% Component map change point detection
%{
This script finds change points within the component maps. Each column of the
"maps" variable contains one component's topography across the recording
contacts. The goal is to compare how the component maps, which are derived
solely from the statistics of the signals recorded on the contacts, overlap
with the anatomical estimates of the contact locations.

%}

clearvars; close all; clc
set(0,'DefaultFigureVisible','on')

%--------------------------------------------------------------------------
% Inputs
%--------------------------------------------------------------------------
% Folder with the per-session .mat files holding the component maps. It is
% made the working directory so the session files can be loaded by name.
datFolder = 'E:\Arizona\HPS\LFP_Analysis\result_final\eLife_revisionResults';
cd(datFolder)
filDir = dir('*s.mat');

% Anatomical label of every contact (raw cell contents). The header row is
% dropped; the main loop indexes the remaining rows by session number.
[~, ~, allLocs] = xlsread('E:\HPS_LFP_eLife\1_final_files\allChanLocations_final.xlsx');
allLocs = allLocs(2:end, :);

%--------------------------------------------------------------------------
% Outputs (1 = enabled, 0 = disabled)
%--------------------------------------------------------------------------
saveFig = 0;    % write the change-point figure of every component (PDF)
figFolder = 'E:\Arizona\HPS\LFP_Analysis\result_final\eLife_revisionResults\compTops\allMeans_5perc';

saveData = 0;   % write change points and match counts to a .mat file
saveDatFolder = 'E:\Arizona\HPS\LFP_Analysis\result_final\eLife_revisionResults\compTops\matFiles';

% Number of circularly shifted maps scored per component (1000 used in paper)
numShuff = 1000;

%%% Loop through all sessions in the dataset
% For each session, modality and component, the component map (one value per
% contact) is segmented with findchangepts. The resulting contact groups are
% scored against the anatomical contact labels: groupCount is the number of
% contacts that pass both consistency checks below (gC2 and gC3). The same
% score is computed for numShuff circularly shifted copies of the map to
% build a null distribution.
%
% Results left in the workspace:
%   gCount_real{sessi,modi,compi} - match count of the observed map
%   gCount{sessi,modi,compi}      - 1 x numShuff match counts of the shifted maps
%   gCount2(sessi,modi,compi)     - mean of gCount{sessi,modi,compi}
%   GEDTrans                      - change points / residuals (saveData == 1 only)
%
% Scoring notes:
%   * segments are labelled 1..numel(ipt)+1 for any number of change points
%     (with exactly one change point the last segment used to stay 0, which
%     reversed the tie-break between equally large subgroups);
%   * a pair of anatomical groups that share no segment no longer resets the
%     mismatches already found for the first group, so the score does not
%     depend on the order of the anatomical groups;
%   * shared segments are located with ismember, so two anatomical groups
%     sharing more than one segment are handled instead of erroring;
%   * sessions without any labelled contact score 0 instead of erroring.

nModalities = 3; % use 1 if the GED results are not modality specific (i.e., computed across all trials)
anatCodes = {'ce','ab','ba','la','me'}; % central, accessory basal, basal, lateral, medial nucleus -> codes 1..5; anything else -> 0
statType = 'mean'; % statistic used by findchangepts
minThreshFrac = 0.05; % 0.05 used in the paper: contact-group differences must be at least 5% of the map's range

if saveFig == 1 && ~exist(figFolder, 'dir')
    mkdir(figFolder);
end

