%% style
clear; clc
% Session-wide graphics defaults applied to every figure created below:
% thick lines, large markers, bold 14-pt text, and Arial axes labels.
set(groot, ...
    'DefaultLineLineWidth',   1.5, ...
    'DefaultLineMarkerSize',  10, ...
    'DefaultTextFontWeight',  'bold', ...
    'DefaultTextFontSize',    14, ...
    'DefaultAxesFontWeight',  'bold', ...
    'DefaultAxesFontSize',    14, ...
    'DefaultAxesFontName',    'Arial');
% Fixed seed so the randperm-shuffled scatter draw orders are reproducible.
rng(5)
%% generate the six example systems (one column per system)
n_tpt = 10^4;
% Column k of X / Y / T holds system k:
%   1 strong forcing, 2 counterexample, 3 oscillation, 4 proxy,
%   5 strong forcing with the roles of X and Y swapped, 6 common cause.
generators = {@strong_forcing, @counterexample, @oscillation, ...
              @proxy, @strong_forcing, @commoncause};
swap_xy = [false, false, false, false, true, false];
n_sys = numel(generators);
X = zeros(n_tpt, n_sys);
Y = zeros(n_tpt, n_sys);
T = zeros(n_tpt, n_sys);
for sys = 1 : n_sys
    [x, y, t] = generators{sys}(n_tpt);
    if swap_xy(sys)
        % reverse the direction of the example: the forced variable becomes X
        [x, y] = deal(y, x);
    end
    X(:,sys) = x; Y(:,sys) = y; T(:,sys) = t;
end

% save data
% csvwrite keeps at most 5 significant digits, which would truncate these
% chaotic trajectories; write full (round-trippable) double precision.
if ~exist('data', 'dir')
    mkdir('data');
end
dlmwrite('data/X.csv', X, 'precision', '%.17g');
dlmwrite('data/Y.csv', Y, 'precision', '%.17g');
%% show dynamics
idx_show = 501 : 700;
fontsize = 18;

% Two time-series figures per system:
%   figA<sys>.svg shows X (red markers, no time ticks),
%   figB<sys>.svg shows Y (blue markers, labelled time axis).
panel_tag   = {'A', 'B'};
panel_style = {'r.', 'b.'};
panel_label = {'X', 'Y'};
for sys = 1 : 6
    t_win  = T(idx_show, sys);
    % T is increasing, so the left limit is the sample just before the window
    t_left = min(T(idx_show - 1, sys));
    series = {X(idx_show, sys), Y(idx_show, sys)};
    for p = 1 : 2
        fig1 = figure;
        scatter(t_win, series{p}, panel_style{p})
        hold on
        plot(t_win, series{p}, 'k-', 'LineWidth', 1)
        xlim([t_left, max(t_win)])
        if p == 1
            xticks([])
        else
            xlabel('time')
        end
        yticks([])
        ylabel(panel_label{p})
        set(gca,'LineWidth',2,'FontSize',fontsize,'FontName','Arial','fontweight','bold',...
            'units','inches','position',[2 2 3.4 0.85])
        saveas(fig1, sprintf('raw/fig%s%d.svg', panel_tag{p}, sys))
    end
end

%% show X map to Y
% Per-system sample window [first, last] used for the embedding plots.
idx_show = [repmat(501, 6, 1), [2000; 10^4; 2000; 2000; 10^4; 10^4]];
tau = [2,2,72,2,1,1];      % embedding delay per system
sz  = repmat(20, 1, 6);    % scatter marker area per system

% 2-D delay embedding (X(t - tau), X(t)) with points coloured by Y(t).
for sys = [1,4,5]
    rows  = idx_show(sys,1) : idx_show(sys,2);
    xs    = X(rows, sys);
    ys    = Y(rows, sys);
    lag   = tau(sys);
    x_lag = xs(1:end-lag);
    x_now = xs(1+lag:end);
    y_now = ys(1+lag:end);
    fig2 = figure;
    % shuffle plotting order (presumably so no time segment systematically overdraws another)
    order = randperm(length(x_lag));
    scatter(x_lag(order), x_now(order), sz(sys), y_now(order), 'filled')
    h = colorbar;
    title(h, 'Y(t)')
    xlabel('X(t - \tau)')
    ylabel('X(t)')
    xticks([])
    yticks([])
    % grey-scale points on a light-blue axes background, kept when saving
    colormap gray; set(gca,'Color',[0.7 0.7 1]); set(gcf, 'InvertHardCopy', 'off');
    set(gca,'LineWidth',2,'FontSize',fontsize,'FontName','Arial','fontweight','bold',...
        'units','inches','position',[2 2 2 2])
    saveas(fig2, sprintf('raw/figC%d.svg',sys))
    saveas(fig2, sprintf('raw/figC%d.png',sys))
