%% Script for the analysis of juxta recordings from jRCaMP1a_stCoChR positive neurons
clearvars;  close all;
% NOTE(review): the workspace is cleared here, but C_raw, S_raw, Stamps,
% Stim_times and B_raw are used below without being loaded by this script.
% Load them right after this line before running the remaining sections.

%Description of variables called throughout the script:
% C_raw: fluorescence trace extracted from ROIs. Each column is a neuron.
% S_raw: time of each detected AP. Each column is a neuron.
% Stamps: the start-of-frame time. Each column is a neuron.
% Stim_times: time of stimulus onset. Each column is a neuron.
% B_raw: background trace; presumably laid out like C_raw (frames x neurons).

Fs = 11.03;                                     %Nominal imaging frame rate (Hz)
[T,N] = size(C_raw);                            %Number of frames (T) and cells (N)
N_stim = size(Stim_times,1);                    %Number of stimulations (one per row)
t = (0:T-1)'/Fs;                                %Nominal time vector starting at 0 = frame1
t_cell = [zeros(1,N); cumsum(diff(Stamps,1))];  %Exact time of frames (first frame is t=0)

%% Interpolate 'spike times' and 'stim times' referred to 'imaging times'
% Map event times from the acquisition clock (Stamps) onto each cell's
% frame-time axis (t_cell). interp1 'linear' does not extrapolate, so events
% outside the recorded frames become NaN.
S_alg = S_raw;
St_alg = zeros(size(Stim_times,1),N);

for i = 1:N
    % One mask for both sample vectors so interp1 always gets matching lengths.
    valid = ~isnan(Stamps(:,i)) & ~isnan(t_cell(:,i));
    % Stim_times may hold one column per neuron, or a single shared column.
    if size(Stim_times,2) == N
        stim_col = i;
    else
        stim_col = 1;
    end
    S_alg(:,i)  = interp1(Stamps(valid,i),t_cell(valid,i),S_raw(:,i),'linear');
    St_alg(:,i) = interp1(Stamps(valid,i),t_cell(valid,i),Stim_times(:,stim_col));
end

%% Generate stimulation trace (200 ms illumination)
% St_trace(:,i) is 1 on the frames covered by each stimulus and 0 elsewhere.
% Stim_frm(j,i) is the frame of cell i closest to the onset of stimulus j.
stim_dur = 0.2;                                 %Illumination duration (s)
St_trace = zeros(T,N);
Stim_frm = zeros(N_stim,N);
for i = 1:N
    for j = 1:N_stim
        % St_alg is expressed on the t_cell axis, so match it against this
        % cell's actual frame times (MIN skips NaN padding), not the nominal t.
        [~,Stim_frm(j,i)] = min(abs(t_cell(:,i)-St_alg(j,i)));
        if isnan(St_alg(j,i))
            warning('Stimulus %d of cell %d falls outside the imaged frames.',j,i);
        end
        % Clip at the last frame so a late stimulus cannot grow St_trace past T.
        last_frm = min(T,Stim_frm(j,i)+round(Fs*stim_dur));
        St_trace(Stim_frm(j,i):last_frm,i) = 1;
    end
end

%% Filter and plot traces
% Z-score each cell (for visualization only), then smooth along time with a
% moving median of about half a second to suppress high-frequency noise.
C_nrm = normalize(C_raw,'zscore');
med_win = round(Fs/2);
C_flt = movmedian(C_nrm,med_win,1);
B_nrm = normalize(B_raw,'zscore');              %Background, z-scored the same way

% One stacked panel per cell (x-axis in frames), all sharing the x-range.
fig = figure();
set(fig,'units','normalized','outerposition',[0 0 1 1]);
for k = 1:N
    ax(k) = subplot(N,1,k);
    plot(C_flt(:,k));
end
linkaxes(ax,'x');

%% Plot imaging and stimulation
% Per cell: smoothed trace against frame time, spikes as black unit stems,
% and stimulation periods as translucent red blocks of height 4.
figure();
set(gcf,'units','normalized','outerposition',[0 0 1 1]);
for k = 1:N
    ax(k) = subplot(N,1,k);
    tk = t_cell(:,k);
    spk = S_alg(:,k);
    plot(tk,C_flt(:,k));
    hold on;
    stem(spk,ones(size(spk)),'Marker','none','Color','k');
    area(tk,4*St_trace(:,k),'FaceColor','r','EdgeColor','none','FaceAlpha',0.3);
    hold off;
