clc
% 'clear' (rather than 'clear all') empties the workspace without also
% flushing compiled functions, MEX files and breakpoints.
clear
close all
% N neurons. k and J0 equal k0 and 1 at N = 512 and are rescaled by 512/N
% otherwise.
N = 512;
alpha = 0.19;       % amplitude of the external input Iext
mbar = 152.5;       % adaptation strength before rescaling into m (below)
k0 = 5;
k = k0 / N * 512; %global inhibition factor
J0 = 1;
J0= J0 / N * 512;   % amplitude scale of the recurrent kernel J
%%%%%%%%%%%%%%%%Parameters Settings
a = 0.4;            % width of the recurrent kernel and of the input bump
tau = 3;            % time constant of U (units presumably ms - TODO confirm)
tau_v =48*tau;      % time constant of the adaptation variable V
dt = tau/10;        % forward-Euler integration step
m = mbar* tau/tau_v;    % coupling of U into the V equation

%%%%%%%%%%%%%%%%%%Matrix Construction
% Network state, one entry per neuron: U (driven by Iext and the recurrent
% input), the adaptation variable V, the rate r, and the external input Iext.
Iext = zeros(N, 1);
U = zeros(N, 1);
V = zeros(N, 1);
r = zeros(N, 1);

% Map all neurons to [-pi, pi): N evenly spaced positions. The endpoint +pi
% is dropped because it coincides with -pi on the ring.
x = linspace(-pi, pi, N+1);

pos = x(1: N);
% Recurrent kernel: a 1-D Gaussian (amplitude J0/(sqrt(2*pi)*a), width a) of
% the wrapped ring distance between each neuron and neuron 1 (pos(1) = -pi).
% Anchoring it at index 1 lets the recurrent input be computed as the
% circular convolution ifft(Jfft .* fft(r)).
offset = pos - pos(1);          % forward distance from neuron 1, in [0, 2*pi)
dx = min(offset, pi - pos);     % pi - pos(i) equals 2*pi - offset(i)
J = J0/(sqrt(2*pi)*a) * exp(-(dx.^2)/(2*a^2));
J = J(:);                       % column vector, same shape as r
Jfft = fft(J);
%%%%%%%%%%%%%% Stimulus and recording setup
pos=pos';           % column vector of neuron positions, same shape as U, V, r
vbar = 0.55;
v=a/tau_v*vbar;     % stimulus speed: vbar in units of a/tau_v
% v = 4*1e-3;
T=8*pi/v;           % run length: the stimulus travels 8*pi, i.e. 4 laps
loc=-pi*5/8;        % initial stimulus position

% Count the samples the main loop records (steps with t > 3T/4) by replaying
% its exact time accumulation, so the recording buffers can be allocated at
% their final size. (zeros(1, length(T/dt)) would give a single column,
% since length of the scalar T/dt is 1.)
% NOTE: keep these conditions identical to the main loop's while/if tests.
nRecorded = 0;
tProbe = 0;
while tProbe < T
    if tProbe > T/4*3
        nRecorded = nRecorded + 1;
    end
    tProbe = tProbe + dt;
end
centerx_U = zeros(1, nRecorded);    % decoded bump centre per recorded step
centerx_I = zeros(1, nRecorded);    % stimulus position per recorded step
Ar = zeros(1, nRecorded);           % max(r) per recorded step
t = 0;
r_t = zeros(N, nRecorded);          % rate vector per recorded step

