ccc
tic
% ---- Neural sheet -------------------------------------------------------
% N x N neurons tile the periodic square [-pi, pi)^2. linspace returns N+1
% points including +pi, which coincides with -pi on the circle, so the last
% point is dropped.
N = 256;
x = linspace(-pi, pi, N+1);
pos = x(1: N);

% ---- Recurrent weight kernel --------------------------------------------
% Isotropic 2D Gaussian of width a and amplitude J0 (J0 equals 1 when
% N = 128). Its peak sits at element (1,1) and it wraps around the edges, so
% ifft2(Jfft .* fft2(r)) is a circular convolution of r with this kernel.
a = 0.3;
J0 = 1/(N^2)*128^2;

% Periodic distance of each grid coordinate from pos(1) = -pi: forward
% distance pos - pos(1), or backward distance pi - pos (= 2*pi - (pos - pos(1))).
dper = min(pos - pos(1), pi - pos);
[DY, DX] = meshgrid(dper, dper);     % DX(i,j) = dper(i), DY(i,j) = dper(j)
dis = sqrt(DX.^2 + DY.^2);
J = J0/(2*pi*a^2) * exp(-(dis.^2)/(2*a^2));
% imagesc(J)

Jfft = fft2(J);


% ---- Time constants and Euler step -------------------------------------
tau = 1;                  % time constant of U
tau_v = 50;               % time constant of V
dt = tau/10;              % integration step

% ---- Network parameters -------------------------------------------------
density = N*N/(4*pi^2);   % neurons per unit area of the 2*pi x 2*pi sheet
alpha = 0.2;              % amplitude of the external input bump
mbar = 190;               % (20 was noted as an alternative value)
m = mbar* tau/tau_v;      % coupling of U into the V equation
% NOTE(review): kfactor is not used anywhere else in this script.
kfactor=(1+tau/tau_v)*(1+2*m-tau/tau_v)*density*J0^2/(32*pi*a^2*(1+m)^4);
k = 5/(N^2)*128^2;        % global inhibition factor (divisive term in r)

% ---- Stimulus -------------------------------------------------------------
v_ext = 6.2*1e-3/5;       % stimulus displacement per unit time
loc = 0;                  % stimulus x-position (advanced in the loop)
loc_y = -0.02;            % stimulus y-position (held fixed)
a_noise = 0.01;           % amplitude of per-cell Gaussian input noise
[X,Y] = meshgrid(pos,pos);

% ---- Corridor offsets (grid cells) used by the masks in the loop ---------
band = floor(a*0.7/pi*N);
band_y = floor(a*0.7/pi*N);

% ---- State and recording buffers ------------------------------------------
[Iext, Iext_0, U, V, r] = deal(zeros(N, N));
t = 0;
i = 1;                    % index of the next recorded sample
[I_sum, r_probe] = deal([]);

% Video output (disabled):
% myVideo = VideoWriter('T-maze.mp4','MPEG-4');
% myVideo.FrameRate = 40;
% open(myVideo);

