% HPS LFP Analysis code
%{
The main goal of this code is to perform generalized eigendecomposition (GED) on LFP
data collected from the amygdalae of monkeys F and B during the
multisensory processing project. The code also extracts time-frequency (TF)
power from the contacts and from the component signals.

In response to reviewer feedback, this code is being modified to 
directly compare responses between modalities.

This code was created by Mike X Cohen and Jeremiah Morrow.

mikexcohen@gmail.com
jeremiahkmorrow@gmail.com

%}


close all; clc; clear

%% Setup
% Set saveData to true to write results to disk. Several of the outputs are
% very large (up to a couple of GB per analyzed session).
saveData = false;

% folder that holds the per-session LFP data folders
dataDir = 'E:\Arizona\HPS\LFP_Analysis\LFP_data';
% folder that result files are written into
resultsDir = 'E:\Arizona\HPS\LFP_Analysis\result_final\eLife_revisionResults';

% work from the data folder (session files are loaded by relative path later)
cd(dataDir)

% Session list: every sub-folder whose name starts with 'b' or 'f',
% returned as a column cell array of folder names.
dirListing      = dir;
folderNames     = {dirListing.name};
firstLetters    = cellfun(@(nm) nm(1), folderNames);
isSessionFolder = [dirListing.isdir] & ismember(firstLetters, 'bf');
recsesslist     = folderNames(isSessionFolder).';


% Peri-event window, in ms relative to stimulus onset
preStim  = 2000;
postStim = 2000;
% Baseline window, in ms relative to fixation-spot onset
baseWin  = [-1500 -1000];

% Wavelet centre frequencies: 80 log-spaced values from 1 to 100 Hz
frex = logspace(log10(1), log10(100), 80);
% Gaussian width per frequency: 4 to 25 cycles (log-spaced) divided by 2*pi*f
nCyc = logspace(log10(4), log10(25), numel(frex)) ./ (2*pi*frex);

