function [Simil,SimShuf,Act] = PCAsimilarity(X,S_bin,E,running,whoshuf,norm)
%% Similarity measures between principal components
% Computes principal-component bases (leading left singular vectors) of the
% population activity in different behavioral states - running, network
% events (NE), resting, resting without NEs - and compares them with
% subspace-similarity measures and projected variances. The NE-vs-running
% comparison is tested against chance with surrogate NE datasets (randomX).
%
% Inputs:
%      X         dF/F traces, cells x time
%      S_bin     binary event traces, indexed like X (cells x time)
%      E         logical vector marking network-event time points
%      running   logical vector, true while the animal was running
%      whoshuf   surrogate type (see randomX):
%                0 circular shift of each cell's rest activity
%                1 random re-assignment of NE time points to rest periods
%                2 randomize NE composition (shuffle cells within each NE)
%                3 randomize cell attendance (shuffle NE order per cell)
%                4 randomize NE order
%      norm      0 no additional normalization
%                1 subtract each trace's mean (centering)
%                2 z-score each trace (centered, divided by std + 0.001)
%
% Outputs:
%      Simil     1x8 [Krzanowski, Eros, subspace angle/pi, var_runnet,
%                var_runrun, var_netnet, var_runrest, var_runrnne] for the
%                NE basis vs. the running basis; var_AB = trace(B'*C_A*B),
%                the variance of state A's covariance captured by basis B
%      SimShuf   10x4 [Krzanowski, Eros, subspace angle/pi, var_runrand]
%                for each surrogate basis vs. the running basis
%      Act       struct with fields varX (per-cell variance), events
%                (per-cell S_bin sums), covX (mean pairwise covariance) -
%                columns: run, NE, rest, rest without NE, surrogate mean,
%                surrogate std - and explained (cumulative variance fraction)
%      All outputs are [] when a state has fewer than xDim active cells
%      (including states that are absent from the recording).

Simil   = [];
SimShuf = [];
Act     = [];

%% Number of components to compare
% xDim = smallest number of components whose cumulative explained variance
% exceeds thresh.
% NOTE(review): pca(X) treats the rows of X (cells) as observations, while
% the SVD bases below live in cell space (columns of X are samples).
% pca(X') may have been intended - confirm before changing, since it alters
% xDim and every downstream measure.
[~,~,~,~,explained] = pca(X);
explained = cumsum(explained)/sum(explained);
thresh = .5;
xDim = find(explained>thresh,1);

win   = -7:7;                      % sample offsets of a network-event window
rises = find(diff(running)==1);    % last sample before running starts
falls = find(diff(running)==-1);   % last running sample

%% running related activity
[first,last] = pairEdges(rises,falls);
[Xrun,Zrun] = collectBouts(X,S_bin,first,last);
if tooFewActive(Xrun,xDim), return, end
[Xrun_use,varXrun,Covrun] = prepState(Xrun,norm);
[Yrun,sRun] = topComponents(Xrun_use,xDim);

%% network related activity (one window per NE, concatenated in time order)
netpos = find(E);
netpos = netpos(netpos>abs(win(1)) & netpos<size(X,2)-win(end));
netcols = reshape(bsxfun(@plus,win(:),netpos(:).'),1,[]);
Xnet = X(:,netcols);
Znet = sum(S_bin(:,netcols),2);
if tooFewActive(Xnet,xDim), return, end
[Xnet_use,varXnet,Covnet] = prepState(Xnet,norm);
[Ynet,sNet] = topComponents(Xnet_use,xDim);

%% resting related activity (from a running offset to the next onset)
[first,last] = pairEdges(falls,rises);
[Xrest,Zrest] = collectBouts(X,S_bin,first,last);
if tooFewActive(Xrest,xDim), return, end
[Xrest_use,varXrest,Covrest] = prepState(Xrest,norm);
Yrest = topComponents(Xrest_use,xDim);

%% resting activity without network events
% Mark every NE window as "not rest" - clipped to the recording so events
% near either edge neither error nor grow the vector - and redo the rest
% analysis on what remains.
restNoNE = running;
netT = find(E);
for k = 1:numel(netT)
    idx = netT(k)+win;
    restNoNE(idx(idx>=1 & idx<=numel(restNoNE))) = 1;
end
[first,last] = pairEdges(find(diff(restNoNE)==-1),find(diff(restNoNE)==1));
[Xrnne,Zrnne] = collectBouts(X,S_bin,first,last);
if tooFewActive(Xrnne,xDim), return, end
[Xrnne_use,varXrnne,Covrnne] = prepState(Xrnne,norm);
Yrnne = topComponents(Xrnne_use,xDim);

%% NE basis vs. running basis
[krz,eros,ang] = subspaceSimilarity(Ynet,Yrun,sNet,sRun);

%% Projected variances
var_runnet  = trace(Ynet'*Covrun*Ynet);
var_netnet  = trace(Ynet'*Covnet*Ynet);
var_runrun  = trace(Yrun'*Covrun*Yrun);
var_runrest = trace(Yrest'*Covrun*Yrest);
var_runrnne = trace(Yrnne'*Covrun*Yrnne);
Simil = [krz eros ang var_runnet var_runrun var_netnet var_runrest var_runrnne];

%% Chance level from surrogate NE datasets
numit    = 10;
nCells   = size(X,1);
SimShuf  = zeros(numit,4);
varXrand = zeros(nCells,numit);
Zrand    = zeros(nCells,numit);
Covrand  = zeros(numit,1);
parfor ii = 1:numit
    [~,Xrand,~,Zrand(:,ii)] = randomX(X,E,running,xDim,whoshuf,win,S_bin);
    [Xrand_use,varXr,Covr] = prepState(Xrand,norm);
    varXrand(:,ii) = varXr;
    Covrand(ii) = lowerTriSum(Covr);
    [Yrand,sRand] = topComponents(Xrand_use,xDim);
    [krzR,erosR,angR] = subspaceSimilarity(Yrand,Yrun,sRand,sRun);
    SimShuf(ii,:) = [krzR erosR angR trace(Yrand'*Covrun*Yrand)];
end

%% Activity summary per state
% columns: run, NE, rest, rest without NE, surrogate mean, surrogate std
nPairs = 0.5*nCells*(nCells-1);   % number of distinct cell pairs
covX = [lowerTriSum(Covrun) lowerTriSum(Covnet) lowerTriSum(Covrest) ...
        lowerTriSum(Covrnne) mean(Covrand) std(Covrand)]./nPairs;
Act = struct( ...
    'varX',      [varXrun varXnet varXrest varXrnne mean(varXrand,2) std(varXrand,[],2)], ...
    'events',    [Zrun Znet Zrest Zrnne mean(Zrand,2) std(Zrand,[],2)], ...
    'covX',      covX, ...
    'explained', explained);
end

function [first,last] = pairEdges(up,down)
% Pair each "up" transition index with the following "down" transition so
% that first(k):last(k) spans one bout; empty when no complete bout exists.
first = [];
last  = [];
if isempty(up) || isempty(down)
    return
end
if down(1)<up(1)
    down = down(2:end);   % bout already in progress at recording start
end
if ~isempty(down) && up(end)>down(end)
    up = up(1:end-1);     % bout still in progress at recording end
end
n = min(numel(up),numel(down));
first = up(1:n);
last  = down(1:n);
end

function [Xs,Zs] = collectBouts(X,S_bin,first,last)
% Concatenate columns first(k):last(k) of X; Zs = per-cell S_bin sum there.
cols = cell(1,numel(first));
for k = 1:numel(first)
    cols{k} = first(k):last(k);
end
cols = [cols{:}];
Xs = X(:,cols);
Zs = sum(S_bin(:,cols),2);
end

function tf = tooFewActive(Xs,xDim)
% True when fewer than xDim cells have summed activity above 1 in Xs.
tf = sum(sum(Xs,2)>1)<xDim;
end

function [Xuse,varXs,C] = prepState(Xs,normMode)
% Normalize per normMode (0 none, 1 center, 2 z-score with +0.001 guard);
% also return per-cell variance of the raw data and the cell covariance of
% the normalized data.
varXs = var(Xs,[],2);
if normMode == 1
    Xuse = bsxfun(@minus,Xs,mean(Xs,2));
elseif normMode == 2
    Xuse = bsxfun(@rdivide,bsxfun(@minus,Xs,mean(Xs,2)),std(Xs,[],2)+.001);
else
    Xuse = Xs;
end
C = cov(Xuse');
end

function [Y,s] = topComponents(Xuse,xDim)
% Leading xDim left singular vectors (cell-space basis) and singular values.
% Economy SVD: the full decomposition would also build an unused T x T V.
[U,Sg] = svd(Xuse,'econ');
s = diag(Sg(1:xDim,1:xDim));
Y = U(:,1:xDim);
end

function [krz,eros,ang] = subspaceSimilarity(L,M,sL,sM)
% Krzanowski (1979) similarity factor, Eros (Yang 2004) with weights from
% the mean singular values of both bases, and the subspace angle / pi.
krz  = trace(L'*M*M'*L)/size(M,2);
w    = (sL(:)+sM(:))/2;
w    = w./sum(w);
eros = sum(w.*abs(sum(L.*M,1)).');
ang  = subspace(L,M)/pi;
end

function c = lowerTriSum(C)
% Sum of the strictly lower triangle of C (each cell pair counted once).
Lt = tril(C,-1);
c  = sum(Lt(:));
end

function [datrand,Xrand,XrandComp,Zrand] = randomX(X,E,running,xDim,whoshuf,win,S_bin)
% Build one surrogate network-event (NE) dataset for PCAsimilarity.
%
% X, S_bin, E, running and whoshuf are as in PCAsimilarity; win holds the
% sample offsets of one NE window (e.g. -7:7). Surrogates are redrawn until
% at least xDim cells have summed activity > 0 (error after maxIter draws).
%
% Outputs:
%   datrand    struct array, one element per surrogate NE, with fields
%              trialId, events (S_bin window) and spikes (X window)
%   Xrand      concatenated surrogate NE windows of X (cells x samples)
%   XrandComp  X as used for the surrogate (rest activity circularly shifted
%              per cell for whoshuf == 0, otherwise unchanged)
%   Zrand      per-cell sum of S_bin over the surrogate windows
    %%
    wf = 2;           % exclusion margin around real NEs, in multiples of win
    maxIter = 1000;   % fail loudly instead of looping forever
    wlen = numel(win);
    E(1:abs(win(1))*wf)  = 0;
    E(end-win(end)*wf:end) = 0;
    Erand = E;
    rest = find(running==0);
    netT = find(E);
    nEvents = numel(netT);

    % Candidate time points for randomly placed NEs: rest periods away from
    % real NEs and from running on-/offsets.
    rest_E = running;
    for i = 1:nEvents
        rest_E(netT(i)+(win(1)*wf:win(end)*wf)) = 1;
    end
    Drun = find(diff(running));
    % keep only transitions whose full window lies inside the recording
    Drun(Drun<=abs(win(1)) | Drun>length(running)-win(end)) = [];
    for i = 1:numel(Drun)
        rest_E(Drun(i)+win) = 1;
    end
    rest_E = find(rest_E==0);

    xNew = 0;
    iter = 0;
    while xNew<xDim
        iter = iter+1;
        if iter>maxIter
            error('randomX:tooFewActiveCells', ...
                'No surrogate with at least %d active cells after %d draws.', ...
                xDim, maxIter);
        end
        %% random network points during rest
        if whoshuf == 0
            % circularly shift each cell's rest activity by a random lag
            Xrest = X(:,rest);
            for i = 1:size(X,1)
                tpt = randperm(size(Xrest,2),1);
                Xrest(i,:) = Xrest(i,[tpt:end 1:tpt-1]);
            end
            X(:,rest) = Xrest;
        elseif whoshuf == 1
            % place as many NEs as real ones on random candidate time points
            Erand = false(size(E));
            Erand(rest_E(randperm(numel(rest_E),nEvents))) = true;
        end
        %% create shuffled dataset
        netpos = find(Erand);
        netpos = netpos(netpos>abs(win(1)) & netpos<size(X,2)-win(end));
        nNet = numel(netpos);
        tempNet = netpos(randperm(nNet));   % random NE order (whoshuf == 4)

        dat = struct('trialId',{},'events',{},'spikes',{});   % fresh per draw
        for i = 1:nNet
            dat(i).trialId = i;
            if whoshuf == 2
                temp = randperm(size(X,1),size(X,1));
                dat(i).events = S_bin(temp,netpos(i)+win);
                dat(i).spikes = X(temp,netpos(i)+win);
            elseif whoshuf == 4
                dat(i).events = S_bin(:,tempNet(i)+win);
                dat(i).spikes = X(:,tempNet(i)+win);
            else
                dat(i).events = S_bin(:,netpos(i)+win);
                dat(i).spikes = X(:,netpos(i)+win);
            end
        end

        Xrand = [dat.spikes];
        Zrand = [dat.events];
        XrandComp = X;
        xNew = sum(sum(Xrand,2)>0);
        if whoshuf == 3 && nNet > 0
            % shuffle the NE order independently for every cell
            for j = 1:size(X,1)
                order = randperm(nNet,nNet);
                xj = reshape(Xrand(j,:),wlen,nNet);
                Xrand(j,:) = reshape(xj(:,order),1,[]);
                zj = reshape(Zrand(j,:),wlen,nNet);
                Zrand(j,:) = reshape(zj(:,order),1,[]);
            end
            for i = 1:nNet
                cols = (i-1)*wlen+1:i*wlen;
                dat(i).events = Zrand(:,cols);
                dat(i).spikes = Xrand(:,cols);
            end
        end
        datrand = dat;
        Zrand = sum(Zrand,2);
    end
end