for sessi = 1:length(filDir)
    fprintf(['Session #' num2str(sessi) '\r'])

    % load only the GED results, so nothing else stored in the session file
    % can overwrite variables of this script
    fnam = filDir(sessi).name;
    sessData = load(fnam, 'GED_results');
    GED_results = sessData.GED_results;

    % numerical code for the anatomical location of each contact
    % (non-text cells, e.g. empty spreadsheet cells, never match and stay 0)
    sessLocs = allLocs(sessi,3:18);
    chanCode = zeros(numel(sessLocs), 1);
    for codeIdx = 1:numel(anatCodes)
        chanCode(strcmp(sessLocs, anatCodes{codeIdx})) = codeIdx;
    end
    cLocs = nonzeros(unique(chanCode)); % anatomical groups present (unlabelled contacts excluded)
    nAnat = numel(cLocs);

    for modi = 1:nModalities
        for compi = 1:GED_results(modi).ncomps
            fprintf(['Comp #' num2str(compi) '\r'])

            % component map for this modality and component
            compTop1 = GED_results(modi).maps(:,compi);
            % NOTE(review): assumes one map row per contact, in the same order
            % as columns 3:18 of the location sheet - confirm
            nContacts = numel(compTop1);

            % one fixed minimum threshold so the observed and the shifted maps
            % are segmented with the same rule
            transVal = range(compTop1)*minThreshFrac;

            % change points of the observed map, and their plot in figure 1
            [ipt_real, res_real] = findchangepts(compTop1,'Statistic',statType,'MinThreshold',transVal);
            figure(1);clf
            findchangepts(compTop1,'Statistic',statType,'MinThreshold',transVal);
            yMax = max(abs(ylim));

            % save the findchangepts figure
            if saveFig == 1
                figure(2);clf
                axes('YAxisLocation','right','linewidth',5);
                hold on
                plot(compTop1,'k-o','linew',5,'markersize',25)
                % dashed grey line half-way between the contacts on either side of each change point
                for li = 1:length(ipt_real)
                    line([ipt_real(li)-.5 ipt_real(li)-.5],[-yMax yMax],...
                        'color',[.7 .7 .7],'linew',5,'linestyle','--')
                end
                line([1 nContacts],[0 0],'color','k','linestyle','--','linew',5)

                xlim([1 nContacts])
                xticks([])

                ylim([-yMax yMax])
                yticks(-yMax:yMax/2:yMax)
                camroll(270) % rotate the view so the contact axis is vertical

                set(gca,'fontsize',24)
                set(gcf,'position',[40 40 400 600])
                figName = [fnam(1:end-4) '_m' num2str(modi) '_c' num2str(compi) '_' statType];
                %png
                %print(gcf,fullfile(figFolder, figName), '-dpng')
                %pdf
                print(gcf,fullfile(figFolder, figName), '-dpdf','-bestfit')
            end

            % shuffi == 0 scores the observed map, shuffi = 1..numShuff score
            % circularly shifted copies of it
            gCount{sessi,modi,compi} = zeros(1,numShuff);
            for shuffi = 0:numShuff
                if shuffi == 0
                    ipt = ipt_real;
                else
                    if mod(shuffi,5) == 0
                        fprintf([num2str(shuffi) '\r'])
                    end
                    % cut the map after contact compTopIdx (2..nContacts-1) and
                    % move the first piece to the end
                    compTopIdx = randi([2 nContacts-1],1);
                    compTop = [compTop1(compTopIdx+1:end); compTop1(1:compTopIdx)];
                    [ipt, res] = findchangepts(compTop,'Statistic',statType,'MinThreshold',transVal);
                end

                % label every contact with its change-point segment (1..numel(ipt)+1)
                segStart = zeros(nContacts,1);
                segStart(ipt) = 1;
                mGroup = cumsum(segStart) + 1;

                % gC1{k}: segment label of every contact of anatomical group k
                % gC2{k}: within-group check - only contacts carrying the most
                %         common label of the group count as matches
                %         (ties -> smallest label)
                gC1 = cell(1,nAnat);
                gC2 = cell(1,nAnat);
                for cI = 1:nAnat
                    gC1{cI} = mGroup(chanCode == cLocs(cI));
                    uGs = unique(gC1{cI})';
                    if numel(uGs) > 1
                        gComp = arrayfun(@(g) sum(gC1{cI} == g), uGs);
                        gC2{cI} = gC1{cI} == min(uGs(gComp == max(gComp)));
                    else
                        gC2{cI} = ones(numel(gC1{cI}),1);
                    end
                end

                % gC3{k}: cross-group check - a segment shared by two anatomical
                % groups stays a match only for the group holding more of its
                % contacts; everything starts as a match
                gC3 = cellfun(@(g) ones(1,numel(g)), gC1, 'UniformOutput', false);
                for gi = 1:nAnat
                    for gj = 1:nAnat
                        if gi == gj
                            continue
                        end
                        shared = intersect(gC1{gi}, gC1{gj});
                        if isempty(shared)
                            continue % no segment split between these two groups
                        end
                        inI = ismember(gC1{gi}, shared);
                        inJ = ismember(gC1{gj}, shared);
                        if nnz(inI) > nnz(inJ)
                            gC3{gj}(inJ) = 0;
                        elseif nnz(inI) < nnz(inJ)
                            gC3{gi}(inI) = 0;
                        elseif numel(gC1{gi}) > numel(gC1{gj})
                            % NOTE(review): on a tie the smaller group's shared
                            % contacts are (re)marked as matches (1), as in the
                            % original script - confirm this is intended
                            gC3{gj}(inJ) = 1;
                        else
                            gC3{gi}(inI) = 1;
                        end
                    end
                end

                % count the contacts that passed both checks
                groupCount = 0;
                if nAnat > 1
                    for gk = 1:nAnat
                        if sum(gC2{gk}) > 1 && sum(gC3{gk}) > 1
                            groupCount = groupCount + sum(gC2{gk}(:) == 1 & gC3{gk}(:) == 1);
                        end
                    end
                elseif nAnat == 1 && sum(gC2{1}) > 1
                    % a single anatomical group: only the within-group check applies
                    groupCount = sum(gC2{1});
                end

                if shuffi == 0
                    gCount_real{sessi,modi,compi} = groupCount;
                else
                    gCount{sessi,modi,compi}(shuffi) = groupCount;
                    if saveData == 1
                        GEDTrans.sessNum(sessi).modi(modi).netNum(compi).meanMeth.transPts{shuffi} = ipt;
                        GEDTrans.sessNum(sessi).modi(modi).netNum(compi).meanMeth.residual{shuffi} = res;
                    end
                end
            end
            close all
            fprintf('\r')

            % average shuffled match count, compared later with the observed value
            if numShuff > 0
                gCount2(sessi,modi,compi) = mean(gCount{sessi,modi,compi});
            end

            if saveData == 1
                GEDTrans.sessNum(sessi).modi(modi).netNum(compi).meanMeth.type = statType;
                GEDTrans.sessNum(sessi).modi(modi).netNum(compi).meanMeth.transPts_real = ipt_real;
                GEDTrans.sessNum(sessi).modi(modi).netNum(compi).meanMeth.residual_real = res_real;
            end
        end
    end
