function network = NetClust(network)
%% Cluster analysis for NE related data
% NETCLUST  Hierarchical and k-means cluster analysis of a network raster.
%
% Input:
%   network - structure created by NetEvents.m. This function reads the
%             fields network.netraster and network.netnum.
%             NOTE(review): netraster is presumably cells x network events
%             (its rows are correlated against each other) - confirm.
%
% Output:
%   network - the input structure with these fields added or overwritten:
%     gap         - [gap; W; EW], hierarchical scores from ShuffleClust
%     gapK        - [gapK; Wk; EWk], k-means scores from ShuffleClust
%     silh        - struct with hier_ss, hier_wss, kmeans_ss, kmeans_wss
%     cellID      - one row per row of netraster; column 1 is the cluster
%                   label, column 2 the sort index; NaN for excluded rows
%     netcorr     - corr(netraster') over all rows
%     netcluster  - correlation matrix of the valid rows, reordered by
%                   cluster label
%     clustThresh - rows of [cluster count, mean within-cluster correlation]
%     clustthresh - correlation threshold returned by ShuffleClust
%     clustID     - struct with the field clustsig


netraster = network.netraster;
netnum = network.netnum;

%% cluster sorting

% Pearson correlation between the rows of netraster; the rows of the
% correlation matrix are then used as observations for the linkage.
% NOTE(review): 'seuclidean' is MATLAB's *standardized* Euclidean distance,
% not the squared Euclidean distance ('squaredeuclidean') - confirm intent.
[netcorr,~] = corr(netraster');    
% rows whose self-correlation is NaN are excluded from the analysis
isnet = ~isnan(diag(netcorr));
netcorr = netcorr(isnet,isnet);
Y = netcorr;
Z = linkage(Y,'weighted','seuclidean');  

%% Scores of the observed data versus shuffled rasters

numit = 1000; % number of shuffle iterations passed to ShuffleClust
[W,EW,thresh,Wk,EWk] = ShuffleClust(network,numit,round(sum(isnet)/2));
% difference of the log scores, shuffled (EW) minus observed (W)
gap = log(EW)-log(W);
gapK = log(EWk)-log(Wk);

network.gap = [gap;W;EW];
network.gapK = [gapK;Wk;EWk];
%% r-threshold for cluster number
% aa: probability per bin of netnum, bb: the bin edges
[aa,bb] = histcounts((netnum),'normalization','probability');

% Walk the bins until the cumulative probability exceeds 0.5 and take the
% left edge of that bin plus 0.5 as meannet.
sumaa = 0;
for ii = 1:length(aa)
    sumaa = sumaa + aa(ii);
    if sumaa >0.5
        meannet = bb(ii)+.5;
        break
    end
end 
numcell = sum(isnet);
% initial guess of the cluster count: valid rows divided by meannet
numclust = round(numcell/meannet);

clustThresh = [];
k = 1; % flag: 1 until the first cluster count exceeding thresh is found
% offsets added to numclust in the loop below.
% NOTE(review): the upper bound is round(numcell)/2, i.e. the division
% happens after rounding; numclust+j is 0 for the first offset when
% numclust <= 1 - confirm that this cannot occur.
num = -round(numclust/2):1:round(numcell)/2;
netcorr = corr(netraster(isnet,:)');
for j = num
    C = cluster(Z,'maxclust',numclust+j);
    B = zeros(1,max(C));

    % B(i): mean of the strictly lower triangle of the correlation matrix
    % restricted to cluster i (exact zeros are also turned into NaN)
    for i = 1:max(C)      
        A = netcorr(C==i,C==i);
        A = tril(A,-1);
        A(A==0) = NaN;
        B(i) =  nanmean(A(:));
    end
    
    % remember the first cluster count whose mean within-cluster
    % correlation exceeds the shuffle threshold
    if nanmean(B) >thresh &&  k == 1
        newnum = numclust+j;
        clustsig = B;
        clustsig(2,:) = B>thresh;
        k = 0;
    end
   
    % the row index maps the first offset in num to row 1
    clustThresh(j+round(numclust/2)+1,:) = [numclust+j,nanmean(B)];   
end

% no cluster count exceeded the threshold: fall back to the values of the
% last loop iteration
if k ==1
    newnum = numclust+j;
    clustsig = B;
    clustsig(2,:) = B>thresh;
end
    
numclust = newnum;

%% Silhouette testing for hierarchical and k-means clustering
ss = zeros(1,round(sum(isnet)/2));
wss = zeros(1,round(sum(isnet)/2));
CC = zeros(size(Y,1),round(sum(isnet)/2));
parfor i = 1:round(sum(isnet)/2)
    CC(:,i) = cluster(Z,'maxclust',i);%'cutoff',.2,'depth',50);
    s = silhouette(Y,CC(:,i),'cos');%'Euclidean');
    ss(i) = nanmean(s);
    % (group size - 1) times the per-group variance summed over columns,
    % summed over groups
    wss(i) = sum((grpstats(CC(:,i), CC(:,i), 'numel')-1) .* sum(grpstats(Y, CC(:,i), 'var'), 2));
end

network.silh.hier_ss = ss;
network.silh.hier_wss = wss;

% k-means
% NOTE(review): the buffers have 100 columns but the inner loop only fills
% columns 1:10, so columns 11:100 stay zero in the output - confirm which
% repetition count is intended.
ss = zeros(round(sum(isnet)/2),100);
wss = zeros(round(sum(isnet)/2),100);
parfor i = 1:round(sum(isnet)/2)
    for j = 1:10
        cidx = kmeans(Y,i);
        s = silhouette(Y,cidx,'Euclidean');
        ss(i,j) = mean(s);
        wss(i,j) = sum((grpstats(cidx, cidx, 'numel')-1) .* sum(grpstats(Y, cidx, 'var'), 2));
    end
end

network.silh.kmeans_ss = ss;
network.silh.kmeans_wss = wss;
%% Final clustering with the selected cluster count

C = cluster(Z,'maxclust',numclust);
% c: permutation that orders the valid rows by descending cluster label
[~,c] = sort(C,'descend');
temp = netraster(isnet,:);
netcluster = corr(temp(c,:)');
cellID(isnet,1) = C;
cellID(isnet,2) = c;
cellID(~isnet,:) = NaN;


%% Cluster analysis output

clustID.clustsig = clustsig;
network.cellID = cellID;
network.netcorr = corr(netraster');
network.netcluster = netcluster;
network.clustThresh = clustThresh;
network.clustthresh = thresh;
network.clustID = clustID;

end

function [W,EW,thresh,Wk,EWk] = ShuffleClust(network,numit,numclustin)
%% Reference cluster scores and correlation threshold from shuffled rasters
% SHUFFLECLUST  Compare cluster scores of the data with shuffled surrogates.
%
% Inputs:
%   network    - structure with the field netraster
%   numit      - number of shuffled surrogate rasters to evaluate
%   numclustin - largest cluster count handed to clustsort
%
% Outputs:
%   W, Wk   - clustsort scores of the original raster (hierarchical and
%             k-means, the latter averaged over 5 k-means runs)
%   EW, EWk - the same scores averaged (NaN ignored) over the numit
%             shuffled rasters (1 k-means run per shuffle)
%   thresh  - centre of the first 0.01-wide correlation bin at which the
%             cumulative histogram of shuffled pairwise correlations
%             exceeds sigThresh
sigThresh = 0.98; % border of the significance threshold (cumulative fraction)

netraster = network.netraster;
randr = zeros(numit,200);   % one 200-bin correlation histogram per shuffle
EW = zeros(numit,numclustin);
EWk = zeros(numit,numclustin);

parfor i = 1:numit
    shufraster = shufshuf(netraster);
    % Pearson correlation of the shuffled rows; rows with a NaN
    % self-correlation are dropped before building the linkage
    [netcorr,~] = corr(shufraster');
    isnet = ~isnan(diag(netcorr));
    netcorr = netcorr(isnet,isnet);
    Y = netcorr;
    Z = linkage(Y,'weighted','seuclidean');
    [w,wk] = clustsort(Y,Z,numclustin,1);
    EW(i,:) = w;
    EWk(i,:) = wk;
    randrtemp = threshfit(shufraster);
    randr(i,:) = randrtemp;
end

EW = nanmean(EW,1);
EWk = nanmean(EWk,1);

% same pipeline on the unshuffled raster
[netcorr,~] = corr(netraster');
isnet = ~isnan(diag(netcorr));
netcorr = netcorr(isnet,isnet);
Y = netcorr;
Z = linkage(Y,'weighted','seuclidean');

[W,Wk] = clustsort(Y,Z,numclustin,5);

% x holds the centres of the histogram bins filled by threshfit
x = -1+0.005:.01:1-0.005;
% Pool the histograms over shuffles. The explicit dimension keeps the
% result a 1x200 vector even when numit == 1; sum(randr) would collapse a
% single-row matrix to a scalar and yield a meaningless threshold.
y = cumsum(sum(randr,1)/sum(randr(:)));
thresh = x(find(y>sigThresh,1));

end

function shufraster = shufshuf(netraster)
    % SHUFSHUF  Permute the columns of every row independently.
    %   Each row of the result holds the same values as the corresponding
    %   row of netraster, in an order drawn by its own call to randperm.
    %   The result is created with zeros() and therefore of class double.
    nrow = size(netraster,1);
    ncol = size(netraster,2);
    shufraster = zeros(size(netraster));
    for r = 1:nrow
        order = randperm(ncol);
        shufraster(r,:) = netraster(r,order);
    end
end

function randr = threshfit(netraster)
    % THRESHFIT  Histogram of the pairwise row correlations of a raster.
    %   Correlates the rows of netraster, drops rows whose self-correlation
    %   is NaN, blanks the diagonal and counts the remaining entries in
    %   bins with the edges -1:.01:1. Every pair is counted twice because
    %   both triangles of the symmetric matrix are kept.
    rho = corr(netraster');
    valid = ~isnan(diag(rho));
    rho = rho(valid,valid);
    % exclude the self-correlations on the diagonal from the histogram
    rho(logical(eye(size(rho,1)))) = NaN;
    randr = histcounts(rho,-1:.01:1);
end

function [W,Wk] = clustsort(Y,Z,numclust,numk)
    %% Within-cluster correlation scores for 1..numclust clusters
    % CLUSTSORT  Score hierarchical and k-means partitions of Y.
    %
    % Inputs:
    %   Y        - square correlation matrix; its rows are the observations
    %   Z        - linkage tree built from Y
    %   numclust - largest cluster count to evaluate (rounded)
    %   numk     - number of k-means repetitions per cluster count
    %
    % Outputs:
    %   W  - 1 x numclust, score of the hierarchical partition into j
    %        clusters
    %   Wk - 1 x numclust, score of the k-means partition into j clusters,
    %        averaged (NaN ignored) over the numk repetitions
    %
    % The score of a partition is the sum over clusters of half the mean
    % pairwise correlation inside the cluster (see clustscore).
    %
    % The correlations are read from Y. The earlier code referred to a
    % variable netraster that does not exist in this function's workspace,
    % and it ran kmeans on the linkage tree Z instead of on the data Y.
    nmax = round(numclust);
    W = zeros(1,nmax);
    Wk = zeros(1,nmax);
    for j = 1:nmax
        % hierarchical partition into j clusters
        W(j) = clustscore(Y,cluster(Z,'maxclust',j));

        % k-means partition into j clusters, repeated numk times because
        % the result depends on the random initialisation
        wk = zeros(1,numk);
        for k = 1:numk
            wk(k) = clustscore(Y,kmeans(Y,j));
        end
        Wk(j) = nanmean(wk);
    end
end

function score = clustscore(Y,C)
    % CLUSTSCORE  Sum over clusters of half the mean within-cluster correlation.
    %   C holds one cluster label per row of Y. Clusters with a single
    %   member have no pairs, give NaN and are skipped by nansum.
    B = zeros(1,max(C));
    for i = 1:max(C)
        A = Y(C==i,C==i);
        % keep each pair once: strictly lower triangle, the rest (and any
        % exact zero) becomes NaN and is ignored by nanmean
        A = tril(A,-1);
        A(A==0) = NaN;
        B(i) = nanmean(A(:))/2;
    end
    score = nansum(B);
end