% Analysis for the grid cell firing pattern
% - Divyansh Mittal
% - V1.0
% - Run this after completing the main simulation of the rate-based model

%%%%	Requirements	%%%%
% - Set the path of the directory containing all.mat (output of the main simulation)
% - Set the path of the file containing the rodent's virtual trajectory


%%%%	Output	%%%%
% ------------------------------------------------------------------- %
% - Average rate: Mean of the smoothed spatial rate map over all bins of the
%                 arena (entire simulation run)
% - Peak rate: Highest value of the smoothed spatial rate map over the
%              entire simulation run
% - Number of fields: Total number of the grid fields in the whole arena
% - Mean field size: Average size of the grid field in the arena
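% - Average spacing between grid field centres: mean distance (in bins)
%                   between all pairs of detected field peaks
% - Gridness score: min(Acorr60, Acorr120) - max(Acorr30, Acorr90, Acorr150),
%                   where AcorrX is the correlation of the autocorrelogram of
%                   the smoothed, thresholded spatial map with itself rotated
%                   by X degrees
% - Information rate: sum_i p_i * r_i * log2(r_i / r_mean), where p_i is the
%                   occupancy probability of bin i, r_i its smoothed rate and
%                   r_mean the average rate above
% - Sparsity: r_mean^2 / sum_i p_i * r_i^2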


% -Spatial rate maps 
% ------------------------------------------------------------------- %

clc; clear; %close all;

% ---- Input locations (edit these for your machine) ----
% sim_dir: directory holding all.mat, the output of the main simulation.
%          The script cd's into it, so the result files saved at the end of
%          this script are written there as well.
% traj_file: MAT file holding the rodent's virtual trajectory.
sim_dir = '/Volumes/DIVYANSH2/Computational/Network_model/Grid_hetero_res_full_run/Res_freq_network/HPF/Tau10_hpf06_R_12';
traj_file = '/Volumes/DIVYANSH2/Computational/Network_model/Grid_hetero_res_full_run/Res_freq_network/LPF/data/pos_2m.mat';

% Fail early with a clear message instead of a bare cd/load error.
if exist(sim_dir, 'dir') ~= 7
    error('Simulation directory not found: %s', sim_dir);
end
if exist(traj_file, 'file') ~= 2
    error('Trajectory file not found: %s', traj_file);
end

% The rest of the script reads rcd (cells x time steps) and pos (x/y rows x
% time steps); these are expected to come from the two files below.
cd(sim_dir);
load('all.mat');
load(traj_file);

% ---- Arena geometry and spatial binning ----
side_of_arena = 2;          % arena side length in m (same extent as minx..maxx)
nSpatialBins = 100;         % number of bins along each axis
minx = 0; maxx = 2; % m
miny = 0; maxy = 2; % m

% ---- Grid-field detection parameters ----
guassfilt_n = 4;            % sigma (in bins) of the 2-D Gaussian smoothing
thresh_for_peak = 0.6;      % fraction of the peak rate needed to be a field peak
thresh_in_grid = 0.2;       % fraction of the peak rate counted as inside a field
spikeThresh = 0.1;          % NOTE(review): appears unused in this script

% ---- Per-cell accumulators ----
% rcdcells only fixes the number of analysed cells: one per row of rcd.
rcdcells = zeros(size(rcd, 1), 1);
accumulator_size = [nSpatialBins, nSpatialBins, numel(rcdcells)];
occupancy = zeros(accumulator_size);
spikes = zeros(accumulator_size);
clear accumulator_size

% Spatial binning of the trajectory and of every cell's activity.
% pos(1,:) / pos(2,:) hold the x / y position at each time step and
% rcd(c,:) the activity of cell c at each time step. Maps are indexed
% (y-bin, x-bin, cell).
max_time_steps = 100e3 - 1;   % analysis window of the original script (samples)
nTimeSteps = min([max_time_steps, size(pos, 2), size(rcd, 2)]);
nCells = length(rcdcells);

xy = pos(1:2, 1:nTimeSteps);
if any(isnan(xy(:)))
    error('Trajectory contains NaN positions in the analysed window.');
end

% Equal-width bins: a position p lands in bin floor((p-min)/(max-min)*n)+1,
% clamped to [1, nSpatialBins] so samples on (or beyond) the arena edge go to
% the outermost bins. Every row/column of the map can now receive samples.
xindex = min(max(floor((xy(1, :) - minx) / (maxx - minx) * nSpatialBins) + 1, 1), nSpatialBins);
yindex = min(max(floor((xy(2, :) - miny) / (maxy - miny) * nSpatialBins) + 1, 1), nSpatialBins);
bin_subs = [yindex(:), xindex(:)];
map_size = [nSpatialBins, nSpatialBins];

% Occupancy depends only on the trajectory, so it is counted once and
% replicated for every cell (occupancy(:,:,c) is read per cell later).
occupancy = repmat(accumarray(bin_subs, 1, map_size), [1, 1, nCells]);

% Sum each cell's activity over the time steps spent in each bin.
spikes = zeros(nSpatialBins, nSpatialBins, nCells);
for cell_no = 1:nCells
    fprintf('Binning cell %d of %d\n', cell_no, nCells);
    spikes(:, :, cell_no) = accumarray(bin_subs, rcd(cell_no, 1:nTimeSteps).', map_size);
end

% Mean activity per bin; unvisited bins are 0/0 = NaN at this point.
rate_map_all = spikes ./ occupancy;

%%% Analysis %%%
% For each cell: smooth its raw spatial rate map and extract the firing,
% field, gridness, spatial-information and sparsity measures listed in the
% file header. Every measure is stored as a 1 x nCells row vector.
nCells = length(rcdcells);
Average_firing_rate = zeros(1, nCells);
Peak_firing_rate    = zeros(1, nCells);
Num_Grid_fields     = zeros(1, nCells);
Mean_field_size_pix = zeros(1, nCells);
Mean_field_size_cm  = zeros(1, nCells);
grid_score          = zeros(1, nCells);
info_rate           = zeros(1, nCells);
sparsity            = zeros(1, nCells);
Average_spacing     = zeros(1, nCells);

for j = 1:nCells
    % Keep the original (unsmoothed) map; unvisited bins are NaN (0/0)
    % and are treated as silent.
    rate_map_org = rate_map_all(:, :, j);
    rate_map_org(isnan(rate_map_org)) = 0;

    % 2-D Gaussian smoothing (sigma = guassfilt_n bins)
    rate_map = imgaussfilt(rate_map_org, guassfilt_n);
    size_map = size(rate_map);
    rate_map_vec = rate_map(:);

    %%%% Average and peak rate %%%%
    Average_firing_rate(j) = mean(rate_map_vec);
    Peak_firing_rate(j) = max(rate_map_vec);

    %%%% Number of fields in the arena %%%%
    % Bins below thresh_for_peak * peak are zeroed; every remaining regional
    % maximum counts as one field.
    red_rate_map = rate_map;
    red_rate_map(red_rate_map < thresh_for_peak*Peak_firing_rate(j)) = 0;
    if Peak_firing_rate(j) > min(rate_map_vec)
        Local_peaks = imregionalmax(red_rate_map);
    else
        % A flat map (e.g. a silent cell) is one plateau that imregionalmax
        % reports as a maximum in every bin; it has no fields.
        Local_peaks = false(size_map);
    end
    Num_Grid_fields(j) = nnz(Local_peaks);

    %%%% Mean field size %%%%
    % Nonzero bins at or above thresh_in_grid * peak are in-field; the
    % in-field area is divided evenly among the detected fields.
    red_rate_map2 = rate_map;
    red_rate_map2(red_rate_map2 < thresh_in_grid*Peak_firing_rate(j)) = 0;
    if Num_Grid_fields(j) > 0
        Mean_field_size_pix(j) = nnz(red_rate_map2) / Num_Grid_fields(j);
    else
        Mean_field_size_pix(j) = 0;
    end
    % Mean_field_size_pix is an area (a bin count), so it is converted with
    % the area of one bin. NOTE: Mean_field_size_cm is therefore in cm^2.
    bin_area_cm2 = (side_of_arena*100/size_map(1)) * (side_of_arena*100/size_map(2));
    Mean_field_size_cm(j) = bin_area_cm2 * Mean_field_size_pix(j);

    %%%% Grid score %%%%
    % Correlate the autocorrelogram with rotated copies of itself:
    % min over 60/120 deg minus max over 30/90/150 deg.
    % NOTE(review): the full autocorrelogram is used (no annulus mask around
    % the central peak) -- confirm this is intended.
    AC = xcorr2(red_rate_map2);
    rot_corr = @(angle_deg) corr2(AC, imrotate(AC, angle_deg, 'crop'));
    grid_score(j) = min(rot_corr(60), rot_corr(120)) ...
        - max([rot_corr(-30), rot_corr(-90), rot_corr(-150)]);

    %%%% Information rate %%%%
    % sum_i p_i * r_i * log2(r_i / r_mean), with p_i the occupancy probability.
    % NOTE(review): r_mean is the unweighted mean over bins, not the
    % occupancy-weighted sum_i p_i * r_i -- confirm this is intended.
    occ = occupancy(:, :, j);
    prob_occ_vec = occ(:) ./ sum(occ(:));
    prob_occ_vec(isnan(prob_occ_vec)) = 0;   % no occupancy at all
    log_ratio = log2(rate_map_vec ./ Average_firing_rate(j));
    log_ratio(isinf(log_ratio)) = 0;         % zero-rate bins (log2(0) = -Inf) contribute nothing
    info_rate(j) = sum(prob_occ_vec .* (rate_map_vec .* log_ratio));

    %%%% Sparsity %%%%
    sparsity(j) = Average_firing_rate(j)^2 / sum(prob_occ_vec .* (rate_map_vec .* rate_map_vec));

    %%%% Spacing %%%%
    % Mean Euclidean distance (in bins) over all pairs of field peaks; NaN
    % when fewer than two peaks were found.
    % NOTE(review): a mean over all pairs grows with arena size; grid spacing
    % is often defined from nearest neighbours only -- confirm intent.
    [peak_rows, peak_cols] = find(Local_peaks);
    Average_spacing(j) = mean(pdist([peak_rows, peak_cols]));
end

% A NaN information rate (e.g. a cell whose mean rate is zero, giving 0/0
% inside the log) is counted as 0.
info_rate(isnan(info_rate)) = 0;

% Population summary across cells. Entry k of the column vectors
% mean_all / std_all summarises measure k:
%   1 Average_firing_rate   2 Peak_firing_rate     3 Num_Grid_fields
%   4 Mean_field_size_pix   5 Mean_field_size_cm   6 grid_score
%   7 info_rate             8 sparsity             9 Average_spacing
per_cell_measures = {Average_firing_rate, Peak_firing_rate, Num_Grid_fields, ...
    Mean_field_size_pix, Mean_field_size_cm, grid_score, info_rate, ...
    sparsity, Average_spacing};
mean_all = cellfun(@mean, per_cell_measures).';
std_all = cellfun(@std, per_cell_measures).';
clear per_cell_measures

% Results are written to the current folder.
save("all_mesaure_100.mat", 'mean_all', 'std_all', 'Average_firing_rate', 'Peak_firing_rate', 'Num_Grid_fields', 'Mean_field_size_pix', 'Mean_field_size_cm', 'grid_score', 'info_rate', 'sparsity', 'Average_spacing')
save('rate_map_org_100.mat', 'rate_map_all')