% This script computes wavelet coherence results from data previously saved in .mat files.
% Requires the "morletTransform" and "totalCoh" functions to run.
% Author: Dr. Jean-Sebastien Blouin, September 2011
% Modified: Dr. Romain Tisserand, January 2017

clear all; close all;
tic

  % Experiment description: condition label, subject IDs and signal labels.
  files = 'Coherence between Fy and SVS - Gait Initiation'; % Experimental condition
  subs = ['03';'04';'05';'06';'07';'08';'09';'10';'12';'13']; % Subjects (one 2-character ID per row)
  % The signal labels differ in length ('SVS' has 3 characters), so they
  % cannot be stacked with [..;..] -- that raises a dimension-mismatch error.
  % char() pads every row with trailing blanks to a common width.
  musc = char('Fx','Fy','SVS'); % Data used for analysis
  maxmusc = size(musc,1);       % number of signal labels

  muscleList = [1 8];%[9 16]; %9:8+maxmusc; % which muscles you want to process
  
  % ---------------- Data loading ----------------
  fs = 1e3; % sampling frequency (Hz) -- TO ADJUST to the analyzed signal!
  % The .mat file is expected to provide "Matrix" (consumed further down).
  % It is loaded after fs and before the settings below, so only fs could be
  % overridden by a same-named variable stored in the file.
  load('C:\Users\user\Documents\Projects\Locomotor_transitions\Data\CoherenceAnalysis\Matrixes\Initiation_HF_Normalized.mat'); 

  % ---------------- Plotting / statistics ----------------
  doPlot = 2;    % any value > 0 enables plotting (a value of 2 was meant to
                 % also save results, but no saving is implemented here)
  pLevel = 0.01; % significance level; coherence below the matching threshold is zeroed
  maxCohPerSubj = 0.25; % wavelet display scale (0.1 for 2-30 Hz, decrease for 10-25 Hz)

  % Padding handling: 1 -> take off 25% of each step, 2 -> zero the outer 25%,
  % any other value -> leave the data untouched
  doRemovePadding = 0;

  numsteps = 100; % trials (steps) available for each subject

  % ---------------- Wavelet parameters ----------------
  fmin = 1;  % lower limit of the frequency range
  fmax = 20; % upper limit of the frequency range
  r_f = 1;   % density of the frequency list; lower for a coarser list (normally 1)
  resolutionPar = 1; % larger -> finer frequency resolution:
                     % freq. resolution = (f/5)/resolutionPar,
                     % time resolution  = (1/f)*resolutionPar

  lagTimeList = 0.2; % lag(s) in seconds; other lags tried: (25:25:350)/1000

  resampleN = 0; % neighbour terms for resample(): 3 is the usual choice, 0 runs
                 % faster while debugging; MATLAB's default of 10 was checked
                 % and found to be overkill

  numSubjects = 10; % number of subjects
  numTrials = 1;
  opengl('neverselect');
  
  % Derive the "steps" structure from the loaded "Matrix", optionally
  % handling the padded edges of every trial (see doRemovePadding).
  switch doRemovePadding
    case 1
      steps = unpadSteps(Matrix);   % take off the padding
    case 2
      steps = clearPadding(Matrix); % zero the padding
    otherwise
      steps = Matrix;               % keep the data as loaded
  end
  clear Matrix; % the raw copy is no longer needed
  
  % Determine the length to which every step is resampled (the mean length).
  % The "steps" struct array is indexed by SUBJECT -- the analysis loop reads
  % steps(s).T<n> for s = 1:numSubjects -- so lengths are gathered from
  % steps(subj). (Indexing by the trial counter only ever read subject 1.)
  steplengths = zeros(numSubjects*numsteps,1);
  k = 1;
  for subj = 1:numSubjects
    for step = 1:numsteps
      tmp = steps(subj).(['T' int2str(step)]); % dynamic field instead of eval
      steplengths(k) = size(tmp,1);            % number of samples (rows) in this step
      k = k + 1;
    end
  end

  mshortest = min(steplengths); % shortest step (samples), kept for inspection
  mlargest = max(steplengths);  % longest step (samples), kept for inspection
  meansl = mean(steplengths);   % mean step length (samples)

  % standardStepLength is to which everything will be standardized: the mean
  standardStepLength = round(meansl);

  % Frequencies for the wavelet transform. Starting at fmin, each entry is
  % the previous one plus previous/(10*r_f*resolutionPar). Growth stops
  % right after the first value that exceeds fmax, so the list always has
  % at least two entries and ends with one frequency above fmax.
  freqList = fmin;
  reachedTop = false;
  while ~reachedTop
    prevF = freqList(end);
    freqList(end+1) = prevF/(10*r_f*resolutionPar) + prevF;
    reachedTop = freqList(end) > fmax;
  end

  % Plotting helpers: time axis in ms for one standardized step, sizes of
  % the frequency and lag dimensions, and a near-square grid of subplots
  % with one panel per lag.
  sampleIdx = 0:standardStepLength-1;
  t = sampleIdx/fs*1000;
  nfreq = length(freqList);
  nlags = length(lagTimeList);
  nlagRows = floor(sqrt(nlags));   % rows of the lag grid
  nlagCols = ceil(nlags/nlagRows); % enough columns to hold every lag
  
 % Main analysis: wavelet coherence between the SVS signal (input) and a
 % force signal (output), per subject and pooled across subjects, per lag.
 for m = 1 % force column: m=1 for Fx (anteroposterior force), m=2 for Fy (mediolateral force)
    for f = 1 %:1:4 %for all trials Only doing 1 trial at a time

      SVSchan = 3; % EVS signal is the third column of each trial

      % Running sums pooled over all subjects
      Wav1 = zeros(nfreq,standardStepLength,nlags);
      Wav2 = zeros(nfreq,standardStepLength,nlags);
      Wav12 = zeros(nfreq,standardStepLength,nlags);

      % Per-subject coherence per lag (overwritten for each subject)
      cohlags12 = zeros(nfreq,standardStepLength,nlags);

      for s = 1:numSubjects
        pause(.1); % brief pause (presumably to let figures/console refresh)

        % Running sums for this subject
        W1 = zeros(nfreq,standardStepLength,nlags);  % input auto-power
        W2 = zeros(nfreq,standardStepLength,nlags);  % output auto-power
        W12 = zeros(nfreq,standardStepLength,nlags); % cross-spectrum
        maxTotalCoh = zeros(nlags,1);

        for lagNo = 1:nlags
          lagTime = lagTimeList(lagNo);
          lagN = fix(fs*lagTime); % lag in samples, applied before resampling

          for stepno = 1:numsteps
            trialData = steps(s).(['T' int2str(stepno)]); % dynamic field instead of eval
            yin = trialData(:,SVSchan); % EVS signal is the input
            yout = trialData(:,m);      % Force signal is the output
            nsamples = length(yin);
            % morletTransform output is assumed to be nfreq x nsamples (consistent
            % with the transposes and accumulation below) -- confirm in morletTransform
            Wb1 = morletTransform(yin,fs,resolutionPar,freqList);
            Wb2 = morletTransform(yout,fs,resolutionPar,freqList);

            % Implement the lag along time; circshift wraps the first lagN
            % samples around to the end of the step
            Wb2shifted = circshift(Wb2',-lagN)';

            % Normalize to the standard trial length (standardStepLength).
            % The two conjugate transposes cancel: resample works along columns.
            Wa1 = resample(Wb1',standardStepLength,nsamples,resampleN)';
            Wa2shifted = resample(Wb2shifted',standardStepLength,nsamples,resampleN)';

            % Add the current step to the running totals
            W1(:,:,lagNo) = W1(:,:,lagNo) + abs(Wa1).^2;
            W2(:,:,lagNo) = W2(:,:,lagNo) + abs(Wa2shifted).^2;
            W12(:,:,lagNo) = W12(:,:,lagNo) + Wa1.*conj(Wa2shifted);
          end

          % Coherence (and gain) for this subject and lag
          coh12 = (abs(W12(:,:,lagNo).^2)) ./ (W1(:,:,lagNo).*W2(:,:,lagNo));
          gain12 = (abs(W12(:,:,lagNo))) ./ (W1(:,:,lagNo));
          cohlags12(:,:,lagNo) = coh12;

          if(doPlot>0)
            % Non-significant coherence is brought to zero
            thresh = 1.*(1-pLevel^(1/(numsteps-1)));
            cohTh = coh12;
            cohTh(cohTh<thresh) = 0;
            % totalCoh is external; outputs as named in its signature -- see totalCoh
            [totalCoherence,sumCoh,maxTotalCoh(lagNo),maxLocation] = totalCoh(cohTh,freqList);
            relMaxLocPerc = 100*maxLocation/length(t);
          end
        end
        % (No per-subject axis equalization: its subplot call was disabled, so
        %  axis() only hit whatever axes were current and opened an empty figure.)

        % Add subject data to the pooled running sums
        Wav1 = Wav1 + W1;
        Wav2 = Wav2 + W2;
        Wav12 = Wav12 + W12;
      end

      % Coherence across subjects for each lag
      cohAv12 = (abs(Wav12.^2))./(Wav1.*Wav2);

      if(doPlot>0)
        % Figure 3: time-frequency map of the pooled, thresholded coherence per lag
        figure(3)
        for lagNo = 1:nlags
          subplot(nlagRows,nlagCols,lagNo);

          % Non-significant coherence brought to zero
          thresh = 1.*(1-pLevel^(1/(numSubjects*numsteps-1)));
          cohTh = cohAv12(:,:,lagNo);
          cohTh(cohTh<thresh) = 0;
          cohAv12(:,:,lagNo) = cohTh;
          pcolor(t,freqList,cohTh); shading interp;
        end

        % Figure 4: total coherence over time per lag
        figure(4)
        for lagNo = 1:nlags
          subplot(nlagRows,nlagCols,lagNo);
          [totalCoherence,sumCoh,maxTotalCoh(lagNo),maxLocation] = totalCoh(cohAv12(:,:,lagNo),freqList);
          relMaxLocPerc = 100*maxLocation/length(t);
          plot(t,totalCoherence);
        end

        % Same y-range on every total-coherence subplot
        maxTotalCohAll = max(eps,max(maxTotalCoh));
        for lagNo = 1:nlags
          subplot(nlagRows,nlagCols,lagNo);
          axis([t(1) t(end) 0 maxTotalCohAll]);
        end
      end
    end
  end

  toc
  