% Cesanek et al. (2021) Figure 5 Source Code (MATLAB code - rename with .m suffix to run)
% Creates all figure subplots and runs all relevant reported analyses

% Start from a clean slate: clear workspace, close all figures, clear command window
clearvars
close all
clc

% The gramm plotting library must be on the MATLAB path before any plotting below.
% (Fixed a stray ')' that previously appeared after the URL in the message.)
if exist('gramm','file')==0
    error('The gramm plotting library is required to run this code.\nYou can download it from https://github.com/piermorel/gramm.\n');
end

%% Sample size estimation for GMM analysis
% Simulates numSimulations datasets of N participants drawn from a 50-50
% mixture of two unit-variance Gaussians 3.5 SD apart, fits 1- and
% 2-component GMMs to each, and reports how often AIC recovers 2 components.
fprintf('\n** WEB-BASED TASK: Running GMM power analysis, please wait.');

N = 36; % sample size per simulated experiment
numSimulations = 1000;
recoveries = nan(1,numSimulations); % AIC-selected component count per simulation

% Promote the non-convergence warning to an error so the try/catch below can
% detect failed fits. Hoisted out of the loop: the call is loop-invariant.
warning('error','stats:gmdistribution:FailedToConverge'); %#ok<CTPCT>

for simulation = 1:numSimulations
    if mod(simulation,100)==1
        fprintf('.'); % progress indicator
    end
    % Occasionally a simulated dataset will fail to converge.
    % Keep regenerating data until both fits succeed (count pulls ahead of
    % err_count only on an error-free pass through the try block).
    count = 0;
    err_count = 0;
    while count == err_count
        try
            data = [randn(N/2,1); randn(N/2,1)+3.5]; % assume 50-50 split
            AIC = nan(1,2);
            for components = 1:2
                GMModel = fitgmdist(data,components,'Options',statset('MaxIter',250));
                AIC(components) = GMModel.AIC;
            end
        catch
            err_count = err_count + 1; % fit failed to converge; retry with fresh data
        end
        count = count + 1;
    end
    % AIC model selection: record which number of components fit best
    [~,numComponents] = min(AIC);
    recoveries(simulation) = numComponents;
end
fprintf('\nCorrectly recovered 2 components in %i%% of simulations with N=%.0f.\n', round(sum(recoveries==2)/numSimulations*100),N);

%% Load Figure 5 source data and define shared analysis constants
clearvars;
if ~exist('Figure5_SourceData1.txt','file')
    error('This analysis requires the following source data files: Figure5_SourceData1.txt');
end
F5 = readtable('Figure5_SourceData1.txt');

% Unpack each table column into its own workspace variable for the
% analyses and plots below
Subject = F5.Subject;
ObjID = F5.ObjID;
Block = F5.Block;
Condition = F5.Condition;
Phase = F5.Phase;
Weight = F5.Weight;
AF = F5.AF;
RF = F5.RF;
ConditionName = F5.ConditionName;

expNames = {'Linear+','Linear++','Linear-','Linear--'}; % in the order of Condition label
expNamesOrdered = {'Linear++','Linear+','Linear-','Linear--'}; % in the order we want to analyze them

% For each condition in analysis order, find its position in the
% alphabetically sorted name list (gramm facets rows in sorted order, so
% row_order is passed to set_order_options to restore the desired order)
expNames_sort = sort(expNames);
row_order = cellfun(@(nm) find(strcmp(expNames_sort,nm)), expNamesOrdered);

% Orange color ramp (light to dark, RGB in [0,1]) shared by all panels
famcolors = [255 124  49;
             255  97  20;
             215  77  18;
             175  56  18;
             137  34  17]/255;

%% GMM
% Web-based task: fit 1- vs 2-component Gaussian mixture models to each
% condition's per-participant mean anticipatory force (AF) on the outlier
% object over the last NumLateTestBins test blocks, select the winner by
% AIC, and hard-cluster participants into non-learners (Class 1) vs
% learners (Class 2). Populates GMModel, Class, s, c for later sections.
NumLateTestBins = 5;
% Late-test outlier trials only (ObjID==3 is the outlier object)
subset = ObjID==3 & Block>(max(Block)-NumLateTestBins);
[groups, s, c] = findgroups(Subject(subset),Condition(subset));
c = expNames(c)'; % condition index -> condition name, one entry per participant
AF_m = splitapply(@nanmean,AF(subset),groups); % per-participant mean late-test outlier AF
Class = nan(size(AF,1),1); % per-trial cluster label (AF is a column here), filled below

fprintf('\n** WEB-BASED TASK: Gaussian model fits to individual differences in outlier learning (Fig. 5c, f, i, l)\n');

