%% Script for the analysis of jRCaMP1a spontaneous activity and ephys
% The input variables described below must already be in the workspace
% (e.g. loaded from a .mat file) when this script starts. 'clear all' would
% wipe them before they are used, so everything else is cleared instead.
clearvars -except C_raw S_raw Stamps;  close all;  clc;

%Description of variables called throughout the script:
% C_raw: fluorescence trace extracted from ROIs. Each column is a neuron.
% S_raw: time of each detected AP. Each column is a neuron.
% Stamps: the start-of-frame time. Each column is a neuron.
for required_var = {'C_raw','S_raw','Stamps'}
    if ~exist(required_var{1},'var')
        error('Variable "%s" must be in the workspace before running this script.', required_var{1});
    end
end
clear required_var

Fs = 11.03;                                     %Sampling frequency (Hz)
[T,N] = size(C_raw);                            %Number of frames (T) and cells (N)
t = (0:T-1)'/Fs;                                %Time vector starting at 0 = frame1, with even sampling (exactly T samples)
t_cell = [zeros(1,N); cumsum(diff(Stamps,1))];  %Exact time of frames (first frame is t=0)
deconv_options = struct('type', 'ar2', ...      % model of the calcium traces. {'ar1', 'ar2'}
    'deconv_method','constrained_foopsi',...    % activity deconvolution method
    'optimize_pars', true, ...                  % optimize AR coefficients
    'temporal_iter',2,...                       % number of block-coordinate descent steps 
    'fudge_factor',0.98);                       % bias correction for AR coefficients

%% Pre-processing of Ca-traces (only for visualization)
% Each raw trace is z-scored, then smoothed along time with a moving median
% of about half a second to suppress high-frequency noise. Both versions of
% every cell are drawn in one axes, stacked with a vertical step of 6.

C_nrm = normalize(C_raw,'zscore');
C_flt = movmedian(C_nrm,round(Fs/2),1);

figure();
set(gcf,'units','normalized','outerposition',[0 0 1 1]);
for i = 1:N
    offset = (i-1)*6;                   %vertical shift of cell i
    plot(t,C_nrm(:,i)-offset,'.', ...
         t,C_flt(:,i)-offset,'LineWidth',3);
    hold on;
    grid on
end
hold off;

%% Denoise with CNMF-e deconvolution and save fitting params
% Each smoothed trace is scaled to [0,1] and deconvolved with deconvolveCa
% (AR(2) model). The denoised trace, the AR coefficients and the noise level
% of every cell are kept for the following sections.

C_scl = normalize(C_flt,'range');                                           %Scale traces between 0 and 1
C_dns = zeros(T,N);                                                         %Initialize Denoised traces to 0
F_params = zeros(N,2);                                                      %AR(2) coefficients per cell (option.pars)
C_sn = zeros(1,N);                                                          %Noise level per cell (option.sn)

