% Run through all subjects and fit each subject's choice RTs with drift-diffusion models (DDM)

% Use `clear` rather than `clear all`: only workspace variables need to be
% reset (the whole workspace is saved at the end of this script), whereas
% `clear all` also flushes compiled functions and slows subsequent calls.
clear
close all

%% extract data
% lo.choice{sub} holds one struct per subject. Each field is transposed
% below, so it is presumably stored as a column vector -- confirm against
% FCP3_data. Ratings are mapped with (x-.9)./99.11, presumably to bring a
% ~1..100 rating scale onto ~[0,1].
lo = load('FCP3_data');
nSubs=length(lo.choice);
subs=1:nSubs;
% subs=setdiff(subs,lo.badSubs);
nSubs=length(subs);
% X and Y collect one z-scored (within-subject) column per subject, assuming
% vec returns a column vector; the horizontal concatenation then requires
% every subject to have the same number of trials.
X = [];
Y = [];
for sub=subs
    % extract prior value difference (dVR0) and average item value
    vL=(lo.choice{sub}.rating1_L'-.9)./99.11;
    vR=(lo.choice{sub}.rating1_R'-.9)./99.11;
    ease{sub}=abs(vL-vR);
    meanV{sub} = (vL+vR)/2;
    % extract average value prior certainty (VCR0)
    cL=(lo.choice{sub}.certainty1_L'-.9)./99.11;
    cR=(lo.choice{sub}.certainty1_R'-.9)./99.11;
    cert{sub}=(cL+cR)./2;
    % extract consequential choice tag
    consequential{sub} = lo.choice{sub}.reward';
    % extract penalized choice tag
    penalty{sub} = lo.choice{sub}.penalty';
    % extract choice confidence
    CC{sub}=(lo.choice{sub}.confidence'-.9)./99.11;
    % extract subjective effort rating
    EF{sub}=(lo.choice{sub}.effort'-.9)./99.11;
    % extract log(RT)
    RT{sub}=log(lo.choice{sub}.RT');
    % extract spreading-of-alternatives
    SP{sub}=(lo.choice{sub}.spreading'-.9)./99.11;
    % extract change-of-mind
    COM{sub} = lo.choice{sub}.changing_mind';%+lo.choice{sub}.error';
    % extract value precision gain: post-choice minus pre-choice certainty
    cL2=(lo.choice{sub}.certainty2_L'-.9)./99.11;
    cR2=(lo.choice{sub}.certainty2_R'-.9)./99.11;
    cert2=(cL2+cR2)./2;
    dcert{sub} = cert2-cert{sub};
    % store MCD-relevant value prior moments
    X = [X,zscore(vec(ease{sub}))];
    Y = [Y,zscore(vec(cert{sub}))];
end


%% run subject-by-subject model inversion (all data)
% Fit three DDM variants per subject with DDM_fitRT_movingBound:
%   DDM1: drift input dr = |dVR0|
%   DDM2: drift input dr = |dVR0| ./ (eps + 2*(1-VCR0))
%   DDM3: drift input dr = |dVR0|, plus extra input Xstd = eps + 2*(1-VCR0)
% Posterior parameter means are stored in Ps(param,sub,model), and the
% trial-wise correlation between observed and fitted RTs in rs(model,sub).
logFlag = 1; % passed to the fitter; also selects log(RT) for the fit correlation
verbose = 0;
% DDM1/DDM2 posteriors fill rows 1:10 of Ps except row 8, DDM3 fills all 10
% rows (presumably the extra parameter relates to Xstd -- confirm).
i12 = setdiff(1:10,8);
% Columns are indexed by subject id (not by position in subs), so size them
% by the largest id; excluded subjects then stay NaN instead of the array
% silently growing with zeros.
Ps = NaN(10,max([0,subs]),3);
rs = NaN(3,max([0,subs]));
for sub=subs
    % set common DDM inputs
    rt = exp(RT{sub});
    % Rebuild the choice code for every subject so that no entry is left
    % over from the previous subject (0 = neither confirm nor change-of-mind).
    co = zeros(1,numel(COM{sub}));
    co(COM{sub}==0) = 1; % choice confirms prior prefs
    co(COM{sub}==1) = -1; % choice goes against prior pref
    % bound covariates: consequential and penalized choice tags
    Xbounds = [consequential{sub}',penalty{sub}'];
    % set DDM1
    dr = abs(ease{sub});
%     Xparams.v = [ones(length(dr),1),dr(:)];
%     [ddm] = DDM_fitRT_Xparams(rt,co,Xparams,verbose,0,logFlag);
    [ddm] = DDM_fitRT_movingBound(rt,co,dr,verbose,0,Xbounds,[],logFlag);
    posteriorSUB{1,sub} = ddm.vba.posterior;
    outSUB{1,sub} = ddm.vba.out;
    % set DDM2 (eps guards against division by zero when cert == 1)
    dr = abs(ease{sub})./(eps+2*(1-cert{sub}));
%     Xparams.v = [ones(length(dr),1),dr(:)];
%     [ddm] = DDM_fitRT_Xparams(rt,co,Xparams,verbose,0,logFlag);
    [ddm] = DDM_fitRT_movingBound(rt,co,dr,verbose,0,Xbounds,[],logFlag);
    posteriorSUB{2,sub} = ddm.vba.posterior;
    outSUB{2,sub} = ddm.vba.out;
    % set DDM3
    dr = abs(ease{sub});
    Xstd =[eps+2*(1-vec(cert{sub}))];
%     Xparams.v = [ones(length(dr),1),dr(:)];
%     Xparams.s = [ones(length(dr),1),Xstd(:)];
%     [ddm] = DDM_fitRT_Xparams(rt,co,Xparams,verbose,0,logFlag);
    [ddm] = DDM_fitRT_movingBound(rt,co,dr,verbose,0,Xbounds,Xstd,logFlag);
    posteriorSUB{3,sub} = ddm.vba.posterior;
    outSUB{3,sub} = ddm.vba.out;
    % extract parameters and postdiction accuracy
    for i=1:3
        if i<3
            ip = i12;
        else
            ip = 1:10;
        end
        Ps(ip,sub,i) = posteriorSUB{i,sub}.muPhi;
        Ey = outSUB{i,sub}.suffStat.gx;
        if length(unique(Ey))==1
            % constant predictions: correlation is undefined, score as 0
            rs(i,sub) = 0;
        else
            if logFlag
                rs(i,sub) = corr(log(rt)',Ey);
            else
                rs(i,sub) = corr(rt',Ey);
            end
        end
    end
    disp(['sub #',num2str(sub),': r = ',num2str(rs(:,sub)')])
end

%% eyeball fit accuracy per choice feature
% Across-subject mean (mr) and variance of the mean (vr) of the trial-wise
% fit correlations, in percent, one entry per DDM variant; drawn with
% plotUncertainTimeSeries (presumably as mean +/- error band).
mr = mean(rs(:,subs)*100,2);
vr = var(rs(:,subs)*100,0,2)/nSubs;
[haf,hf,hp] = plotUncertainTimeSeries(mr,vr,[],[]);
set(haf,'XTick',1:3,'XTickLabel',{'DDM1','DDM2','DDM3'},'Box','off')
ylabel(haf,'trial-by-trial correlation (%)')
set(get(haf,'parent'),'Name','DDM fit accuracy')



%% display model-based RFX analysis results
% Bin trials on an N-by-N grid of X (z-scored |dVR0|) and Y (z-scored VCR0)
% with smart2DbinPlot. Binned RT data are drawn as error bars and binned
% DDM2/DDM3 predictions as lines, with one colour per Y (VCR0) bin.
% NOTE(review): assumes EZxy is N-by-N with rows = X bins and columns =
% Y bins, as the axis label and legend below imply -- confirm.
N = 3;
col = 'bgrc'; % one colour per Y bin, so this supports N <= 4
hf = figure('name','RFX fit');
% 2D-binning of data
i = 1;
Zi = [];
for sub=subs % concatenate data over subjects
    if ~logFlag
        Zi = [Zi,vec(exp(RT{sub}))];
    else
        % NOTE(review): here the data are z-scored per subject, but the
        % model predictions binned below (suffStat.gx) are not; if gx is on
        % the raw log(RT) scale, data and fits are not on the same scale.
        Zi = [Zi,zscore(vec(RT{sub}))];
    end
end
os = smart2DbinPlot(X,Y,Zi,N,0);
EZs{i,1} = os.EZxy;
VZs{i,1} = os.VZxy;
Ns{i,1} = os.nxy;
Xs{i,1} = os.Xxy;
Ys{i,1} = os.Yxy;
ha = axes('parent',hf,'nextplot','add');
% Horizontal jitter between the N series: (j-(N+1)/2)*dx places the N
% offsets symmetrically around each integer x tick.
dx = 0.5./N;
for j=1:N
    % error bar half-width sqrt(variance/count), presumably the standard error
    he = errorbar((j-(N+1)/2)*dx+[1:N],EZs{i,1}(:,j),sqrt(VZs{i,1}(:,j)./Ns{i,1}(:,j)),'parent',ha,...
        'marker','.','linestyle','-','linewidth',2,'markersize',24,'color',col(j));
end
% percentile-range labels for the N bins
for il=1:N
    strx{il} = [num2str(round(100*(il-1)/N)),'-',num2str(round(100*il/N)),'%'];
    strl{il} = ['VCR0 percentile: ',strx{il}];
end
% 2D-binning of model fits (only DDM2 and DDM3 are drawn)
markers = {'*','diamond','p'};
linestyles = {'-.','--',':'};
for i=2:3
    Zi = [];
    for sub=subs
        Zi = [Zi,vec(outSUB{i,sub}.suffStat.gx)];
    end
    out = smart2DbinPlot(X,Y,Zi,N,0);
    EZs{i,2} = out.EZxy;
    VZs{i,2} = out.VZxy;
    Ns{i,2} = out.nxy;
    Xs{i,2} = out.Xxy;
    Ys{i,2} = out.Yxy;
    for j=1:N
        % same jitter as the data error bars so fits line up with the data
        he = plot((j-(N+1)/2)*dx+[1:N],EZs{i,2}(:,j),'parent',ha,...
            'marker',markers{i},'linestyle',linestyles{i},'linewidth',1,'markersize',12,'color',col(j));
    end
end
title(ha,'RT (DDM)')
xlabel(ha,'dVR0 percentile')
set(ha,'xtick',1:N,'xticklabel',strx,'xlim',0.5+[0,N])
% labels only the first N plotted series, i.e. the data error bars
legend(ha,strl)
getSubplots


%% save results
% Save the entire workspace. The file name records the date and whether the
% fits used log(RT) (logFlag = 1, the tag used so far) or raw RT.
thisdate = datestr(now,1);%'18-Jan-2021'; %
if logFlag
    rtTag = 'logRT';
else
    rtTag = 'RT';
end
fsn = ['results_DDM_',thisdate,'_V100_',rtTag,'_all.mat'];
save(fsn)