end
linkaxes(ax,'x');

%% Plot the background artifact
% Same layout as the imaging figure, but showing each cell's z-scored
% background trace with spike stems and stimulation blocks.
figure();
set(gcf,'units','normalized','outerposition',[0 0 1 1]);
for k = 1:N
    ax(k) = subplot(N,1,k);
    tk = t_cell(:,k);
    spk = S_alg(:,k);
    plot(tk,B_nrm(:,k));
    hold on;
    stem(spk,ones(size(spk)),'Marker','none');
    area(tk,4*St_trace(:,k),'FaceColor','r','EdgeColor','none','FaceAlpha',0.3);
    hold off;
end
linkaxes(ax,'x');

%% Calculate the number of spikes during the stimulation window
% N_spks(i,j) counts the APs of cell i strictly inside (onset, onset+200 ms)
% of stimulus j. Spike and stimulus times are both on the raw clock.
spk_win = 0.200;                                %Counting window after onset (s)
N_spks = zeros(N,N_stim);
for i = 1:N
    % Stim_times may hold one column per neuron, or a single shared column.
    if size(Stim_times,2) == N
        stim_col = i;
    else
        stim_col = 1;
    end
    for j = 1:N_stim
        onset = Stim_times(j,stim_col);
        N_spks(i,j) = nnz(S_raw(:,i)>onset & S_raw(:,i)<onset+spk_win);
    end
end

figure();
histogram(N_spks(:));
title(['Mean = ' num2str(mean(N_spks(:))), ' spikes/stim']);

%% Effectively apply median filtering
% From here on C_raw holds the median-filtered traces (about 0.5 s window,
% applied along time).
% NOTE(review): C_raw is overwritten in place, so re-running only this
% section filters the data a second time.
med_win_raw = round(Fs/2);
C_raw = movmedian(C_raw,med_win_raw,1);

%% Calculate the average DF/F and tau for each recording

%F0 calculated in ten frames at the beginning of each recording
%Extract the stimulation periods from -2s to +3s
% C_stm(:,i,j) is the fractional DF/F = (F-F0)/F0 of cell i around stimulus j.
% Stimulus onset sits at window index pre_frm+1.
pre_frm  = round(Fs*2);                         %Frames kept before stimulus onset
post_frm = round(Fs*3);                         %Frames kept after stimulus onset
C_stm = nan(pre_frm+post_frm+1,N,N_stim);       %Trials that cannot be extracted stay NaN
figure();
for i = 1:N
    F0 = mean(C_raw(1:10,i),1);                 %Per-cell baseline (independent of j)
    for j = 1:N_stim
        win_start = Stim_frm(j,i)-pre_frm;
        win_end   = Stim_frm(j,i)+post_frm;
        if win_start < 1 || win_end > T
            warning('Stimulus %d of cell %d: window exceeds the recording; trial left as NaN.',j,i);
            continue;
        end
        C_stm(:,i,j) = (C_raw(win_start:win_end,i)-F0)/F0;
        C_stm(:,i,j) = C_stm(:,i,j)- mean(C_stm(1:20,i,j),1); %Subtract average in the first 20 frames, to correct for drifting baseline
        plot(C_stm(:,i,j)); hold on;
    end
end
hold off;

%Plot and fit the average transient at the offset of the stimulation
% Trials to exclude, one [cell stimulus] pair per row: significant
% spontaneous activity close to stim. Out-of-range pairs are ignored, because
% assigning to them would silently enlarge C_stm and bias the average.
excl_trials = [3 2];
for r = 1:size(excl_trials,1)
    if excl_trials(r,1) <= size(C_stm,2) && excl_trials(r,2) <= size(C_stm,3)
        C_stm(:,excl_trials(r,1),excl_trials(r,2)) = nan;
    else
        warning('Excluded trial [%d %d] does not exist; ignored.',excl_trials(r,1),excl_trials(r,2));
    end