figure();   set(gcf,'units','normalized','outerposition',[0 0 1 1]);
for i = 1:N
    valid = ~isnan(C_scl(:,i));                 %only non-NaN samples are deconvolved
    [c, ~, option] = deconvolveCa(C_scl(valid,i)', deconv_options, 'ar2','optimize_pars',1,'optimize_b',1);
    C_dns(valid,i) = c;                         %write back at the frames the samples came from
    F_params(i,:) = option.pars'; 
    C_sn(i) = option.sn;
    subplot(N,1,i);  plot(t,C_scl(:,i),t,C_dns(:,i));   
end

%% Interpolate 'spike times' referred to 'imaging times'
% S_raw is expressed on the same clock as Stamps; map each spike onto the
% per-cell imaging time axis t_cell by linear interpolation. Spikes outside
% the stamped range become NaN (interp1 does not extrapolate by default).

S_alg = S_raw;
for i = 1:N
    valid = ~isnan(Stamps(:,i)) & ~isnan(t_cell(:,i));     %same samples on both axes
    S_alg(:,i) = interp1(Stamps(valid,i),t_cell(valid,i),S_raw(:,i),'linear');
end

figure();   set(gcf,'units','normalized','outerposition',[0 0 1 1]);
ax = gobjects(N,1);
for i = 1:N
    ax(i) = subplot(N,1,i);
    stem(S_alg(:,i),ones(size(S_alg(:,i))),'Marker','none'); hold on;
    plot(t_cell(:,i),C_dns(:,i),'LineWidth',2); hold off;
end
linkaxes(ax,'x');
set(ax,'YLim',[0 1]);   %apply to every subplot, not only the current one


%% Quantify DF/F caused by a burst of APs
% A burst is accepted when
%   - its last spike is followed by more than burst_int s without spikes,
%   - all its spikes lie less than burst_l s before that last spike, and
%   - its first spike comes more than burst_l s after the previous spike.
% For each accepted burst the cell id, number of APs, mean firing rate and
% the DF/F0 waveform starting at the first spike are stored in AP_waveform.
C_filt = movmedian(C_raw,round(Fs/2),1);

burst_l = 0.350;      %burst length (s)
burst_int = 0.8;      %Inter-burst distance (s)
win_len = round(burst_int*Fs);   %frames kept after the first spike of a burst
AP_waveform = struct(); AP_waveform.id = []; AP_waveform.count = []; 
AP_waveform.fluo =[]; AP_waveform.freq =[];  

figure();   set(gcf,'units','normalized','outerposition',[0 0 1 1]);
ax = gobjects(N,1);
for i = 1:N
    spk_t = S_alg(~isnan(S_alg(:,i)),i);            %Get valid spike times
    b_count = [];   b_times = [];   b_freq = [];   m = 1;
    if ~isempty(spk_t)
        isi = [diff(spk_t); t(end)-spk_t(end)];     %Interval to the next spike (last one: to the end of the recording)
        for j = find(isi > burst_int)'                          %Detect burst terminators (burst_int far from the next burst)
            prev_isi = cumsum(flip(isi(1:j)))-isi(j);           %Time from each earlier spike to spike j
            b_count(m) = nnz(prev_isi<burst_l);                 %Number of spikes in the burst        
            %b_times(m) = spk_t(j);                             %Time of the last spike
            b_times(m) = spk_t(j-b_count(m)+1);                 %Time of the first spike
            if b_count(m) == 1
                b_freq(m) = nan;
            else
                b_freq(m) = 1/mean(isi(j-b_count(m)+1:j-1),1);
            end
            if j-b_count(m)>0 && (isi(j-b_count(m)))>burst_l    %Accept only if first spike happened burst_l s after the previous one
                m=m+1;
            end    
        end
        %Slot m is either unused or holds a rejected candidate: drop it
        b_count = b_count(1:m-1);   b_times = b_times(1:m-1);   b_freq = b_freq(1:m-1);
    end
    ax(i)=subplot(N,1,i);  stem(S_alg(:,i),ones(size(S_alg(:,i))),'Marker','none');     hold on;    
    plot(t_cell(:,i),C_scl(:,i),'LineWidth',2);
    if ~isempty(b_times)
        text(b_times,-0.1*ones(size(b_times)),string(b_count));   %label each burst with its AP count
    end
    hold off;
    ylim([-0.2 1]);
    
    %Append this cell's bursts to the structure
    %id = cell id, count = APs in the burst, freq = mean firing rate in the burst,
    %fluo = DF/F0 from the first spike to win_len frames later (one column per burst)
    AP_waveform.id = [AP_waveform.id, i*ones(size(b_count))];
    AP_waveform.count = [AP_waveform.count, b_count];
    AP_waveform.freq = [AP_waveform.freq, b_freq];
    for l = 1:length(b_count)   %get DF/F0 for each burst 
        [~,closest_frame] = min(abs(t_cell(:,i)-b_times(l)));
        F0 = C_filt(closest_frame,i);                   %baseline: filtered fluorescence at the first spike
        %F0 = mean(C_filt(closest_frame-5:closest_frame,i),1);
        frames = closest_frame:closest_frame+win_len;
        waveform = nan(numel(frames),1);                %NaN-pad when the window runs past the last frame
        in_range = frames <= T;
        waveform(in_range) = (C_filt(frames(in_range),i)-F0)./F0;
        AP_waveform.fluo = [AP_waveform.fluo, waveform];
    end
end
linkaxes(ax,'x');

n_AP_max = min([8, max(AP_waveform.count)]);   %largest burst size analysed (at most 8)

%% Look at DF/F waveforms  for 1,2,3..APs (single traces) 
% One panel per burst size n. It shows every individual burst waveform
% (light grey), their mean (black) and a +/- SEM band across bursts.
figure();   set(gcf,'units','normalized','outerposition',[0 0 0.5 0.5]);
ave_fluo = []; ave_n = []; sem_fluo = [];
time = [0:1/Fs:(size(AP_waveform.fluo,1)-1)/Fs]';
for n = 1:n_AP_max
    sel = AP_waveform.count == n;           %bursts made of exactly n APs
    traces = AP_waveform.fluo(:,sel);

    ave_n(:,n) = nnz(sel);
    ave_fluo(:,n) = mean(traces,2,'omitnan');
    sem_fluo(:,n) = std(traces,0,2,'omitnan')./sqrt(ave_n(:,n));

    lo = ave_fluo(:,n)-sem_fluo(:,n);
    hi = ave_fluo(:,n)+sem_fluo(:,n);

    subplot(1,n_AP_max,n);
    plot(time,traces,'Color',[0.9 0.9 0.9]);
    hold on;
    plot(time,ave_fluo(:,n),'k','LineWidth',2);
    ciplot(lo,hi,time,[0.6 0.6 0.6]);
    alpha(0.3);
    hold off;
    ylim([-0.2 0.3]);
end

%% Look at DF/F waveforms  for 1,2,3..APs (single cells)  
% ave_fluo(:,n,i) is the mean waveform of cell i over its n-AP bursts.
% One panel per burst size n shows each cell mean (grey), the mean across
% cells (black) and a +/- SEM band. The SEM count is the number of cells
% whose mean is not NaN at the first frame.
figure();   set(gcf,'units','normalized','outerposition',[0 0 0.5 0.5]);
ave_fluo = []; ave_n = []; sem_fluo = [];
time = [0:1/Fs:(size(AP_waveform.fluo,1)-1)/Fs]';
for n = 1:n_AP_max
    is_n = AP_waveform.count == n;
    for i = 1:N
        ave_fluo(:,n,i) = mean(AP_waveform.fluo(:,is_n & AP_waveform.id == i),2,'omitnan');
    end
    cell_means = ave_fluo(:,n,:);

    subplot(1,n_AP_max,n);
    plot(time,squeeze(cell_means),'Color',[0.8 0.8 0.8]);
    hold on;

    n_F = nnz(~isnan(cell_means(1,1,:)));
    mean_F = mean(cell_means,3,'omitnan');
    sem_F = std(cell_means,0,3,'omitnan')./sqrt(n_F);

    plot(time,mean_F,'k','LineWidth',2);
    ciplot(mean_F-sem_F,mean_F+sem_F,time,[0.6 0.6 0.6]);
    alpha(0.3);
    hold off;
    ylim([-0.2 0.3]);
end

%% Get other measurements
% Average the per-cell waveforms over cells, plot the first burst sizes and
% find, for each burst size, the frame at which the cross-cell mean peaks
% (loc). DF/F values of single cells and single bursts are then read at
% that frame.
mean_wave = mean(ave_fluo,3,'omitnan');            %frames x burst size, averaged over cells
n_show = min(4, size(mean_wave,2));                %number of burst sizes to display
figure();   set(gcf,'units','normalized','outerposition',[0.3 0 0.3 1]);
plot(mean_wave(:,1:n_show));
if n_show > 0
    legend(compose('%d AP',1:n_show));
end
[~,loc] = max(mean_wave);                          %frame of the maximum for each burst size, useful later

%for each cell, calculate the DF/F associated with bursts (at the frame
%corresponding to the maximum in the average burst fluorescence)
%DF(:,:,1) = number of bursts, DF(:,:,2) = mean DF/F at the peak frame
DF = zeros(N,n_AP_max,2);
for i = 1:N
    for j = 1:n_AP_max
        idx = find(AP_waveform.id == i & AP_waveform.count == j);
        DF(i,j,1) = length(idx);
        DF(i,j,2) = mean(AP_waveform.fluo(loc(j),idx),2);
    end
end

%Pool all bursts of each size across cells
Peak_values = cell(1,n_AP_max);
Freq_values = cell(1,n_AP_max);
for j = 1:n_AP_max
    idx = find(AP_waveform.count == j);
    Peak_values{j} = AP_waveform.fluo(loc(j),idx)';
    Freq_values{j} = AP_waveform.freq(idx)';
end

DF_cell_F = DF(:,:,2);
DF_cell_n = DF(:,:,1);

%% Fit the AR model for each cell
% Simulate each cell's response to a single spike using its fitted AR(2)
% coefficients, scale it to [0,1] and fit a difference of exponentials to
% obtain rise and decay time constants.

%AR(2) model definition (see Pnevmatikakis et al., 2016):
%   y(n) = g1*y(n-1) + g2*y(n-2) + s(n),  with y = 0 before the spike
s = [0 0 1 zeros(1,round(60*Fs))];                 %train of 1 min with an initial spike 
n_samp = numel(s);
x = (0:n_samp-1)/Fs;                               %time, one sample every 1/Fs s
g = F_params;                                      %parameters from denoising 
y = zeros(N,n_samp);                               %fluorescence
tau = zeros(N,1);                                  %initialize decay times 
for i = 1:N
    y(i,:) = filter(1,[1 -g(i,1) -g(i,2)],s);      %same recursion as above, zero initial conditions
end

y = normalize(y','range');                                     %normalize between 0 and 1 (one column per cell)

%Fit with a double exponential and plot.
%fittype orders the coefficients alphabetically: [A, td, tr, x0].
fo = fitoptions('Method','NonlinearLeastSquares','StartPoint', [3, 3, 0.5, 0.05],'Lower', [0.1, 0.5, 0.1, 0.001],'Upper', [10, 10, 2, 2]);
ft = fittype('A*(exp(-(x-x0)/td)-exp(-(x-x0)/tr))','options',fo);
figure();   set(gcf,'units','normalized','outerposition',[0 0.3 1 0.5]);
for i = 1:N
    f = fit(x',y(:,i),ft);
    coeffvals= coeffvalues(f);
    subplot(1,N,i); plot(f,x,y(:,i));
    hold on;    stem(x,s,'LineWidth',4, 'Marker','none','color','r');    
    xlim([-2 Inf]);  ylim([-0.2 Inf]);
    text(8,0.5,{['Rise: ' num2str(coeffvals(3)) ' s'],['Decay: ' num2str(coeffvals(2)) ' s']});
    tau(i) = coeffvals(2);                         %decay time constant td (s)
    xlabel('Time(s)');  ylabel('Fluorescence'); hold off;
end

