% Start from an empty workspace with no open figures.
% (clear rather than clear all: only variables need resetting; clear all also
% flushes compiled functions/MEX files and slows the next run.)
clear;
close all;

% Load the averaged statistics: ave_MSElin / SEM_MSElin have one column per
% delay (delays 100, 1000, 2000, 3000 ms; column 2 is T=1000).
load MSElin_stats_subjall_trialall.mat

% Hold out T=1000 (column 2) for cross-validation: keep it in the *_T1
% vectors and remove it from the matrices that are used for fitting.
SEM_MSElin_T1 = SEM_MSElin(:,2);
SEM_MSElin(:,2) = [];
ave_MSElin_T1 = ave_MSElin(:,2);
ave_MSElin(:,2) = [];

%% Fig.3A Direct storage fit to all items:
% Direct-storage model: the error grows linearly with item count and delay,
%   MSE(item, T) = alpha * item * (T - 100) + MSE(item, T=100),
% with alpha found by grid search minimising the SEM-weighted squared error
% over all item counts and fitted delays.
s1 = 'Fig.3A Direct storage fit to all items, using weighted least squares for fit';
disp(s1);

SEM = SEM_MSElin;

Items = [1 2 4 6];
Times = [100 2000 3000];
lineColors = {'k';'b';'c';'g'};   % one colour per item count (row of Items)
alpha_init = 1/18000/3.2;
range = pi;  % in units of radians, the variable goes from [0,pi]

Nsearchpts = 10000;
alpharange = linspace(alpha_init/20, 2*alpha_init, Nsearchpts);
cost_direct = zeros(Nsearchpts,1);
% T=100 data, the offset of every model curve (independent of alpha)
baseline = ave_MSElin(:,1)*ones(1,numel(Times));

