% Script for eLife manuscript "Attenuation of dopamine-modulated prefrontal value signals 
                                % underlies probabilistic reward learning deficits in old age."
% This is the master script for all the behavioural modelling. It should be run
% from the folder that contains all the model files.

% Script by Peter Dayan, Lieke de Boer and Marc Guitart-Masip, 2016
%%% Adapted from Quentin J. M. Huys, UCL, London 2011
%%% Reference:
%%% Guitart-Masip M, Huys QJM, Fuentemilla L, Dayan P, Duzel E, Dolan RJ (2012)
%%% Go and no-go learning in reward and punishment: Interactions between affect and effect. NeuroImage. doi:10.1016/j.neuroimage.2012.04.024

%% Optional: start a parallel pool for the parfor loop below
%c=parcluster;
%c.NumWorkers=22 ; % this is the maximum number of workers you are allowed
%parpool(22) ; % use the same number of workers as above

%%
% Reset the workspace. Plain 'clear' removes the variables; 'clear all' would
% also flush compiled functions/MEX files, which slows code down and is not
% needed here.
clear;

% dt = input('enter the date i.e.:150225 ','s');

% Tag used as the prefix of every output file name written by the fitting loop.
dt = '160808_EM_all';

% Subject data: the fitting loop reads the per-subject cell arrays A and R
% from this file. NOTE(review): confirm tab_data.mat provides them.
load tab_data.mat

dosave = 1;      % save EM results to the mat/ folder after every EM iteration
docomp = 0;      % NOTE(review): not used in this script
docheck = 0;     % NOTE(review): not used in this script
Nsample = 2000;  % NOTE(review): not used in this script

% Candidate models. Each entry is the name of a model function (one file per
% model, expected in the current folder) that is called through feval. Npar(k)
% is the number of free parameters of model ff{k}.
ff = { ...
    'llba', ...
    'llbac', ...
    'llbacf', ...         % forgets the unchosen option with a forgetting rate
    'llbbdist1p', ...     % learns p(reward) for each bandit as a beta distribution with equal alpha and beta parameters
    'llbbdist2p', ...     % as the preceding model, but with different alpha and beta parameters for the beta distribution
    'llbbdist2p_b', ...   % beta distribution + choice kernel
    'llbbdist2p_V', ...
    'llbbdist2p_c', ...   % value of the option unchosen on the previous trial is modified according to its current uncertainty
    'llbbdist2p_V2', ...  % NOTE(review): same description as 'llbbdist2p_c' in the original - probably a copy-paste, confirm
    'llbbdist2p_vc' ...   % as 'llbbdist2p_c' + confidence bonus on the previously unchosen option
    };
Npar = [2 3 4 2 3 4 4 4 4 5];

% ff and Npar are parallel arrays; catch a mismatch here instead of deep inside the EM loop.
assert(numel(Npar) == numel(ff), 'Npar must list one parameter count per model in ff.');


% EM fit of every model in ff.
% - E step: per subject, fminsearch finds the parameter estimate under the
%   current group prior Z.
% - M step: the group mean (mu) and spread (nu) are re-estimated from those
%   estimates.
% Each model is fitted Nite times (ite), each run starting from random values.
% Reads from earlier in the script: A, R (per-subject cell arrays), ff, Npar,
% dt, dosave. Also needs NumHessian and the model functions on the path.
% NOTE(review): model functions are called as f(x, a, r, Z, flag). The flag=1
% value is stored in PL and the flag=0 value in LL. Presumably flag=1 adds the
% prior term - confirm in the model files.

Nite     = 10;    % independent EM runs per model
maxEmit  = 1000;  % safety cap on EM iterations (the loop previously had none and could run forever)
Nrestart = 0;     % extra jittered fminsearch restarts per subject (0 = single start, as before)

% 'DerivativeCheck' dropped: it only affects gradient-based solvers; fminsearch ignores it.
options=optimset('display','off');
warning('off','optim:fminunc:SwitchingMethod')

