% Start from an empty workspace and load the inputs of this script.
% E, Ap and PayOff are used below but never assigned here, so these two
% files must supply them (which file holds which is not visible here —
% presumably PayOff comes from PayOff_R2; confirm).
clear all;
load('data_tab_simulation_betadist2p_vc');

%%%%%%%%%%%%
%%% Reward schedule
%%%%%%%%%%%%
load('PayOff_R2');


%%%%%%%%%%%%
%%% Simulate every parameter set (one column of E per subject)
%%%%%%%%%%%%
% For each column sj of E, run nIter independent simulated sessions of
% nTrials trials. Results are stored as nIter-by-nTrials matrices:
%   Choice{sj}(ite,t)  = 1 if option 1 was chosen, 0 if option 2
%   Optimal{sj}(ite,t) = 1 if the chosen option has the larger PayOff(:,t)
% Requires E and PayOff in the workspace and the external function conff.
nIter   = 100;   % simulated sessions per subject
nTrials = 220;   % trials per session; PayOff needs at least this many columns

% Parameters are stored column-wise. length(E) would return the row count
% (>= 5) instead of the subject count whenever there are fewer than 5 subjects.
nSubj   = size(E,2);
Choice  = cell(1,nSubj);
Optimal = cell(1,nSubj);

for sj = 1:nSubj

    %%%%%%%%
    %%% load parameters (mapped from the unconstrained values in E)
    %%%%%%%%%
    beta = exp(E(1,sj));          % softmax inverse temperature, > 0
    w    = 1/(1+exp(-E(2,sj)));   % decay of the chosen option's counts toward 1, in (0,1)
    f    = 1/(1+exp(-E(3,sj)));   % decay of the unchosen option's counts toward 1, in (0,1)
    cfd  = E(4,sj);               % weight of the conff-based bias added on the next trial
    v    = E(5,sj);               % weight of the variance bonus for the previously unchosen option

    choice  = zeros(nIter,nTrials);
    optimal = zeros(nIter,nTrials);

    for ite = 1:nIter
        % Every simulated session starts from the same prior. Previously Q
        % and qs carried over from the end of the preceding session.
        Q  = ones(2,2);   % rows = options; column 1 grows after r=1, column 2 after r=0
        qs = [0 0];       % bias computed on the previous trial

        for t = 1:nTrials

            % Mean and variance of Beta(Q(k,1), Q(k,2)) for each option k
            tot = sum(Q,2)';
            q   = Q(:,1)' ./ tot;
            V   = (Q(:,1)' .* Q(:,2)') ./ (tot.^2 .* (tot+1));

            % Bias carried over from the previous trial
            q = q + qs;

            % Variance bonus for the option not chosen on the previous trial
            if t > 1
                q(3-a) = q(3-a) + v*V(3-a);
            end

            % Softmax choice probabilities (max-subtracted for numerical stability)
            q  = beta*q;
            l0 = q - max(q);
            la = l0 - log(sum(exp(l0)));
            pg = exp(la);

            if pg(1) > rand
                a = 1;
            else
                a = 2;
            end

            % Did the agent pick the option with the larger payoff on this trial?
            [~, b] = max(PayOff(:,t));
            optimal(ite,t) = (a == b);

            % r = 1 with probability PayOff(a,t) (assumes entries lie in [0,1])
            r = double(PayOff(a,t) > rand);

            % Bias for the next trial, computed from the pre-update counts
            c  = conff(Q(1,1),Q(1,2),Q(2,1),Q(2,2));
            qs = [0 0];
            qs(a) = cfd*(1-2*c)*(3-2*a); % this makes c(unchosen)-c(chosen)

            % Decay both options' counts toward 1 and add the new outcome
            Q(a,:)   = w + (1-w)*Q(a,:);
            Q(a,2-r) = Q(a,2-r) + 1;
            Q(3-a,:) = f + (1-f)*Q(3-a,:);

            choice(ite,t) = (a == 1);   % 1 = option 1, 0 = option 2
        end
    end
    Choice{sj}  = choice;
    Optimal{sj} = optimal;
end


% For every subject k in Ap:
%   OC{k}       - Ap{k} recoded so that 2 becomes 0 (matching Choice's coding)
%   allOC1(k,:) - that recoded series as one row
%   allSC1(k,:) - per-trial mean of Choice{k} over its simulated sessions
for k = 1:length(Ap)
    recoded = Ap{k};
    recoded(recoded == 2) = 0;
    OC{k} = recoded;

    allOC1(k,:) = recoded;
    allSC1(k,:) = mean(Choice{k});
end

%%%%%%%%%%%%
%%% Plot data vs simulated choice curves
%%%%%%%%%%%%
% Each curve is the across-subject mean, per trial, of the option-1 choice
% indicator. Rows 1..nOlder of allOC1/allSC1 are the 'older' group; the
% remaining rows are the 'younger' group.
nOlder = 27;
nT     = size(allOC1,2);

% Open a new window. `clear all` does not close figures, so on a re-run this
% plot would otherwise land in the previous run's last figure, left with `hold on`.
figure;
plot(mean(allOC1),'b')
hold on
plot(mean(allSC1),'r')
legend('data','simulation')

figure;
subplot(2,2,1)
plot(mean(allOC1(1:nOlder,:)),'b')
hold on
plot(mean(allSC1(1:nOlder,:)),'r')
axis([0 nT 0 1])
title('older')

subplot(2,2,2)
plot(mean(allOC1(nOlder+1:end,:)),'b')
hold on
plot(mean(allSC1(nOlder+1:end,:)),'r')
axis([0 nT 0 1])
title('younger')

% Simulated curves: older (blue) vs younger (green)
figure
plot(mean(allSC1(1:nOlder,:)),'b')
hold on
plot(mean(allSC1(nOlder+1:end,:)),'g')
legend('older','younger')

% Data curves: older (blue) vs younger (green)
figure;
plot(mean(allOC1(1:nOlder,:)),'b')
hold on
plot(mean(allOC1(nOlder+1:end,:)),'g')
legend('older','younger')


% Persist the recoded data choices and both per-subject curve matrices
save('simulationstats.mat', 'OC', 'allOC1', 'allSC1');