% number of GED repetitions used to build the null eigenvalue distribution
nperm = 500;
%%% Loop through all sessions
% For each session: load the LFP, notch-filter, epoch around stimulus and
% fixation onset, run GED (post-stimulus vs. baseline covariance), threshold
% components against a permutation null, and compute wavelet TF power for
% the contacts plus any significant components.
for sessno = 1:length(recsesslist)
    clearvars -except saveData dataDir resultsDir recsesslist preStim postStim baseWin sessno frex nCyc nperm

    cd(dataDir)
    fprintf(['Working on session ' recsesslist{sessno} '...\r'])

    % file that everything is saved into (if saveData is true)
    outputfile = fullfile(resultsDir, [recsesslist{sessno} '_GED_results1.mat']);

    % load in LFP data (provides the HPSRawData structure used below)
    load(fullfile(recsesslist{sessno}, [recsesslist{sessno} '_HPSData.mat']))

    % set up the data structure using the standard structure from the MATLAB toolbox EEGLAB
    EEG = eeg_emptyset; % from EEGLAB, creates an empty data structure
    EEG.srate = HPSRawData.LFPData(1).data.ADFreq; % sampling rate from the LFP data file
    % all epoch indexing below is done in ms, i.e. one sample per ms
    if EEG.srate ~= 1000
        warning('Session %s: sampling rate is %g Hz, but epoching assumes 1000 Hz.', recsesslist{sessno}, EEG.srate)
    end

    % Data were recorded with one 16-channel v-probe; drop contacts judged bad
    % NOTE(review): the first two checks match on HPSRawData.sessDate while the
    % third matches on the folder name - confirm this mix is intentional.
    channels2use = 1:16;
    if strcmpi(HPSRawData.sessDate,'120117')
        channels2use = [1:6 8 10:16];
    elseif strcmpi(HPSRawData.sessDate,'022317')
        channels2use = 1:15;
    elseif strcmpi(recsesslist{sessno},'foz_030917')
        channels2use = [1:9 11:16];
    end
    nChans = length(channels2use);

    % remove 60 Hz line noise: 59-61 Hz Butterworth band-stop (FilterOrder 2),
    % designed once per session at the file's sampling rate, applied with filtfilt
    notchFilt = designfilt('bandstopiir','FilterOrder',2, ...
        'HalfPowerFrequency1',59,'HalfPowerFrequency2',61, ...
        'DesignMethod','butter','SampleRate',EEG.srate);
    for chani = channels2use
        HPSRawData.LFPData(chani).data.Values = filtfilt(notchFilt,HPSRawData.LFPData(chani).data.Values);
    end

    % Trials with artifacts (indices into HPSRawData.goodtrial).
    % 0 = session inspected, nothing to reject; [] = session not listed.
    switch lower(recsesslist{sessno})
        case 'blu_062917', trials2remove = [362 363];
        case 'blu_121817', trials2remove = 568;
        case 'foz_021017', trials2remove = 30;
        case 'foz_021317', trials2remove = [142 203];
        case 'foz_021517', trials2remove = [225 386];
        case 'foz_021717', trials2remove = 36;
        case 'foz_022417', trials2remove = 321;
        case 'foz_051817', trials2remove = 449;
        case 'foz_052317', trials2remove = 152;
        case 'foz_052517', trials2remove = [138 139 142 287 300 600 602 605];
        case 'foz_060617', trials2remove = [7 112 138 198 202 221 299 442];
        case 'foz_061317', trials2remove = [38 55 616];
        case 'foz_061517', trials2remove = [10 81 288 330 436 485];
        case 'foz_101917', trials2remove = [11 199 367];
        case {'blu_070517','blu_071117','blu_071317','blu_071817','blu_072017', ... % blu_072017: ch14 a bit weird towards end
              'blu_112917','blu_120517','blu_120717','blu_120817','blu_121117', ...
              'blu_121317','blu_121517','blu_122017','blu_122217', ...
              'foz_012517','foz_012617','foz_020917','foz_022117','foz_022317', ...
              'foz_022817','foz_030217','foz_030317','foz_051617','foz_053017', ...
              'foz_060117','foz_060817','foz_101217'}
            trials2remove = 0;
        otherwise
            trials2remove = [];
    end

    % Epoch every good trial into contacts x time (1 ms samples) x trials.
    % Row k holds contact channels2use(k), so excluded contacts leave no zero rows.
    nTrials = length(HPSRawData.goodtrial);
    EEG.allData     = zeros(nChans, preStim+postStim+1, nTrials);
    EEG.allBaseData = zeros(nChans, baseWin(2)-baseWin(1)+1, nTrials);
    EEG.trialtype   = zeros(1, nTrials);

    for triali = 1:nTrials
        stimOn = HPSRawData.goodtrial(triali).stimOn;
        fixOn  = HPSRawData.goodtrial(triali).fixspotOn;
        for ci = 1:nChans
            lfp = HPSRawData.LFPData(channels2use(ci)).data.Values;
            EEG.allData(ci,:,triali)     = lfp(stimOn-preStim : stimOn+postStim);        % stimulus-locked
            EEG.allBaseData(ci,:,triali) = lfp(fixOn+baseWin(1) : fixOn+baseWin(2));    % fixation-locked baseline
        end

        % label each trial by modality from its condition number
        c = HPSRawData.goodtrial(triali).cndNumber;
        if c<9
            EEG.trialtype(triali) = 1;  % images
        elseif c<17
            EEG.trialtype(triali) = 2;  % air puff
        elseif c<25
            EEG.trialtype(triali) = 3;  % audio
        else
            EEG.trialtype(triali) = 99; % no image, puff, or sound (i.e., sham trials present in some sessions)
        end
    end

    %%% get the stimulus ID number for each trial
    EEG.trial_StimID = [HPSRawData.goodtrial(:).cndNumber];

    %%% trial rejection (zeros in trials2remove mean "nothing to reject")
    rejIdx = trials2remove(trials2remove > 0);
    if ~isempty(rejIdx)
        EEG.allData(:,:,rejIdx)     = [];
        EEG.allBaseData(:,:,rejIdx) = [];
        EEG.trialtype(rejIdx)       = [];
        EEG.trial_StimID(rejIdx)    = [];
    end
    EEG.trials = size(EEG.allData,3);

    %trial condition labels
    EEG.condlabels = {'images';'airpuff';'audio'};

    for modi = 1%:length(EEG.condlabels)
        clearvars -except saveData dataDir resultsDir recsesslist preStim postStim baseWin sessno EEG modi trials2remove channels2use ...
            outputfile frex nCyc GED_results nperm

        nChans = length(channels2use); % number of contact rows in EEG.allData

        %%% if you want to do modalities individually use this code
        %EEG.data = EEG.allData(:,:,EEG.trialtype == modi);
        %EEG.baseData = EEG.allBaseData(:,:,EEG.trialtype == modi);
        %%% if you want to do modalities individually use this code


        %%% if you want to group all trials regardless of modality use this code (also, switch the 'modi' loop to only equal 1)
        EEG.data = EEG.allData;
        EEG.baseData = EEG.allBaseData;
        %%% if you want to group all trials regardless of modality use this


        EEG.xmin      = -preStim/1000;          % time in seconds before stimulus delivery
        EEG.xmax      = postStim/1000;          % time in seconds after stimulus delivery
        EEG.trials    = size(EEG.data,3);       % number of trials analyzed
        EEG.times     = -preStim:postStim;      % peri-event time points in ms
        EEG.pnts      = size(EEG.data,2);       % number of peri-event time points
        EEG.pnts_base = size(EEG.baseData,2);   % number of baseline time points
        EEG.nbchan    = nChans;                 % contacts kept for this session (of the 16-channel v-probe)
        EEG.setname   = recsesslist{sessno};    % name of the recording session

        % common average rereference (subtract the mean across contacts)
        EEG.data = bsxfun(@minus,EEG.data,mean(EEG.data,1));
        EEG.baseData = bsxfun(@minus,EEG.baseData,mean(EEG.baseData,1));

        % baseline subtraction: per-contact mean over baseline time points and trials
        EEG.data = bsxfun(@minus, EEG.data, mean( mean( EEG.baseData, 2), 3) );

        % Event related activity on each channel (as seen in figure 1 and 2b) can be generated from EEG.data at this point
        erp = mean(EEG.data,3);


        %%% GED for pre-cue (baseline) vs. post-stim (0-1000 ms) covariance
        taskIdx = dsearchn(EEG.times',[0 1000]');

        [covB,covT] = deal( zeros(nChans) );
        for triali = 1:EEG.trials
            % baseline covariance (mean-centred, normalized by N-1)
            tdat = EEG.baseData(:,:,triali);
            tdat = bsxfun(@minus,tdat,mean(tdat,2));
            covB = covB + tdat*tdat'/(size(tdat,2)-1);

            % task covariance (mean-centred, normalized by N-1)
            tdat = EEG.data(:,taskIdx(1):taskIdx(2),triali);
            tdat = bsxfun(@minus,tdat,mean(tdat,2));
            covT = covT + tdat*tdat'/(size(tdat,2)-1);
        end

        % average over trials; these can be plotted with imagesc (figure 2a)
        covB = covB/EEG.trials;
        covT = covT/EEG.trials;

        % generalized eigendecomposition with shrinkage regularization of covB
        regu_gam = .01;
        covBr = (1-regu_gam)*covB + regu_gam*mean(eig(covB))*eye(nChans);
        [evecs,evals] = eig(covT,covBr);
        [evals,sidx] = sort(diag(evals),'descend'); % descending eigenvalues
        evecs = evecs(:,sidx);                      % eigenvectors in the same order

        % unit-norm eigenvectors (visualized with imagesc in figure 2a)
        evecs = bsxfun(@rdivide,evecs,sqrt(sum(evecs.^2,1)));

        % component maps, sign-flipped so the largest-magnitude contact weight is positive
        % (columns give the maps in figure 2d, 2f and figure 3)
        maps = covT * evecs;
        for mi=1:size(maps,2)
            tmp = maps(:,mi);
            tmp = tmp./norm(tmp);
            [~,smax] = max(abs(tmp));
            maps(:,mi) = tmp * sign(maps(smax,mi));
        end

        %%% percent gap between each eigenvalue and the next one
        Lpct = [100*(evals(1:end-1)-evals(2:end))./evals(2:end);0];

        %%% permutation null: GED between two random 501-ms windows of the task data
        % NOTE(review): no rng seed is set, so the threshold varies slightly between runs.
        permevals = zeros(nChans,nperm);
        for permi = 1:nperm
            [covB_R,covT_R] = deal( zeros(nChans) );

            for triali = 1:EEG.trials
                rns = randsample(-preStim:1000,2); % two random window start times (ms)
                permIdx = dsearchn(EEG.times',[rns(1) rns(1)+500 rns(2) rns(2)+500]');

                % "baseline" covariance from the first random window
                tdat = EEG.data(:,permIdx(1):permIdx(2),triali);
                tdat = bsxfun(@minus,tdat,mean(tdat,2));
                covB_R = covB_R + tdat*tdat'/(size(tdat,2)-1);

                % "task" covariance from the second random window
                tdat = EEG.data(:,permIdx(3):permIdx(4),triali);
                tdat = bsxfun(@minus,tdat,mean(tdat,2));
                covT_R = covT_R + tdat*tdat'/(size(tdat,2)-1);
            end

            % same regularization as the real GED; normalize by trials and keep sorted eigenvalues
            covBr_R = (1-regu_gam)*covB_R + regu_gam*mean(eig(covB_R))*eye(nChans);
            permevals(:,permi) = sort(eig(covT_R/EEG.trials,covBr_R/EEG.trials),'descend');
        end
        % Rarely, a decomposition produces infinite values; exclude them from the null
        permevals(~isfinite(permevals)) = NaN;

        % significant components: z (vs. null of largest eigenvalue) above the one-tailed
        % p<.05 cutoff AND more than 1% larger than the next eigenvalue. Lpct is not
        % monotonic, so the significant set need not be the first ncomps components.
        evalsZvals = (evals-nanmean(permevals(1,:))) / nanstd(permevals(1,:));
        sigComps   = evalsZvals>abs(norminv(.05)) & Lpct>1;
        compIdx    = find(sigComps);
        ncomps     = numel(compIdx); % components per session were plotted as a histogram in figure 2h

        % skip the TF analysis if there are no components
        if ncomps==0
            tf_raw_all = [];
            tf_baseDiff_avg = [];
            tf_baseDiff_Tri = [];
            tfMod = [];
            basepower = [];
        else
            % append the significant component time series after the contact rows
            compW = evecs(:,compIdx)';
            EEG.data(nChans+1:nChans+ncomps,:,:) = reshape( compW * reshape(EEG.data,nChans,[]) ,[ncomps EEG.pnts EEG.trials]);
            EEG.baseData(nChans+1:nChans+ncomps,:,:) = reshape( compW * reshape(EEG.baseData,nChans,[]) ,[ncomps EEG.pnts_base size(EEG.baseData,3)]);
            EEG.nbchan = size(EEG.data,1);

            %%% remove non-component time series if you want to only focus on component related analyses. averages of component time series across trials (3rd dimension) were plotted in figure 2 e and 2g.
            %EEG.data(1:nChans,:,:) = [];
            %EEG.baseData(1:nChans,:,:) = [];
            %EEG.nbchan = size(EEG.data,1);
            %%% remove non-component time series if you want to only focus on component related analyses.

            %%% TF decomposition using Morlet wavelet convolution
            fprintf('Performing time-freq decomposition ...\r')

            % wavelet and convolution parameters (trials are concatenated and convolved at once)
            wavtime    = -2:1/EEG.srate:2;                  % wavelet time vector (s)
            halfwav    = floor(length(wavtime)/2)+1;        % first sample of the trimmed convolution result
            nWave      = length(wavtime);
            nData      = EEG.pnts * EEG.trials;             % peri-event samples, all trials
            nData_base = EEG.pnts_base * size(EEG.baseData,3); % baseline samples, all trials
            nConv      = nData+nWave-1;
            nConv_base = nData_base+nWave-1;

            times2save    = -1500:20:1500;                  % time points (ms) to keep
            times2saveidx = dsearchn(EEG.times',times2save');

            % pre-allocate outputs
            tf_raw_all      = zeros(EEG.nbchan,length(times2save),EEG.trials,length(frex));
            tf_baseDiff_avg = zeros(EEG.nbchan,length(times2save),length(frex));
            tf_baseDiff_Tri = zeros(EEG.nbchan,length(times2save),EEG.trials,length(frex));
            tfMod           = zeros(EEG.nbchan,length(EEG.condlabels),length(frex),length(times2save));
            basepower       = zeros(EEG.nbchan,length(frex));

            % FFT of the concatenated data
            dataX      = fft(reshape(EEG.data,EEG.nbchan,[]),nConv,2);
            dataX_base = fft(reshape(EEG.baseData,EEG.nbchan,[]),nConv_base,2);

            for fi = 1:length(frex)
                fprintf(['   Freq ' num2str(frex(fi)) '...\r'])

                % complex Morlet wavelet, transformed at both convolution lengths and
                % scaled by its spectral peak
                cmw = exp(1i*2*pi*frex(fi)*wavtime) .* exp( (-wavtime.^2)/(2*nCyc(fi)^2) );
                cmwX = fft(cmw, nConv);
                cmwX = cmwX./max(cmwX);
                cmwX_base = fft(cmw, nConv_base);
                cmwX_base = cmwX_base./max(cmwX_base);

                % frequency-domain multiplication == time-domain convolution (analytic signal)
                as = ifft( bsxfun(@times,dataX,cmwX) ,[],2);
                as = reshape( as(:,halfwav:end-halfwav+1) ,size(EEG.data) );         % channels x time x trials
                as_b = ifft( bsxfun(@times,dataX_base,cmwX_base) ,[],2);
                as_b = reshape( as_b(:,halfwav:end-halfwav+1) ,size(EEG.baseData) ); % channels x time x trials

                % baseline power per channel (mean over time and trials)
                basepower(:,fi) = mean(mean(abs( as_b ).^2,3),2);

                % raw single-trial power at the saved time points
                pow = abs(as(:,times2saveidx,:)).^2;
                tf_raw_all(:,:,:,fi) = pow;
                % dB change from baseline: trial-averaged power, and single-trial power
                tf_baseDiff_avg(:,:,fi)   = 10*log10( bsxfun(@rdivide, mean(pow,3), basepower(:,fi)) );
                tf_baseDiff_Tri(:,:,:,fi) = 10*log10( bsxfun(@rdivide, pow, basepower(:,fi)) );

                % trial-averaged dB power per modality (figure 4a-f and figure 4-s1;
                % statistical thresholding is done in source code 2)
                for condi = 1:length(EEG.condlabels)
                    tfMod(:,condi,fi,:) = 10*log10( bsxfun(@rdivide, mean(pow(:,:,EEG.trialtype==condi),3), basepower(:,fi)) );
                end
            end
        end

        % collect results by modality
        if saveData
            % subsequent analyses mainly use tf_baseDiff_Tri
            GED_results(modi).tf_baseDiff_Tri = tf_baseDiff_Tri;

            % tfMod is used for the analyses in figure 4 (saving all 'tf_' variables produces a large file)
            %GED_results(modi).tfMod = tfMod;

            GED_results(modi).ncomps = ncomps;
            GED_results(modi).sigComps = sigComps; % which eigenvectors the appended component rows came from
            GED_results(modi).evalsZvals = evalsZvals;
            GED_results(modi).Lpct = Lpct;
            GED_results(modi).maps = maps;
            GED_results(modi).evecs = evecs;
            GED_results(modi).evals = evals;
        end
    end

    if saveData
        save(outputfile,'GED_results','frex','nCyc','preStim','postStim','baseWin','trials2remove','channels2use')
    end
    % data from the excel file 'allChanLocations_final' can be used to
    % re-run this code and include or exclude recording contacts based on their
    % estimated location (inside or outside the amygdala)
end