tic
% Main simulation loop. Each iteration does the following:
%   - moves the stimulus centre loc by v_ext*dt;
%   - rebuilds the NaN-masked external input Iext;
%   - takes one Euler step of U and V, and recomputes the rate r.
% The loop stops once loc reaches pi/2 - a/2. After t exceeds t_start, probe
% rates and the decoded centre are recorded at index i, and the activity map
% is redrawn every 5th recorded step.
while  loc < pi/2-a/2
%     loc = loc + v_ext/2*sqrt(2) * dt;
    loc = loc + v_ext * dt;
    % Wrap loc back into [-pi, pi).
    if loc >= pi
        loc = loc - 2 * pi;
    end
   
    % External input: a Gaussian bump centred at (loc, loc_y) using periodic
    % distances, plus independent noise randn(1)*a_noise for every cell.
    % NOTE(review): this scalar double loop makes N^2 randn(1) calls per step.
    % randn(N,N) should draw the same values in the same (column-major) order
    % and could replace it. Left unchanged here.
    for x = 1: N
	    for y = 1: N
            dis = min(abs(pos(x) - loc), 2 * pi - abs(pos(x)-loc));
            dis_y = min(abs(pos(y) - loc_y), 2 * pi - abs(pos(y)-loc_y));
			Iext(y, x) = alpha/(2 * pi * a^2) * (exp(-((dis).^2+(dis_y).^2)/8/a^2)+randn(1).*a_noise);
            % So that the predicted distance can exceed the radius of the predicted object.
        end
    end
    
    % Mask cells outside the allowed region with NaN:
    %   - columns 1..N/4;
    %   - rows 1..N/8 and 7N/8..N;
    %   - corridor cut-outs set by band, band_y and width.
    % NOTE(review): the region presumably forms a T-shaped track (cf. the
    % commented 'T-maze.mp4' video). Confirm.
    Iext(:,1:N/4) = nan;
    Iext(1:N/8,:) = nan;
    Iext(7*N/8:end,:) = nan;
    
    Iext(N/2+band_y:end,N/4+1:N*3/4-band) = nan;
    Iext(1:N/2-band_y,N/4+1:N*3/4-band) = nan;
    % NOTE(review): width = 10*N/256 is an integer only when N is a multiple
    % of 128. Otherwise some of the indices below become non-integer.
    width = 10*N/256;
    
    Iext(N/2+band_y+width:end,N/4+1:N*3/4-band+width) = nan;
    Iext(1:N/2-band_y-width,N/4+1:N*3/4-band+width) = nan;
    Iext(:,N*3/4+band+2*width:end) = nan;
    % Staircase edge of the cut-outs: one column per bi, row start shifted by bi.
    % NOTE(review): this uses bi = 1:5, while the matching U and r_p masks
    % below use bi = 1:width. Confirm whether the difference is intended.
    for bi = 1:5
        Iext(N/2+band_y+bi:end,N*3/4-band+bi) = nan;
        Iext(1:N/2-band_y-bi,N*3/4-band+bi) = nan;
    end
    
    
    % Rescale so the non-negative entries sum to 190*N^2/128^2. NaN and
    % negative entries fail Iext>=0 and are excluded from the sum.
    % Record the resulting sum in I_sum.
    Iext = Iext./sum(sum(Iext(Iext>=0))).*(190*N^2/128^2);
    I_sum = [I_sum,sum(sum(Iext(Iext>=0)))];
    
    % Euler step for U. The recurrent term is the circular convolution of r
    % with the kernel J, computed through the precomputed Jfft.
    dU = dt * (-U + ifft2(Jfft.*fft2(r)) - V + Iext)/tau;
	U = U + dU;
    
    % Apply the NaN mask to U. This copy uses bi = 1:width for the staircase edge.
    U(:,1:N/4) = nan;
    U(1:N/8,:) = nan;
    U(7*N/8:end,:) = nan;
    U(N/2+band_y:end,N/4+1:N*3/4-band) = nan;
    U(1:N/2-band_y,N/4+1:N*3/4-band) = nan;
    U(N/2+band_y+width:end,N/4+1:N*3/4-band+width) = nan;
    U(1:N/2-band_y-width,N/4+1:N*3/4-band+width) = nan;
    U(:,N*3/4+band+2*width:end) = nan;
    for bi = 1:width
        U(N/2+band_y+bi:end,N*3/4-band+bi) = nan;
        U(1:N/2-band_y-bi,N*3/4-band+bi) = nan;
    end
    % max omits NaN by default, so masked cells become 0 here and U >= 0.
	U = max(U, 0);
	dV = dt * (-V + m*U) / tau_v;% Euler step for V, driven by m*U with time constant tau_v
	V = V + dV;
	U = max(U, 0); % NOTE(review): redundant, since U is unchanged since the previous max
	r = U.^2./(1+k*sum(U(:).^2)); % squared activity divided by (1 + k * total squared activity)
    
    
    % Recording starts once t exceeds t_start.
    t_start = 300;
    if t> t_start
        % Probe neurons: rows N/2+id, N/2+id+4 and N/2-id of column 3N/4.
        id = 25*N/128;
        r_probe_l(i) = r(N/2+id,N/4*3);
        r_probe_l2(i) = r(N/2+id+4,N/4*3);
        r_probe_r(i) = r(N/2-id,N/4*3);
        % expectation2 is defined outside this file. center(1) and center(2)
        % are recorded as the decoded x and y position. They are presumably
        % the bump centre; confirm against expectation2.
        center=expectation2(r);
        center_x(i)=center(1);
        center_y(i)=center(2);
        centerx_I(i) = loc;
        centery_I(i) = loc_y;
        time(i) = t-t_start;
        % The recorded arrays grow by one element per step (no preallocation).
        i=i+1;
    % Redraw every 5th recorded step (i has just been incremented).
    if i/5 == floor( i/5) 
        % Display copy of r with the NaN mask applied (bi = 1:width variant).
        r_p = r;
        r_p(:,1:N/4) = nan;
        r_p(1:N/8,:) = nan;
        r_p(7*N/8:end,:) = nan;
        r_p(N/2+band_y:end,N/4+1:N*3/4-band) = nan;
        r_p(1:N/2-band_y,N/4+1:N*3/4-band) = nan;
        r_p(N/2+band_y+width:end,N/4+1:N*3/4-band+width) = nan;
        r_p(1:N/2-band_y-width,N/4+1:N*3/4-band+width) = nan;
        r_p(:,N*3/4+band+2*width:end) = nan;
        for bi = 1:width
            r_p(N/2+band_y+bi:end,N*3/4-band+bi) = nan;
            r_p(1:N/2-band_y-bi,N*3/4-band+bi) = nan;
        end
