% -------------------------------------------------------------------------
% Code for Figure 2 in Lee et al., "Sifting of visual information in the
% superior colliculus", submitted to eLife (2020)
% -------------------------------------------------------------------------
% open figure window
f = figure('Position',[334   558   906   410], 'Name', 'Lee et al. 2020 - Figure 2','NumberTitle','off');
% f = figure('Units', 'inch', 'Position',[0   0   7.5   5], 'Name', 'Lee et al. 2020 - Figure 2','NumberTitle','off');

% example recording and neurons shown in panels A and B; per the panel B
% axis comments, cell 840 is the sSC example and cell 820 the dSC example
example_date = '20190416';
cell_sSC = 840;
cell_dSC = 820;
loom_ind = 3;    % data file index (within the recording) of the expanding black disk
checker_ind = 2; % data file index (within the recording) of the flickering checkerboard

% -------------------------------------------------------------------------
% (A) Single neuron comparison: response to looming vs checkerboard stimulus
% -------------------------------------------------------------------------
% place panel label
panel_label_A = subplot('Position',[0.01 0.96 0.03 0.03]);
text(panel_label_A, 0, 0.5, 'A', 'FontSize', 15, 'FontWeight', 'bold','FontName','Arial');
axis(panel_label_A,'off');

% plot icon for expanding black disk
ax_a1 = subplot('Position',[0.04 0.79 0.075 0.06]);
icon_expanding_black_disk(ax_a1);

% plot icon for checkerboard
ax_a2 = subplot('Position',[0.010+0.03 0.600 0.075 0.06]);
icon_checkerboard(ax_a2);

% plot response to expanding black disk (left column: cell_sSC, right column: cell_dSC)
ax_a3 = subplot('Position',[0.13 0.73 0.15 0.18]);
response_expanding_black_disk(ax_a3, example_date, loom_ind, cell_sSC);
ax_a4 = subplot('Position',[0.3 0.73 0.15 0.18]);
response_expanding_black_disk(ax_a4, example_date, loom_ind, cell_dSC);

% plot response to checkerboard: per cell, one axis for the start of the
% stimulus, a narrow axis for the break in the time axis, and one axis for
% the end of the stimulus
ax_a5 = subplot('Position',[0.130 0.540 0.065 0.18]);
ax_a56 = subplot('Position',[0.195 0.540 0.020 0.18]);
ax_a6 = subplot('Position',[0.215 0.540 0.065 0.18]);
response_checkerboard(ax_a5, ax_a6, ax_a56, example_date, checker_ind, cell_sSC);

ax_a7 = subplot('Position',[0.3 0.540 0.065 0.18]);
ax_a78 = subplot('Position',[0.365 0.540 0.020 0.18]);
ax_a8 = subplot('Position',[0.385 0.540 0.065 0.18]);
response_checkerboard(ax_a7, ax_a8, ax_a78, example_date, checker_ind, cell_dSC);

% -------------------------------------------------------------------------
% (B) Single neuron comparison: spike triggered average stimulus
% -------------------------------------------------------------------------
% place panel label
panel_label_B = subplot('Position',[0.01 0.47 0.03 0.03]);
text(panel_label_B, 0, 0.5, 'B', 'FontSize', 15, 'FontWeight', 'bold','FontName','Arial');
axis(panel_label_B,'off');

% place label for spatial receptive field
ax_b_spatial_label = subplot('Position',[0.02 0.36 0.075 0.06]);
text(ax_b_spatial_label, 0, 0.5, 'Spatial RF', 'FontSize', 11, ...
    'FontWeight', 'bold','FontName','Arial','HorizontalAlignment','left');
axis(ax_b_spatial_label, 'off');

% place label for temporal receptive field
ax_b_temporal_label = subplot('Position',[0.02 0.16 0.075 0.06]);
text(ax_b_temporal_label, 0, 0.5, 'Temporal RF', 'FontSize', 11, ...
    'FontWeight', 'bold','FontName','Arial','HorizontalAlignment','left');
axis(ax_b_temporal_label, 'off');

% plot spatial and temporal receptive fields (STA from the checkerboard data file)
ax_b1 = subplot('Position',[0.1+0.03  0.3 0.07 0.18]); % spatial center, sSC
ax_b2 = subplot('Position',[0.18+0.03 0.3 0.07 0.18]); % spatial surround, sSC
ax_b3 = subplot('Position',[0.1+0.03 0.1 0.15 0.18]); % temporal, sSC
ax_b4 = subplot('Position',[0.27+0.03 0.3 0.07 0.18]);  % spatial center, dSC
ax_b5 = subplot('Position',[0.35+0.03 0.3 0.07 0.18]); % spatial surround, dSC
ax_b6 = subplot('Position',[0.27+0.03 0.1 0.15 0.18]); % temporal, dSC
plot_spatiotempoeral_rf(ax_b1, ax_b2, ax_b3, ax_b4, ax_b5, ax_b6, example_date, checker_ind, cell_sSC, cell_dSC);

% -------------------------------------------------------------------------
% (C) Population summary: looming vs. checkerboard
% -------------------------------------------------------------------------
% place panel label
panel_label_C = subplot('Position',[0.51 0.96 0.03 0.03]);

text(panel_label_C, 0, 0.5, 'C', 'FontSize', 15, 'FontWeight', 'bold','FontName','Arial');
axis(panel_label_C,'off');

% plot icon for comparing expanding black disk vs checkerboard
ax_c1 = subplot('Position',[0.57-0.01 0.90 0.06 0.05]);
ax_c111 = subplot('Position',[0.57+0.06 0.90 0.02 0.05]);
% own handle name so that ax_c2 (population summary below) does not overwrite it
ax_c11 = subplot('Position',[0.65+0.01 0.90 0.06 0.05]);

icon_expanding_black_disk(ax_c1);
icon_checkerboard(ax_c11);
text(ax_c111, 0.5, 0.5, 'vs.', 'FontSize', 10, 'FontWeight', 'bold','FontName','Arial','HorizontalAlignment','center');
axis(ax_c111,'off');

% plot population summary (selectivity vs depth) and histogram
ax_c2 = subplot('Position',[0.54 0.36 0.20 0.50]);
ax_c3 = subplot('Position',[0.54 0.15 0.20 0.15]);
selectivity_vs_depth_ch(ax_c2, ax_c3);

% -------------------------------------------------------------------------
% (D) Population summary: looming vs white receding disk
% -------------------------------------------------------------------------
% place panel label
panel_label_D = subplot('Position',[0.72+0.03 0.96 0.03 0.03]);
text(panel_label_D, 0, 0.5, 'D', 'FontSize', 15, 'FontWeight', 'bold','FontName','Arial');
axis(panel_label_D,'off');

% plot icon for comparing expanding black vs contracting white disks
% ax_d1 = subplot('Position', [0.84 0.86 0.15 0.1]);
ax_d1 = subplot('Position',[0.78+0.03-0.01 0.90 0.06 0.05]);
ax_d111 = subplot('Position',[0.84+0.03 0.90 0.02 0.05]);
ax_d11 = subplot('Position',[0.86+0.03+0.01 0.90 0.06 0.05]);
icon_expanding_black_disk(ax_d1);
icon_receding_white_disk(ax_d11);
text(ax_d111, 0.5, 0.5, 'vs.', 'FontSize', 10, 'FontWeight', 'bold','FontName','Arial','HorizontalAlignment','center');
axis(ax_d111,'off');

% plot population summary (selectivity vs depth) and histogram
ax_d2 = subplot('Position',[0.78 0.36 0.20 0.50]);
ax_d3 = subplot('Position',[0.78 0.15 0.20 0.15]);
selectivity_vs_depth(ax_d2, ax_d3);

