% Script for eLife manuscript "Dopaminergic, neural and computational contributions to
% probabilistic reward learning in old age"
% This script extracts ROI (or single-voxel) fMRI timecourses from the
% preprocessed SPM images and then cuts them into per-trial epochs to obtain
% peri-stimulus timecourses.
%
%%
% If your goal is to recreate figure 1c, start from the cell near the end of
% the script that begins with "load tcplot_fig1c"
%%
%
% Miriam C. Klein, April 2010
% Adapted by Lieke de Boer, July 2015

% ------------------------------------------------------------------- %
% PATHS, SUBJECTS AND GROUP ASSIGNMENT
% ------------------------------------------------------------------- %
% NOTE(review): helper functions called later (orthogonalise_to_physio,
% normalise, ols, plot_JackKnife) are presumably found on this path - confirm.
addpath 'C:\Experiments\DAD\scripts\timecourse\';

clear;

% Group code per subject, in the same order as subjins/subjins_behav: the
% first 27 subjects are coded 2 and the remaining 30 are coded 1. The
% plotting section at the end labels group 1 'young' and group 2 'old'.
group           = [ones(1,27)*2 ones(1,30)];

% Imaging subject folder names (each is a subfolder of basedir).
subjins         = {'D02','D03','D04','D05','D08','D11','D13','D14','D15','D16','D17','D18','D19','D20','D21','D22',...
    'D23','D24','D25','D26','D29','D31','D33','D60','D61','D84','D85','D34','D35','D36','D37','D38',...
    'D39','D40','D42','D43','D46','D47','D48','D49','D50','D52','D55','D58','D62','D63','D64','D65',...
    'D66','D67','D70','D72','D80','D82','D83','D86','D90'};

% Behavioural IDs, position-matched to subjins.
subjins_behav   =   {'dad02','dad03','dad04','dad05','dad08','dad11','dad13','dad14','dad15','dad16','dad17','dad18',...
    'dad19','dad20','dad21','dad22','dad23','dad24','dad25','dad26','dad29','dad31','dad33','dad60',...
    'dad61','dad84','dad85','dad34','dad35','dad36','dad37','dad38','dad39','dad40','dad42','dad43',...
    'dad46','dad47','dad48','dad49','dad50','dad52','dad55','dad58','dad62','dad63','dad64','dad65',...
    'dad66','dad67','dad70','dad72','dad80','dad82','dad83','dad86','dad90'};

% The three lists above are matched purely by position; stop early if they
% have drifted apart, otherwise group labels would land on the wrong subjects.
if numel(group) ~= numel(subjins) || numel(subjins_behav) ~= numel(subjins)
    error('Subject lists out of sync: %d group codes, %d imaging IDs, %d behavioural IDs.', ...
        numel(group), numel(subjins), numel(subjins_behav));
end

behavdir        = 'C:\Experiments\DAD\Behaviour\TAB\';  % not referenced further in this script
basedir         = 'C:\Experiments\DAD\TAB_art\';         % root of the per-subject run folders and of the outputs

maskdir         = 'C:\Experiments\DAD\TAB_art\2ndLvL\beta_QchRout\1sample_Q\'; % folder holding the ROI mask images (.img)

tcFold          = 'timecourse_RQVCd';   % output subfolder (under basedir) for the extracted timecourses

% SET MASK (ROI) TO EXTRACT TS FROM
% ------------------------------------------------------------------- %
% File-name prefixes of the mask images in maskdir; one set of outputs per entry.
masknames = {'z_vmPFC_Qonly_fwe05'};
% --------------------------------------------------------------------------

cd(basedir);
nsess           = 2;   % number of runs (sessions) per subject
TR              = 2;   % repetition time in seconds

% SET PARAMETERS
% ------------------------------------------------------------------- %
% Analysis switches (1 = on, 0 = off)
physio_correct  = 1; % regress physiological noise out of each run (orthogonalise_to_physio), then normalise

% Decide whether to extract from a single peak voxel or from each ROI mask.
% Use one mode at a time: neighbour averaging needs the peak-voxel location,
% which is only computed when roi_extract is 0.
include_surrounding_voxels  = 0; % average the peak voxel with its 26 neighbours (3x3x3 = 27 voxels)
roi_extract                 = 1; % average over the non-zero voxels of each mask in masknames

