% run through all subjects and fit them using MCD

% start from an empty workspace and close any open figure windows
clear all
close all

% extract data
% load the data file into the struct lo; per-subject data are read from
% the cell array lo.choice below
lo = load('FCP3_data');
nSubs=length(lo.choice);
subs=1:nSubs;
% optional subject exclusion (currently disabled):
% subs=setdiff(subs,lo.badSubs);
nSubs=length(subs); % recomputed so nSubs stays in sync with subs if the exclusion above is enabled
% Per-subject extraction loop.
% Each trial-wise field of lo.choice{sub} is transposed and stored in a cell
% array indexed by subject. All rating-like fields go through the same affine
% map (x-.9)./99.11 (presumably rescaling the rating scale to about [0,1] --
% TODO confirm against the task's rating range).
% X and Y accumulate one z-scored column per subject; the horizontal
% concatenation requires every subject to have the same number of trials.
X = [];
Y = [];
for sub=subs
    % extract prior value difference and average item value
    vL=(lo.choice{sub}.rating1_L'-.9)./99.11;
    vR=(lo.choice{sub}.rating1_R'-.9)./99.11;
    ease{sub}=abs(vL-vR); % absolute difference between the two rescaled ratings
    meanV{sub} = (vL+vR)/2;
    % extract average value prior certainty
    cL=(lo.choice{sub}.certainty1_L'-.9)./99.11;
    cR=(lo.choice{sub}.certainty1_R'-.9)./99.11;
    cert{sub}=(cL+cR)./2;
    % extract consequential choice tag
    consequential{sub} = lo.choice{sub}.reward';
    % extract penalized choice tag
    penalty{sub} = lo.choice{sub}.penalty';
    % extract choice confidence
    CC{sub}=(lo.choice{sub}.confidence'-.9)./99.11;
    % extract subjective effort rating
    EF{sub}=(lo.choice{sub}.effort'-.9)./99.11;
    % extract log(RT)
    RT{sub}=log(lo.choice{sub}.RT');
    % extract spreading-of-alternatives
    SP{sub}=(lo.choice{sub}.spreading'-.9)./99.11;
    % extract change-of-mind
    COM{sub} = lo.choice{sub}.changing_mind';%+lo.choice{sub}.error';
    % extract value precision gain: post-choice minus pre-choice average certainty
    cL2=(lo.choice{sub}.certainty2_L'-.9)./99.11;
    cR2=(lo.choice{sub}.certainty2_R'-.9)./99.11;
    cert2=(cL2+cR2)./2;
    dcert{sub} = cert2-cert{sub};
    % store MCD-relevant value prior moments
    % (vec is not a MATLAB built-in; presumably the toolbox helper that
    % reshapes its input into a column -- confirm)
    X = [X,zscore(vec(ease{sub}))];
    Y = [Y,zscore(vec(cert{sub}))];
end

% assemble data
% Z holds one entry per choice output feature (each entry is the per-subject
% cell array built above); znames gives the display name of each entry, in
% the same order.
Z = {CC, SP, RT, COM, EF, dcert};
znames = {'choice confidence', 'spreading of alternatives', 'reaction time', ...
    'change of mind', 'effort ratings', 'precision gain'};

% set MCD VBA-inversion options
switchFlags = [1; 1; 1]; % forwarded to the observation function through options.inG
g_fname = @g_MCD_extended; % other candidates: @g_MCD_weirdR, @g_MCD_expEfficacy, @g_MCD
% model dimensions: only n_phi (here 23) is non-zero
dim = struct('n_theta',0, 'n',0, 'n_phi',23);
% prior covariance of the n_phi parameters is 100 times the identity
options = struct( ...
    'priors', struct('SigmaPhi', 1e2.*eye(dim.n_phi)), ...
    'verbose', 1, ...
    'DisplayWin', 0, ...
    'inG', struct('switchFlags', switchFlags));
% model structure handed to VBA_main for every subject
M = struct('dim',dim, 'options',options, 'g_fname',g_fname);

% run subject-by-subject model inversion (all data)
% For each subject: stack the six z-scored choice features into y (one row
% per feature, in the same order as Z/znames), invert the model with
% VBA_main, and score the fit of each feature by the correlation between
% the data row and the corresponding row of outSUB{sub}.suffStat.gx.
% rs (features x subjects) and Ps (parameters x subjects) grow as the loop runs.
for sub=subs
    y = [zscore(CC{sub});
        zscore(SP{sub});
        zscore(RT{sub});
        zscore(COM{sub});
        zscore(EF{sub});
        zscore(dcert{sub})];
    % all-zero mask with the size of y (presumably "exclude no data point"
    % -- confirm against the VBA documentation)
    M.options.isYout = zeros(size(y));
    % model inputs, one row each: ease, prior certainty, consequential tag,
    % penalty tag, mean item value
    M.u = [ease{sub};cert{sub};consequential{sub};penalty{sub};meanV{sub}];
    [posteriorSUB{sub},outSUB{sub}] = VBA_main(y,M);
    Ey = outSUB{sub}.suffStat.gx; % compared row-by-row with y below
    for i=1:size(y,1)
        if length(unique(Ey(i,:)))==1
            % constant row in Ey: record r = 0 instead of calling corr,
            % and print the index of the affected feature
            rs(i,sub) = 0;
            disp(num2str(i))
        else
            rs(i,sub) = corr(y(i,:)',Ey(i,:)');
        end
    end
    disp(['sub #',num2str(sub),': r = ',num2str(rs(:,sub)')])
    % keep the first 11 entries of the posterior mean muPhi
    % NOTE(review): dim.n_phi is 23, so only a subset of the parameters is
    % stored here -- confirm that 1:11 is the intended selection
    Ps(:,sub) = posteriorSUB{sub}.muPhi(1:11);
end

% eyeball fit accuracy per choice feature
% NOTE: rs is permuted in place, so re-running this section on its own
% applies the permutation a second time. After this line the rows of rs
% follow the order CC, SP, COM, dcert, RT, EF (the order of the tick labels
% set below), not the order of Z/znames.
rs = rs([1,2,4,6,3,5],:); % re-order choice outcome features
mr = mean(100*rs(:,subs),2); % across-subject mean correlation, in percent
vr = var(100*rs(:,subs),[],2)./nSubs; % across-subject variance divided by the number of subjects
[haf,hf,hp] = plotUncertainTimeSeries(mr,vr,[],[]);
set(get(haf,'parent'),'name','MCD fit accuracies')
set(haf,'xtick',1:6,'xticklabel',{'Pc','SoA','Qcom','diV','RT','Eff'})
ylabel(haf,'trial-by-trial correlation (%)')
box(haf,'off')



% display model-based RFX analysis results
% For every choice output feature, the z-scored data and the model
% predictions are binned with smart2DbinPlot over X (z-scored ease) and
% Y (z-scored prior certainty), then drawn in one subplot: data as solid
% lines with error bars, model as dashed lines.
N = 2; % bin count passed to smart2DbinPlot; also the number of lines per subplot
col = 'bgrc'; % one colour per line, indexed by j (so N must not exceed 4)
hf = figure('name','RFX fit');

% Tick and legend labels, and the horizontal offset between lines, depend
% only on N, so they are built once instead of once per feature.
for il=1:N
    strx{il} = [num2str(round(100*(il-1)/N)),'-',num2str(round(100*il/N)),'%'];
    strl{il} = ['certainty (data): ',strx{il}];
    strl{N+il} = ['certainty (model): ',strx{il}];
end
dx = 0.5./N;

for i=1:length(Z) % loop over choice output features (Z={CC,SP,RT,COM,EF,dcert})
    % 2D-binning of data
    tmp = Z{i}; % i^th choice feature
    Zi = [];
    for sub=subs % concatenate data over subjects
        Zi = [Zi,zscore(vec(tmp{sub}))];
    end
    os = smart2DbinPlot(X,Y,Zi,N,0);
    EZs{i,1} = os.EZxy;
    VZs{i,1} = os.VZxy;
    Ns{i,1} = os.nxy;
    Xs{i,1} = os.Xxy;
    Ys{i,1} = os.Yxy;
    ha = subplot(3,2,i,'parent',hf,'nextplot','add');
    for j=1:N
        % error bar = sqrt(VZxy./nxy) for the bin
        he = errorbar((j-N/2)*dx+[1:N],EZs{i,1}(:,j),sqrt(VZs{i,1}(:,j)./Ns{i,1}(:,j)),'parent',ha,...
            'marker','.','linestyle','-','linewidth',2,'markersize',24,'color',col(j));
    end

    % 2D-binning of model fits (row i of gx is paired with feature i of Z)
    Zi = [];
    for sub=subs
        Zi = [Zi,vec(outSUB{sub}.suffStat.gx(i,:))];
    end
    out = smart2DbinPlot(X,Y,Zi,N,0);
    EZs{i,2} = out.EZxy;
    VZs{i,2} = out.VZxy;
    Ns{i,2} = out.nxy;
    Xs{i,2} = out.Xxy;
    Ys{i,2} = out.Yxy;
    for j=1:N
        he = plot((j-N/2)*dx+[1:N],EZs{i,2}(:,j),'parent',ha,...
            'marker','diamond','linestyle','--','linewidth',1,'markersize',12,'color',col(j));
    end
    title(ha,znames{i})
    xlabel(ha,'choice ease percentile')
    set(ha,'xtick',1:N,'xticklabel',strx,'xlim',0.5+[0,N])
    % legend entries follow plotting order: N data lines, then N model lines
    legend(ha,strl)
end
getSubplots

% check condition effects
% Only subjects whose penalty tag AND consequential tag both take more than
% one distinct value are kept; their stored parameters are then passed to
% GLM_contrast with a constant (all-ones) design column.
subok = [];
for sub=subs
    if length(unique(penalty{sub}))<=1 || length(unique(consequential{sub}))<=1
        continue % at least one tag is constant for this subject: skip
    end
    subok(end+1,1) = sub;
end
GLM_contrast(ones(length(subok),1),Ps(:,subok)',1,'t',1)


% save results
% Build a compact tag from the model switches (e.g. [1;1;1] -> '111') so the
% output file name records the configuration. The loop covers every element
% of switchFlags, so the tag stays complete if the flag vector changes length.
flags = [];
for i=1:numel(switchFlags)
    flags = [flags,num2str(switchFlags(i))];
end
% file name includes the run date, formatted by datestr(now,1)
fsn = ['results_MCD_',datestr(now,1),'_wobias_',flags,'_V100_Dcert_extended.mat'];
% no variable list is given, so every variable in the workspace is written
save(fsn)