end
A = mean(C_stm,3,'omitnan');                    %Trial-averaged transient, one column per cell
time = (0:size(A,1)-1)/Fs;                      %Window time (s), 0 = window start (onset near 2 s)
figure();   set(gcf,'units','normalized','outerposition',[0.3 0 0.4 1]);
for i = 1:N
    ax(i) = subplot(N,1,i); plot(A(:,i));
end
linkaxes(ax,'x');

% Fit a single exponential to each average, starting 2.5 s into the window.
% Press a key or mouse button to advance to the next cell.
am = nan(1,N);                                  %Fitted amplitudes
ta = nan(1,N);                                  %Fitted decay time constants (s)
figure();
for i = 1:N
    subplot(2,1,1); plot(time, A(:,i));
    subplot(2,1,2); [am(i), ta(i)] = exp_fit(A(round(2.5*Fs):end,i),1/Fs);
    waitforbuttonpress;
end

%Generate recap tables
Summary = [am' ta' mean(N_spks,2)];

%% For figure (version I)
% All single trials (grey) and the mean of the per-cell averages (red), with
% the 200 ms stimulation period shaded. Stim_ave must be a vector shaped like
% time; zeros(length(time)) would build a square matrix and make AREA draw
% one series per column.
stim_on = round(Fs*2)+1;                        %Window index of stimulus onset
Stim_ave = zeros(size(time));
Stim_ave(stim_on:stim_on+round(Fs*0.2)) = 0.6;
figure(); set(gcf,'units','normalized','outerposition',[0.3 0.3 0.15 0.4]);
for i = 1:N
    for j = 1:N_stim
        plot(time,C_stm(:,i,j),'Color',[0.7 0.7 0.7]); hold on;
    end
end
plot(time,mean(A,2),'r','LineWidth',3);
area(time,Stim_ave,'FaceColor','r','EdgeColor','none','FaceAlpha',0.3);
hold off;
xlabel('Time (s)');
ylabel('DF/F');                                 %C_stm is fractional (F-F0)/F0, not percent

%% For figure (version II)
% Per-cell average transients (grey), their mean (black) with a +/- SEM band
% across cells, and the 200 ms stimulation period shaded. Stim_ave must be a
% vector shaped like time; zeros(length(time)) would build a square matrix.
stim_on = round(Fs*2)+1;                        %Window index of stimulus onset
Stim_ave = zeros(size(time));
Stim_ave(stim_on:stim_on+round(Fs*0.2)) = 0.6;
A_mean = mean(A,2);
A_sem  = std(A,0,2)/sqrt(N);
figure(); set(gcf,'units','normalized','outerposition',[0.3 0.3 0.15 0.4]);

plot(time,A,'Color',[0.7 0.7 0.7]); hold on;
plot(time,A_mean,'k','LineWidth',3);
% ciplot is not defined in this file (presumably a File Exchange helper on the path).
ciplot(A_mean-A_sem,A_mean+A_sem,time,[0.6 0.6 0.6]);    alpha(0.3);

area(time,Stim_ave,'FaceColor','r','EdgeColor','none','FaceAlpha',0.3);
hold off;
xlabel('Time (s)');
ylabel('DF/F');                                 %A is fractional (F-F0)/F0, not percent

%% Definition of the exponential fit
function [ampl, tau] = exp_fit(y,delta)
% EXP_FIT  Fit a single exponential y(k) = a*exp(b*k) to a sampled decay.
%   [AMPL, TAU] = EXP_FIT(Y, DELTA) fits the 'exp1' model against the sample
%   index k = 0..numel(Y)-1, where DELTA is the sampling interval in seconds.
%   It plots data and fit in the current axes and returns the amplitude a
%   and the time constant TAU = -DELTA/b in seconds.
%   Bounds keep a >= 0 (start 0.03) and TAU within [0.01, 15] s (start 0.6 s).
%   Non-finite samples (e.g. from an all-NaN average) are left out of the fit.
    y = y(:);
    x = (0:numel(y)-1)';                        %Sample index; colon end-point must be a scalar
    ok = isfinite(y);
    f = fit(x(ok),y(ok),'exp1','StartPoint',[0.03,-delta/0.6], 'Upper',[+Inf -delta/15],'Lower',[0 -delta/0.01]);
    plot(f,x(ok),y(ok),'-');
    ampl = f.a;
    tau  = -delta/f.b;
end