for i = 1:Nsearchpts
    alpha = alpharange(i);
    direct_model = alpha*(Items'*(Times-100)) + baseline;
    % compute cost based on Weighted LS fit to all curves
    cost_direct(i) = sum(sum((direct_model-ave_MSElin).^2./SEM.^2));
end

% set optimal param value; min returns the first minimising index, so alpha
% stays a scalar even if several grid points tie:
[~, ibest] = min(cost_direct);
alpha = alpharange(ibest);
direct_model = alpha*(Items'*(Times-100)) + baseline;

s2=['Sum of Weighted Square Error for Fig.3A: ', num2str( sum(sum((direct_model-ave_MSElin).^2./SEM.^2)) )];
s3=['Sum of Square Error for Fig.3A: ', num2str( sum(sum((direct_model-ave_MSElin).^2)) )];
disp(s2)
disp(s3)

% plot data (o: fitted delays, x: held-out T=1000) with the direct model;
% errors are normalised by range^2
figure('position', [0, 400, 400, 300]);
hL = errorbar((Times./1000)'*ones(1,4), ave_MSElin'/range^2, SEM_MSElin'/range^2, 'o','MarkerSize',11,'LineWidth',2);
set(hL, {'color'}, lineColors);
hold on;
errorbar(ones(1,4), ave_MSElin_T1'/range^2, SEM_MSElin_T1'/range^2, 'x','MarkerSize',11,'LineWidth',2);
hL = plot(Times./1000, direct_model'/range^2,'LineWidth',2);
xlabel('delay (sec)');
ylabel('mean square error (norm)')
set(gca,'XLim',[0 3])
set(gca,'XTick',[0.1 1 2 3])
set(hL, {'color'}, lineColors);
savefig('3_Fig3a.fig')
saveas(gcf,'3_Fig3a','epsc')

% Predictions for the held-out T=1000 point (one value per item count)
direct_model = alpha*(Items'*(1000-100)) + ave_MSElin(:,1);
s3_new=['Sum of Square Error for T=1000 prediction: ', num2str( sum((direct_model-ave_MSElin_T1).^2) )];
disp(s3_new)

% direct_model saved here is the T=1000 prediction, loaded by the BIC section
save direct_model_allitems_WLS.mat direct_model alpha


%% Fig.3B Direct storage fit to 6 item:
% Same direct-storage model as Fig.3A,
%   MSE(item, T) = alpha * item * (T - 100) + MSE(item, T=100),
% but alpha is chosen by ordinary least squares on the 6-item curve only;
% the reported errors are still evaluated over all item counts.
s4 = 'Fig.3B Direct storage fit to 6 item, using ordinary least squares for fit';
disp(s4);

SEM = SEM_MSElin;

Items = [1 2 4 6];
Times = [100 2000 3000];
lineColors = {'k';'b';'c';'g'};   % one colour per item count (row of Items)
alpha_init = 1/18000/3.2;
range = pi;  % in units of radians, the variable goes from [0,pi]

Nsearchpts = 10000;
alpharange = linspace(alpha_init/20, 2*alpha_init, Nsearchpts);
cost_direct = zeros(Nsearchpts,1);
% T=100 data, the offset of every model curve (independent of alpha)
baseline = ave_MSElin(:,1)*ones(1,numel(Times));
i6 = find(Items == 6);   % row holding the 6-item curve

for i = 1:Nsearchpts
    alpha = alpharange(i);
    direct_model = alpha*(Items'*(Times-100)) + baseline;
    % compute cost based only on fit to 6-item curve
    cost_direct(i) = sum((direct_model(i6,:)-ave_MSElin(i6,:)).^2);
end

% set optimal param value; min returns the first minimising index, so alpha
% stays a scalar even if several grid points tie:
[~, ibest] = min(cost_direct);
alpha = alpharange(ibest);
direct_model = alpha*(Items'*(Times-100)) + baseline;

s5=['Sum of Weighted Square Error for Fig.3B: ', num2str( sum(sum((direct_model-ave_MSElin).^2./SEM.^2)) )];
s6=['Sum of Square Error for Fig.3B: ', num2str( sum(sum((direct_model-ave_MSElin).^2)) )];
disp(s5)
disp(s6)

% plot cost as fxn of alpha:
% figure;
% plot(alpharange, cost_direct);

% plot data (o: fitted delays, x: held-out T=1000) with the direct model;
% errors are normalised by range^2
figure('position', [0, 0, 400, 300]);
hL = errorbar((Times./1000)'*ones(1,4), ave_MSElin'/range^2, SEM_MSElin'/range^2, 'o','MarkerSize',11,'LineWidth',2);
set(hL, {'color'}, lineColors);
hold on;
errorbar(ones(1,4), ave_MSElin_T1'/range^2, SEM_MSElin_T1'/range^2, 'x','MarkerSize',11,'LineWidth',2);
hL = plot(Times./1000, direct_model'/range^2,'LineWidth',2);
xlabel('delay (sec)');
ylabel('mean square error (norm)')
set(gca,'XLim',[0 3])
set(gca,'XTick',[0.1 1 2 3])
set(hL, {'color'}, lineColors);
savefig('3_Fig3b.fig')
saveas(gcf,'3_Fig3b','epsc')

% Predictions for the held-out T=1000 point (one value per item count)
direct_model = alpha*(Items'*(1000-100)) + ave_MSElin(:,1);
s6_new=['Sum of Square Error for T=1000 prediction: ', num2str( sum((direct_model-ave_MSElin_T1).^2) )];
disp(s6_new)

% direct_model saved here is the T=1000 prediction, loaded by the BIC section
save direct_model_6item_OLS.mat direct_model alpha



%% Fig.5B Coded storage
% Coded-storage model, for item count k and delay T:
%   MSE(k,T) = range^2/(2*pi*e) * (1 + 1/(2*Dc*(T-100)))^(-N/k) + MSE(k,T=100)
% fit by an exhaustive log-spaced grid over (N, Dc) that minimises the
% SEM-weighted squared error; the log-cost landscape is saved as Fig5a.
s7 = 'Fig.5B Coded storage, using weighted least squares for fit';
disp(s7);

% MSElin is in units of radians^2
SEM = SEM_MSElin;

Items = [1 2 4 6];
Times = [100 2000 3000];
Dc_init = 2.4/10000;
range = pi;  % in units of radians, the variable goes from [0,pi]

% Coded-model prediction for parameters (N, Dc) at the delays in row vector T:
% one row per item count, one column per delay.  At T=100 the bracket is
% 1 + Inf, so the delay term is exactly 0 and the model equals the T=100 data.
coded_fn = @(N, Dc, T) (range^2/(2*pi*exp(1))) ...
    * (1 + 1./(2*Dc*ones(numel(Items),1)*(T-100))).^(-N./(Items'*ones(1,numel(T)))) ...
    + ave_MSElin(:,1)*ones(1,numel(T));

Nsearchpts = 200-1;
Nrange_begin = 0;
Nrange_end = 3;
Dcrange_begin = -8;
Dcrange_end = 8;

Nrange = logspace(Nrange_begin, Nrange_end, Nsearchpts);
Dcrange = Dc_init*logspace(Dcrange_begin, Dcrange_end, Nsearchpts);

% cost_coded(i,j): row i indexes Dcrange, column j indexes Nrange
cost_coded = zeros(Nsearchpts, Nsearchpts);

for j = 1:Nsearchpts
    N = Nrange(j);
    for i = 1:Nsearchpts
        Dc = Dcrange(i);
        coded_model = coded_fn(N, Dc, Times);
        cost_coded(i,j) = sum(sum((coded_model-ave_MSElin).^2./SEM.^2));
    end
end


% figure;
% imagesc(Nrange, Dcrange, log(cost_coded), 'CData',log(cost_coded));
% title('fit of coded model');
% colorbar;

% Log cost over the (N, Dc) grid, viewed from above on log-log axes.
figure('position', [600, 400, 400, 300]);
surf(Nrange, Dcrange, log(cost_coded), 'EdgeColor','none');
view(0,90);
set(gca,'xscale','log');
set(gca,'yscale','log');
set(gca, 'ydir', 'reverse');
title('fit of coded model');
xlabel('N');
ylabel('D');
axis([10.^Nrange_begin 10.^Nrange_end Dc_init.*10.^Dcrange_begin Dc_init.*10.^Dcrange_end]);
colorbar;
savefig('3_Fig5a.fig')
saveas(gcf,'3_Fig5a','epsc')


% find valley bottom values of N, Dc:
% For each N (column of cost_coded) take the minimum cost over Dc and the Dc
% that attains it.  min returns the first minimising index, so each Dcmin(j) is
% a single value even when several Dc give exactly the same cost (with find,
% such a tie returned several indices and the scalar assignment failed).
[mincost, mincost_Dc_ind] = min(cost_coded, [], 1);
mincost = mincost(:);
Dcmin = Dcrange(mincost_Dc_ind);
Dcmin = Dcmin(:);

% Fit quality (left axis) and N/Dc "total resource use" (right axis) along
% the valley of best Dc for each N.
figure('position', [600, 0, 400, 150]);
[hAx,hLine1,hLine2] = plotyy(Nrange,mincost,Nrange,Nrange'.*(1./Dcmin),'semilogx','loglog');
xlabel('N')
ylabel(hAx(1),'fit quality along valley') % left y-axis
ylabel(hAx(2),'total resource use') % right y-axis
set(hAx(1),{'ycolor'},{'b'})
set(hAx(2),{'ycolor'},{'k'})
set(hLine1,'LineWidth',0.1,'marker','o','color','b')
set(hLine2,'LineWidth',0.1,'marker','o','color','k')
savefig('3_Fig5c.fig')
saveas(gcf,'3_Fig5c','epsc')


% Fig5b: best-fitting coded model for three fixed N values.  For each N the
% Dc grid is scanned for the minimum SEM-weighted squared error and the data
% (o: fitted delays, x: held-out T=1000) are plotted with the model curves.
% The N=10 fit is reported and saved for the BIC section.

% Coded-model prediction for (N, Dc) at the delays in row vector T:
% one row per item count, one column per delay.
coded_at = @(N, Dc, T) (range^2/(2*pi*exp(1))) ...
    * (1 + 1./(2*Dc*ones(numel(Items),1)*(T-100))).^(-N./(Items'*ones(1,numel(T)))) ...
    + ave_MSElin(:,1)*ones(1,numel(T));
lineColors = {'k';'b';'c';'g'};   % one colour per item count (row of Items)
Nshown = [5 10 100];              % one subplot per N
Dc_cost_coded = zeros(1, Nsearchpts);

figure('position', [600, 200, 400, 150]);
for p = 1:numel(Nshown)
    subplot(1,3,p)
    N = Nshown(p);
    for i = 1:Nsearchpts
        Dc = Dcrange(i);
        coded_model = coded_at(N, Dc, Times);
        Dc_cost_coded(i) = sum(sum((coded_model-ave_MSElin).^2./SEM.^2));
    end
    % first minimising index, so Dc stays a scalar even if costs tie
    [~, ibest] = min(Dc_cost_coded);
    Dc = Dcrange(ibest);
    coded_model = coded_at(N, Dc, Times);

    hold on;
    hL = errorbar((Times./1000)'*ones(1,4), ave_MSElin'/range^2, SEM_MSElin'/range^2, 'o','MarkerSize',11,'LineWidth',2);
    set(hL, {'color'}, lineColors);
    errorbar(ones(1,4), ave_MSElin_T1'/range^2, SEM_MSElin_T1'/range^2, 'x','MarkerSize',11,'LineWidth',2);
    hL = plot(Times./1000, coded_model'/range^2);
    set(hL, {'color'}, lineColors);
    title(['N=', num2str(N)]);
    if p == 1
        xlabel('delay (sec)');
        ylabel('mean squared error (norm)');
    end
    set(gca,'XLim',[0 3]);
    set(gca,'XTick',[0.1 1 2 3]);

    if N == 10
        s8=['Sum of Weighted Square Error for Fig.5B: ', num2str( sum(sum((coded_model-ave_MSElin).^2./SEM.^2)) )];
        s9=['Sum of Square Error for Fig.5B: ', num2str( sum(sum((coded_model-ave_MSElin).^2)) )];
        disp(s8)
        disp(s9)

        % Predictions for the held-out T=1000 point (one value per item count)
        coded_model = coded_at(N, Dc, 1000);
        s9_new=['Sum of Square Error for T=1000 prediction: ', num2str( sum((coded_model-ave_MSElin_T1).^2) )];
        disp(s9_new)

        % coded_model saved here is the T=1000 prediction, loaded by the BIC section
        save coded_model_allitems_WLS.mat coded_model N Dc
    end
end

savefig('3_Fig5b.fig')
saveas(gcf,'3_Fig5b','epsc')




%% BIC calculations
% Score the three saved T=1000 predictions against the raw trial data: each
% trial's squared error is scored under a Gaussian centred on the model value
% with the empirical variance, and the log-likelihoods feed a BIC comparison.

s10 = 'BIC calculations';
disp(s10);

Items = [1 2 4 6];
Times = [100 2000 3000];
nItems = numel(Items);
nSubjects = 10;
delayIdx = 2;   % second delay index of Data (Delay: 100, 1000, 2000, 3000), i.e. T=1000

ave_MSElin = zeros(nItems,1);
var_MSElin = zeros(nItems,1);
SEM_MSElin = zeros(nItems,1);
LL_direct_model_allitems_WLS = zeros(nItems,1);
LL_direct_model_6item_OLS = zeros(nItems,1);
LL_coded_model_allitems_WLS = zeros(nItems,1);

range = pi;

% Load the saved fits into structs so they cannot overwrite workspace variables.
% Direct model all items WLS
M = load('direct_model_allitems_WLS.mat');
LL_direct_model_allitems_WLS_direct_model = M.direct_model;
LL_direct_model_allitems_WLS_alpha = M.alpha;

% Direct model 6item OLS
M = load('direct_model_6item_OLS.mat');
LL_direct_model_6item_OLS_direct_model = M.direct_model;
LL_direct_model_6item_OLS_alpha = M.alpha;

% Coded model all items WLS
M = load('coded_model_allitems_WLS.mat');
LL_coded_model_allitems_WLS_coded_model = M.coded_model;
LL_coded_model_allitems_WLS_N = M.N;
LL_coded_model_allitems_WLS_Dc = M.Dc;

% Load each subject's raw data once (numbers of trials may differ across
% subjects); only the Data variable is read from each file.
subjectData = cell(nSubjects,1);
for k = 1:nSubjects
    S = load(['Data/Subject_',num2str(k),'_data.mat'], 'Data');
    subjectData{k} = S.Data;
end

% Gaussian log-likelihood of the samples x under mean m and variance v.
gaussLL = @(x, m, v) sum(log(1/sqrt(2*pi*v))-((x-m).^2)/(2*v));

for i = 1:nItems
    % pool all trials of all subjects for this item count at delayIdx
    data_oneitem_onedelay = [];
    for k = 1:nSubjects
        workdata = squeeze(subjectData{k}(i,delayIdx,:));
        % convert to radians and discard empty trials (stored as 1000)
        workdata = (2*pi/360)*workdata(workdata~=1000);
        data_oneitem_onedelay = [data_oneitem_onedelay; workdata];
    end

    % MSElin is vector of sq error for all subjects and all trials:
    MSElin = (data_oneitem_onedelay).^2;

    % all trial, all subject MSE average for each item at this delay:
    ave_MSElin(i) = mean(MSElin);
    var_MSElin(i) = var(MSElin);
    SEM_MSElin(i) = sqrt(var_MSElin(i)./length(data_oneitem_onedelay));

    % Use data vectors and theoretical predictions to generate likelihoods:
    LL_direct_model_allitems_WLS(i) = gaussLL(MSElin, LL_direct_model_allitems_WLS_direct_model(i), var_MSElin(i));
    LL_direct_model_6item_OLS(i) = gaussLL(MSElin, LL_direct_model_6item_OLS_direct_model(i), var_MSElin(i));
    LL_coded_model_allitems_WLS(i) = gaussLL(MSElin, LL_coded_model_allitems_WLS_coded_model(i), var_MSElin(i));
end


% Model comparison (lower BIC is better).
% NOTE(review): every model gets the same penalty 2*log(660*4*2*pi), so the
% penalty cancels in DeltaBIC.  The direct model fits one parameter (alpha)
% and the coded model two (N, Dc) -- confirm the intended parameter counts and
% the sample-size argument of the log.
nParams = 2;
bicLogArg = 660*4*2*pi;
BIC_direct_model_allitems_WLS = -2*sum(LL_direct_model_allitems_WLS) + nParams*log(bicLogArg);
BIC_direct_model_6item_OLS = -2*sum(LL_direct_model_6item_OLS) + nParams*log(bicLogArg);
BIC_coded_model_allitems_WLS = -2*sum(LL_coded_model_allitems_WLS) + nParams*log(bicLogArg);

s11 = ['BIC for direct model all items WLS: ', num2str( BIC_direct_model_allitems_WLS ), ' (2D/N: ', num2str(LL_direct_model_allitems_WLS_alpha/range^2), ')'];
s12 = ['BIC for direct model 6 item OLS: ', num2str( BIC_direct_model_6item_OLS ), ' (2D/N: ', num2str(LL_direct_model_6item_OLS_alpha/range^2), ')'];
s13 = ['BIC for coded model all items WLS: ', num2str( BIC_coded_model_allitems_WLS ), ' (N: ', num2str(LL_coded_model_allitems_WLS_N), ', D: ', num2str(LL_coded_model_allitems_WLS_Dc),')'];
disp(s11)
disp(s12)
disp(s13)

% more positive favors coded (lower BIC is better, coded is subtracted)
DeltaBIC_direct_minus_coded = BIC_direct_model_allitems_WLS - BIC_coded_model_allitems_WLS;
s14 = ['Delta BIC = BIC(direct model all items WLS) - BIC(coded model all items WLS): ', num2str(DeltaBIC_direct_minus_coded) ];
disp(s14)

%% Generate report
% Write a short header and every message printed above to Results.txt.

fid = fopen('Results.txt','wt');
if fid == -1
    error('Could not open Results.txt for writing.');
end
fprintf(fid, 'Re-fit using only T=[100 2000 3000], i.e., cross-validation for T=1000. \n');
fprintf(fid, 'For the coded model, fit quality reaches to a minimum around N=10. \n\n');
% Pass the messages as %s arguments rather than as the format string, so any
% '%' or '\' they contain is written literally instead of being interpreted.
reportLines = {s1,s2,s3,s3_new,s4,s5,s6,s6_new,s7,s8,s9,s9_new,s10,s11,s12,s13,s14};
fprintf(fid, '%s\n', reportLines{1:end-1});
fprintf(fid, '%s', reportLines{end});   % no trailing newline, as before
fclose(fid);