%% style
% Reset the workspace and command window, then install the graphics defaults
% (line width, marker size, bold 14-pt text/axes fonts, Arial axes font) on the
% graphics root so that every figure created afterwards picks them up.
clear; clc
set(0,'DefaultLineLineWidth',1.5,'DefaultTextFontWeight','bold', ...
    'DefaultTextFontSize',14,'DefaultAxesFontWeight','bold', ...
    'DefaultAxesFontSize',14, 'DefaultLineMarkerSize',10);
set(0,'DefaultAxesFontName', 'Arial')
% Fixed seed so that the bootstrap resampling done later in the script is
% reproducible from run to run.
rng(5)
% Font size applied explicitly to the axes of each result figure.
fontsize = 18;
% Statistic used for the plotted curves and handed to bootci: the mean with
% NaN entries ignored. mean(...,'omitnan') replaces nanmean, which MathWorks
% documents as not recommended; the two agree for the vectors used here.
f = @(x) mean(x, 'omitnan');
%% show separate
% For each of the three systems, plot cross map skill (rho) against training
% data size for the "separate" train/test split results and save the figure
% as raw/resA<sys>.svg.
%
% Inputs : data/sys<sys>_separate_XmapY.csv and data/sys<sys>_separate_YmapX.csv,
%          read with readtable; the columns lib_size and rho are used.
% Uses   : f (statistic handle) and fontsize, both set in the style section.
% Per library size the script plots f(rho) together with a bootstrap
% confidence interval from bootci (1000 resamples, alpha = 0.05, 'bca').

% saveas errors out when the target folder is missing, so create it first.
% Called with output arguments, mkdir reports success for an existing folder
% as well and does not print a warning.
out_dir = 'raw';
[dir_ok, dir_msg] = mkdir(out_dir);
if ~dir_ok
    error('Could not create output folder "%s": %s', out_dir, dir_msg);
end

% Both mapping directions go through the same steps; only the file-name tag
% and the colour differ (XmapY in green, YmapX in magenta).
map_dirs    = {'XmapY', 'YmapX'};
line_colors = {'g', 'm'};

for sys = 1 : 3
    fig1 = figure;
    for d = 1 : numel(map_dirs)
        T = readtable(sprintf('data/sys%d_separate_%s.csv', sys, map_dirs{d}) );
        lib_sizes = unique(T.lib_size);
        n_sizes = length(lib_sizes);
        % Columns of RHO: [statistic, lower CI bound, upper CI bound]
        RHO = zeros(n_sizes,3);
        for i = 1 : n_sizes
            rho = T.rho(T.lib_size == lib_sizes(i));
            RHO(i,1) = f(rho);
            RHO(i,2:3) = bootci(1000, {f,rho}, 'alpha', 0.05, 'type', 'bca');
        end
        % Curve through the per-size statistic.
        plot(lib_sizes, RHO(:,1), line_colors{d})
        % hold is switched on only after the first plot call, matching the
        % original statement order; on the second pass it is already on.
        hold on
        % One vertical segment per library size spanning the confidence interval.
        for i = 1 : n_sizes
            plot([lib_sizes(i), lib_sizes(i)], [RHO(i,2), RHO(i,3)], line_colors{d})
        end
        % Filled markers on top of the curve.
        scatter(lib_sizes, RHO(:,1), 40, line_colors{d}, 'filled')
    end
    % Axis labelling and fixed 2x2 inch axes placed 2 inches from the
    % figure's left and bottom edges.
    xlabel('training data size')
    ylabel('cross map skill')
    xticks([0, 150, 300])
    ylim([-1, 1])
    set(gca,'LineWidth',2,'FontSize',fontsize,'FontName','Arial','fontweight','bold',...
        'units','inches','position',[2 2 2 2])
    % The title doubles as the legend (it refers to the magenta curve as purple).
    title(sprintf('system %d, separate train/test splits\n"y cause x" in green\n"x cause y" in purple',sys))
    saveas(fig1,sprintf('%s/resA%d.svg',out_dir,sys) )
end

%% show interspersed
% For each of the three systems, plot cross map skill (rho) against training
% data size for the "interspersed" train/test results and save the figure as
% raw/resB<sys>.svg.
%
% Inputs : data/sys<sys>_interspersed_XmapY.csv and
%          data/sys<sys>_interspersed_YmapX.csv, read with readtable; the
%          columns lib_size and rho are used.
% Uses   : f (statistic handle) and fontsize, both set in the style section.
% Per library size the script plots f(rho) together with a bootstrap
% confidence interval from bootci (1000 resamples, alpha = 0.05, 'bca').