end

% 3-D delay embedding (X(t - 2 tau), X(t - tau), X(t)) with points coloured by Y(t).
for sys = [2,3,6]
    rows   = idx_show(sys,1) : idx_show(sys,2);
    xs     = X(rows, sys);
    ys     = Y(rows, sys);
    lag    = tau(sys);
    x_lag2 = xs(1:end-2*lag);
    x_lag1 = xs(1+lag:end-lag);
    x_now  = xs(1+2*lag:end);
    y_now  = ys(1+2*lag:end);
    fig2 = figure;
    order = randperm(length(x_lag2));
    scatter3(x_lag2(order), x_lag1(order), x_now(order), sz(sys), y_now(order), 'filled')
    h = colorbar;
    title(h, 'Y(t)')
    xlabel('X(t - 2 \tau)')
    ylabel('X(t - \tau)')
    zlabel('X(t)')
    xticks([])
    yticks([])
    zticks([])
    % Per-system camera / limits (values appear to come from an interactive session).
    switch sys
        case 2
            % red outline of the region X(t-2tau) in [0.4,0.6], X(t) in [0.35,0.9]
            % drawn in the plane X(t - tau) = 0.3
            hold on
            plot3([0.6, 0.4, 0.4, 0.6, 0.6], [0.3, 0.3, 0.3, 0.3, 0.3], ...
                  [0.35, 0.35, 0.9, 0.9, 0.35], 'r')
            view([-153.8 66.7])
        case 3
            view([107.7 12.6])
        case 6
            view([152.5 18.6]);
            xlim([0.163675716503682 0.91499051740526]);
            ylim([0.157075457694744 0.908390258596322]);
            zlim([0.119463400496835 0.870778201398412]);
    end
    colormap gray; set(gca,'Color',[0.7 0.7 1]); set(gcf, 'InvertHardCopy', 'off');
    set(gca,'LineWidth',2,'FontSize',fontsize,'FontName','Arial','fontweight','bold',...
        'units','inches','position',[2 2 2 2])
    % raster copy first (300 dpi), then switch to the painters renderer for the vector outputs
    print(fig2, sprintf('raw/figC%d.png',sys), '-dpng', '-r300')
    set(fig2, 'Renderer', 'painters');
    saveas(fig2, sprintf('raw/figC%d.svg',sys))
    saveas(fig2, sprintf('raw/figC%d.pdf',sys))
end

%% deal with system 2
% Zoomed 2-D projection of the system-2 embedding onto (X(t - 2 tau), X(t)),
% computed from a much longer run (presumably so the small region is densely sampled).
sys = 2;
lag = tau(sys);
[x,y,t] = counterexample(6*10^5);
fig2 = figure;
v1 = x(1:end-2*lag);    % X(t - 2 tau)
v2 = x(1+2*lag:end);    % X(t)
% Colour by Y(t), aligned with v2 exactly as in the 3-D figure. The previous
% y(1+tau:end) was shifted by one delay (i.e. Y(t - tau)) and longer than v1/v2.
v3 = y(1+2*lag:end);
idx = randperm(length(v1));
scatter(v1(idx), v2(idx), 8, v3(idx),'filled')
xlabel('X(t - 2 \tau)')
ylabel('X(t)')
xlim([0.4, 0.6])
ylim([0.35, 0.9])
xticks([])
yticks([])
set(gca, 'XDir','reverse')
colormap gray; set(gca,'Color',[0.7 0.7 1]); set(gcf, 'InvertHardCopy', 'off');
set(gca,'LineWidth',2,'FontSize',fontsize,'FontName','Arial','fontweight','bold',...
    'units','inches','position',[2 2 1 1])
% Separate file: 'raw/figC2.png' already holds the full 3-D embedding figure
% and must not be overwritten by this inset.
print(fig2, sprintf('raw/figC_inset%d.png',sys), '-dpng', '-r500')
%%
close all


%% make data: we're going to map FROM X TO Y

function [x, y, t] = strong_forcing(n_tpt)
% STRONG_FORCING  Two coupled logistic maps in which X drives Y.
%   x(k) = x(k-1) * (3.8 - 3.8*x(k-1))
%   y(k) = y(k-1) * (3.1 - 3.1*y(k-1) - 2*x(k-1))
% Starts at (x, y) = (0.2, 0.4); Y never feeds back into X.
% Returns column vectors x, y of length n_tpt and the row vector t = 1:n_tpt.
x = zeros(n_tpt, 1);
y = zeros(n_tpt, 1);
x(1) = 0.2;
y(1) = 0.4;
for k = 2 : n_tpt
    xp = x(k-1);
    yp = y(k-1);
    x(k) = xp * (3.8 - 3.8 * xp);
    y(k) = yp * (3.1 - 3.1 * yp - 2 * xp);