end

% Save the change points (GEDTrans) together with the observed and shuffled
% match counts. The output folder is created if needed, and fullfile builds
% the path so it works on any platform.
if saveData == 1
    if ~exist(saveDatFolder, 'dir')
        mkdir(saveDatFolder);
    end
    save(fullfile(saveDatFolder, 'GED_trans'),'GEDTrans','gCount','gCount2','gCount_real')
end


%% count number of comps per session for each modality
% Counts come from gCount_real, which the main loop fills for every component.
% GEDTrans only exists when saveData == 1. Sessions or modalities without
% components count as 0, including sessions at the end of the list.
nSess = size(gCount_real, 1);
if exist('filDir', 'var')
    nSess = max(nSess, numel(filDir));
end
hasComp = ~cellfun(@isempty, gCount_real); % session x modality x component
compsFound = sum(hasComp, 3);              % session x modality
nModCols = min(3, size(compsFound, 2));
compCounts = zeros(nSess, 3);
compCounts(1:size(compsFound,1), 1:nModCols) = compsFound(:, 1:nModCols);

vc = compCounts(:,1)'; % modality 1 (visual)
tc = compCounts(:,2)'; % modality 2 (tactile)
ac = compCounts(:,3)'; % modality 3 (auditory)

%figure 5 bar plots
% One histogram per modality: how many sessions had 0, 1, 2, ... components,
% with the total number of components in the title.
modCounts = {vc, tc, ac};
modTitles = {'Vis   n = ', 'Tac   n = ', 'Aud  n = '};
figure;
for panel = 1:numel(modCounts)
    subplot(3,1,panel)
    histogram(modCounts{panel})
    ylim([0 20])
    ylabel('count')
    xlim([-0.5 5.5])
    xlabel('# comps')
    title([modTitles{panel} num2str(sum(modCounts{panel}))])
end