% Gaussian tuning values, scaled by 2.1e-3/3.3*15 (origin of these constants
% not shown here), for every pair of neurons: entry (i,j) depends on the
% wrapped ring distance between pos(i) and pos(j). Tuning_matrix and Prob
% are read only by the commented-out decoding code in the main loop.
[posJ, posI] = meshgrid(pos, pos);      % posI(i,j) = pos(i), posJ(i,j) = pos(j)
pairDist = abs(posI - posJ);
ringDist = min(pairDist, 2*pi - pairDist);
Tuning_matrix = 2.1*1e-3*exp(-(ringDist.^2)/(2*a^2))./3.3.*15;
Prob = zeros(N,502);
tic
timestamp = 1;  % used only by the commented-out decoding code below
j = 1;          % column of the next recorded sample
while t < T
    % Move the stimulus along the ring and wrap it back into [-pi, pi).
    loc = loc + v * dt;
    if loc >= pi
        loc = loc - 2 * pi;
    end
    % External input: Gaussian bump of height alpha centred on loc, using
    % the wrapped (shortest) distance on the ring.
    dis = min(abs(pos - loc), 2 * pi - abs(pos - loc));
    Iext = alpha  * exp(-(dis.^2/(4*a^2)));
    % Recurrent input: circular convolution of the kernel J with the rates,
    % computed with FFTs. Both factors are real, so any imaginary part is
    % round-off. real() keeps U real, so max(U, 0) below is a genuine
    % rectification (max on complex values compares magnitudes).
    Irec = real(ifft(Jfft.*fft(r)));
    % Forward-Euler step. V is updated from the new, not-yet-rectified U.
    dU = dt * (-U - V + Iext+Irec)/tau;
    U = U + dU;
    dV = dt * (-V + m.*U) / tau_v;
    V = V + dV;
    U = max(U, 0);
    % Rates: squared U divided by a global-inhibition term of strength k.
    r = U.^2./(1+k.*sum(U(:).^2));%.*(ratio)
    %if t>10
    % Record only the last quarter of the run.
    if t>T/4*3
        centerx_I(1,j)=loc;
        % Population-vector estimate of the bump centre.
        maxp = angle(sum(r .* exp(1i.*pos)));
        % If loc and maxp lie within 4*a of opposite sides of the +/-pi
        % seam, shift maxp by 2*pi so it is on the same branch as loc.
        if loc > pi - 4*a && maxp < -pi+4*a
            maxp = maxp + 2 * pi;
        end

        if maxp > pi - 4*a && loc < -pi+4*a
            maxp = maxp - 2 * pi;
        end
        centerx_U(1,j)=maxp;

        r_t(:,j)=r(:);
        Ar(j) = max(r);
%         if rem(j,10) == 0
%             plot(r.*50),hold on
%             plot(Iext),hold off
%             axis([1 N 0 alpha])
%             drawnow
%         end
%         if rem(j,30) == 0
%             centerx_U(1,timestamp)=maxp;
%             centerx_I(1,timestamp)=loc;
%             r_mean(:) = mean(r_t(:,end-30:end),2);
%             Tau = 30*dt;
% %             spike_num = poissrnd(r_mean.*Tau./3.3.*15);
%             Prob(:,timestamp)=1;
%             for pos_j = 1:N
%                 for r_i = 1:N
%                     dis_re = floor(pos_j - (loc+pi)/2/pi*N);
%                     if dis_re >= 1
%                         Pos_relative = dis_re;%+N/2;
%                     else
%                         Pos_relative = N + dis_re;%-N/2;
%                     end
% %                     Prob(Pos_relative,timestamp) = Prob(Pos_relative,timestamp)*(Tuning_matrix(r_i,pos_j))^(spike_num(r_i)).*exp(-Tau*Tuning_matrix(r_i,pos_j));
%                     Prob(Pos_relative,timestamp) = Prob(Pos_relative,timestamp)*(Tuning_matrix(r_i,pos_j))^(r_mean(r_i).*Tau)*exp(-Tau*Tuning_matrix(r_i,pos_j));
% %                     Prob(Pos_relative,timestamp) = Prob(Pos_relative,timestamp)*(Tuning_matrix(r_i,pos_j)*Tau)^(spike_num(r_i).*Tau)./(factorial(spike_num(r_i)))*exp(-Tau*Tuning_matrix(r_i,pos_j));
%                 end
%             end
% %             Prob(:,timestamp) = (Prob(:,timestamp)-min(Prob(:,timestamp)))./(max(Prob(:,timestamp))-min(Prob(:,timestamp)));
%             Prob(:,timestamp) = Prob(:,timestamp)./sum(Prob(:,timestamp));
%             timestamp = timestamp + 1;
%         end

        j=j+1;
    end
    t = t + dt;
%     disp(t/T)
end
toc
% toc

% Time axis of the recorded samples. Scaling integer indices guarantees
% exactly j-1 points, one per recorded column of r_t / entry of Ar.
time = (1:j-1)*dt;
[Time, Pos] = meshgrid(time,pos);
[CI, ~] = meshgrid(centerx_I(1:length(time)),pos);
relative_Pos = Pos-CI;      % each neuron's position relative to the stimulus


