% Cesanek et al. (2021) Figure 4 Source Code (MATLAB code - rename with .m suffix to run)
% Creates all figure subplots and runs all relevant reported analyses

clearvars
close all
clc

% gramm (https://github.com/piermorel/gramm) is used for every figure below.
if exist('gramm','file')==0
    % error() only expands escape sequences such as \n when it is given more
    % than one argument, so the URL is passed as a format argument.
    error('The gramm plotting library is required to run this code.\nYou can download it from %s.', ...
        'https://github.com/piermorel/gramm');
end

% Trial-level source data for Figure 4, read into table F4.
if ~exist('Figure4_SourceData1.txt','file')
    error('This analysis requires the following source data files: Figure4_SourceData1.txt.');
end
F4 = readtable('Figure4_SourceData1.txt');

% Unpack the table columns (column vectors) into row vectors used by every
% section below. RF is rounded to 4 decimals, presumably so that nominally
% equal force values fall into the same findgroups/unique group.
Subject        = F4.Subject.';
ExperimentName = F4.ExperimentName.';
Block          = F4.Block.';
ObjID          = F4.ObjID.';
AF             = F4.AF.';
RF             = round(F4.RF, 4).';
Phase          = F4.Phase.';

% Object palette, 0-255 RGB rescaled to 0-1. The figures use row 3 for the
% outlier object (ObjID 3) and rows 1, 2, 4, 5 for the family objects.
famRGB = [255 124  49
          255  97  20
          215  77  18
          175  56  18
          137  34  17];
famcolors = famRGB ./ 255;

%% Fig 4, left: Anticipatory Force Timelines for LinearUp and LinearDown
% One group per unique (Subject, Block, Phase, ObjID, RF, ExperimentName)
% combination. The per-group key values are returned as:
%   b = Block, p = Phase, o = ObjID, w = RF, e = ExperimentName
[groups, ~, b, p, o, w, e] = findgroups(Subject,Block,Phase,ObjID,RF,ExperimentName);
% Mean anticipatory force per group (nanmean skips NaN trials).
AF_m = splitapply(@nanmean,AF,groups);
% RF is one of the grouping keys, so unique() returns that group's single RF value.
RF_m = splitapply(@unique,RF,groups);

% (Re)create figure 41 from scratch and size it in inches.
close(figure(41));
figh = figure(41);

% Facet-row order handed to gramm's set_order_options below.
row_order = [2 1];
% Pixel-per-inch conversion used to size the figure window.
pixPerInch = 500/6.95;
WidthInches = 5;
HeightInches = 2.35;
figh.Position(1) = 0;
figh.Position(2) = 0;
figh.Position(3) = pixPerInch*WidthInches;
figh.Position(4) = pixPerInch*HeightInches;

% Layer 1: family objects (ObjID ~= 3). Per-cycle mean anticipatory force
% (x = Block), colored by the object's RF (w), grouped by Phase (p), one
% facet row per experiment (e), drawn as an SEM area.
clear g
g = gramm('x',b, 'y',AF_m, 'color',w, 'group',p, 'subset', o~=3);
g.set_layout_options('redraw',false,'margin_height',[.1 .02],'margin_width',[.07 .02],'gap',[0.02 0.02]);
g.no_legend();
g.set_names('x','','y','','color','','Column','','Row','');
% x runs from -1 to (largest block rounded up to a multiple of 10) + 1,
% for integer block numbers.
g.axe_property('XLim',[-1 10*ceil(max(b-.5)/10)+1],'YLim',[4 17]);
g.set_order_options('row',row_order);
g.facet_grid(e,[]);
g.set_point_options('base_size',2);
g.set_line_options('base_size',0.5);
% Four family colors (palette rows 1, 2, 4, 5); row 3 is used for the outlier below.
g.set_color_options('map',famcolors([1 2 4 5],:),'n_color',4,'n_lightness',1);
g.stat_summary('geom',{'area'},'width',1,'type','sem');
g.draw();
% Layer 2: dotted line at each object's RF. x, color and subset are
% presumably carried over from layer 1 by gramm's update().
g.update('y',RF_m);
g.no_legend();
g.stat_summary('geom',{'line'});
g.set_line_options('styles',{':'},'base_size',0.5);
g.draw();
% Layer 3: switch to the outlier object (ObjID == 3) in palette row 3 and
% draw its RF as a dotted line.
g.update('y',RF_m,'color',[],'subset',o==3);
g.no_legend();
g.set_point_options('base_size',3);
g.set_line_options('base_size',1);
g.set_color_options('map',famcolors(3,:),'n_color',1,'n_lightness',1);
g.stat_summary('geom',{'line'});
g.set_line_options('styles',{':'});
g.draw();
% Layer 4: outlier mean anticipatory force as a solid SEM area plus points.
g.update('y',AF_m);
g.no_legend();
g.set_line_options('styles',{'-'});
g.stat_summary('geom',{'area','point'});
g.draw();

%% Fig 4, right: End of Test Phase averages for LinearUp and LinearDown
NumEndTestBins = 16;
% Analysis window: the last NumEndTestBins blocks of the session.
TimePeriods = Block>(max(Block)-NumEndTestBins);

% Data for the family (ObjID ~= 3): mean anticipatory force (N) and mean
% mass (RF / 9.81) per Subject x Experiment x Object x Phase.
subset = TimePeriods & ObjID~=3;
[groups, ~, e, ~, p] = findgroups(Subject(subset), ExperimentName(subset), ObjID(subset), Phase(subset));
AF_fam = splitapply(@nanmean,AF(subset),groups);
Mass_fam = splitapply(@nanmean,RF(subset)/9.81,groups);

% Helper functions: OLS slope / intercept of y on x (row vectors).
% Trials whose y is NaN are dropped (as in the analysis-section helpers), so
% one missing anticipatory-force value does not make a participant's whole
% fit, and hence the plotted regression, NaN.
mySlope = @(x,y) subsref( ([ones(length(x(~isnan(y))),1) x(~isnan(y))']\y(~isnan(y))'), struct('type','()','subs',{{2}}));
myIntercept = @(x,y) subsref( ([ones(length(x(~isnan(y))),1) x(~isnan(y))']\y(~isnan(y))'), struct('type','()','subs',{{1}}));

% Fit AF against mass for each participant using OLS on the family objects
[groupsLR, ~, eLR] = findgroups(Subject(subset),ExperimentName(subset));
WindowSlopes = splitapply(mySlope,RF(subset)/9.81,AF(subset),groupsLR);
WindowIntercepts = splitapply(myIntercept,RF(subset)/9.81,AF(subset),groupsLR);
% Evaluate every participant's fit on a mass grid: one row of ypred per
% participant, one column per xpred value.
xpred = 0.5:0.01:1.3;
ypred = repmat(WindowIntercepts',1,length(xpred))+repmat(WindowSlopes',1,length(xpred)).*repmat(xpred,length(WindowSlopes),1);

% Data for the outlier (ObjID == 3; there won't be any for Training phases),
% in the same TimePeriods window. Each participant's mean AF is plotted at
% x = 0.9, the outlier's mass on the mass axis.
subset2 = TimePeriods & ObjID==3;
[groups2, ~, e2, p2] = findgroups(Subject(subset2),ExperimentName(subset2),Phase(subset2));
WindowOutlierAF = splitapply(@nanmean, AF(subset2), groups2);
WindowOutlierIW = repmat(0.9,size(WindowOutlierAF));

% (Re)create figure 42 from scratch and size it in inches.
close(figure(42));
figh = figure(42);

% Facet-row order handed to gramm's set_order_options below.
row_order = [2 1];
% Pixel-per-inch conversion used to size the figure window.
pixPerInch = 500/6.95;
WidthInches = 1.2;
HeightInches = 2.35;
figh.Position(1) = 0;
figh.Position(2) = 0;
figh.Position(3) = pixPerInch*WidthInches;
figh.Position(4) = pixPerInch*HeightInches;

% First draw the regression line for the family: ypred holds one fitted line
% per participant (rows), summarized by stat_summary with 'sem', one facet
% row per experiment (eLR). lightness/chroma 0 presumably renders it black.
g = gramm('x',xpred,'y',ypred);
g.set_layout_options('redraw',false,'margin_height',[.1 .02],'margin_width',[.15 .02],'gap',[0.02 0.02]);
g.set_order_options('row',row_order);
g.no_legend();
g.set_names('x','', 'y','', 'color','','Row','','Column','');
g.axe_property('XLim',[.5 1.3],'YLim',[4 17],'XTick',[.6 .75 .9 1.05 1.2],'TickLength',[0.03 0.03]);
g.set_line_options('base_size',1);
g.set_color_options('lightness',0,'chroma',0);
g.facet_grid(eLR,[]);
% Reference line AF = mass * 9.81, i.e. force equal to the object's weight.
g.geom_abline('slope',9.81,'intercept',0,'style','k:'); % If Mass on x-axis, slope of unity line = +gravity
g.stat_summary('type','sem');
g.draw();

% Plot family data points over the regression line
% Points are colored by the per-group mean mass (Mass_fam), so each family
% object gets its own color from palette rows 1, 2, 4, 5.
% NOTE(review): the original comment here mentioned forcing color codes for
% "Uncorr+", which is not an experiment in this script; it looks carried
% over from another figure. Confirm the intended color mapping.
g.update('x',Mass_fam,'y',AF_fam,'color',Mass_fam);
g.no_legend();
g.facet_grid(e,[]);
g.set_point_options('base_size',5);
g.set_color_options('map',famcolors([1 2 4 5],:),'n_color',8,'n_lightness',1);
g.stat_summary('geom',{'point'},'type','sem');
g.draw();
% Same data again, as darker (x0.8) SEM error bars.
g.update();
g.no_legend();
g.set_line_options('base_size',1.5);
g.set_color_options('map',famcolors([1 2 4 5],:)*0.8,'n_color',8,'n_lightness',1);
g.stat_summary('geom',{'errorbar'},'type','sem');
g.draw();
% Plot the outlier average: per-participant outlier AF at x = 0.9,
% faceted by experiment (e2), in palette row 3.
g.update('x',WindowOutlierIW,'y',WindowOutlierAF);
g.no_legend();
g.set_point_options('base_size',5);
g.set_color_options('map',famcolors(3,:),'n_color',1,'n_lightness',1);
g.facet_grid(e2,[]);
g.stat_summary('geom',{'point'},'type','sem');
g.draw();
% Outlier SEM error bar in a darker (x0.8) shade.
g.update();
g.no_legend();
g.set_line_options('base_size',1.5);
g.set_color_options('map',famcolors(3,:)*0.8,'n_color',1,'n_lightness',1);
g.stat_summary('geom',{'errorbar'},'type','sem','width',0);
g.draw();

% We want to draw lines for the correct outlier weight: a short dotted
% horizontal segment at x = 0.9 +/- lineWidth in each experiment's facet.
% RFs_byExp(k) pairs with exps(k); unique(e) sorts alphabetically and
% row_order reorders it, which, assuming e contains exactly 'LinearDown' and
% 'LinearUp', gives {'LinearUp','LinearDown'}.
RFs_byExp = [(1.5+1.425)/2 (1.2+1.275)/2]*9.81; % Y-axis units
exps = unique(e);
exps = exps(row_order);
phases = unique(p);

outlierX = .9; % X-axis units
lineWidth = 0.075;
% Two endpoints per experiment -> 4 points in total.
xData = repmat(outlierX+[-1 1]*lineWidth,1,length(exps));
yData = repelem(RFs_byExp,2);
eData = repelem(exps,2); 
% NOTE(review): repelem(phases,4) matches the 4 points only if the window
% contains a single phase - confirm.
pData = repelem(phases,4); 
g.update('x',xData,'y',yData);
g.no_legend();
g.facet_grid(eData,pData);
g.set_color_options('map',famcolors(3,:),'n_color',1,'n_lightness',1);
g.set_line_options('base_size',1.5,'styles',{':'});
g.geom_line();
g.draw();

%% Analysis

% Console text for a binary test decision H (0 or 1), looked up as outcomes{H+1}.
outcomes = { 'not significant', ...
             'significant' };

% Tail names accepted by ttest/ttest2, paired index-for-index with a
% plain-English description of the alternative hypothesis. Pick an index
% with find(strcmpi(comps, ...)) and pass tails{index} as the 'tail' option.
tails = { 'both', ...
          'left', ...
          'right' };
comps = { 'different', ...
          'less', ...
          'greater' };

%% Correlation of anticipatory force with object weight, end of training
% For each experiment: correlate RF with AF within each participant over the
% last NumEndTrainingBins blocks up to block 30 (Phase == 2), average in
% Fisher-z space, and report the back-transformed mean r with a
% normal-approximation 95% CI.
NumEndTrainingBins = 8;
fprintf('\n\n** Correlation of Anticipatory Force and Object Weight @ End of Training (%i trial cycles) (Fig. 4a, c) **\n',NumEndTrainingBins);
expList = {'LinearUp','LinearDown'};
for k = 1:numel(expList)
    expName = expList{k};
    subset = strcmpi(ExperimentName,expName) & Phase==2 & Block>30-NumEndTrainingBins;
    groups = findgroups(Subject(subset));

    % One Fisher-z-transformed correlation per participant
    fisherZ = @(x,y) atanh(corr(x,y,'rows','pairwise'));
    EarlyCorrsFisherZ = splitapply(fisherZ, RF(subset)', AF(subset)', groups');

    % Mean and 95% CI half-width in z space, mapped back to r with tanh
    zMean = mean(EarlyCorrsFisherZ);
    zHalfWidth = 1.96*std(EarlyCorrsFisherZ)/sqrt(length(EarlyCorrsFisherZ));
    LinearEarlyMeanCorr = tanh(zMean);
    LinearEarlyMeanCorr_UB = tanh(zMean + zHalfWidth);
    LinearEarlyMeanCorr_LB = tanh(zMean - zHalfWidth);

    fprintf('\n%s: r = %.2f, 95%% CI = [%.2f, %.2f]\n', expName, ...
        LinearEarlyMeanCorr, LinearEarlyMeanCorr_LB, LinearEarlyMeanCorr_UB);
end

%% Average anticipatory force on first test trial
% Outlier object (ObjID 3) on block 31: first average within each
% participant, then across participants within each experiment.
subset = Block==31 & ObjID==3;
[groups, ~, e] = findgroups(Subject(subset),ExperimentName(subset));
firstTestTrialOutlierAF = splitapply(@nanmean, AF(subset), groups);

[groups2, e2] = findgroups(e);
% Normal-approximation 95% CI half-width over the non-NaN values
ci95HalfWidth = @(x) 1.96*nanstd(x)/sqrt(sum(~isnan(x)));
firstTestTrialMeans = splitapply(@nanmean, firstTestTrialOutlierAF, groups2);
firstTestTrialCIs = splitapply(ci95HalfWidth, firstTestTrialOutlierAF, groups2);
firstTestTrialLBs = firstTestTrialMeans - firstTestTrialCIs;
firstTestTrialUBs = firstTestTrialMeans + firstTestTrialCIs;

% NOTE(review): this header cites "Fig. 2a, c, e" although this script
% produces Figure 4; possibly a copy-over label - confirm before publishing.
fprintf('\n\n** Average Anticipatory Force on First Test Trial (Fig. 2a, c, e) **\n');
expList = {'LinearUp','LinearDown'};
for k = 1:numel(expList)
    expName = expList{k};
    eii = find(strcmpi(e2, expName));
    fprintf('\n%s: %.2f N, 95%% CI = [%.2f, %.2f]\n', expName, ...
        firstTestTrialMeans(eii), firstTestTrialLBs(eii), firstTestTrialUBs(eii));
end

%% Test of Outlier Learning: AF minus Family-predicted weight
% For each analysis window:
%   1. fit AF = intercept + slope*RF per participant on the family objects,
%   2. evaluate that fit at 0.9*9.81 (the outlier's mass times g) to get the
%      family-predicted AF,
%   3. test with a one-sided paired t-test whether the outlier's AF exceeds it.
% 'End of Test' is processed last, so after the loop d1 (LinearUp) and
% d2 (LinearDown) hold end-of-test outlier AFs for the next section.

% Helper functions: OLS slope / intercept of y on x (row vectors), ignoring
% trials where y is NaN. Loop-invariant, so defined once.
mySlope = @(x,y) subsref( ([ones(length(x(~isnan(y))),1) x(~isnan(y))']\y(~isnan(y))'), struct('type','()','subs',{{2}}));
myIntercept = @(x,y) subsref( ([ones(length(x(~isnan(y))),1) x(~isnan(y))']\y(~isnan(y))'), struct('type','()','subs',{{1}}));

for ww = {'End of First Twenty Cycles of Test','End of Test'}
    window = ww{1};
    if strcmpi(window,'End of Test')
        % Last NumLateTestBins blocks of the session
        NumLateTestBins = 16;
        insideWindow = Block>(max(Block)-NumLateTestBins) & Block<=max(Block);
    elseif strcmpi(window,'End of First Twenty Cycles of Test')
        % Last NumLateTestBins blocks up to and including block 50
        NumLateTestBins = 4;
        insideWindow = Block>(50-NumLateTestBins) & Block<=50;
    else
        error('Unknown analysis window "%s".', window);
    end

    fprintf('\n\n** Test for Outlier Learning @ %s (%i trial cycles) (Fig. 2a-f) **\n',window,NumLateTestBins);
    exps = {'LinearUp','LinearDown'};

    % Fit a linear function for each participant using OLS on Training Objects
    subset = ismember(ExperimentName,exps) & insideWindow & ObjID~=3;
    [groups, s1, e1] = findgroups(Subject(subset),ExperimentName(subset));
    WindowSlopes = splitapply(mySlope,RF(subset),AF(subset),groups);
    WindowIntercepts = splitapply(myIntercept,RF(subset),AF(subset),groups);

    % Get the Anticipatory Force for the Test Object
    subset2 = ismember(ExperimentName,exps) & insideWindow & ObjID==3;
    [groups2, s2, e2] = findgroups(Subject(subset2),ExperimentName(subset2));
    WindowOutlierAF = splitapply(@nanmean, AF(subset2), groups2);

    % The fits (grouped by `groups`) and the outlier means (grouped by
    % `groups2`) are indexed with the same e2-based masks below, so both
    % groupings must list the same (Subject, Experiment) keys in the same
    % order. Otherwise the paired test would silently compare different
    % participants.
    if ~isequal(s1, s2) || ~isequal(e1, e2)
        error('Family fits and outlier averages cover different participants in the "%s" window.', window);
    end

    % Family-predicted AF (N) at the outlier's mass of 0.9 kg
    familyPredicted = WindowSlopes*0.9*9.81+WindowIntercepts;

    % One-sided alternative: outlier AF greater than the family prediction
    comparison = find(strcmpi(comps,'greater'));

    % LinearUp. mean(CI) is the midpoint of the symmetric t-based CI, i.e.
    % the sample mean.
    isUp = strcmpi(e2,'LinearUp');
    d1 = WindowOutlierAF(isUp);
    familyPredicted_d1 = familyPredicted(isUp);
    [~,~,CI,~] = ttest(d1);
    [~,~,CIpred,~] = ttest(familyPredicted_d1);
    fprintf('\nLinearUp anticipatory force: %.2f N, 95%% CI = [%.2f, %.2f]',mean(CI),CI);
    fprintf('\nLinearUp family-predicted weight: %.2f N, 95%% CI = [%.2f, %.2f]\n',mean(CIpred),CIpred);
    [H,P,~,STATS] = ttest(d1,familyPredicted_d1,'tail',tails{comparison});
    fprintf('AF %sly %s than family-predicted weight (t(%i) = %.2f, p = %.2g)\n',outcomes{H+1},comps{comparison}, STATS.df, STATS.tstat, P);

    % LinearDown
    isDown = strcmpi(e2,'LinearDown');
    d2 = WindowOutlierAF(isDown);
    familyPredicted_d2 = familyPredicted(isDown);
    [~,~,CI,~] = ttest(d2);
    [~,~,CIpred,~] = ttest(familyPredicted_d2);
    fprintf('\nLinearDown anticipatory force: %.2f N, 95%% CI = [%.2f, %.2f]',mean(CI),CI);
    fprintf('\nLinearDown family-predicted weight: %.2f N, 95%% CI = [%.2f, %.2f]\n',mean(CIpred),CIpred);
    [H,P,~,STATS] = ttest(d2,familyPredicted_d2,'tail',tails{comparison});
    fprintf('AF %sly %s than family-predicted weight (t(%i) = %.2f, p = %.2g)\n',outcomes{H+1},comps{comparison}, STATS.df, STATS.tstat, P);
end

%% Comparing final learning in LinearUp with Linear++, and LinearDown with Linear+
% Uses d1 (LinearUp) and d2 (LinearDown) from the previous section, which
% must come from its 'End of Test' window. Running the whole script
% guarantees this because that window is processed last. The check below
% catches missing or stale values when sections are run one at a time.
if ~exist('window','var') || ~strcmpi(window,'End of Test') || ~exist('d1','var') || ~exist('d2','var')
    error('Run the "Test of Outlier Learning" section (ending with its ''End of Test'' window) before this section.');
end

if ~exist('Figure2_SourceData1.txt','file')
    error('This analysis requires the following source data files: Figure2_SourceData1.txt.');
end
F2 = readtable('Figure2_SourceData1.txt');
% Per-participant mean outlier (ObjID 3) AF over blocks 55-70.
% NOTE(review): this hard-coded window presumably mirrors the 16-cycle
% 'End of Test' window used for d1/d2 (max block 70) - confirm against the
% Figure 2 data.
subset2 = F2.Block > 54 & F2.Block <= 70 & F2.ObjID==3;
[groups2, s2, e2] = findgroups(F2.Subject(subset2),F2.ExperimentName(subset2));
WindowOutlierAF = splitapply(@nanmean, F2.AF(subset2), groups2);
a1=WindowOutlierAF(strcmpi(e2,'Linear+'));
a2=WindowOutlierAF(strcmpi(e2,'Linear++'));

fprintf('\n\n** Direct comparison between the LinearDown (Up) and Linear+ (++) groups, respectively (End of Test Phase)\n');
comparison = find(strcmpi(comps,'different'));

% Two-sample (independent groups) t-tests. Note that STATS.tstat is computed
% as first argument minus second, i.e. Linear++ minus LinearUp and Linear+
% minus LinearDown, so its sign is reversed relative to the sentence printed.
[H,P,~,STATS] = ttest2(a2,d1,'tail',tails{comparison});
fprintf('\nLinearUp %sly %s than Linear++ (t(%i) = %.2f, p = %.2g)\n',outcomes{H+1},comps{comparison}, STATS.df, STATS.tstat, P);

[H,P,~,STATS] = ttest2(a1,d2,'tail',tails{comparison});
fprintf('\nLinearDown %sly %s than Linear+ (t(%i) = %.2f, p = %.2g)\n',outcomes{H+1},comps{comparison}, STATS.df, STATS.tstat, P);