% Commonality analysis using Spearman correlations
% (which is like Pearson correlation on ranks)

% Make the project helpers (mysquareform, correlate, tiedrank2,
% find_clusters, ...) visible on the MATLAB path.
addpath helper_functions

%% Load data first

% Labels of the five regions of interest; used further down for figure titles.
roinames = {'EVC','LO','pFS','lPFC','Parietal'};

% The two .mat files are expected to provide RDM_fMRI and RDM_MEG, which the
% following sections read.
load data/RDM_fMRI.mat
load data/RDM_MEG.mat

% Time axis with one entry per MEG sample (612 samples from -100 to 5000).
% It is not used before the plotting section redefines t, hence the lint
% suppression.
t = linspace(-100,5000,612); %#ok<NASGU>

%% Set up model variables

% Build two binary 32 x 32 model RDMs (0 = same category, 1 = different).
% The 32 conditions are laid out as 4 consecutive blocks of 8:
%   - task model:   conditions within the same block of 8 are "same"
%   - object model: conditions at the same position inside each block
%                   (i.e. 8 apart) are "same"

% indicator matrix, 32 x 4: row r belongs to block ceil(r/8)
model1 = double(bsxfun(@eq, ceil((1:32)'/8), 1:4));
% indicator matrix, 32 x 8: row r belongs to position mod(r-1,8)+1
model2 = double(bsxfun(@eq, mod((0:31)',8)+1, 1:8));

% the outer product of an indicator matrix is 1 where two conditions share a
% category; subtracting from 1 turns that similarity into a dissimilarity
RDM_model_task = 1 - model1*model1';
RDM_model_object = 1 - model2*model2';

% the indicator matrices are only intermediates
clear model1 model2

%% Select lower triangular matrix

% Turn every RDM into its vector form (RDV) with the project helper
% mysquareform. Judging from how the results are indexed further down, the
% data RDVs hold one column per ROI (fMRI) and one column per time point
% (MEG) -- confirm against mysquareform if in doubt.
RDV_fMRI = mysquareform(RDM_fMRI);
RDV_MEG = mysquareform(RDM_MEG);
RDV_task = mysquareform(RDM_model_task);
RDV_object = mysquareform(RDM_model_object);

%% This would be simple MEG-fMRI fusion

% Spearman correlation of every fMRI column with every MEG column
% (rows of Fusionmat: ROIs, columns: time points).
Fusionmat = corr(RDV_fMRI,RDV_MEG,'type','spearman');

%% We want more complicated model-based fusion (based on commonality analysis)

% For every ROI and every MEG time point three semipartial Spearman
% correlations are computed with the project helper correlate(); in each of
% them the MEG RDV is the first column and the fMRI RDV the second. Element
% (2,1) of the returned matrix is squared and the reduced models are
% compared with the full model:
%   CMEGMRItask = r(MEG,MRI | obj)^2  - r(MEG,MRI | task,obj)^2
%   CMEGMRIobj  = r(MEG,MRI | task)^2 - r(MEG,MRI | task,obj)^2
% NOTE(review): which variables correlate() partials out for entry (2,1) is
% defined in the helper, not here -- confirm against helper_functions.

disp('Running commonality analysis...')

% options shared by all three calls
semipartial_opts = {'type','spearman','method','semipartialcorr'};

% Both loops run backwards so that the very first assignment already creates
% the result matrices at their final size (612 x 5).
for j_roi = 5:-1:1

    roi_rdv = RDV_fMRI(:,j_roi);

    for i_time = 612:-1:1

        meg_rdv = RDV_MEG(:,i_time);

        r_full = correlate([meg_rdv roi_rdv RDV_task RDV_object],semipartial_opts{:});
        r_obj_only = correlate([meg_rdv roi_rdv RDV_object],semipartial_opts{:});
        r_task_only = correlate([meg_rdv roi_rdv RDV_task],semipartial_opts{:});

        full_sq = r_full(2,1).^2;
        CMEGMRItask(i_time,j_roi) = r_obj_only(2,1).^2 - full_sq;
        CMEGMRIobj(i_time,j_roi) = r_task_only(2,1).^2 - full_sq;

    end
end
disp('done.')

%% Now we run a randomization test, randomizing the rows and columns of the MEG data

% Since all values in MEG are unique, we don't need to use tiedrank
% and since all other values don't change across the time points and the
% permutations, we can just calculate them once, rank transform them, and
% instead of running spearman correlation run pearson correlation.

% if you want to speed this up, run variance partitioning on multiple
% permutations in parallel (treating them as separate), on the ranks of the
% permutations

% Result of this section: CMEGMRItask_perm and CMEGMRIobj_perm
% (time x ROI x permutation). They are obtained, in order of preference,
%   1. from the full permutation file written by an earlier run,
%   2. from the reduced file (only the time windows of interest are stored,
%      everything else is filled with zeros),
%   3. by running the permutations now and saving the full file.

perm_file_full = 'data/commonality_perm.mat';
perm_file_reduc = 'data/commonality_perm_reduc.mat';
perm_loaded = false;

% 1. full (non-reduced) permutations from a previous run
if exist(perm_file_full,'file')
    disp('Loading pre-calculated permutations ...')
    try
        load(perm_file_full)
        perm_loaded = true;
        disp('done.')
    catch load_err
        % keep the best-effort behaviour (fall through to the other
        % sources) but tell the user why the file was skipped
        warning('commonality:permLoadFailed', ...
            'Could not load %s (%s), trying other sources.', ...
            perm_file_full,load_err.message)
    end
end

% 2. reduced permutations
if ~perm_loaded && exist(perm_file_reduc,'file')
    disp('Loading pre-calculated permutations ...')

    % time windows that are contained in the reduced file
    edges_task = 13:432;
    edges_obj = 253:432;
    load(perm_file_reduc)

    % we couldn't submit all data points because of size limitations, so we
    % keep only the relevant parts and set all others to zero
    CMEGMRIobj_perm = double(CMEGMRIobj_perm);
    CMEGMRItask_perm = double(CMEGMRItask_perm);
    sz = size(CMEGMRIobj_perm); sz(1) = 612; % full number of time points
    % fill in zeros before and after the stored window
    CMEGMRIobj_perm = cat(1,zeros(edges_obj(1)-1,sz(2),sz(3)), CMEGMRIobj_perm);
    CMEGMRIobj_perm = cat(1,CMEGMRIobj_perm, zeros(sz(1)-edges_obj(end),sz(2),sz(3)));
    CMEGMRItask_perm = cat(1,zeros(edges_task(1)-1,sz(2),sz(3)), CMEGMRItask_perm);
    CMEGMRItask_perm = cat(1,CMEGMRItask_perm, zeros(sz(1)-edges_task(end),sz(2),sz(3)));

    perm_loaded = true;
    disp('done.')
end

% 3. nothing on disk: run the permutations
if ~perm_loaded

    n_perm = 5000;                      % number of permutations
    n_roi = numel(roinames);            % number of ROIs
    n_cond = size(RDM_MEG,1);           % number of conditions (rows of the RDM)
    [n_cells,n_time] = size(RDV_MEG);   % RDV entries x time points

    fprintf('Making %d permutations of MEG matrices...\n',n_perm)
    RDV_MEG_perm = zeros([n_cells n_time n_perm]);
    ind = tril(true(size(RDM_MEG(:,:,1))),-1);
    for i_perm = 1:n_perm
        rp = randperm(n_cond);
        curr_RDM = RDM_MEG(rp,rp,:);
        % convert to vector (faster than repeatedly calling mysquareform)
        for i_time = 1:n_time % important: use same permutation across time!
            tmp = curr_RDM(:,:,i_time);
            RDV_MEG_perm(:,i_time,i_perm) = tmp(ind);
        end
    end

    % get ranks by sorting (because all values are unique in MEG) (this is
    % quite memory intense, alternatively replace the construction of
    % allranks below with
    % for i_time = n_time:-1:1, for i_perm = n_perm:-1:1, allranks(:,i_time,i_perm) = tiedrank2(RDV_MEG_perm(:,i_time,i_perm)); end, end
    disp('calculating ranks in parallel, if you get out of memory, use the tiedrank2 loop described in the comment above this message')
    [~,sortind] = sort(RDV_MEG_perm);
    clear RDV_MEG_perm
    % turn the per-column sort indices into linear indices of the full array
    sortind2 = sortind(:) + kron((0:n_cells:(n_perm*n_time*n_cells-1))',ones(n_cells,1));
    clear sortind
    clear allranks
    % the element that sorts to position k receives rank k
    allranks(sortind2) = repmat((1:n_cells)',n_time*n_perm,1);
    clear sortind2
    allranks = reshape(allranks,[n_cells n_time n_perm]);

    fprintf('Running %d permutations...\n',n_perm)
    disp('(come back in a few hours)')

    % the model ranks are identical for every ROI, permutation and time
    % point, so they are computed once
    xtask = tiedrank2(RDV_task);
    xobj = tiedrank2(RDV_object);

    CMEGMRItask_perm = zeros(n_time,n_roi,n_perm);
    CMEGMRIobj_perm = zeros(n_time,n_roi,n_perm);

    ct = 0;
    ct2 = 0;
    fprintf(repmat('\f',1,9))
    for j_roi = n_roi:-1:1

        xMRI = tiedrank2(RDV_fMRI(:,j_roi));

        for i_perm = n_perm:-1:1
            % progress display: ct counts (ROI, permutation) pairs, so the
            % counter shown advances once per n_roi pairs and ends at n_perm
            ct = ct+1;
            if ~mod(ct,n_roi)
                ct2 = ct2+1;
                fprintf(repmat('\b',1,9))
                fprintf('%04i/%04i',ct2,n_perm)
            end

            for i_time = n_time:-1:1

                y = allranks(:,i_time,i_perm);

                % same three semipartial correlations as for the real data,
                % but Pearson on ranks instead of Spearman
                rMEG_MRItaskobj = correlate([y xMRI(:,1) xtask(:,1) xobj(:,1)],'type','pearson','method','semipartialcorr');
                rMEG_MRIobj = correlate([y xMRI(:,1) xobj(:,1)],'type','pearson','method','semipartialcorr');
                rMEG_MRItask = correlate([y xMRI(:,1) xtask(:,1)],'type','pearson','method','semipartialcorr');

                CMEGMRItask_perm(i_time,j_roi,i_perm) = rMEG_MRIobj(2,1).^2-rMEG_MRItaskobj(2,1).^2;
                CMEGMRIobj_perm(i_time,j_roi,i_perm) = rMEG_MRItask(2,1).^2-rMEG_MRItaskobj(2,1).^2;

            end
        end

    end

    fprintf('\ndone.\n')

    save(perm_file_full,'CMEGMRItask_perm','CMEGMRIobj_perm')

end


%% Cluster statistics

% get cluster correction using 0.05 as cluster-inducing threshold and
% restricting our analysis to time periods where we have a hypothesis (see
% below)

% For every ROI and time point, sort the permutation values from largest to
% smallest and read off
%   cutoff_* : the value at position floor(0.05 * number of permutations)
%   mean_*   : the value at position floor(0.5 * number of permutations)
%              (despite the name this is the middle of the sorted values)
% The ROI loop runs backwards so the first assignment allocates all columns.
for i_roi = 5:-1:1
    perm_vals = squeeze(CMEGMRItask_perm(:,i_roi,:)); % time x permutation
    n_cols = size(perm_vals,2);
    perm_desc = sort(perm_vals,2,'descend');
    cutoff_task(:,i_roi) = perm_desc(:,floor(0.05*n_cols));
    mean_task(:,i_roi) = perm_desc(:,floor(0.5*n_cols));

    perm_vals = squeeze(CMEGMRIobj_perm(:,i_roi,:));
    n_cols = size(perm_vals,2);
    perm_desc = sort(perm_vals,2,'descend');
    cutoff_obj(:,i_roi) = perm_desc(:,floor(0.05*n_cols));
    mean_obj(:,i_roi) = perm_desc(:,floor(0.5*n_cols));

    % a relatively conservative alternative (no permutation is larger) would
    % be the first column of the sorted values:
    % cutoff_task(:,i_roi) = perm_desc(:,1);
end

% loop over permutations to get maximum cluster size
% pick edges: for task: 13:432 (101ms to 3600ms which is 0 to 3500)
%             for obj: 253:432 (2101ms to 3600ms)
edges_task = 13:432;
edges_obj = 253:432;

% Take the number of permutations from the data instead of assuming 5000, so
% that the loop agrees with the percentile computations, which also use the
% actual array size.
n_perm = size(CMEGMRItask_perm,3);

% loops run backwards so the first assignment allocates c_task / c_obj
for i_perm = n_perm:-1:1
    for j_roi = 5:-1:1
        % find_clusters gives us the cluster sizes of the current permutation
        % in the current ROI that are larger than the cutoff value; only the
        % largest one is kept
        % NOTE(review): this assumes find_clusters returns a non-empty
        % result when nothing exceeds the cutoff -- confirm against helper
        c_task(i_perm,j_roi) = max(find_clusters(CMEGMRItask_perm(edges_task,j_roi,i_perm)>cutoff_task(edges_task,j_roi)));
        c_obj(i_perm,j_roi) = max(find_clusters(CMEGMRIobj_perm(edges_obj,j_roi,i_perm)>cutoff_obj(edges_obj,j_roi)));
    end
end

% per ROI: maximum cluster size at position floor(0.05 * number of
% permutations) of the descending order
for j_roi = 5:-1:1
    c_sorted = sort(c_task(:,j_roi),'descend');
    clust_cutoff_task(j_roi) = c_sorted(floor(0.05*size(c_task,1)));
    c_sorted = sort(c_obj(:,j_roi),'descend');
    clust_cutoff_obj(j_roi) = c_sorted(floor(0.05*size(c_obj,1)));
end

% now get cutoff as maximum cutoff of all (which we need for a cutoff corrected for multiple comparisons)
clust_cutoff_taskall = max(clust_cutoff_task);
clust_cutoff_objall = max(clust_cutoff_obj);

% with these cutoffs check out real cluster sizes
% TODO: we may potentially get percentiles for cluster sizes found empirically for p-values?

% all_clustind_* hold, per time point and ROI, the label of the surviving
% cluster the time point belongs to (0 = not part of a surviving cluster).
% Time points outside the analysis window stay 0.
all_clustind_task = zeros(612,5);
all_clustind_obj = zeros(612,5);
for j_roi = 5:-1:1
    % task: c is used as the list of cluster sizes and clustind as the
    % cluster label of every sample in the window (as returned by the
    % project helper find_clusters)
    above_cutoff = CMEGMRItask(edges_task,j_roi)>cutoff_task(edges_task,j_roi);
    [c,~,~,clustind] = find_clusters(above_cutoff);
    % labels of clusters that do not exceed the corrected size cutoff
    % (was <clust_cutoff_task(j_roi))
    too_small = find(c(1:length(c))<=clust_cutoff_taskall);
    clustind(ismember(clustind,too_small)) = 0;
    all_clustind_task(edges_task,j_roi) = clustind;

    % object: same procedure (was <clust_cutoff_obj(j_roi))
    above_cutoff = CMEGMRIobj(edges_obj,j_roi)>cutoff_obj(edges_obj,j_roi);
    [c,~,~,clustind] = find_clusters(above_cutoff);
    too_small = find(c(1:length(c))<=clust_cutoff_objall);
    clustind(ismember(clustind,too_small)) = 0;
    all_clustind_obj(edges_obj,j_roi) = clustind;
end


%% Plotting

% Let's plot results (task and obj within one, i.e. 5 plots)
% using a quadratic scale y-axis and plotting significance above
% (quadratic scale y-axis to reflect an axis similar to correlation values
% of previous RSA and fusion results)

% Plotted time window and the matching samples: 0.12 samples per unit of
% tlim, i.e. 432 samples for the 3600-unit window.
tlim = [-100 3500];
t = linspace(tlim(1),tlim(2),diff(tlim)*0.12);
tind = 1:numel(t);

% Smoothing kernel: smooth_kern samples of a normal pdf evaluated on
% [-2.355, 0], scaled to sum to 1.
% smoothed results may look nicer (no smoothing, i.e. kernel = 1 used for figure in paper)
smooth_kern = 1;
% symmetric alternative:
% y = normpdf(linspace(-2.355,2.355,smooth_kern));
y = normpdf(linspace(-2.355,0,smooth_kern));
y = y./sum(y);
% a single-tap kernel is set to exactly 1 (identity)
if isscalar(y)
    y = 1;
end

% One figure per ROI showing
%   - the squared fusion correlation as a grey area,
%   - the task and object commonality coefficients as lines,
%   - horizontal bars where the cluster test marks a time point,
% all on a signed square-root y-axis (ticks are placed at
% sign(v)*sqrt(|v|) but labelled with v).

% The curves are mirror-padded by (smooth_kern-1)/2 samples on each side
% before the 'valid' convolution so that the output keeps its length; this
% needs an odd kernel length.
if mod(smooth_kern,2) ~= 1
    error('smooth_kern must be an odd positive integer (got %g).',smooth_kern)
end
n_pad = (smooth_kern-1)/2;
n_pts = numel(tind);
% index vector: mirrored start, original samples, mirrored end
mirror_ind = [n_pad:-1:1, 1:n_pts, n_pts:-1:n_pts-n_pad+1];

% fact will provide where to place the significance bars (one entry per ROI)
fact_task = sqrt([0.095 0.095 0.095 0.19 0.375]); % was 0.075 0.075 0.075 0.19 0.335
fact_obj  = sqrt([0.09  0.09  0.09  0.18 0.36  ]); % was 0.07 0.07 0.07 0.18 0.32

for i_roi = 1:5

    % y-ticks (in original units) depend on the ROI
    switch i_roi
        case {1,2,3}
            tytick = [-0.01:0.01:0.04 0.06 0.08 0.10]; % was -0.01:0.01:0.08
        case 4
            tytick = [-0.01 0 0.01 0.04:0.02:0.1 0.15 0.20]; % was 0.04:0.02:0.20
        case 5
            tytick = [-0.01 0 0.01 0.05:0.05:0.2 0.3 0.4]; % was 0.05:0.05:0.35
    end

    tyticklabel = num2cell(tytick);
    tylim = tytick([1 end]);

    % all three curves as column vectors so that they can be padded the
    % same way (Fusionmat stores time along its columns)
    taskplot = CMEGMRItask(tind,i_roi);
    objplot = CMEGMRIobj(tind,i_roi);
    ceilplot = Fusionmat(i_roi,tind).'.^2;

    taskplot = conv(taskplot(mirror_ind),y,'valid');
    objplot = conv(objplot(mirror_ind),y,'valid');
    ceilplot = conv(ceilplot(mirror_ind),y,'valid').'; % back to a row for area()

    figure;
    h3 = area(t,sqrt(ceilplot)); % plot fusion as baseline

    hold on
    % we convert values to quadratic scale and change the y-axis
    h1 = plot(t,sign(taskplot).*sqrt(abs(taskplot)));
    h2 = plot(t,sign(objplot).*sqrt(abs(objplot)));
    xlim(tlim)
    title(strrep(roinames{i_roi},'_','\_'))
    h1.LineWidth = 1.5;
    h2.LineWidth = 1.5;
    h3.LineWidth = 0.5;
    h3.FaceColor = [0.9 0.9 0.9];
    h3.EdgeColor = [0.9 0.9 0.9];
    ha = gca;
    ha.YTickMode = 'manual';
    ha.YTick = sign(tytick).*sqrt(abs(tytick));
    ha.YTickLabel = tyticklabel;
    ha.YLim = sign(tylim).*sqrt(abs(tylim));

    % add significance bars
    sig_task = double(all_clustind_task(tind,i_roi)>0);
    sig_task(sig_task==0) = NaN; % NaN will allow plotting as a line with breaks
    sig_obj = double(all_clustind_obj(tind,i_roi)>0);
    sig_obj(sig_obj==0) = NaN;

    plot(t,sig_task*fact_task(i_roi),'b','linewidth',1.5);
    plot(t,sig_obj*fact_obj(i_roi),'r','linewidth',1.5);

    % plot time 0 and the zero line
    plot([0 0],ha.YLim,'k--','linewidth',1.5)
    plot(xlim,[0 0],'k--','linewidth',1.5)

    % Create the legend last and for the three curves only; a legend created
    % earlier picks up the bars and reference lines as extra entries in
    % MATLAB releases where legends update automatically.
    legend([h3 h1 h2],{'fusion','task','object'})

% activate below if you want to save the results
%     print(['commonality_' roinames{i_roi} '.eps'],'-painters','-depsc')
% saveas(gcf,['commonality_' roinames{i_roi} '.png'],'png')

end