L_diff = centerx_U-centerx_I;   % decoded bump centre minus stimulus position
max(L_diff)
% Troughs of L_diff; their mean spacing (divided by 1e3) is used as Period.
[~,Locs] = findpeaks(-L_diff);
if numel(Locs) < 2
    warning('Fewer than two troughs in L_diff: Period and Phase will be NaN.');
end
Period = mean(diff(time(Locs)))./1e3;


% Phase in degrees (360 per Period). NOTE(review): the 0.55 and +50 offsets
% are hard-coded alignment constants whose origin is not shown here.
Phase = (time./1e3-0.55)./Period.*360;
Phase = Phase+50 ;
%%
%fig 4a
% Rate of the middle neuron over the recorded window, plotted against
% Phase/720 - 3.5, with x labels shown on every 5th integer tick only.
figure
neuron_id = floor(N/2);     % equals N/2 for even N; stays a valid index for odd N
plot(Phase./720-3.5,r_t(neuron_id,:).*5000,'k','linewidth',1)
xlabel('theta cycles')
ylabel('firing rate(Hz)')
% set(gca,'Xtick',0:180:720)
axis([-6 6 0 15])
set(gca,'linewidth',1,'fontsize',10,'fontname','Arial');
set(gcf,'unit','centimeters','position',[25,17,20,5])
set(gca,'xtick',-10:1:20);
% Keep the labels of ticks 1, 6, 11, ... (i.e. -10, -5, 0, 5, ...) and blank
% the rest. These names do not shadow the model parameter a, and cellstr
% normalises the char-matrix form returned by older MATLAB releases.
tickLabels = cellstr(get(gca,'xticklabel'));
% tickLabels = tickLabels(5:end-5);
thinnedLabels = repmat({''}, size(tickLabels));
keepLabel = mod(1:numel(tickLabels),5)==1;
thinnedLabels(keepLabel) = tickLabels(keepLabel);
set(gca,'xticklabel',thinnedLabels)
box off

%%
%fig 4d
% Peak population rate (Ar) against theta phase, with vertical dashed lines
% at the 7th/8th troughs (Locs) and 7th/8th peaks (Locs_pos) of L_diff.
[~,Locs_pos] = findpeaks(L_diff);
if numel(Locs) < 8 || numel(Locs_pos) < 8
    error('fig 4d needs at least 8 troughs and 8 peaks in L_diff (found %d and %d).', ...
        numel(Locs), numel(Locs_pos));
end
figure      % new figure, so the fig 4a plot is not overwritten
plot(Phase,Ar.*5000,'r','linewidth',2),hold on
for markerIdx = [Locs(8) Locs(7) Locs_pos(8) Locs_pos(7)]
    plot([Phase(markerIdx) Phase(markerIdx)],[0 15],'k--')
end
hold off
% plot(Phase,L_diff+pi)
axis([0 720 0 15])
set(gca,'linewidth',1,'fontsize',10,'fontname','Arial');
set(gcf,'unit','centimeters','position',[25,17,10,7])
xlabel('theta phase')
ylabel('maximum firing rate(hz)')
set(gca,'Xtick',0:180:720)
%%
%fig 4e
% Population activity (r_t scaled by 5000) drawn as a surface over theta
% phase (shifted by -360) and position relative to the stimulus, viewed from
% above, for neuron rows pos_s:pos_end and recorded columns time_s:time_end.
[Theta_phase, ~] = meshgrid(Phase,pos);
step_1 = 1;
step_2 = 2;
% Window bounds, clamped to the available data so a shorter run or a
% smaller N does not index past the end of the recorded matrices.
time_end = min(3500, numel(Phase));
time_s = 1800;
pos_end = min(290, numel(pos));
pos_s = 110;
figure
idx_1 = pos_s:step_1:pos_end;
idx_2 = time_s:step_2:time_end;

surf(Theta_phase(idx_1,idx_2)-360,...
    relative_Pos(idx_1,idx_2),...
    r_t(idx_1,idx_2).*5000),hold on

% plot3(Phase-360,L_diff,1,'w','linewidth',2.5)
shading flat
axis([-270 450 -0.72 1.3])
view([0 0 1])
xlabel('Theta phase')
ylabel('Relative Loc.')
set(gca,'Xtick',0:180:720)
set(gca,'Ytick',-3:0.5:3)
colorbar
colormap hot
caxis([0 15])
box on

set(gca,'linewidth',3,'fontsize',15,'fontname','Arial');
set(gcf,'unit','centimeters','position',[25,17,23,15])
% subplot(2,1,2)