% shared x-axis label below the histograms of panels C and D
ax_axislabel = subplot('Position',[0.76-0.1 0.03 0.2 0.05]);
text(ax_axislabel, 0.5, 0.5, 'Looming selectivity index', 'FontSize', 10, 'FontName','Arial','HorizontalAlignment','center');
axis(ax_axislabel,'off');


% -------------------------------------------------------------------------
% Functions
% -------------------------------------------------------------------------

function plot_spatiotempoeral_rf(ax1, ax2, ax3, ax4, ax5, ax6, recordingdate, dataind, cell1, cell2)
% -------------------------------------------------------------------------
% plots the spatial and temporal receptive fields of the two neurons
% specified
% -------------------------------------------------------------------------
% inputs
%   ax1 : axis handle for spatial RF (center) of cell 1
%   ax2 : axis handle for spatial RF (surround) of cell 1
%   ax3 : axis handle for temporal RF of cell 1
%   ax4 : axis handle for spatial RF (center) of cell 2
%   ax5 : axis handle for spatial RF (surround) of cell 2
%   ax6 : axis handle for temporal RF of cell 2
%   recordingdate : string, specify which recording to use
%   dataind : index (into the recording's data list d) of the data file
%             whose stimparam.mat provides the checker size param.dx
%   cell1 : ID of first cell whose RF to plot
%   cell2 : ID of second cell whose RF to plot
% notes
%   "center" / "surround" panels show the first / second SVD component
%   returned by denoiseRFbySVD. Both cells are drawn with one shared color
%   scale and one shared temporal normalization, both taken from cell 1.
% -------------------------------------------------------------------------

% load recording info and stimulus parameters
procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');
load(fullfile(procdir,'ordered_rhds.mat'),'d');
load(fullfile(procdir,'cellid_masks.mat'),'cellid');
load(fullfile(procdir,d{1,dataind},'stimparam.mat'),'param');

% first and second SVD components of the spatial and temporal RF
[~, spatial1, temporal1] = denoiseRFbySVD(recordingdate, 0, cell1);
[~, spatial2, temporal2] = denoiseRFbySVD(recordingdate, 0, cell2);

% diverging colormap: first row blue, middle row white, last row red
colsize = 125;
half = floor(colsize/2);
mid = ceil(colsize/2);
col = ones(colsize,3);
col(1:mid,2) = 0:1/half:1;
col(1:mid,3) = 0:1/half:1;
col(mid:colsize,1) = 1:-1/half:0;
col(mid:colsize,2) = 1:-1/half:0;
col = flipud(col);

% ---- temporal RF ----
% signs are adjusted so each component is drawn in its intended color
svd_temporal_11 = temporal1{1, cellid == cell1};
svd_temporal_12 = -temporal1{2, cellid == cell1};
svd_temporal_21 = temporal2{1, cellid == cell2};
svd_temporal_22 = -temporal2{2, cellid == cell2};

% normalization shared by both cells, taken from cell 1
tmax = max([abs(svd_temporal_11); abs(svd_temporal_12)]);

hold(ax3, 'on');
hold(ax6, 'on');

% time axis and a tick at the zero time point (last sample), then the components
n3 = length(svd_temporal_12);
plot(ax3, zeros(1,n3), 'LineWidth', 1.2, 'Color', [0 0 0]);
plot(ax3, [n3, n3], [-0.2, 0.2], 'LineWidth', 1.2, 'Color', [0 0 0]);
plot(ax3, svd_temporal_11/tmax, 'LineWidth', 1.5, 'Color', col(1,:));
plot(ax3, svd_temporal_12/tmax, 'LineWidth', 1.5, 'Color', col(end,:));

n6 = length(svd_temporal_22);
plot(ax6, zeros(1,n6), 'LineWidth', 1.3, 'Color', [0 0 0]);
plot(ax6, [n6, n6], [-0.2, 0.2], 'LineWidth', 1.4, 'Color', [0 0 0]);
plot(ax6, svd_temporal_21/tmax, 'LineWidth', 1.5, 'Color', col(1,:));
plot(ax6, svd_temporal_22/tmax, 'LineWidth', 1.5, 'Color', col(end,:));

% scale bar and label for temporal rf: spans samples n6-6 .. n6-1
% NOTE(review): the '100 ms' label assumes 5 sample intervals = 100 ms - confirm
plot(ax6, n6-1:-1:n6-6, -ones(1,6)*0.7, 'LineWidth', 1.5, 'Color', [0,0,0]);
text(ax6, (n6-1)*0.5+(n6-6)*0.5, -0.9, '100 ms',...
    'HorizontalAlignment', 'center', 'FontSize', 10, 'FontName', 'Arial', 'Color', [0 0 0]);

hold(ax3, 'off');
hold(ax6, 'off');

ylim(ax3, [-1.1 1.1]);
ylim(ax6, [-1.1 1.1]);
axis(ax3, 'off');
axis(ax6, 'off');

% ---- spatial RF ----
% signs are adjusted so each component is drawn in its intended color
svd_spatial_11 = -spatial1{1, cellid == cell1};
svd_spatial_12 = -spatial1{2, cellid == cell1};
svd_spatial_21 = -spatial2{1, cellid == cell2};
svd_spatial_22 = -spatial2{2, cellid == cell2};

% symmetric color limits shared by all four panels, set by cell 1's first component
color_limits = [-1 1]*max(max(abs(svd_spatial_11)));

draw_framed_rf_image(ax1, svd_spatial_11, color_limits);
draw_framed_rf_image(ax2, svd_spatial_12, color_limits);
draw_framed_rf_image(ax4, svd_spatial_21, color_limits);
draw_framed_rf_image(ax5, svd_spatial_22, color_limits);

% scale bar for spatial rf (drawn on ax5, whose hold is still on)
% NOTE(review): d2m is presumably the distance to the monitor and 0.53 / 20
% monitor geometry constants - confirm against probetype_d2monitor
[~,d2m] = probetype_d2monitor(recordingdate);
scale_bar_angles = 10;
ddx = tan(scale_bar_angles/2*pi/180)/0.53*20*2*d2m; % in pixels
ddx = ddx/param.dx; % in checkers
n_cols = size(svd_spatial_22,2);
plot(ax5, [n_cols-2   n_cols-2-ddx], [4 4], 'k', 'LineWidth', 1.5);
text(ax5, n_cols-2-0.5*ddx, 2, [num2str(scale_bar_angles), ' deg'], ...
    'HorizontalAlignment', 'center', 'FontSize', 10, 'FontName', 'Arial', 'Color', [0 0 0]);

% release hold on every spatial panel (previously ax2, ax4, ax5 were left on)
for ax = [ax1, ax2, ax4, ax5]
    hold(ax, 'off');
    colormap(ax, col);
    axis(ax, 'equal');
    axis(ax, 'off');
end

end

function draw_framed_rf_image(ax, img, color_limits)
% draws img as a scaled image on ax with a thin black frame around it.
% hold is switched on BEFORE imagesc so imagesc adds to the axes instead of
% resetting its properties; hold is left on for the caller to add more.
hold(ax, 'on');
imagesc(ax, 'CData', img, color_limits);
x_edge = [0.5, size(img,2)+0.5];
y_edge = [0.5, size(img,1)+0.5];
plot(ax, x_edge([1 2]), y_edge([1 1]), 'k', 'LineWidth', 0.7);
plot(ax, x_edge([1 2]), y_edge([2 2]), 'k', 'LineWidth', 0.7);
plot(ax, x_edge([1 1]), y_edge([1 2]), 'k', 'LineWidth', 0.7);
plot(ax, x_edge([2 2]), y_edge([1 2]), 'k', 'LineWidth', 0.7);
end

function response_checkerboard(ax1, ax2, ax12, recordingdate, dataind, cell)
% -------------------------------------------------------------------------
% plots response of specified neuron to flickering checkerboard
% -------------------------------------------------------------------------
% The PSTH (spikes/s, 0.1 s bins) is split over two axes: ax1 covers 2 s
% before (clipped at t = 0) to 10 s after the first stimulus trigger, ax2
% covers 10 s before to 2 s after the last trigger. ax12 sits between them
% and draws a broken time axis labeled with the rounded skipped time. The
% shown stimulus period is shaded red.
% -------------------------------------------------------------------------
% input
%   ax1 : axis handle to plot beginning of checkerboard
%   ax2 : axis handle to plot end of checkerboard
%   ax12 : axis handle to plot break in time axis
%   recordingdate : string, specify which recording to use
%   dataind : specify which data file in the recording to use
%   cell : ID of the cell to plot
% -------------------------------------------------------------------------

procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');

load(fullfile(procdir,'ordered_rhds.mat'),'d');
load(fullfile(procdir,'cellid_masks.mat'),'cellid');

binsize = 0.1; % PSTH bin width (s)
n_stim = 100;  % bins of the stimulus shown at its start and at its end
n_pad = 20;    % bins shown outside the stimulus

% load spiketrain and stim timing (spiketrain.mat also supplies cellid)
load(fullfile(procdir,d{1,dataind},'spiketrain.mat'),'cellid','spiketrain');
load(fullfile(procdir,d{1,dataind},'stim.mat'),'stim_trigger');

t_first = stim_trigger(1);
t_last = stim_trigger(end);
spikes = spiketrain{1,cellid==cell};

% set bin edges
bins1 = max(0,t_first-n_pad*binsize):binsize:t_first+n_stim*binsize;
bins2 = t_last-n_stim*binsize:binsize:t_last+n_pad*binsize;

hold(ax1, 'on');
hold(ax2, 'on');

% indicate stimulus (drawn first so the PSTH lies on top); the top edge is
% adjusted below once the y range is known
shade1 = fill(ax1, [t_first t_first+n_stim*binsize t_first+n_stim*binsize t_first],...
    [0 0 300 300], 'r', 'EdgeColor', 'none', 'FaceAlpha', 0.1);
shade2 = fill(ax2, [t_last-n_stim*binsize   t_last   t_last   t_last-n_stim*binsize],...
    [0 0 300 300], 'r', 'EdgeColor', 'none', 'FaceAlpha', 0.1);

% duration of the stimulus that is not shown
time_between = t_last-n_stim*binsize - (t_first+n_stim*binsize);

% plot psth
h1 = histogram(ax1, spikes, bins1, 'Normalization','countdensity',...
    'DisplayStyle', 'bar', 'EdgeAlpha',0,'FaceColor',[0,0,0], 'FaceAlpha',1);
h2 = histogram(ax2, spikes, bins2, 'Normalization','countdensity',...
    'DisplayStyle', 'bar', 'EdgeAlpha',0,'FaceColor',[0,0,0], 'FaceAlpha',1);

% common y scale for both halves; weak responses use a fixed 70 spikes/s scale
max_val = max([h1.Values h2.Values 1]);
if max_val < 50
    max_val = 70;
end
y_top = max_val*1.2;

% extend the shading to the top of the axes (a fixed height of 300 left it
% short for cells firing above 250 spikes/s)
shade1.YData = [0; 0; y_top; y_top];
shade2.YData = [0; 0; y_top; y_top];

% highlight bottom border
line(ax1, [bins1(1) bins1(end)], [0,0], 'color', [0,0,0], 'LineWidth', 1.0);
line(ax2, [bins2(1) bins2(end)], [0,0], 'color', [0,0,0], 'LineWidth', 1.0);

% scale bars (1 s horizontal, 50 spikes/s vertical) in the top right of ax2
% NOTE(review): the x anchor is bins2(end) scaled by
% horizontal_scale_bar_position, i.e. a fraction of the absolute time of
% the last bin edge, not of the plotted x range; for very late stimuli the
% bars shift left and can leave the window - confirm
horizontal_scale_bar_position = 0.999; % multiplies the absolute time bins2(end)
horizontal_scale_bar_length = 1; % in s
vertical_scale_bar_position = 0.99; % fraction of the top of the y range
vertical_scale_bar_length = 50; % in spikes/s

bar_x = bins2(end)*horizontal_scale_bar_position;
bar_y = y_top*vertical_scale_bar_position;

% horizontal bar
line(ax2, [bar_x,  bar_x-horizontal_scale_bar_length], [bar_y   bar_y],...
    'Color', [0 0 0], 'LineWidth', 1.4);

% vertical bar
line(ax2, [bar_x,  bar_x], [bar_y   bar_y-vertical_scale_bar_length],...
    'Color', [0 0 0], 'LineWidth', 1.4);

% remove axis
axis(ax1, 'off');
axis(ax2, 'off');

% plot break in the time axis
hold(ax12, 'on');
text(ax12, 0.5, max_val*0.95, ['~ ', num2str(round(time_between)), ' s'], 'HorizontalAlignment', 'center', 'FontSize', 8, 'FontName', 'Arial');
plot(ax12, [0 1],[0 0], 'k', 'LineWidth', 1.0);
plot(ax12, [0.3 0.4],[-max_val*0.1,  max_val*0.1], 'k', 'LineWidth', 1.0);
plot(ax12, [0.6 0.7],[-max_val*0.1,  max_val*0.1], 'k', 'LineWidth', 1.0);
xlim(ax12, [0 1]);
ylim(ax12, [-max_val*0.1 y_top]);
axis(ax12, 'off');
hold(ax12, 'off');

% adjust xlim and ylim so that every subplot is the same
ylim(ax1, [-max_val*0.1,   y_top]);
ylim(ax2, [-max_val*0.1,   y_top]);
xlim(ax1, [bins1(1)   bins1(end)]);
xlim(ax2, [bins2(1)   bins2(end)]);

hold(ax1, 'off');
hold(ax2, 'off');

end

function response_expanding_black_disk(ax, recordingdate, dataind, cell)
% -------------------------------------------------------------------------
% plots response of specified neuron to the expanding black disk
% -------------------------------------------------------------------------
% Draws the PSTH (spikes/s, 0.1 s bins) from 0.5 s before the first to
% 0.5 s after the last stimulus trigger, shades each (odd, even) trigger
% pair red, labels the cell 'sSC' (relative depth >= -100) or 'dSC', and
% adds 1 s / 50 sp/s scale bars (text labels only for cells deeper than -100).
% -------------------------------------------------------------------------
% input
%   ax : specify the axis handle
%   recordingdate : string, specify which recording to use
%   dataind : specify which data file in the recording to use
%   cell : ID of the cell to plot
% -------------------------------------------------------------------------

procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');

load(fullfile(procdir,'ordered_rhds.mat'),'d');
load(fullfile(procdir,'cellid_masks.mat'),'cellid');

binsize = 0.1; % PSTH bin width (s)

% load spiketrain and stim timing (spiketrain.mat also supplies cellid)
load(fullfile(procdir,d{1,dataind},'spiketrain.mat'),'cellid','spiketrain');
load(fullfile(procdir,d{1,dataind},'stim.mat'),'stim_trigger');

% set bin edges
bins = stim_trigger(1)-5*binsize:binsize:stim_trigger(end)+5*binsize;

% color and text label from the cell's depth (looked up once)
relative_depth = classify_layer2(recordingdate);
cell_depth = relative_depth(cellid==cell);
if cell_depth >= -100
    txt = 'sSC';
    col = [222,45,38]/255;
else
    txt = 'dSC';
    col = [49,130,189]/255;
end
show_scale_labels = cell_depth < -100; % only the deeper example gets scale bar text

hold(ax,'on');

% plot psth
h = histogram(ax, spiketrain{1,cellid==cell}, bins, 'Normalization','countdensity',...
    'DisplayStyle', 'bar', 'EdgeAlpha',0,'FaceColor',[0,0,0],'FaceAlpha',1);

% save max firing for the y scale
max_val = max([h.Values,1]);

% indicate stimulus period: shade each (on, off) pair of triggers
for k=1:2:length(stim_trigger)-1
    fill(ax, [stim_trigger(k) stim_trigger(k+1) stim_trigger(k+1) stim_trigger(k)],...
        [0 0 max_val*1.1 max_val*1.1], 'r', 'EdgeColor', 'none', 'FaceAlpha', 0.1);
end

% highlight bottom border
line(ax, [bins(1) bins(end)], [0,0], 'color', [0,0,0], 'LineWidth', 1.0);

% label
text(ax, (bins(end)-bins(1))*0.5+bins(1), max_val*1.25, txt,...
                'Horizontalalignment','center', 'FontSize', 11, 'Color', col, 'FontWeight', 'bold','FontName','Arial');

% scale bar
% NOTE(review): the x anchor is bins(end) scaled by
% horizontal_scale_bar_position, i.e. a fraction of the absolute time of
% the last bin edge, not of the plotted x range - confirm
horizontal_scale_bar_position = 0.99; % multiplies the absolute time bins(end)
horizontal_scale_bar_length = 1; % in s
vertical_scale_bar_position = 0.99; % fraction of the top of the shaded range
vertical_scale_bar_length = 50; % in spikes/s

bar_x = bins(end)*horizontal_scale_bar_position;
bar_y = max_val*1.1*vertical_scale_bar_position;

% horizontal
line(ax, [bar_x bar_x-horizontal_scale_bar_length], [bar_y bar_y],...
    'Color', [0 0 0], 'LineWidth', 1.4);
if show_scale_labels
    text(ax, bar_x-0.5*horizontal_scale_bar_length,...
        max_val*1.25*vertical_scale_bar_position,'1 s', 'Color', [0 0 0], 'Horizontalalignment','center', 'FontSize', 9, 'FontName','Arial');
end

% vertical
line(ax, [bar_x   bar_x], [bar_y    bar_y-vertical_scale_bar_length],...
    'Color', [0 0 0], 'LineWidth', 1.4);
if show_scale_labels
    text(ax, bar_x+0.6, max_val*1.2*vertical_scale_bar_position-0.5*vertical_scale_bar_length,...
        {'50','sp/s'}, 'Color', [0 0 0], 'Horizontalalignment','left', 'FontSize', 9, 'FontName','Arial');
end

hold(ax,'off');

% remove axis
axis(ax, 'off');

% adjust xlim and ylim so that every subplot is the same
ylim(ax, [0 max_val*1.2]);
xlim(ax, [bins(1) bins(end)]);

end

function icon_checkerboard(ax)
% -------------------------------------------------------------------------
% draws an icon of checkerboard: three random black/white 10x10 frames
% side by side (separated by 2 white columns), each outlined in black
% -------------------------------------------------------------------------
% input
%   ax: axis handle to draw the icon
% notes
%   A fixed seed makes the icon identical on every call; the caller's
%   random number generator state is restored afterwards so drawing the
%   icon has no side effect on later random draws.
% -------------------------------------------------------------------------

% generate the frames with a fixed seed, then restore the caller's RNG
rng_state = rng;
rng(1002,'twister');
a = rand(10,10)>0.5;
b = rand(10,10)>0.5;
c = rand(10,10)>0.5;
rng(rng_state);

im = [a ones(10,2) b ones(10,2) c];

% hold on before imagesc so imagesc adds to the axes instead of resetting it
hold(ax, 'on');

imagesc(ax, im);

colormap(ax, gray);

% black outline around each 10x10 frame
for x_left = [0.5, 12.5, 24.5]
    x_right = x_left + 10;
    plot(ax, [x_left,x_right],[10.5,10.5],'k','linewidth',0.75);
    plot(ax, [x_left,x_right],[0.5,0.5],'k','linewidth',0.75);
    plot(ax, [x_left,x_left],[0.5,10.5],'k','linewidth',0.75);
    plot(ax, [x_right,x_right],[0.5,10.5],'k','linewidth',0.75);
end

axis(ax, 'equal');
axis(ax, 'off');

hold(ax, 'off');

end

function icon_expanding_black_disk(ax)
% -------------------------------------------------------------------------
% draws an icon of an expanding black disk: a row of three gray squares,
% each holding a centered black disk whose diameter grows left to right
% -------------------------------------------------------------------------
% input
%   ax: axis handle to draw the icon
% -------------------------------------------------------------------------

background = [1 1 1]*0.7;   % gray of the squares
side = 1;                   % edge length of each square
gap = 0.2;                  % spacing between squares
disk_color = [0 0 0];
diameters = [0.2 0.4 0.8];  % disk diameter in each square, left to right
n_panels = numel(diameters);

hold(ax,'on');

% per panel: gray square first, then the disk centered on top of it
for panel = 1:n_panels
    left = (panel-1)*(side+gap);
    dia = diameters(panel);
    rectangle(ax, 'Position',[left 0 side side],'Curvature',[0 0],'FaceColor',background,'EdgeColor','none');
    rectangle(ax, 'Position',[left+side/2-dia/2 side/2-dia/2 dia dia],'Curvature',[1 1],'FaceColor',disk_color,'EdgeColor','none');
end

hold(ax,'off');

% hide the axes and fit the limits tightly around the three squares
axis(ax,'equal');
axis(ax,'off');
xlim(ax, [0, side*n_panels+gap*(n_panels-1)]);
ylim(ax, [0, side]);
end

function icon_receding_white_disk(ax)
% -------------------------------------------------------------------------
% draws an icon of a receding (contracting) white disk: a row of three
% gray squares, each holding a centered white disk whose diameter shrinks
% from left to right
% -------------------------------------------------------------------------
% input
%   ax: axis handle to draw the icon
% -------------------------------------------------------------------------

% set parameters
gray_level = [1 1 1]*0.7;   % gray of the squares
length_square = 1;          % edge length of each square
d_square = 0.2;             % spacing between squares
disk_color = [1,1,1];       % white
diameters = [0.8 0.4 0.2];  % disk diameter in each square, left to right
n_panels = numel(diameters);

hold(ax,'on');

% per panel: draw the gray background, then the white disk centered on it
for k = 1:n_panels
    x0 = (k-1)*(length_square+d_square);
    dia = diameters(k);
    rectangle(ax, 'Position',[x0 0 length_square length_square],'Curvature',[0 0],'FaceColor',gray_level,'EdgeColor','none');
    rectangle(ax, 'Position',[x0+length_square/2-dia/2 length_square/2-dia/2 dia dia],'Curvature',[1 1],'FaceColor',disk_color,'EdgeColor','none');
end

hold(ax,'off');

% turn off axis info
axis(ax,'equal');
axis(ax,'off');

% set limits tightly around the three squares
xlim(ax, [0,length_square*n_panels+d_square*(n_panels-1)]);
ylim(ax, [0,length_square]);
end

function selectivity_vs_depth(ax1,ax2)
% -------------------------------------------------------------------------
% plot population summary of looming selectivity index vs. depth in SC
% (expanding black disk vs. contracting white disk)
% -------------------------------------------------------------------------
% inputs
%   ax1 : axis handle for plotting the population summary
%   ax2 : axis handle for plotting the histogram
% notes
%   A cell is included when compute_selectivity marks it significant and
%   its depth + 100 um lies in (-750, 350]; sSC = (0, 350], dSC = (-750, 0].
%   Random jitter is added to the plotted markers only, so the cell
%   selection, sSC/dSC split, histograms and KS test use un-jittered values.
% -------------------------------------------------------------------------

% data sets used, with the index (within each recording) of the data file
% for the expanding black disk (loom_j) and the contracting white disk (wr_j).
% Alternatives previously listed (commented out): 20180905 (loom_j 3,
% wr_j 8) as data set 2; loom_j 15 / wr_j 21 for 20190416.
recordingdates = {'20190209', '20170517', '20170119', '20180708', '20190320', '20190416'};
loom_j         = {3,          2,          2,          2,          3,          3};
wr_j           = {13,         29,         3,          17,         5,          9};

% range of depth corresponding to SC (350 microns above and 750 microns
% below the superficial / deep boundary)
uprange = [0,350];
lowrange = [-750,0];
totalrange = [min(lowrange),max(uprange)];

upcolor = [222,45,38]/255;
lowcolor = [49,130,189]/255;

ncells = zeros(1,size(recordingdates,2));
sl = [];
d = [];

hold(ax1, 'on');

% fixed seed so the display jitter is reproducible
rng(102,'twister');

for i=1:size(recordingdates,2)

    % get depth
    depth = classify_layer2(recordingdates{1,i});

    % get selectivity index
    [s,sig] = compute_selectivity(recordingdates{1,i}, loom_j{1,i}, wr_j{1,i});

    % display-only jitter so that points are not on top of each other
    % (random numbers drawn in the same order as before: s, then depth)
    jitter_s = (rand(size(s)) - 0.5)/0.5*0.02;
    jitter_d = (rand(size(depth)) - 0.5)/0.5*5;

    % add 100 microns to account for the thickness of SO
    depth = depth+100;

    % significant cells within the SC depth range (un-jittered depth)
    keep = sig & depth>lowrange(1) & depth<=uprange(2);

    % plot (selectivity is only ever shifted left, as before)
    plot(ax1, s(keep) - abs(jitter_s(keep)), depth(keep) + jitter_d(keep), 'ko', 'MarkerSize',4);

    % save the un-jittered values for the histogram and statistics
    sl = [sl, s(keep)];
    d = [d, depth(keep)];
    ncells(i) = length(s(keep));

end

% ylabel(ax1, 'Depth from bottom of SO (\mum)', 'FontSize', 10, 'FontName', 'Arial');

xrange = [-1,1];

% dashed lines: sSC/dSC border and selectivity index 0.75
plot(ax1, xrange, [uprange(1) uprange(1)], 'k--', 'LineWidth', 1.3);
plot(ax1, [0.75 0.75], totalrange, 'k--', 'LineWidth', 1.3);

% x cells from y recordings
text(ax1, xrange(1)*0.9, totalrange(1)*0.90,...
    {[num2str(sum(ncells)),' cells from '],[num2str(size(ncells,2)),' recordings']}, 'FontSize', 8,'FontName','Arial');

xlim(ax1, xrange);
ylim(ax1, totalrange);
ax1.TickDir = 'out';
ax1.XTick = [];
ax1.Box = 'off';
ax1.XColor = 'none';
ax1.FontSize = 9;
ax1.FontName = 'Arial';
% y ticks and tick labels are hidden on this panel
set(ax1,'ytick',[])
set(ax1,'yticklabel',[])
hold(ax1,'off')

% plot histogram
hold(ax2, 'on');

binsize = 0.2;
binedges= xrange(1):binsize:xrange(2);

histogram(ax2, sl(d>uprange(1) & d<=uprange(2)), binedges,'EdgeColor','none',...
    'EdgeAlpha',0,'Normalization','probability','FaceColor',upcolor,'LineStyle','none');
histogram(ax2, sl(d>lowrange(1) & d<=lowrange(2)), binedges,'EdgeColor','none',...
    'EdgeAlpha',0,'Normalization','probability','Facecolor',lowcolor,'LineStyle','none');

% NOTE(review): no legend entries are defined for this axis - confirm this
% call is intended
legend(ax2, 'boxoff');

% significance tests: two sample KS test sSC vs dSC ('Alpha' only affects
% the discarded test decision; p is reported to 1 significant digit)
[~,p_val_ks] = kstest2(sl(d>uprange(1) & d<=uprange(2)),sl(d>lowrange(1) & d<=lowrange(2)),'Alpha',0.005);

% text(ax2, 0, 0.3, ['\it p', '\rm < ',num2str(round(p_val_ks,9))], 'FontSize', 8,'FontName','Arial');
text(ax2, 0, 0.3, ['\it p', '\rm = ',num2str(p_val_ks,1)], 'FontSize', 8,'FontName','Arial');

ylim(ax2, [0 0.4]);

ax2.TickDir = 'out';
ax2.XTick = xrange(1):0.5:xrange(2);
ax2.FontSize = 9;
ax2.FontName = 'Arial';
ax2.YTick = 0:0.2:0.4;

hold(ax2,'off');

end

function selectivity_vs_depth_ch(ax1,ax2)
% -------------------------------------------------------------------------
% plot population summary of looming selectivity index (response to looming
% vs. checkerboard) against depth in SC
% -------------------------------------------------------------------------
% inputs
%   ax1 : axis handle for plotting the population summary
%   ax2 : axis handle for plotting the histogram
% notes
%   A cell is included when compute_selectivity_ch marks it significant and
%   its depth + 100 um lies in (-750, 350]; sSC = (0, 350], dSC = (-750, 0].
%   Random jitter is added to the plotted markers only, so the cell
%   selection, sSC/dSC split, histograms and KS test use un-jittered values.
% -------------------------------------------------------------------------

% data sets used, with the index (within each recording) of the data file
% for the expanding black disk (loom_j) and the checkerboard (ch_j).
% Alternative loom_j values previously listed (commented out) "for grid":
% 3, 4, 3, 26, 27.
recordingdates = {'20190209', '20170928', '20190320', '20190416', '20180824'};
loom_j         = {15,         2,          6,          3,          3};
ch_j           = {2,          3,          2,          2,          2};

% range of depth corresponding to SC (350 microns above and 750 microns
% below the superficial / deep boundary)
uprange = [0,350];
lowrange = [-750,0];
totalrange = [min(lowrange),max(uprange)];

upcolor = [222,45,38]/255;
lowcolor = [49,130,189]/255;

ncells = zeros(1,size(recordingdates,2));
sl = [];
d = [];

hold(ax1, 'on');

% fixed seed so the display jitter is reproducible
rng(102,'twister');

for i=1:size(recordingdates,2)
    % get depth
    depth = classify_layer2(recordingdates{1,i});

    % get selectivity
    [s,sig] = compute_selectivity_ch(recordingdates{1,i}, loom_j{1,i}, ch_j{1,i});

    % display-only jitter so that points don't lie on top of one another
    % (random numbers drawn in the same order as before: s, then depth)
    jitter_s = (rand(size(s)) - 0.5)/0.5*0.01;
    jitter_d = (rand(size(depth)) - 0.5)/0.5*5;

    % add 100 microns to account for the thickness of SO
    depth = depth+100;

    % significant cells within the SC depth range (un-jittered depth)
    keep = sig & depth>lowrange(1) & depth<=uprange(2);

    % plot data points
    plot(ax1, s(keep) + jitter_s(keep), depth(keep) + jitter_d(keep), 'ko', 'MarkerSize',4);

    % keep un-jittered data for the histogram and statistics
    sl = [sl, s(keep)];
    d = [d, depth(keep)];
    ncells(i) = length(s(keep));
end

ylabel(ax1, 'Depth from bottom of SO (\mum)', 'FontSize', 10, 'FontName', 'Arial');

xrange = [-1,1];

% dashed lines: sSC/dSC border and selectivity index 0.75
plot(ax1, xrange, [uprange(1) uprange(1)], 'k--', 'LineWidth', 1.3);
plot(ax1, [0.75 0.75], totalrange, 'k--', 'LineWidth', 1.3);

% x cells from y recordings
text(ax1, xrange(1)*0.9, totalrange(1)*0.90,...
    {[num2str(sum(ncells)),' cells from '],[num2str(size(ncells,2)),' recordings']}, 'FontSize', 8,'FontName','Arial');

xlim(ax1, xrange);
ylim(ax1, totalrange);
ax1.TickDir = 'out';
ax1.XTick = [];
ax1.YTick = totalrange(1):150:totalrange(2);
ax1.Box = 'off';
ax1.XColor = 'none';
ax1.FontSize = 9;
ax1.FontName = 'Arial';

hold(ax1,'off')

% plot histogram
hold(ax2, 'on');

binsize = 0.2;
binedges= xrange(1):binsize:xrange(2);

h1 = histogram(ax2, sl(d>uprange(1) & d<=uprange(2)), binedges,'EdgeColor','none',...
    'EdgeAlpha',0,'Normalization','probability','FaceColor',upcolor,'LineStyle','none');
h2 = histogram(ax2, sl(d>lowrange(1) & d<=lowrange(2)), binedges,'EdgeColor','none',...
    'EdgeAlpha',0,'Normalization','probability','Facecolor',lowcolor,'LineStyle','none');

legend(ax2, {'sSC', 'dSC'}, 'Location','northwest','FontSize',8);
legend(ax2, 'boxoff');

ylabel(ax2, 'Probability','FontSize',8,'FontName','Arial');

% KS test sSC vs dSC ('Alpha' only affects the discarded test decision;
% p is reported to 1 significant digit)
[~,p_val_ks] = kstest2(sl(d>uprange(1) & d<=uprange(2)),sl(d>lowrange(1) & d<=lowrange(2)),'Alpha',0.005);

% text(ax2, -0.1, 0.3, ['\it p', '\rm < ',num2str(round(p_val_ks,15))], 'FontSize', 8,'FontName','Arial');
text(ax2, -0.1, 0.3, ['\it p', '\rm = ',num2str(p_val_ks,1)], 'FontSize', 8,'FontName','Arial');

% y limit: highest bin rounded up to the next 0.1; fall back to 0.1 when
% the histograms are empty / NaN so ylim stays valid
ylim_top = max([h1.Values, h2.Values]);
if isempty(ylim_top) || ~(ylim_top > 0)
    ylim_top = 0.1;
end
ylim_top = ceil(ylim_top*10)/10;

ylim(ax2, [0 ylim_top]);

ax2.TickDir = 'out';
ax2.XTick = xrange(1):0.5:xrange(2);
ax2.FontSize = 9;
ax2.FontName = 'Arial';
ax2.YTick = 0:ylim_top/2:ylim_top;

hold(ax2,'off');

end

function [s,sig] = compute_selectivity(recordingdate,loom_j,wr_j)
% -------------------------------------------------------------------------
% looming selectivity index: expanding black disk vs. contracting white disk
% -------------------------------------------------------------------------
% input
%   recordingdate : char, which recording to use
%   loom_j : index of data file with the response to the expanding black disk
%   wr_j : index of data file with the response to the contracting white disk
% output
%   s : looming selectivity index (RL-RO)/(RL+RO), where RL and RO are the
%       spike counts to the expanding black disk and to the contracting
%       white disk after subtracting the expected background count and
%       clipping negative values to 0 (NaN when both are 0)
%   sig : logical, true where either response is significant according to
%       poisson_sigtester (evaluated on the raw, un-subtracted counts)
% -------------------------------------------------------------------------

p_cutoff = 5e-3;

if any(strcmp(recordingdate, {'20190209','20190320'}))
    % these recordings present the stimuli at randomized positions; the
    % helper picks the best looming trial and the white disk at that spot
    [RL, RO] = getloomresponse_randomloom(recordingdate, loom_j, wr_j);
    RL_bg = get_background(recordingdate, loom_j, loom_j);
    RO_bg = get_background(recordingdate, wr_j, loom_j);
else
    % first trial of each stimulus
    RL = get_loom_response(recordingdate, 1, loom_j, loom_j);
    RL_bg = get_background(recordingdate, loom_j, loom_j);
    RO = get_loom_response(recordingdate, 1, wr_j, loom_j);
    RO_bg = get_background(recordingdate, wr_j, loom_j);
end

% significance is tested identically for both kinds of recording
sig = poisson_sigtester(p_cutoff, RL, RO, RL_bg, RO_bg);

% remove the expected background count; negative responses become 0
RL = RL - RL_bg;
RL(RL < 0) = 0;
RO = RO - RO_bg;
RO(RO < 0) = 0;

s = (RL - RO) ./ (RL + RO);
end

function [s,sig] = compute_selectivity_ch(recordingdate,loom_j,ch_j)
% -------------------------------------------------------------------------
% looming selectivity index: expanding black disk vs. checkerboard
% -------------------------------------------------------------------------
% input
%   recordingdate : char, which recording to use
%   loom_j : index of data file with the response to the expanding black
%       disk; per the original author, this should be a file in which the
%       stimulus is repeated at one location (TODO confirm)
%   ch_j : index of data file with the response to the checkerboard
% output
%   s : looming selectivity index (RL-RC)/(RL+RC), where RL is the first-
%       trial spike count to the expanding black disk and RC is the
%       checkerboard spike count expected over one looming duration from the
%       average firing rate; both have the expected background count
%       subtracted and negatives clipped to 0 (NaN when both are 0)
%   sig : logical, true where either response is significant according to
%       poisson_sigtester (evaluated on the raw, un-subtracted counts)
% -------------------------------------------------------------------------

p_cutoff = 5e-3;

% raw responses and the counts expected from background firing
RL    = get_loom_response(recordingdate, 1, loom_j, loom_j);
RL_bg = get_background(recordingdate, loom_j, loom_j);
RC    = get_ch_response(recordingdate, ch_j, loom_j);
RC_bg = get_background(recordingdate, ch_j, loom_j);

sig = poisson_sigtester(p_cutoff, RL, RC, RL_bg, RC_bg);

% background subtraction, clipped at zero
RL = RL - RL_bg;
RL(RL < 0) = 0;
RC = RC - RC_bg;
RC(RC < 0) = 0;

s = (RL - RC) ./ (RL + RC);

end

function sig = poisson_sigtester(p_cutoff, loom_spikecount, wr_spikecount, loom_spikecount_bg, wr_spikecount_bg)
% -------------------------------------------------------------------------
% flags visually responsive neurons: under a Poisson model whose mean is the
% expected background spike count, a cell is significant when the p-value
% 1 - cdf(response) for either stimulus falls below 'p_cutoff'
% -------------------------------------------------------------------------
% input
%   p_cutoff : scalar significance threshold
%   loom_spikecount, wr_spikecount : responses to the two stimuli
%   loom_spikecount_bg, wr_spikecount_bg : Poisson means for the matching
%       responses; only the first N entries of every input are used, where
%       N = length(loom_spikecount)
% output
%   sig : 1 x N logical, true if either response is significant
% -------------------------------------------------------------------------
% NOTE(review): 1 - cdf(x) equals P(X > x). For integer x this is smaller
% than the usual upper-tail p-value P(X >= x), i.e. the test is less
% conservative by one count. Left unchanged to reproduce the figure.

n = length(loom_spikecount);
% first n elements as a row, whatever the input orientation
first_n_row = @(v) reshape(v(1:n), 1, []);

p_loom = 1 - cdf('Poisson', first_n_row(loom_spikecount), first_n_row(loom_spikecount_bg));
p_wr   = 1 - cdf('Poisson', first_n_row(wr_spikecount),   first_n_row(wr_spikecount_bg));

sig = (p_loom < p_cutoff) | (p_wr < p_cutoff);
end

function [loom_response,wr_response] = getloomresponse_randomloom(recordingdate,loom_j,wr_j)
% -------------------------------------------------------------------------
% responses to an expanding black disk and a contracting white disk that
% were presented at randomized positions. For each cell, the looming
% response is its largest single-trial spike count. The white-disk response
% is the spike count of the first white-disk trial shown at the same (x,y)
% position as that best looming trial.
% -------------------------------------------------------------------------
% input
%   recordingdate : char, recording folder under M:\mouse_sc\
%   loom_j : index (into d) of the data file with the expanding black disk
%   wr_j : index (into d) of the data file with the contracting white disk
% output
%   loom_response : 1 x N, max spike count over trials for each cell
%   wr_response : 1 x N, white-disk spike count at the best looming
%       position; NaN if that position never occurred in the white-disk
%       sequence
% -------------------------------------------------------------------------

procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');
rhds = load(fullfile(procdir,'ordered_rhds.mat'),'d');

% --- expanding black disk: best trial and its position, for every cell
loomfile = rhds.d{1,loom_j};
st = load(fullfile(procdir,loomfile,'spiketrain.mat'),'spiketrain','cellid');
tr = load(fullfile(procdir,loomfile,'stim.mat'),'stim_trigger');
pr = load(fullfile(procdir,loomfile,'stimparam.mat'),'param');

ncells = length(st.cellid);
loom_response = zeros(1,ncells);
pos = zeros(2,ncells);
for k = 1:ncells
    spikecount = CountSpikesPerTrial(st.spiketrain{1,k},tr.stim_trigger);
    [sc,maxtrial] = max(spikecount);
    loom_response(k) = sc;
    pos(1,k) = pr.param.position_x(pr.param.sequence_x(maxtrial));
    pos(2,k) = pr.param.position_y(pr.param.sequence_y(maxtrial));
end

% --- contracting white disk: response at the best looming position
wrfile = rhds.d{1,wr_j};
st = load(fullfile(procdir,wrfile,'spiketrain.mat'),'spiketrain','cellid');
tr = load(fullfile(procdir,wrfile,'stim.mat'),'stim_trigger');
pr = load(fullfile(procdir,wrfile,'stimparam.mat'),'param');

% position of each white-disk trial; the same for every cell, so computed once
wr_x = pr.param.position_x(pr.param.sequence_x);
wr_y = pr.param.position_y(pr.param.sequence_y);

wr_response = zeros(1,length(st.cellid));
for k = 1:size(st.spiketrain,2)
    spikecount = CountSpikesPerTrial(st.spiketrain{1,k},tr.stim_trigger);
    trial = find(wr_x==pos(1,k) & wr_y==pos(2,k),1);
    if isempty(trial)
        % the best looming position was never shown as a white disk;
        % assigning an empty result would error or delete an element
        wr_response(k) = NaN;
    else
        wr_response(k) = spikecount(trial);
    end
end

end

function loom_response = getloomresponse_randomloom2(recordingdate,loom_j)
% -------------------------------------------------------------------------
% finds the response to the looming stimulus at the location that elicited
% the max response, considering only the first trial at each location
% -------------------------------------------------------------------------
% input
%   recordingdate : char, recording folder under M:\mouse_sc\
%   loom_j : index (into d) of the data file with the looming stimulus
% output
%   loom_response : 1 x N, for each cell the max spike count over the first
%       trial presented at each (x,y) position index pair
% -------------------------------------------------------------------------
procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');
rhds = load(fullfile(procdir,'ordered_rhds.mat'),'d');

loomfile = rhds.d{1,loom_j};
st = load(fullfile(procdir,loomfile,'spiketrain.mat'),'spiketrain','cellid');
tr = load(fullfile(procdir,loomfile,'stim.mat'),'stim_trigger');
pr = load(fullfile(procdir,loomfile,'stimparam.mat'),'param');
param = pr.param;

% index of the first trial at each (x,y) position index pair. This depends
% only on the stimulus sequence, so it is computed once instead of per cell.
ind = zeros(length(param.position_x),length(param.position_y));
for t = 1:length(param.position_x)
    for v = 1:length(param.position_y)
        first = find(param.sequence_x==t & param.sequence_y==v,1);
        if ~isempty(first)
            ind(t,v) = first;
        end
    end
end
ind = reshape(ind,1,[]);
ind = ind(ind>0);   % drop positions that were never presented

loom_response = zeros(1,length(st.cellid));
for k = 1:length(st.cellid)
    spikecount = CountSpikesPerTrial(st.spiketrain{1,k},tr.stim_trigger);
    loom_response(k) = max(spikecount(ind));
end
end

function bgresponse = get_background(recordingdate,dataind,loom_dataind)
% -------------------------------------------------------------------------
% computes the number of spikes expected from background activity during one
% looming-stimulus presentation. It uses the firing rate over the period
% preceding the first stimulus trigger of each data file, multiplied by the
% mean looming-stimulus duration, and floors the result at 1 spike.
% -------------------------------------------------------------------------
% input
%   recordingdate : char, recording folder under M:\mouse_sc\
%   dataind : 1 x K int, indices (into d) of the data files to evaluate
%   loom_dataind : 1 x P int, P >= 1; only loom_dataind(1) is used, to get
%       the looming-stimulus duration
% output
%   bgresponse : K x N, expected background spike count per data file (row)
%       and cell (column); [] if dataind is empty
% -------------------------------------------------------------------------

procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');

rhds = load(fullfile(procdir,'ordered_rhds.mat'),'d');
ref = load(fullfile(procdir,rhds.d{1,loom_dataind(1)},'stim.mat'),'stim_trigger');
% mean of (even - odd) trigger times; triggers appear to be onset/offset pairs
loomduration = mean(ref.stim_trigger(2:2:end)-ref.stim_trigger(1:2:end-1));

bgresponse = [];
for j = 1:length(dataind)
    fname = rhds.d{1,dataind(j)};
    st = load(fullfile(procdir,fname,'spiketrain.mat'),'spiketrain');
    tr = load(fullfile(procdir,fname,'stim.mat'),'stim_trigger');
    first_onset = tr.stim_trigger(1);

    % allocate once, before filling any row, so the results for earlier
    % data files are not wiped out
    if j == 1
        bgresponse = zeros(length(dataind),size(st.spiketrain,2));
    end

    for k = 1:size(st.spiketrain,2)
        % spikes before the first trigger / time before the first trigger
        % (assumes spike and trigger times share a clock starting at 0),
        % scaled to one looming duration and floored at 1 expected spike
        nspk = sum(st.spiketrain{1,k} < first_onset);
        bgresponse(j,k) = max([nspk / first_onset * loomduration, 1]);
    end
end
end

function bgresponse = get_background2(recordingdate,dataind,loom_dataind)
% -------------------------------------------------------------------------
% computes the number of spikes expected during one looming-stimulus
% presentation from baseline firing. It takes the mean firing rate over the
% whole data file (all spikes / recordingDuration) and multiplies it by the
% mean looming-stimulus duration.
% -------------------------------------------------------------------------
% input
%   recordingdate : char, recording folder under M:\mouse_sc\
%   dataind : 1 x K int, indices (into d) of the data files to evaluate
%   loom_dataind : 1 x P int, P >= 1; only loom_dataind(1) is used, to get
%       the looming-stimulus duration
% output
%   bgresponse : K x N, expected spike count per data file (row) and cell
%       (column); [] if dataind is empty
% -------------------------------------------------------------------------

procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');

rhds = load(fullfile(procdir,'ordered_rhds.mat'),'d');
ref = load(fullfile(procdir,rhds.d{1,loom_dataind(1)},'stim.mat'),'stim_trigger');
% mean of (even - odd) trigger times; triggers appear to be onset/offset pairs
loomduration = mean(ref.stim_trigger(2:2:end)-ref.stim_trigger(1:2:end-1));

bgresponse = [];
for j = 1:length(dataind)
    st = load(fullfile(procdir,rhds.d{1,dataind(j)},'spiketrain.mat'),'spiketrain','recordingDuration');

    % allocate once, before filling any row, so the results for earlier
    % data files are not wiped out
    if j == 1
        bgresponse = zeros(length(dataind),size(st.spiketrain,2));
    end

    for k = 1:size(st.spiketrain,2)
        bgresponse(j,k) = length(st.spiketrain{1,k})/st.recordingDuration*loomduration;
    end
end
end

function [response,i] = get_loom_response(recordingdate, trial_index, dataind, loom_dataind)
% -------------------------------------------------------------------------
% response to the looming stimulus: the number of spikes a cell fires in the
% trial selected by 'trial_index', rescaled from each file's mean stimulus
% duration to the reference looming duration. When several data files are
% given, the largest rescaled response over files is returned.
% -------------------------------------------------------------------------
% input
%   recordingdate : char
%   trial_index : 1 x 1 int; trial to read. 0 takes the max over all
%       trials; -1 takes the last trial, counted in the reference file
%   dataind : 1 x K int, indices (into d) of the data files to look at
%   loom_dataind : 1 x P int, P >= 1; loom_dataind(1) is the reference
%       file that sets the cell count and the looming duration
% output
%   response : 1 x N, rescaled spike count (max over the K files)
%     N: number of cells
%   i : 1 x N, index into dataind of the file giving that maximum
% -------------------------------------------------------------------------
procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');

rhds = load(fullfile(procdir,'ordered_rhds.mat'),'d');
reffile = rhds.d{1,loom_dataind(1)};
ref_cells = load(fullfile(procdir,reffile,'spiketrain.mat'),'cellid');
ref_stim = load(fullfile(procdir,reffile,'stim.mat'),'stim_trigger');
ref_trig = ref_stim.stim_trigger;

if trial_index == -1
    % two triggers per trial, so the last trial is half the trigger count
    trial_index = length(ref_trig)/2;
end

loomduration = mean(ref_trig(2:2:end)-ref_trig(1:2:end-1));

response = zeros(length(dataind),length(ref_cells.cellid));
for fileno = 1:length(dataind)
    fname = rhds.d{1,dataind(fileno)};
    st = load(fullfile(procdir,fname,'spiketrain.mat'),'spiketrain','cellid','masks');
    tr = load(fullfile(procdir,fname,'stim.mat'),'stim_trigger');
    trig = tr.stim_trigger;
    stimduration = mean(trig(2:2:end)-trig(1:2:end-1));

    for c = 1:size(st.spiketrain,2)
        counts = CountSpikesPerTrial(st.spiketrain{1,c},trig);
        if trial_index == 0
            picked = max(counts);
        else
            picked = counts(trial_index);
        end
        response(fileno,c) = picked/stimduration*loomduration;
    end
end

[response,i] = max(response,[],1);
end

function response = get_ch_response(recordingdate,dataind,loom_dataind)
% -------------------------------------------------------------------------
% returns the response to the checkerboard stimulus: the average firing
% rate between the first and last stimulus triggers, expressed as the number
% of spikes expected over one looming-stimulus duration
% -------------------------------------------------------------------------
% input
%   recordingdate : char
%   dataind : 1 x K int, indices (into d) of checkerboard data files
%   loom_dataind : 1 x P int, P >= 1; loom_dataind(1) sets the cell count
%       and the looming duration
% output
%   response : 1 x N, max over the K files of the scaled spike count
%     N: number of cells
% -------------------------------------------------------------------------

procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');

rhds = load(fullfile(procdir,'ordered_rhds.mat'),'d');
reffile = rhds.d{1,loom_dataind(1)};
ref_cells = load(fullfile(procdir,reffile,'spiketrain.mat'),'cellid');
ref_stim = load(fullfile(procdir,reffile,'stim.mat'),'stim_trigger');

% mean duration of the looming stimulus (even minus odd triggers)
stim_duration = mean(ref_stim.stim_trigger(2:2:end)-ref_stim.stim_trigger(1:2:end-1));

% checkerboard response of every cell in every file
response = zeros(length(dataind),length(ref_cells.cellid));
for fileno = 1:length(dataind)
    fname = rhds.d{1,dataind(fileno)};
    st = load(fullfile(procdir,fname,'spiketrain.mat'),'spiketrain');
    tr = load(fullfile(procdir,fname,'stim.mat'),'stim_trigger');
    t_first = tr.stim_trigger(1);
    t_last = tr.stim_trigger(end);

    for c = 1:size(st.spiketrain,2)
        spk = st.spiketrain{1,c};
        response(fileno,c) = sum(spk>t_first & spk<t_last)/(t_last-t_first)*stim_duration;
    end
end

response = max(response,[],1);

end

function response = get_ch_response2(recordingdate,dataind,loom_dataind)
% -------------------------------------------------------------------------
% returns the response to the checkerboard stimulus: the number of spikes a
% cell fires after the first stimulus trigger, within a window as long as
% the looming stimulus being compared
% -------------------------------------------------------------------------
% input
%   recordingdate : char
%   dataind : 1 x K int, indices (into d) of checkerboard data files
%   loom_dataind : 1 x P int, P >= 1; loom_dataind(1) sets the cell count
%       and the window length
% output
%   response : 1 x N, max over the K files of the spike count in the window
%     N: number of cells
% -------------------------------------------------------------------------

procdir = fullfile('M:\mouse_sc\',recordingdate,'proc');

rhds = load(fullfile(procdir,'ordered_rhds.mat'),'d');
reffile = rhds.d{1,loom_dataind(1)};
ref_cells = load(fullfile(procdir,reffile,'spiketrain.mat'),'cellid');
ref_stim = load(fullfile(procdir,reffile,'stim.mat'),'stim_trigger');

% window length = mean looming duration (even minus odd triggers)
stim_duration = mean(ref_stim.stim_trigger(2:2:end)-ref_stim.stim_trigger(1:2:end-1));

% checkerboard response of every cell in every file
response = zeros(length(dataind),length(ref_cells.cellid));
for fileno = 1:length(dataind)
    fname = rhds.d{1,dataind(fileno)};
    st = load(fullfile(procdir,fname,'spiketrain.mat'),'spiketrain');
    tr = load(fullfile(procdir,fname,'stim.mat'),'stim_trigger');
    t_start = tr.stim_trigger(1);

    for c = 1:size(st.spiketrain,2)
        spk = st.spiketrain{1,c};
        response(fileno,c) = sum(spk>t_start & spk<(t_start+stim_duration));
    end
end

response = max(response,[],1);

end