%         plot(center(1),center(2),'k.','markersize',15),hold on
%         plot(loc,loc_y,'b.','markersize',15),
%         plot(center_x,center_y,'k','linewidth',2)
%         hold off
        
            
        % Rate map drawn as a surface viewed from above, with stacked red
        % '>' markers at plot coordinates (about -3a, loc) on top of it.
        surf(Y',X',r_p'),hold on
        plot3(-3.1*a,loc,1,'r>','linewidth',100,'markersize',45)
        plot3(-2.9*a,loc,1,'r>','linewidth',100,'markersize',45)
        plot3(-3*a,loc,1,'r>','linewidth',100,'markersize',40)
        plot3(-3*a,loc,1,'r>','linewidth',100,'markersize',35)
        plot3(-3*a,loc,1,'r>','linewidth',100,'markersize',30)
        plot3(-3*a,loc,1,'r>','linewidth',100,'markersize',45),hold off
        axis equal
        axis off
%         colormap t
        shading flat
        view([0 0 1])
        drawnow
        % Once loc > 0.5, save every 30th recorded step as a BMP file named
        % floor(i/30) in the current folder.
        if loc >0.5 && i/30 == floor(i/30)
            name = num2str(floor(i/30));
            saveas(gcf,name,'bmp')
        end
%         frame = getframe(gcf);
%         im = frame2im(frame); 
%         writeVideo(myVideo,im); 
    end
    end
    
	t = t + dt;
    % Print the remaining distance to the stopping position on every step.
    disp(pi/2 - a/2 -loc);
end
toc
% close(myVideo)   % re-enable together with the VideoWriter setup
%%
% Difference between the decoded x-centre and the stimulus x-position at
% each recorded step. findpeaks on its negation finds the local minima of
% this difference; their indices (locs) mark the probe-rate plots.
L_diff = center_x - centerx_I;
[pks, locs] = findpeaks(-L_diff);
plot(L_diff)
hold on
plot(locs, -pks, 'r.')
% time = time./1e3.*3;
%%
% Probe-neuron firing-rate traces, with vertical dashed markers placed 100
% recorded samples after each trough of L_diff (locs from findpeaks).
% NOTE(review): the display scalings below (time/1e3*3 - 0.5 on x, *1e4*2
% on y) are carried over unchanged; their units are not established here.
t_plot = time./1e3.*3-0.5;
% Skip troughs within 100 samples of the end of the recording. For those,
% time(locs+100) would index past the end of time and raise an error.
mark_idx = locs(locs + 100 <= numel(time)) + 100;

figure
subplot(1,2,1)
plot(t_plot,r_probe_l*1e4*2,'r','linewidth',1),hold on
plot(t_plot,r_probe_r*1e4*2,'b','linewidth',1)
% The loop variable is not named i, so the recording counter i is left intact.
for im = 1:numel(mark_idx)
    plot(t_plot(mark_idx(im))*[1 1],[0 8],'k--','linewidth',1)
end
legend('probe neuron A','probe neuron B')
legend box off
xlabel('Time (s)')
ylabel('Firing rate (Hz)')
% set(gcf,'unit','centimeters','position',[20,17,15,14])
box off
set(gca,'linewidth',1,'fontsize',15,'fontname','Arial');
axis([0 1.9 0 13])

subplot(1,2,2)
plot(t_plot,r_probe_l*1e4*2,'r','linewidth',1),hold on
for im = 1:numel(mark_idx)
    plot(t_plot(mark_idx(im))*[1 1],[0 8],'k--','linewidth',1)
end
xlabel('Time (s)')
ylabel('Firing rate (Hz)')
set(gcf,'unit','centimeters','position',[20,17,25,10])
box off
set(gca,'linewidth',1,'fontsize',15,'fontname','Arial');
axis([0 1.9 0 13])

%%
% Normalized cross-correlation of probe A vs probe B (bottom panel) and
% auto-correlation of probe A (top panel).
% The lag axis comes from xcorr's second output, which is exactly zero at
% the zero-lag sample. The previous ((1:n) - n/2) axis was offset by half a
% sample. Lags use the same dt/1e3*3 scaling as the time axis elsewhere.
[c, lags] = xcorr(r_probe_l*1e4*2,r_probe_r*1e4*2);

figure
subplot(2,1,2)
Time = lags.*dt./1e3.*3;

plot(Time,c./max(c),'k','linewidth',1)
box on
xlabel('Time (s)')
ylabel('Normalized cross correlation')
set(gca,'linewidth',1,'fontsize',15,'fontname','Arial');
% set(gcf,'unit','centimeters','position',[25,17,15,10])
axis([-0.4 0.4 0 1.1])

[c, lags] = xcorr(r_probe_l*1e4*2,r_probe_l*1e4*2);

subplot(2,1,1)
Time = lags.*dt./1e3.*3;

plot(Time,c./max(c),'k','linewidth',1)
box on
xlabel('Time (s)')
ylabel('Normalized auto-correlation')
set(gca,'linewidth',1,'fontsize',15,'fontname','Arial');
set(gcf,'unit','centimeters','position',[25,17,15,20])
axis([-0.4 0.4 0 1.1])
%%
% Single-sided amplitude spectrum of probe neuron A's rate trace, normalized
% to its maximum. f_data and P_data hold the 5-15 Hz portion.
data = r_probe_l;
Fs = 1/dt*1e3/3;            % Sampling frequency 
L = length(data);
% data_0 = zeros(1,2*L);
% data = [data_0,Data,data_0];
% L = length(data);
% Named Yspec (not Y) so the meshgrid Y used by surf in the simulation loop
% is not overwritten.
Yspec = fft(data);
P2 = abs(Yspec/L);
nHalf = floor(L/2)+1;       % bins for frequencies 0 .. floor(L/2)*Fs/L
P1 = P2(1:nHalf);
% Double every bin that has a mirror image. DC never does, and neither does
% the Nyquist bin, which exists only for even L.
if mod(L,2) == 0
    P1(2:end-1) = 2*P1(2:end-1);
else
    P1(2:end) = 2*P1(2:end);
end
f = Fs*(0:nHalf-1)/L;
% Select 5-15 Hz by frequency value. The previous floor(5/Fs*L) indexing was
% one bin low and produced index 0 for short recordings.
in_band = f >= 5 & f <= 15;
f_data = f(in_band);
P_data = P1(in_band);
% Draw in a new figure instead of overwriting the current axes.
figure
plot(f,P1./max(P1),'k','markersize',20)
axis([0 100 0 1])
%%
% Single-sided amplitude spectrum of probe neuron A's rate trace, normalized
% to its maximum (the same computation as the preceding spectrum cell).
% f_data and P_data hold the 5-15 Hz portion.
data = r_probe_l;
Fs = 1/dt*1e3/3;            % Sampling frequency 
L = length(data);
% data_0 = zeros(1,2*L);
% data = [data_0,Data,data_0];
% L = length(data);
% Named Yspec (not Y) so the meshgrid Y used by surf in the simulation loop
% is not overwritten.
Yspec = fft(data);
P2 = abs(Yspec/L);
nHalf = floor(L/2)+1;       % bins for frequencies 0 .. floor(L/2)*Fs/L
P1 = P2(1:nHalf);
% Double every bin that has a mirror image. DC never does, and neither does
% the Nyquist bin, which exists only for even L.
if mod(L,2) == 0
    P1(2:end-1) = 2*P1(2:end-1);
else
    P1(2:end) = 2*P1(2:end);
end
f = Fs*(0:nHalf-1)/L;
% Select 5-15 Hz by frequency value. The previous floor(5/Fs*L) indexing was
% one bin low and produced index 0 for short recordings.
in_band = f >= 5 & f <= 15;
f_data = f(in_band);
P_data = P1(in_band);
% Draw in a new figure instead of overwriting the current axes.
figure
plot(f,P1./max(P1),'k','markersize',20)
axis([0 100 0 1])



%%
% Decoded position versus time. The green line is the stimulus x-position.
% Decoded samples are split by center_y:
%   |center_y| <  pb : x-centre, black
%    center_y >= pb  : center_y shifted by pi/2, red
%    center_y <= -pb : |center_y| shifted by 7*pi/8, blue
% Every trace is offset by +0.5 in y and -0.5 in time.
figure
pb = 0.3;
% Clip decoded x-centres above pi/2 + 0.25 (modifies center_x in place).
center_x(center_x > pi/2+0.25) = pi/2+0.25;
sel_mid = abs(center_y) < pb;
sel_pos = center_y >= pb;
sel_neg = center_y <= -pb;
c_left = center_y(sel_pos);
c_right = abs(center_y(sel_neg));
t_shift = time - 0.5;

plot(t_shift, centerx_I+0.5, 'g', 'linewidth', 3)
hold on
plot(t_shift(sel_mid), center_x(sel_mid)+0.5, 'k.', 'markersize', 1)
plot(t_shift(sel_pos), c_left+pi/2+0.5, 'r.', 'markersize', 1)
plot(t_shift(sel_neg), c_right+pi/8*7+0.5, 'b.', 'markersize', 1)
xlabel('time')
ylabel('decoded position')
set(gcf,'unit','centimeters','position',[20,17,25,7.5])
box off
set(gca,'linewidth',1,'fontsize',15,'fontname','Arial');
% axis([0 2.8 0 4.5])