% Shift (in seconds) added to the trial onsets before epoching.
% Original rationale: with an ascending acquisition and the ventral striatum
% about 1/3 of the way up from the bottom, onsets would be shifted by 1/3 of
% a TR (-1.68 was used earlier; -1, i.e. half a TR, would match slice-timing
% correction to the middle slice). It is 0 because slice-timing correction
% was already performed relative to the striatum (changed from -1 by Lieke).
shifting        = 1;
sliceshift      = 0;

% Duration (s) of the peri-stimulus window cut out for every trial, from
% stimulus presentation to outcome presentation.
trial_dur = 15;

% SET VOXEL COORDINATE (only used when roi_extract == 0)
% ------------------------------------------------------------------- %
% MNI coordinate (mm) of the voxel to extract from, intended to be the group
% peak. [99,99,99] looks like a placeholder - replace it before using the
% peak-voxel mode.
dummy_vxl = [99,99,99];
vxl_to_extract_from = dummy_vxl;

% colours (RGB triplets) used by the plots
bl=[0 0 1];
gr=[0 1 0];
re=[1 0 0];
ye=[1 1 0];
kl=[0 0 0];
wh=[1 1 1];
cy=[0 1 1];
ma=[1 0 1];
or=[0.905882358551025 0.509803950786591 0];

% Names used for the output folders/files. Always a cell array, so the main
% loop can use length(roi_name) and roi_name{roi_indx} in either mode
% (a bare char array would make that brace indexing fail).
if roi_extract
    roi_name = masknames;
else
    roi_name = {['peak_',num2str(vxl_to_extract_from(1)),'_',num2str(vxl_to_extract_from(2)),'_',...
        num2str(vxl_to_extract_from(3))]};