end
t = 1 : n_tpt;
end

function [x, y, t] = counterexample(n_tpt)
% COUNTEREXAMPLE  Two weakly, mutually coupled logistic maps.
%   x(k) = x(k-1) * (3.8 - 3.8*x(k-1) - 0.15*y(k-1))
%   y(k) = y(k-1) * (3.1 - 3.1*y(k-1) - 0.15*x(k-1))
% Starts at (x, y) = (0.2, 0.4).
% Returns column vectors x, y of length n_tpt and the row vector t = 1:n_tpt.
x = zeros(n_tpt, 1);
y = zeros(n_tpt, 1);
x(1) = 0.2;
y(1) = 0.4;
for k = 2 : n_tpt
    xp = x(k-1);
    yp = y(k-1);
    x(k) = xp * (3.8 - 3.8 * xp - 0.15 * yp);
    y(k) = yp * (3.1 - 3.1 * yp - 0.15 * xp);
end
t = 1 : n_tpt;
end

function [x, y, t] = proxy(n_tpt)
% PROXY  A logistic map X observed through a one-step-lagged nonlinear proxy Y.
%   x(k) = x(k-1) * (3.8 - 3.8*x(k-1)),   x(1) = 0.2
%   y(k) = sin(10*x(k-1)),                y(1) = 0
% Returns column vectors x, y of length n_tpt and the row vector t = 1:n_tpt.
x = zeros(n_tpt, 1);
x(1) = 0.2;
for k = 2 : n_tpt
    x(k) = x(k-1) * (3.8 - 3.8 * x(k-1));
end
y = [0; sin(10 * x(1:end-1))];
t = 1 : n_tpt;
end

function [x, y, t] = oscillation(n_tpt)
% OSCILLATION  Two uncoupled nonlinear oscillators plus a driven integrator.
%   State s = [p1, v1, p2, v2, q] evolves as
%     p1' = v1,   v1' = -omega1^6 * p1^5
%     p2' = v2,   v2' = -omega2^6 * p2^5
%     q'  = p1 - p2
%   It is integrated with ode45 (RelTol = AbsTol = 1e-6) from s = [0 1 0 1 1]
%   at the first sample time 0.1, and sampled at t = 0.1, 0.2, ..., 0.1*n_tpt.
% Returns x = q and y = p2 as columns, and the column of sample times t from ode45.
omega1 = 3/2;
omega2 = 4/2;
k1 = (omega1)^6;
k2 = (omega2)^6;
rhs = @(s) [s(2); -k1*s(1)^5; s(4); -k2*s(3)^5; s(1) - s(3)];

tspan = (0.1 : 0.1 : 0.1 * n_tpt);
s0 = [0, 1, 0, 1, 1];
opts = odeset('RelTol',1e-6,'AbsTol',1e-6);
[t, sol] = ode45(@(tt, s) rhs(s), tspan, s0, opts);
x = sol(:,5);
y = sol(:,3);
end

function [x, y, t] = commoncause(n_tpt)
% COMMONCAUSE  A driver logistic map D forcing two responders that do not
% interact with each other (column 1 = D, column 2 = R2, column 3 = R3):
%   D(k)  = D(k-1)  * (3.69 - 3.69*D(k-1))
%   R2(k) = R2(k-1) * (3.78 - 3.78*R2(k-1) - 0.40*D(k-1))
%   R3(k) = R3(k-1) * (3.71 - 3.71*R3(k-1) - 0.33*D(k-3))
% The first max(delta) = 3 rows are held at (D, R2, R3) = (0.2, 0.3, 0.4).
% Returns x = R3, y = R2 (columns) and the row vector t = 1:n_tpt.

% Coupling lags per column; delta(1) is not read (the driver always uses lag 1).
delta = [1,1,3];

r = [3.69, 0   , 0;
     0.4 , 3.78, 0;
     0.33, 0   , 3.71];

S = zeros(n_tpt, 3);
warmup = max(delta);
S(1:warmup, :) = repmat([0.2, 0.3, 0.4], warmup, 1);

for k = warmup+1 : n_tpt
    prev = S(k-1, :);
    S(k,1) = prev(1) * ( r(1,1) - r(1,1) * prev(1) );
    S(k,2) = prev(2) * ( r(2,2) - r(2,2) * prev(2) - r(2,1) * S(k-delta(2),1) );
    S(k,3) = prev(3) * ( r(3,3) - r(3,3) * prev(3) - r(3,1) * S(k-delta(3),1) );
end
x = S(:,3);
y = S(:,2);
t = 1 : n_tpt;
end