for cond = 1:4
    AIC = nan(1,2);
    BIC = nan(1,2);
    exp = expNamesOrdered{cond};
    fprintf('\n*** %s\n',exp);
    % Fit both candidate models (1 and 2 components) to this condition
    for components = 1:2
        idx = 2*(cond-1)+(components-1)+1; % 1,2 = Linear++; 3,4 = Linear+; 5,6 = Linear-; 7,8 = Linear--;
        fprintf('Fitting %i-component GMM\n',components);
        GMModel{cond,components} = fitgmdist(AF_m(strcmpi(c,exp)),components,'Options',statset('MaxIter',1000));
        AIC(components) = GMModel{cond,components}.AIC;
        BIC(components) = GMModel{cond,components}.BIC;
        fprintf('AIC = %.1f\n',AIC(components));
        fprintf('BIC = %.1f\n',BIC(components));
        for ci = 1:components
            %fprintf('Mu%i = %.2f, Sigma%i = %.2f\n',ci,GMModel{components}.mu(ci),GMModel{components}.Sigma(ci));
        end
    end
    
    % Choose the winning model
    [~,numComponents] = min(AIC);
    fprintf('\nBest fit by %i components.\n',numComponents);
    mu = GMModel{cond,numComponents}.mu;
    sigma = GMModel{cond,numComponents}.Sigma;
    
    % Hard-cluster using the winning model
    class_assign = cluster(GMModel{cond,numComponents},AF_m(strcmp(c,exp)));
    subjectNums = s(strcmp(c,exp));
    % Copy each participant's cluster label onto all of that participant's trials
    for si = 1:length(subjectNums)
        Class(Subject==subjectNums(si)) = class_assign(si);
    end 
    % If Class 2 mean is closer to the interpolated weight than Class 1 mean, switch the Class labels
    % (0.981*5 is presumably the family-interpolated outlier force in N;
    % cf. the 0.981 slope used by geom_abline below -- TODO confirm)
    if numComponents==2 && abs(mu(2)-0.981*5)<abs(mu(1)-0.981*5)
        Class(strcmp(expNames(Condition),exp)) = mod(Class(strcmp(expNames(Condition),exp)),2)+1;
    end
    % NOTE(review): the transpose below turns the Nx1 strcmp result into 1xN
    % before the & with the Nx1 'subset'; under R2016b+ implicit expansion
    % this produces an NxN mask rather than a per-trial vector -- confirm
    % this matches the intended selection before modifying.
    group2 = findgroups(Subject(strcmp(expNames(Condition),exp)' & subset));
    Class_m = splitapply(@nanmean,Class(strcmp(expNames(Condition),exp)' & subset)',group2');
    % Report cluster sizes (per-participant labels recovered via the group means)
    fprintf('Cluster 1: %i non-learners',sum(Class_m==1));
    if numComponents==2
        fprintf('\nCluster 2: %i learners\n',sum(Class_m==2));
    else
        fprintf('\n');
    end

end

%% Fig 5, left column: Anticipatory force timelines
% Plots per-block mean anticipatory force (AF, solid area = mean +/- SEM)
% and required force (RF, dotted reference lines) across blocks, one facet
% row per condition. Family objects use the light/dark color ramp; the
% outlier object is drawn separately in the mid-tone color.
[groups, ~, b, p, o, w, c] = findgroups(Subject,Block,Phase,ObjID,Weight,Condition);
% Outlier trials: object 3 when condition index < 5, object 2 otherwise
out = (o==3 & c<5) | (o==2 & c>=5);
c = expNames(c)';
AF_m = splitapply(@nanmean,AF,groups);
RF_m = splitapply(@unique,RF,groups); % RF is constant within a group, so unique yields the scalar

[~,~,ip] = unique(p); % numeric phase index; used as 'group' so lines break between phases

close(figure(51));
figh = figure(51);

% Size the figure in inches (500 px per 6.95 in)
pixPerInch = 500/6.95;
WidthInches = 4.5;
HeightInches = 4;
figh.Position(1) = 0;
figh.Position(2) = 0;
figh.Position(3) = pixPerInch*WidthInches;
figh.Position(4) = pixPerInch*HeightInches;

clear g
% Layer 1: AF timelines for family objects only (subset excludes outlier trials)
g = gramm('x',b, 'y',AF_m, 'color',w, 'group',ip, 'subset', ~out);
g.set_layout_options('redraw',false,'margin_height',[.1 .02],'margin_width',[.07 .02],'gap',[0.02 0.02]);
g.no_legend();
g.set_names('x','','y','','color','','Column','','Row','');
g.axe_property('XLim',[0 36],'YLim',[0 11],'YTick',1:2:11,'YTickLabels',{'1','','5','','9',''});
g.facet_grid(c,[]);
g.set_order_options('row',row_order);
g.set_point_options('base_size',2);
g.set_line_options('base_size',0.5);
g.set_color_options('map',famcolors([1 2 4 5],:),'n_color',4,'n_lightness',1);
g.stat_summary('geom',{'area'},'width',1,'type','sem');
g.draw();
% Layer 2: dotted RF reference lines for the family objects
g.update('y',RF_m);
g.no_legend();
g.stat_summary('geom',{'line'});
g.set_line_options('styles',{':'},'base_size',0.5);
g.draw();
% Layer 3: dotted RF reference line for the outlier object (mid-tone color)
g.update('y',RF_m,'color',[],'subset',out);
g.no_legend();
g.set_point_options('base_size',3);
g.set_line_options('base_size',1);
g.set_color_options('map',famcolors(3,:),'n_color',1,'n_lightness',1);
g.stat_summary('geom',{'line'});
g.set_line_options('styles',{':'});
g.draw();
% Layer 4: AF timeline for the outlier object (subset 'out' carries over
% from the previous update; area + points, solid line)
g.update('y',AF_m);
g.no_legend();
g.set_line_options('styles',{'-'});
g.stat_summary('geom',{'area','point'},'type','sem');
g.draw();

%% Fig 5, middle column: Histograms
% Histograms of per-participant late-test outlier AF, one facet row per
% condition, overlaid with the pdfs of the fitted 1- and 2-component GMMs
% (from the GMM section above). Bars are later rescaled from counts to
% density so they are comparable with the pdf curves.
subset = ObjID==3 & Block>(max(Block)-NumLateTestBins);
[groups, ~, c] = findgroups(Subject(subset),Condition(subset));
c = expNames(c)';
AF_m = splitapply(@nanmean,AF(subset),groups);

binWidth = 0.981/2; % histogram bin width in N (half the 0.981 step used elsewhere)

close(figure(52));
figh = figure(52);

% Size the figure in inches (500 px per 6.95 in)
pixPerInch = 500/6.95;
WidthInches = 1.4;
HeightInches = 4;
figh.Position(1) = 0;
figh.Position(2) = 0;
figh.Position(3) = pixPerInch*WidthInches;
figh.Position(4) = pixPerInch*HeightInches;

clear g
% Layer 1: stacked-bar histograms, coord-flipped so force runs vertically
g = gramm('x',AF_m);
g.set_layout_options('redraw',false,'margin_height',[.1 .02],'margin_width',[.07 .02],'gap',[0.02 0.02]);
g.no_legend();
g.set_names('x','','y','','Row','');
g.axe_property('XLim',[0 11],'YLim',[-0.003 0.75],'XTick',1:2:11,'YTick',0:0.2:0.6,'TickLength',[0.03 0.03]);
g.facet_grid(c,[]);
g.set_order_options('row',row_order);
g.set_color_options('map',famcolors([3 5 1],:),'n_color',3,'n_lightness',1);
g.stat_bin('geom','stacked_bar','edges',((0.981/4)-binWidth):binWidth:(15*0.981),'fill','face','normalization','count');
g.coord_flip();
g.draw();

% Build the GMM density curves: for each condition, evaluate both the
% 1-component (compi=1) and 2-component (compi=2) model pdfs on a 0:0.1:11 grid
density_x = [];
density_y = [];
density_facet = [];
density_line = [];
for condi = 1:4
    for compi = 1:2
        density_x = [density_x 0:0.1:11]; %#ok<AGROW>
        density_y = [density_y pdf(GMModel{condi,compi},(0:0.1:11)')'];  %#ok<AGROW>
        density_facet = [density_facet repelem(condi,length(0:0.1:11))];  %#ok<AGROW>
        %density_line = [density_line repelem(GMModel{condi,compi}.AIC<GMModel{condi,mod(compi,2)+1}.AIC,length(0:0.1:11))];  %#ok<AGROW> 
        density_line = [density_line repelem(compi,length(0:0.1:11))];  %#ok<AGROW>
    end
end
% Layer 2: overlay the density curves (1-component in blue, 2-component in green)
g.update('x',density_x,'y',density_y,'color',density_line,'linestyle',density_line,'size',density_line);
g.no_legend();
g.set_order_options('row',1:4);
% NOTE(review): the second set_line_options call below overrides this one;
% only the 'base_size',1.0 settings take effect
g.set_line_options('base_size',0.5,'step_size',1,'styles',{'-','-'});
g.set_color_options('map',[0 0 .75; 0 .5 0],'n_color',2,'n_lightness',1);
g.set_line_options('base_size',1.0,'step_size',0,'styles',{'-','-'});
g.facet_grid(density_facet,[]);
g.geom_line();
g.draw();

% Per-condition sample sizes, used to rescale bar heights below.
% NOTE(review): 's' here is left over from the GMM section above (the local
% findgroups call discarded its subject output); this works because both
% sections use the identical subset -- re-verify if either subset changes.
N = [length(s(strcmpi(c,'Linear++')))...
    length(s(strcmpi(c,'Linear+')))...
    length(s(strcmpi(c,'Linear-')))...
    length(s(strcmpi(c,'Linear--')))];

% Drop the facet labels because these will be next to the timelines
for paneli = 1:numel(g.facet_axes_handles)
    g.facet_axes_handles(paneli).XAxis.TickDirection = 'out';
    g.facet_axes_handles(paneli).XAxis.LineWidth = 0.5;
    tmp = g.facet_axes_handles(paneli).Children;
    
    % Adjust the bars so they correspond with the probability densities
    % (rescale raw counts to density: divide by N * binWidth)
    BarsList = findobj(tmp,'flat','type','patch');
    for pi = 1:length(BarsList)
        BarsList(pi).Vertices(:,2) = BarsList(pi).Vertices(:,2)/(N(paneli)*binWidth);
    end
end

%% Fig 5, right column: End of test phase averages
% Late-test mean AF per object (points + SEM errorbars), per-participant
% OLS fits over the training objects (black lines), the outlier's actual
% RF (diamond), and the GMM component means with bootstrapped CIs (green
% squares), one facet row per condition.
subset = Block>(max(Block)-NumLateTestBins);

[groups, s, o, p, cl, c] = findgroups(Subject(subset),ObjID(subset),Phase(subset),Class(subset),Condition(subset));
c = expNames(c)';
cl(o~=3) = 0; % cluster labels are only meaningful for outlier (ObjID==3) groups
AF_m = splitapply(@nanmean,AF(subset),groups);
RF_m = splitapply(@unique,RF(subset),groups); % RF is constant within a group

% Helper functions: OLS slope/intercept via backslash; subsref pulls out
% the desired coefficient so these work as anonymous splitapply functions
mySlope = @(x,y) subsref( ([ones(length(x),1) x']\y'), struct('type','()','subs',{{2}}));
myIntercept = @(x,y) subsref( ([ones(length(x),1) x']\y'), struct('type','()','subs',{{1}}));
% Fit a linear function for each participant using OLS on Training Objects
% (ObjID+2 shifts object IDs onto the 3..7 x-axis positions used in the plot)
[groupsLR, sLR, eLR] = findgroups(Subject(subset&ObjID~=3),Condition(subset&ObjID~=3));
WindowSlopes = splitapply(mySlope,ObjID(subset&ObjID~=3)'+2,AF(subset&ObjID~=3)',groupsLR');
WindowIntercepts = splitapply(myIntercept,ObjID(subset&ObjID~=3)'+2,AF(subset&ObjID~=3)',groupsLR');
% Evaluate each participant's fitted line on a dense grid for plotting
xpred = 2:0.1:8;
ypred = repmat(WindowIntercepts',1,length(xpred))+repmat(WindowSlopes',1,length(xpred)).*repmat(xpred,length(WindowSlopes),1);

close(figure(53));
figh = figure(53);

% Size the figure in inches (500 px per 6.95 in)
pixPerInch = 500/6.95;
WidthInches = 1.2;
HeightInches = 4;
figh.Position(1) = 0;
figh.Position(2) = 0;
figh.Position(3) = pixPerInch*WidthInches;
figh.Position(4) = pixPerInch*HeightInches;

% Layer 1: per-object mean AF with SEM errorbars, colored by object
g = gramm('x',o+2, 'y',AF_m, 'color',o);
g.set_layout_options('redraw',false,'margin_height',[.1 .02],'margin_width',[.07 .02],'gap',[0.02 0.02]);
g.no_legend();
g.set_names('x','','y','','color','','Column','','Row','');
g.axe_property('XLim',[2 8],'YLim',[0 11],'XTick',3:7,'YTick',1:2:11,'TickLength',[0.03 0.03]);
g.facet_grid(c,[]);
g.set_order_options('row',row_order);
g.set_point_options('base_size',4,'markers',{'o','d','s'});
g.set_line_options('base_size',0.5);
g.set_color_options('map',famcolors,'n_color',8,'n_lightness',1);
% Red dotted reference line: the linear family relationship (slope 0.981 N per unit)
g.geom_abline('slope',3/3*0.981,'intercept',0,'style','r:');

g.stat_summary('geom',{'point','errorbar'},'width',0,'type','sem');
g.draw();

% Layer 2: average of the per-participant OLS fit lines, in black
g.update('x',xpred,'y',ypred);
g.set_names('x','Volume (cm3/k)','y','Force (N)','color','Family','Column','','Row','');
g.axe_property('XLim',[2 8],'YLim',[0 11],'XTick',3:7,'YTick',1:2:11);
g.set_line_options('base_size',1);
g.set_color_options('lightness',0,'chroma',0);

g.facet_grid(eLR,[]);
g.no_legend();
g.stat_summary('type','sem');
g.draw();

% Layer 3: the outlier's actual required force as a diamond marker
% (non-outlier entries NaN'ed out so only the outlier point is drawn)
RF_m(o~=3) = NaN;
g.update('y',RF_m,'color',[],'subset',[]);
g.no_legend();
g.set_point_options('markers',{'d'},'base_size',5);
g.stat_summary('geom','point');
g.draw();

% Collect the 2-component GMM means for each condition, plotted at the
% outlier's x position (5); compi = 2 restricts to the 2-component models
density_means_x = [];
density_means_y = [];
density_means_facet = [];
density_means_color = [];
for condi = 1:4
    for compi = 2 %1:2
        density_means_x = [density_means_x repelem(5,compi)];  %#ok<AGROW>
        density_means_y = [density_means_y GMModel{condi,compi}.mu']; %#ok<AGROW> 
        density_means_facet = [density_means_facet repelem(condi,compi)]; %#ok<AGROW> 
        density_means_color = [density_means_color repelem(compi,compi)]; %#ok<AGROW> 
    end
end
% See the final section of analysis_code.m for the R code used to generate these bootstrapped CIs
% (the same R code also appears in the commented section at the end of this file)
ci_lower = [8.0 4.92 3.63 5.56 4.07 2.75 3.49 0.91];
ci_upper = [8.89 6.06 5.29 7.18 4.71 3.75 4.67 1.21];
% The ordering here could be wrong wrt density_means_y... this is a bit of a hack to figure it out
% (match each R-derived CI midpoint to the nearest MATLAB component mean)
R_means = reshape(mean([ci_lower;ci_upper]),2,4);
ML_means = reshape(density_means_y,2,4);
correct_order = [];
for ei = 1:4
    for mi = 1:2
        [~,idx] = min(abs(R_means(mi,ei)-ML_means(:,ei)));
        correct_order = [correct_order idx+(ei-1)*2]; %#ok<AGROW>
    end
end
ci_lower = ci_lower(correct_order);
ci_upper = ci_upper(correct_order);
% Shrink the 95% CI half-widths by 1.96 to get +/- 1 SE errorbars
density_means_ymin = density_means_y - (density_means_y-ci_lower)/1.96;
density_means_ymax = density_means_y + (ci_upper-density_means_y)/1.96;
% Layer 4: green squares at the GMM component means with SE errorbars
g.update('x',density_means_x,'y',density_means_y,'ymin',density_means_ymin,'ymax',density_means_ymax,'color',density_means_color);
g.no_legend();
g.set_order_options('row',1:4);
g.set_color_options('map',[0 .5 0],'n_color',1,'n_lightness',1);
g.set_point_options('markers',{'s'},'base_size',3);
g.facet_grid(density_means_facet,[]);
g.geom_point();
g.geom_interval('geom','black_errorbar','width',0);
g.draw();

%% Note: The GMM results in the manuscript were generated using the following R code
% The input data and results are identical to the MATLAB code above
% The mclust R package simply facilitates generating bootstrapped confidence intervals

% rm(list=ls())
% library(mclust)
% 
% data_plusplus = c(6.5220,8.6820,4.8960,5.7480,4.2780,5.4600,7.4640,8.2620,8.2980,8.9280,3.8820,5.6580,8.4660,8.5440,4.8263,8.8860,4.9860,10.1040,4.9200,7.1040,4.3223,6.9662,7.8540,8.2802,5.8680,8.4540,5.7240,6.1195,8.9520,6.5400,8.4060,8.9100,8.3160,9.4860,7.5720,5.6340,4.8180)
% mod_plusplus = Mclust(data_plusplus, G=2, modelNames = c('V'))
% means_plusplus = round(mod_plusplus$parameters$mean,2)
% boot_plusplus = MclustBootstrap(mod_plusplus,nboot=10000,type='pb')
% meanCI_plusplus = round(summary(boot_plusplus,what='ci')$mean,2)
% 
% data_plus = c(4.4406,5.2920,7.2120,3.7620,3.8880,6.0720,2.4600,3.6540,5.2500,5.1240,5.7240,4.6920,5.9520,6.2340,6.4200,3.5280,6.9780,4.4040,4.6860,5.2980,5.0160,7.2540,6.3420,5.7180,4.9080,6.4860,6.6840,5.3040,4.7400,5.4240,4.2780,6.4620,5.0400,5.9400,3.2280,4.7880)
% mod_plus = Mclust(data_plus, G=2, modelNames = c('V'))
% means_plus = round(mod_plus$parameters$mean,2)
% boot_plus = MclustBootstrap(mod_plus,nboot=10000,type='pb')
% meanCI_plus = round(summary(boot_plus,what='ci')$mean,2)
% 
% data_minus = c(3.4440,4.7340,4.4700,4.7220,4.9680,3.7380,4.3980,4.4760,3.6420,2.8920,3.7020,4.4520,3.2160,2.6880,4.5300,3.1440,4.3080,2.9460,4.9440,4.1178,4.0560,4.3680,4.1160,4.3560,4.6440,3.8160,5.5002,3.9120,3.2880,4.1400,5.1480,4.7220,3.6594,3.8940,3.3774,4.2480,4.1700)
% mod_minus = Mclust(data_minus, G=2, modelNames = c('V'))
% means_minus = round(mod_minus$parameters$mean,2)
% boot_minus = MclustBootstrap(mod_minus,nboot=10000,type='pb')
% meanCI_minus = round(summary(boot_minus,what='ci')$mean,2)
% 
% data_minusminus = c(4.9680,0.7860,2.6640,4.2720,0.5460,1.2420,6.3480,1.1100,3.5460,3.8280,0.9660,1.0320,4.8960,3.6420,1.2365,0.9420,1.4940,5.5680,4.2660,1.0680,1.6440,4.1100,0.9360,6.2280,3.3900,0.7980,0.7860,1.6140,3.5940,0.6360,1.2480,2.6460,0.6960,1.6620,4.3020,1.0320,4.1225,0.8580,2.8860)
% mod_minusminus = Mclust(data_minusminus, G=2, modelNames = c('V'))
% means_minusminus = round(mod_minusminus$parameters$mean,2)
% boot_minusminus = MclustBootstrap(mod_minusminus,nboot=10000,type='pb')
% meanCI_minusminus = round(summary(boot_minusminus,what='ci')$mean,2)
% 
% printResults = function(){
%   cat('Linear++\n')
%   cat(sep='', 'Mean 1 = ', means_plusplus[1], ', 95% CI = [', meanCI_plusplus[1,1,1],', ',meanCI_plusplus[2,1,1],']\n')
%   cat(sep='', 'Mean 2 = ', means_plusplus[2], ', 95% CI = [', meanCI_plusplus[1,1,2],', ',meanCI_plusplus[2,1,2],']\n')
%   cat('Linear+\n')
%   cat(sep='', 'Mean 1 = ', means_plus[1], ', 95% CI = [', meanCI_plus[1,1,1],', ',meanCI_plus[2,1,1],']\n')
%   cat(sep='', 'Mean 2 = ', means_plus[2], ', 95% CI = [', meanCI_plus[1,1,2],', ',meanCI_plus[2,1,2],']\n')
%   cat('Linear-\n')
%   cat(sep='', 'Mean 1 = ', means_minus[1], ', 95% CI = [', meanCI_minus[1,1,1],', ',meanCI_minus[2,1,1],']\n')
%   cat(sep='', 'Mean 2 = ', means_minus[2], ', 95% CI = [', meanCI_minus[1,1,2],', ',meanCI_minus[2,1,2],']\n')
%   cat('Linear--\n')
%   cat(sep='', 'Mean 1 = ', means_minusminus[1], ', 95% CI = [', meanCI_minusminus[1,1,1],', ',meanCI_minusminus[2,1,1],']\n')
%   cat(sep='', 'Mean 2 = ', means_minusminus[2], ', 95% CI = [', meanCI_minusminus[1,1,2],', ',meanCI_minusminus[2,1,2],']\n')
% }
% 
% printResults()

%% GMM analysis on laboratory conditions: Linear+ & +Linear, Linear++ & ++Linear
clearvars

fprintf('\n\n** LABORATORY TASK: Gaussian model fits to individual differences in outlier learning.\n');

% Load and vertically concatenate the Figure 2 and Figure 3 source data,
% first dropping the columns that exist in only one of the two tables
if ~(exist('Figure2_SourceData1.txt','file') && exist('Figure3_SourceData1.txt','file'))
    error('This analysis requires the following source data files: Figure2_SourceData1.txt, Figure3_SourceData1.txt');
end
F2 = readtable('Figure2_SourceData1.txt');
F3 = readtable('Figure3_SourceData1.txt');
F2.RT = []; F2.Trial = []; F3.TrialsSinceSameObject = [];
F = [F2; F3];

% Unpack columns as ROW vectors (note the transposes; later sections
% depend on this orientation)
Subject = F.Subject';
ExperimentName = F.ExperimentName';
Block = F.Block';
ObjID = F.ObjID';
AF = F.AF';
RF = round(F.RF,4)'; % rounded so RF values group/compare cleanly
Phase = F.Phase';

% Helper functions: OLS slope/intercept via backslash; subsref extracts
% the desired coefficient so these work as anonymous splitapply functions
mySlope = @(x,y) subsref( ([ones(length(x),1) x']\y'), struct('type','()','subs',{{2}}));
myIntercept = @(x,y) subsref( ([ones(length(x),1) x']\y'), struct('type','()','subs',{{1}}));

% Fit a linear function for each participant using OLS on Training Objects
% (late-test window: blocks 55-70, excluding the outlier object ObjID==3)
subset = Block > 54 & Block <= 70 & ObjID~=3;
groups = findgroups(Subject(subset),ExperimentName(subset));
WindowSlopes = splitapply(mySlope,RF(subset),AF(subset),groups);
WindowIntercepts = splitapply(myIntercept,RF(subset),AF(subset),groups);

% Get the Anticipatory Force for the Test Object
subset2 = Block > 54 & Block <= 70 & ObjID==3;
[groups2, s2, e2] = findgroups(Subject(subset2),ExperimentName(subset2));
WindowOutlierAF = splitapply(@nanmean, AF(subset2), groups2);

% Family-predicted outlier force: each participant's regression evaluated
% at 0.9*9.81 (presumably a 0.9 kg family-interpolated weight times g -- TODO confirm)
familyPredicted = WindowSlopes*0.9*9.81+WindowIntercepts;
% Actual outlier weights for the + and ++ conditions (appear unused below)
plusNull = 1.2*9.81;
plusPlusNull = 1.5*9.81;

% Per-condition subject IDs, observed outlier AF, and family predictions
a1_sub = s2(strcmpi(e2,'Linear+'));
a1=WindowOutlierAF(strcmpi(e2,'Linear+'));
familyPredicted_a1 = familyPredicted(strcmpi(e2,'Linear+'));

a2_sub = s2(strcmpi(e2,'Linear++'));
a2=WindowOutlierAF(strcmpi(e2,'Linear++'));
familyPredicted_a2 = familyPredicted(strcmpi(e2,'Linear++'));

% NOTE(review): the Uncorr+ values are extracted but never used in the
% clustering below -- confirm whether this was intentional
b1_sub = s2(strcmpi(e2,'Uncorr+'));
b1=WindowOutlierAF(strcmpi(e2,'Uncorr+'));
familyPredicted_b1 = familyPredicted(strcmpi(e2,'Uncorr+'));

c1_sub = s2(strcmpi(e2,'+Linear'));
c1=WindowOutlierAF(strcmpi(e2,'+Linear'));
familyPredicted_c1 = familyPredicted(strcmpi(e2,'+Linear'));

c2_sub = s2(strcmpi(e2,'++Linear'));
c2=WindowOutlierAF(strcmpi(e2,'++Linear'));
familyPredicted_c2 = familyPredicted(strcmpi(e2,'++Linear'));
% Fit 1- and 2-component GMMs to each pooled pair of laboratory conditions
% (Linear+ with +Linear, then Linear++ with ++Linear) on the deviation of
% the observed outlier AF from the family-predicted force, and hard-cluster
% participants into non-learners (Class 1) vs learners (Class 2).
%
% BUGFIX: AF is a row vector in this section (transposed on extraction),
% so the original nan(size(AF,1),1) preallocated a single NaN and relied on
% implicit array growth (which zero-fills gaps). nan(size(AF)) preallocates
% one NaN per trial, in register with Subject/AF.
Class = nan(size(AF));
for condi = 1:2
    if condi==1
        fprintf('\n** Linear+ and +Linear groups **\n');
        % Deviation of observed outlier AF from the family prediction
        X = [a1 c1]'-[familyPredicted_a1 familyPredicted_c1]';
        subjectNums = [a1_sub c1_sub];
    else
        fprintf('\n** Linear++ and ++Linear groups **\n');
        X = [a2 c2]'-[familyPredicted_a2 familyPredicted_c2]';
        subjectNums = [a2_sub c2_sub];
    end
    % Fit both candidate models and record their information criteria
    AIC = nan(1,2);
    BIC = nan(1,2);
    for components = 1:2
        fprintf('Fitting %i-component GMM\n',components);
        GMModel{condi,components} = fitgmdist(X,components,'Options',statset('MaxIter',1000));
        AIC(components) = GMModel{condi,components}.AIC;
        BIC(components) = GMModel{condi,components}.BIC;
        fprintf('AIC = %.1f\n',AIC(components));
        fprintf('BIC = %.1f\n',BIC(components));
    end
    % Model selection by AIC
    [~,numComponents] = min(AIC);
    fprintf('\nBest fit by %i components.\n',numComponents);
    
    % Cluster assign
    mu = GMModel{condi,numComponents}.mu;
    sigma = GMModel{condi,numComponents}.Sigma;
    % Hard-cluster using the winning model
    class_assign = cluster(GMModel{condi,numComponents},X);
    % Switch class labels so learners are Class 2.
    % BUGFIX: guard on numComponents==2 (matching the web-task section);
    % with a 1-component winner, mu(2) would be an out-of-bounds index.
    if numComponents==2 && mu(2) < mu(1)
        class_assign = mod(class_assign,2)+1;
        mu = mu([2 1]);
        sigma = sigma([2 1]);
    end
    % Propagate each participant's cluster label to all of their trials
    for si = 1:length(subjectNums)
        Class(Subject==subjectNums(si)) = class_assign(si);
    end 
    fprintf('Cluster 1: %i non-learners\nCluster 2: %i learners\n',sum(class_assign==1),sum(class_assign==2));
end

% Compare end-of-test anticipatory forces against family-predicted forces,
% separately within the non-learner and learner clusters identified above.
% Prints pooled means with 95% CIs for each condition pair.
nonLearnerSubjs = unique(Subject(Class==1));
learnerSubjs = unique(Subject(Class==2));

% Helper functions: OLS slope/intercept that drop NaN responses before fitting
mySlope = @(x,y) subsref( ([ones(length(x(~isnan(y))),1) x(~isnan(y))']\y(~isnan(y))'), struct('type','()','subs',{{2}}));
myIntercept = @(x,y) subsref( ([ones(length(x(~isnan(y))),1) x(~isnan(y))']\y(~isnan(y))'), struct('type','()','subs',{{1}}));

fprintf('\n** LABORATORY TASK: Revisiting Anticipatory Force @ End of Test, considering clusters **\n');
insideWindow = Block > 54 & Block <= 70; % late-test analysis window (blocks 55-70)

clusterMembership = {ismember(Subject,nonLearnerSubjs) ismember(Subject,learnerSubjs)};
clusterNames = {'non-learners','learners'};
for clusti = 1:2
    e_select = clusterMembership{clusti};
    fprintf('\n** Considering only %s **\n',clusterNames{clusti});
    % Fit a linear function for each participant using OLS on Training Objects
    subset = e_select & insideWindow & ObjID~=3;
    groups = findgroups(Subject(subset),ExperimentName(subset));
    WindowSlopes = splitapply(mySlope,RF(subset),AF(subset),groups);
    WindowIntercepts = splitapply(myIntercept,RF(subset),AF(subset),groups);

    % Get the Anticipatory Force for the Test Object
    subset2 = e_select & insideWindow & ObjID==3;
    [groups2, ~, e2] = findgroups(Subject(subset2),ExperimentName(subset2));
    WindowOutlierAF = splitapply(@nanmean, AF(subset2), groups2);

    % Family-predicted outlier force: each participant's regression
    % evaluated at 0.9*9.81 (as in the clustering section above)
    familyPredicted = WindowSlopes*0.9*9.81+WindowIntercepts;

    % Linear+ & +Linear: observed vs family-predicted outlier force
    a11 = WindowOutlierAF(strcmpi(e2,'Linear+'));
    a11_fp = familyPredicted(strcmpi(e2,'Linear+'));
    c11 = WindowOutlierAF(strcmpi(e2,'+Linear'));
    c11_fp = familyPredicted(strcmpi(e2,'+Linear'));
    [~,~,CI,~] = ttest([a11 c11]);          % 95% CI on the pooled mean
    [~,~,CIpred,~] = ttest([a11_fp c11_fp]);
    fprintf('Linear+ & +Linear anticipatory force: %.2f N, 95%% CI = [%.2f, %.2f]\n',mean(CI),CI);
    fprintf('Linear+ & +Linear family-predicted weight of outlier: %.2f N, 95%% CI = [%.2f, %.2f]\n',mean(CIpred),CIpred);
    if(clusti==2)
        fprintf('Linear+ & +Linear actual outlier weight: %.2f N\n',1.2*9.81);
    end
    
    % Linear++ & ++Linear: observed vs family-predicted outlier force
    a22 = WindowOutlierAF(strcmpi(e2,'Linear++'));
    a22_fp = familyPredicted(strcmpi(e2,'Linear++'));
    c22 = WindowOutlierAF(strcmpi(e2,'++Linear'));
    c22_fp = familyPredicted(strcmpi(e2,'++Linear'));
    [~,~,CI,~] = ttest([a22 c22]);
    [~,~,CIpred,~] = ttest([a22_fp c22_fp]);
    fprintf('Linear++ & ++Linear anticipatory force: %.2f N, 95%% CI = [%.2f, %.2f]\n',mean(CI),CI);
    fprintf('Linear++ & ++Linear family-predicted weight of outlier: %.2f N, 95%% CI = [%.2f, %.2f]\n',mean(CIpred),CIpred);
    if(clusti==2)
        fprintf('Linear++ & ++Linear actual outlier weight: %.2f N\n',1.5*9.81);
    end

end