end
%%
for roi_indx = 1:length(roi_name)
    % Per-ROI accumulators of per-subject regressor means (one row per subject).
    all_subs_ssso = []; all_subs_dsso = [];   all_subs_dsdo = []; all_subs_ddso = [];

    % Folder that caches this ROI's per-run timecourses and per-subject outputs.
    roiDir = [basedir tcFold filesep roi_name{roi_indx}];
    if ~exist(roiDir,'dir')
        mkdir(roiDir)
    end

    j=0;

    for i=1:length(subjins); j=j+1; % j mirrors i and indexes the rows of Copes
        subj      = subjins{i};

        % ------------------------------------------------------------------- %
        % Per-subject accumulators, filled run by run
        ps            = [];   % per-trial epochs of all runs stacked (rows = trials)
        ps_tmp        = [];   % never filled; only saved as all2
        stimcond.ssso = [];   % Choice values of the retained trials
        stimcond.dsso = [];   % Q chosen of the retained trials
        stimcond.ddso = [];   % never filled; kept for the all_subs_ddso output
        stimcond.dsdo = [];   % reward of the retained trials

        % --- separate sessions (single % so the whole loop stays in one editor cell)
        for sess = 1:nsess
            % Cached timecourse of this run (same file the code below saves).
            tcFile = [roiDir filesep subj '_run_' num2str(sess) '_' roi_name{roi_indx} '.mat'];

            if exist(tcFile,'file')
                % Timecourse was extracted on an earlier run of the script: reuse it.
                cached = load(tcFile,'tc');
                tc     = cached.tc;

            else
                scandir = [basedir subj filesep 'run' num2str(sess)];
                % Names (without directory) of the smoothed/preprocessed EPI frames.
                V    = spm_select('ExtList',scandir,'^sw.*\.nii$',1:330);
                nVol = size(V,1);   % one row per frame (length() is wrong for a char matrix)
                if nVol == 0
                    error('No sw*.nii images found in %s', scandir);
                end
                % Full path of every frame. Rebuilt from scratch for every run so
                % that rows from a previous subject/run cannot linger in V2.
                V2 = [repmat([scandir filesep],nVol,1) V];

                % ------------------------------------------------------------------- %
                % read all functional images from this run
                disp(['Extract timecourse for subject ',subj, ' run',num2str(sess)]);
                vol = spm_read_vols(spm_vol(V2)); % 4-D: x, y, z, volume

                v  = spm_vol(V2(1,:));   % header of the first frame (voxel-to-world matrix v.mat)
                tc = [];                 % reset so no values from a previous run survive
                if ~roi_extract
                    % Convert the MNI (mm) coordinate into voxel indices.
                    loc = round(v.mat\[vxl_to_extract_from 1]');
                    tc  = squeeze(vol(loc(1), loc(2), loc(3),:));
                end

                if include_surrounding_voxels
                    % Average the peak voxel with its 26 neighbours (3x3x3 cube).
                    % Needs roi_extract == 0 so that loc has been computed above.
                    neighbor_vxls = [[1 0 0];[1 -1 0];[0 -1 0];[-1 -1 0];[-1 0 0];[-1 1 0];[0 1 0];[1 1 0];...
                        [1 0 1];[1 -1 1];[0 -1 1];[-1 -1 1];[-1 0 1];[-1 1 1];[0 1 1];[1 1 1];...
                        [1 0 -1];[1 -1 -1];[0 -1 -1];[-1 -1 -1];[-1 0 -1];[-1 1 -1];[0 1 -1];[1 1 -1];...
                        [0 0 1];[0 0 -1]];
                    for k=1:size(neighbor_vxls,1)
                        tc(:,k+1) = squeeze(vol(loc(1)+neighbor_vxls(k,1), loc(2)+neighbor_vxls(k,2), loc(3)+neighbor_vxls(k,3),:));
                    end
                    tc = mean(tc,2);

                elseif roi_extract
                    % Average over all non-zero voxels of the ROI mask.
                    maskLoc = spm_select('List',maskdir,[masknames{roi_indx},'.*\.img$']);
                    if size(maskLoc,1) ~= 1
                        error('Expected exactly one mask matching ''%s'' in %s, found %d.', ...
                            masknames{roi_indx}, maskdir, size(maskLoc,1));
                    end
                    maskImg = spm_vol([maskdir,maskLoc]);   % mask header
                    maskVol = spm_read_vols(maskImg);       % mask voxel values
                    maskIDX = find(maskVol);                % linear indices of non-zero mask voxels
                    % NOTE(review): these linear indices are applied to the EPI
                    % volumes, so the mask is assumed to share the EPI voxel grid - confirm.
                    for epi = 1:size(vol,4)
                        tmpvol  = vol(:,:,:,epi);
                        tc(epi) = mean(tmpvol(maskIDX)); % ROI-mean signal of this volume
                    end
                    nVoxels = length(maskIDX);
                    fprintf('Number of voxels in ROI: %d\n', nVoxels);
                end

                save(tcFile,'tc');
            end

            % ----------------------------------------------------------------------- %
            % Number of upsampled samples (10 per TR) spanning trial_dur seconds,
            % i.e. the number of columns of every trial's epoch.
            pl_trial(sess,:)    = round(trial_dur*10/TR);
            pl_trialm(sess,:)   = pl_trial(sess,:) -1;

            % ----------------------------------------------------------------------- %
            % Trial information for this run (one cell per session in the file).
            % Loaded into a struct so variables stored in the file cannot silently
            % overwrite workspace variables.
            onsets  = load([basedir tcFold filesep subj filesep 'TC_onsets_' subj '.mat']);
            Choice  = onsets.Choice{sess};
            Qc      = onsets.Qchosen{sess};
            R       = onsets.R{sess};

            % Keep trials 2..end-4. The first is dropped so shifted onsets cannot
            % give a negative time index; per the original note, the last four are
            % dropped because their trial_dur window would not fit in the run.
            sssoIdx = Choice; dssoIdx = Qc;  dsdoIdx = R;   % also saved via 'ss*'/'ds*' below
            stimcond.ssso = [stimcond.ssso;sssoIdx(2:end-4)];
            stimcond.dsso = [stimcond.dsso;dssoIdx(2:end-4)]; % Q chosen (red)
            stimcond.dsdo = [stimcond.dsdo;dsdoIdx(2:end-4)]; % reward (blue)

            % ----------------------------------------------------------------------- %
            % Preprocess this run's timecourse.
            ts = tc;

            % Convert to percent signal change relative to the run mean.
            ts_mean = mean(ts);
            if ts_mean ~= 0
                ts = (ts-ts_mean)/ts_mean*100;
            end

            if physio_correct
                ts = orthogonalise_to_physio(ts, basedir, subj, sess, TR);
                ts = normalise(ts);
            end

            % Start times (s) of the same retained trials 2..end-4, optionally
            % shifted by sliceshift. tot_trials is needed whether or not shifting is on.
            tot_trials = length(Choice)-4;
            if shifting
                Choice(2:tot_trials) = Choice(2:tot_trials)+sliceshift;
            end
            trial_starts = Choice(2:tot_trials);

            clear onsets Qc R Choice

            % Upsample with a cubic spline to 10 samples per TR
            % (10/TR Hz, i.e. 5 Hz for TR = 2 s).
            t           = 0:length(ts)-1;
            t_ups       = 0:0.1:length(ts)-1;
            ts_ups      = spline(t,ts,t_ups);

            % Index into ts_ups of each trial start (seconds -> upsampled sample).
            trialT      = round(trial_starts(1:end,1)*10/TR+1);

            % ----------------------------------------------------------------------- %
            % Cut out one epoch per trial: rows = trials, columns = pl_trial samples.
            tmp         = repmat(trialT,1,pl_trial(sess,1)) + repmat(0:pl_trialm(sess,1),length(trialT),1);
            sessps      = ts_ups(tmp);
            ps          = [ps;sessps]; % epochs of all runs processed so far
        end

        % ------------------------------------------------------------------- %
        % Regress every upsampled time point of the epochs on the trial-wise
        % design [Choice, Q chosen, reward, constant]; one contrast per regressor.
        mat1   = [stimcond.ssso,stimcond.dsso,stimcond.dsdo,ones(length(stimcond.dsdo),1)];
        matrix = mat1;
        if size(ps,1) ~= size(matrix,1)
            error('Subject %s: %d epochs but %d design-matrix rows.', subj, size(ps,1), size(matrix,1));
        end
        [c,v,t]=ols(ps,matrix,eye(size(matrix,2)));

        Copes(j,:,:)=c;

        clear ts*
        % ------------------------------------------------------------------- %
        % ps stacks the epochs of all runs (first and last four trials of each
        % run removed); its pl_trial columns are upsampled time points.
        alll = ps;
        all2 = ps_tmp;

        % 'ds*'/'ss*' pick up dssoIdx, dsdoIdx and sssoIdx of the last run.
        save([roiDir filesep subj '_tc_conds'],'ds*','ss*','alll','all2');

        % Per-subject mean of each regressor over the retained trials.
        all_ssso = mean(stimcond.ssso,1);
        all_dsso = mean(stimcond.dsso,1);
        all_ddso = mean(stimcond.ddso,1);
        all_dsdo = mean(stimcond.dsdo,1);

        clear ds* dd* ss* alll all2 ps_tmp ts tc tmp ps*

        % ----------------------------------------------------------------------- %
        %                 COLLECT PER-SUBJECT RESULTS FOR THIS ROI
        % ----------------------------------------------------------------------- %
        all_subs_ssso = [all_subs_ssso; all_ssso]; %#ok<*AGROW> one row per subject
        all_subs_dsso = [all_subs_dsso; all_dsso];
        all_subs_ddso = [all_subs_ddso; all_ddso];
        all_subs_dsdo = [all_subs_dsdo; all_dsdo];

    end
    roi_Copes{roi_indx} = Copes; %#ok<*SAGROW>
    roi_all_subs_ssso{roi_indx} = all_subs_ssso;
    roi_all_subs_dsso{roi_indx} = all_subs_dsso;
    roi_all_subs_ddso{roi_indx} = all_subs_ddso;
    roi_all_subs_dsdo{roi_indx} = all_subs_dsdo;
    clear Copes
end

% Copes of the first ROI; presumably subjects x regressors x upsampled time
% (rows filled per subject by Copes(j,:,:) = c) - confirm against ols output.
vmPFC = roi_Copes{1};

%% run from here to recreate figure 1c (with data file tcplot_fig1c in the same folder)
% NOTE(review): when starting here in a fresh session, the variables used
% below (vmPFC, group, TR and the colours re/gr/bl) must come from
% tcplot_fig1c.mat or from an earlier run of this script - confirm.
load tcplot_fig1c
all_tc_rois = {vmPFC};
all_tc_rois_names = {'vmPFC'};

for roi_tcidx = 1:length(all_tc_rois)

    T        = all_tc_rois{roi_tcidx};
    roiLabel = all_tc_rois_names{roi_tcidx};   % char, so titles are a single line

    % Column 2 of the design matrix is the Q-chosen regressor.
    Q_young = squeeze(T(group==1,2,:));
    Q_old   = squeeze(T(group==2,2,:));

    %young
    figure;plot_JackKnife(Q_young,re,1,TR);
    title(['young ' roiLabel]);
    axis([0 15 -0.5 1.5])

    %old
    figure;plot_JackKnife(Q_old,re,1,TR);
    title(['old ' roiLabel]);
    axis([0 15 -0.5 1.5])

    %old and young
    figure; plot_JackKnife(Q_young, gr, 1, TR)
    hold on
    plot_JackKnife(Q_old, bl, 1, TR)
    title(['old and young ' roiLabel]);
    axis([0 15 -0.5 1.5])
end