% saveas errors out when the target folder is missing, so create it first.
% Called with output arguments, mkdir reports success for an existing folder
% as well and does not print a warning.
out_dir = 'raw';
[dir_ok, dir_msg] = mkdir(out_dir);
if ~dir_ok
    error('Could not create output folder "%s": %s', out_dir, dir_msg);
end

% Both mapping directions go through the same steps; only the file-name tag
% and the colour differ (XmapY in green, YmapX in magenta).
map_dirs    = {'XmapY', 'YmapX'};
line_colors = {'g', 'm'};

for sys = 1 : 3
    fig1 = figure;
    for d = 1 : numel(map_dirs)
        T = readtable(sprintf('data/sys%d_interspersed_%s.csv', sys, map_dirs{d}) );
        lib_sizes = unique(T.lib_size);
        n_sizes = length(lib_sizes);
        % Columns of RHO: [statistic, lower CI bound, upper CI bound]
        RHO = zeros(n_sizes,3);
        for i = 1 : n_sizes
            rho = T.rho(T.lib_size == lib_sizes(i));
            RHO(i,1) = f(rho);
            RHO(i,2:3) = bootci(1000, {f,rho}, 'alpha', 0.05, 'type', 'bca');
        end
        % Curve through the per-size statistic.
        plot(lib_sizes, RHO(:,1), line_colors{d})
        % hold is switched on only after the first plot call, matching the
        % original statement order; on the second pass it is already on.
        hold on
        % One vertical segment per library size spanning the confidence interval.
        for i = 1 : n_sizes
            plot([lib_sizes(i), lib_sizes(i)], [RHO(i,2), RHO(i,3)], line_colors{d})
        end
        % Filled markers on top of the curve.
        scatter(lib_sizes, RHO(:,1), 40, line_colors{d}, 'filled')
    end
    % Axis labelling and fixed 2x2 inch axes placed 2 inches from the
    % figure's left and bottom edges.
    xlabel('training data size')
    ylabel('cross map skill')
    xticks([0, 150, 300])
    ylim([-1, 1])
    set(gca,'LineWidth',2,'FontSize',fontsize,'FontName','Arial','fontweight','bold',...
        'units','inches','position',[2 2 2 2])
    % The title doubles as the legend (it refers to the magenta curve as purple).
    title(sprintf('system %d, interspersed train/test\n"y cause x" in green\n"x cause y" in purple',sys))
    saveas(fig1,sprintf('%s/resB%d.svg',out_dir,sys) )
end

%% check for nan
% Sanity check on the separate-split XmapY results: for every system, count
% the NaN rho entries at each library size and display the largest count.
for sys = 1 : 3
    csv_path = sprintf('data/sys%d_separate_XmapY.csv', sys);
    T = readtable(csv_path);
    % NaN mask over all rows, computed once and sliced per library size below.
    rho_is_nan = isnan(T.rho);
    lib_sizes = unique(T.lib_size);
    n_sizes = numel(lib_sizes);
    nans = zeros(size(lib_sizes));
    for i = 1 : n_sizes
        rows_at_size = (T.lib_size == lib_sizes(i));
        nans(i) = nnz(rho_is_nan(rows_at_size));
    end
    disp(max(nans))
end
%%
% Same NaN check for the remaining result files. For each file the largest
% per-library-size count of NaN rho entries is displayed, in this order:
% separate YmapX (systems 1-3), interspersed XmapY (1-3), interspersed YmapX (1-3).
file_tags = {'separate_YmapX', 'interspersed_XmapY', 'interspersed_YmapX'};
for t = 1 : numel(file_tags)
    for sys = 1 : 3
        T = readtable(sprintf('data/sys%d_%s.csv', sys, file_tags{t}) );
        % NaN mask over all rows, computed once and sliced per library size below.
        rho_is_nan = isnan(T.rho);
        lib_sizes = unique(T.lib_size);
        n_sizes = numel(lib_sizes);
        nans = zeros(size(lib_sizes));
        for i = 1 : n_sizes
            rows_at_size = (T.lib_size == lib_sizes(i));
            nans(i) = nnz(rho_is_nan(rows_at_size));
        end
        disp(max(nans))
    end
end