% Subjects with data. Empty subjects are skipped in the E step, so their
% columns of E and V stay at zero. They must also be left out of the M step,
% otherwise they drag the group mean towards 0 and distort the group spread.
Nsj     = length(A);
hasData = ~cellfun(@isempty, A);
Nval    = sum(hasData);
if Nval == 0
    error('No subject in A has any data; nothing to fit.');
end

% Function-form save() below needs the output folder to exist.
if dosave && ~exist('mat','dir')
    mkdir('mat');
end

for whichinf=1:size(ff,2)
    mfun = ff{whichinf};      % name of the model function for this fit
    Np   = Npar(whichinf);    % number of free parameters of that model
    
    for ite=1:Nite
        
        % Per-run outputs; par, et and lt grow with emit, so they must be reset.
        exx=[]; E=[]; V=[]; PL=[]; mu=[]; nu=[]; par=[]; lt=[]; et=[];
        
        ld = [dt '-' mfun '-ite' num2str(ite)];   % output file name
        
        % Initial group prior: zero mean, identity matrix.
        Z.mu=zeros(Np,1);
        Z.nui=eye(Np);
        
        E=zeros(Np,Nsj);     % per-subject parameter estimates
        V=zeros(Np,Nsj);     % per-subject variances from the inverse Hessian
        PL=zeros(1,Nsj);     % fminsearch objective at the estimate (flag 1)
        exx=zeros(1,Nsj);    % fminsearch exit flags
        LL=zeros(1,Nsj);     % model value at the estimate with flag 0
        
        emit=0;
        while true
            emit=emit+1;
            
            % E step......................................................
            parfor sj=1:Nsj
                fprintf('%2d\n',sj);
                if hasData(sj)
                    a=A{sj};
                    r=R{sj};
                    
                    % Random start on the first EM iteration, warm start from the last estimate afterwards.
                    init=.1*randn(Np,1);
                    if emit>1; init=E(:,sj); end
                    warning('off','optim:fminunc:SwitchingMethod')
                    obj = @(x) feval(mfun,x, a, r, Z,1);
                    [est,fval,ex]=fminsearch(obj,init,options);
                    
                    % Optional restarts from jittered starting points; keep the best.
                    for i=1:Nrestart
                        [nest,nfval,nex]=fminsearch(obj,init+.1*randn(Np,1),options);
                        if nfval < fval
                            est=nest; fval=nfval; ex=nex;
                        end
                    end
                    
                    hess=NumHessian(obj,est);
                    
                    exx(sj)=ex;
                    E(:,sj)=est;                           % subject's parameter estimate
                    V(:,sj)=max(diag(inv(hess)),1e-5);     % inverse of Hessian = variance, floored at 1e-5
                    PL(sj)=fval;
                    LL(sj)=feval(mfun,est, a, r, Z,0);
                    
                    fprintf('Emit=%i subject %i model %i iteration %i exit status=%i\r',emit,sj,whichinf,ite,ex)
                end
            end
            
            % M step using factorized posterior (subjects with data only) ....
            Ev = E(:,hasData);
            mu = mean(Ev,2);
            % trimmed mean of V removes extreme values of the Hessian
            nu = sqrt(sum(Ev.^2,2)/Nval + trimmean(V(:,hasData),10,2) - mu.^2);
            % NOTE(review): nu is a standard deviation here, so inv(diag(nu)) equals
            % diag(1./nu) - confirm this is the form the model files expect in Z.nui.
            Z.mu = mu; Z.nui = inv(diag(nu));
            
            whichinf
            [mu nu]
            
            par(emit,:) = [sum(LL) sum(PL) mu' nu(:)'];   % save stuff
            et(emit,:,:) = E;
            lt(emit,:,:) = [LL; PL];
            
            if dosave
                save(fullfile('mat',ld),'mu','nu','E','V','LL','PL','par','et','lt','Z','exx','emit','ld','ff','whichinf');
            end
            
            % check convergence ...........................................
            if emit>1 && sum(abs(par(emit,2:end)-par(emit-1,2:end)))<1e-2
                fprintf('\n *** converged *** \n');
                break;
            end
            if emit>=maxEmit
                warning('EM for model %s (ite %i) stopped after %i iterations without converging.',mfun,ite,emit);
                break;
            end
        end
        
    end
end
% quit




