# Example: python3 benchmark.py 1 twodrivers seasonal 12 17
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import pickle
from scipy.integrate import solve_ivp
from scipy.stats import kendalltau
import sys

# inputs
job_idx = int(sys.argv[1])
system_case = sys.argv[2]
# Validate explicitly: an `assert` would be skipped under `python -O`.
if system_case not in ("onedriver", "twodrivers"):
    raise ValueError(
        f"system_case must be 'onedriver' or 'twodrivers', got {system_case!r}")
surr_type = sys.argv[3]
if surr_type == 'seasonal':
    # Parsed to int so the type matches the non-seasonal default of 1 below.
    surr_period_x = int(sys.argv[4])  # surrogate period for testing y causes x
    surr_period_y = int(sys.argv[5])  # surrogate period for testing x causes y
else:
    surr_period_x = 1
    surr_period_y = 1
if surr_type == 'granger':
    import mvgc
else:
    import dtpy_ccm as ccm

# Optional argument 'altstats' (always at position 6, whatever surr_type is)
# switches criterion 2 to the Kendall-tau based statistic.
altstats = len(sys.argv) > 6 and sys.argv[6] == 'altstats'

# analysis parameters
# Library and prediction ranges handed to ccm.run_ccm as lib= / pred=;
# lib[1] is also used as the (fixed) library size for several criteria.
lib = pred = (1, 200)
# Upper limits passed to ccm.run_ccm as max_tau= / max_E=.
max_tau = max_E = 6
# get_pvals only evaluates criterion 4 when this flag is True.
use_criterion_4 = False

### Machinery for stochastic simulations
def simulate_stochastic_process(f, x0, t_eval, noise=0, seed=None,
                                require_nonneg=True):
    """Integrate dx/dt = f(t, x) with forward Euler plus additive Gaussian noise.

    Each step computes
        x[i] = x[i-1] + f(t_eval[i-1], x[i-1]) * dt + N(0, noise)
    with dt = t_eval[i] - t_eval[i-1]. The noise is drawn from the global
    ``np.random`` state and is absolute (not scaled by x or by dt).

    Args:
        f (callable): accepts (t, x) and returns dx/dt as an array shaped like x.
        x0 (array-like): 1-d initial state, one entry per species.
        t_eval (sequence): time points; consecutive differences are the step sizes.
        noise (float or sequence of float): standard deviation of the additive
            Gaussian noise; a sequence gives one value per species.
        seed (int or None): if not None, ``np.random.seed(seed)`` is called
            first (this reseeds the global NumPy RNG).
        require_nonneg (bool): clip negative entries to 0 after every step.

    Returns:
        numpy.ndarray of shape (len(t_eval), len(x0)): row i is the state at
        t_eval[i] and row 0 is x0. If t_eval has at most one point, only the
        initial state row is returned.
    """
    if seed is not None:
        np.random.seed(seed)
    x = [np.asarray(x0, dtype=float)]
    for i in range(1, len(t_eval)):
        dt = t_eval[i] - t_eval[i-1]
        # Keep the noise draw unconditional (also for noise == 0) so the number
        # of draws from the global RNG per step does not depend on the noise level.
        x_next = (x[i-1] + f(t_eval[i-1], x[i-1]) * dt
                  + np.random.normal(0, noise, size=len(x[i-1])))
        if require_nonneg:
            x_next[x_next < 0] = 0
        x.append(x_next)
    return np.array(x)

### Scoring results
# Maps (A[0, 1] significant, A[1, 0] significant) to an outcome label.
classes = {
    (False, True): "Correct",
    (False, False): "Empty",
    (True, True): "Full",
    (True, False): "Opposite",
}


def classify(A):
    """Label one trial's 2x2 p-value matrix.

    The off-diagonal entries A[0, 1] and A[1, 0] are each tested against
    0.05 (p <= 0.05 counts as significant; NaN never does), and the pair of
    results is looked up in ``classes``.
    """
    sig_upper = A[0, 1] <= 0.05
    sig_lower = A[1, 0] <= 0.05
    return classes[sig_upper, sig_lower]

def score(pvals, valid=None):
    """Return the fraction of trials falling in each outcome category.

    Args:
        pvals: iterable of per-trial 2x2 p-value matrices, each labelled by
            classify().
        valid: optional sequence aligned with pvals. Trial i is counted as
            "Invalid" (and not classified) when valid[i][0] > 0.05 or
            valid[i][1] > 0.05.

    Returns:
        dict with keys "Correct", "Empty", "Full", "Opposite", "Invalid"
        mapping to fractions of the number of trials. With no trials every
        fraction is 0.0 instead of raising ZeroDivisionError.
    """
    counts = {"Correct": 0,
              "Empty": 0,
              "Full": 0,
              "Opposite": 0,
              "Invalid": 0}
    trials = list(pvals)  # accept any iterable and know its length
    if not trials:
        return {cat: 0.0 for cat in counts}
    for i, A in enumerate(trials):
        if valid is not None and (valid[i][0] > 0.05 or valid[i][1] > 0.05):
            counts["Invalid"] += 1
        else:
            counts[classify(A)] += 1
    return {cat: n / len(trials) for cat, n in counts.items()}

def get_criterion4_pval(AxB, n_boot=1000):
    """Criterion-4 p-value comparing the best past and best future horizons.

    Assumes that all entries in AxB have max library size.

    The mean rho is computed for every distinct prediction horizon in
    AxB['tp']. The negative horizon with the highest mean and the
    non-negative horizon with the highest mean are selected (the earliest
    one wins ties). Their rho samples are passed to
    ``ccm.EX_gt_EY(negative, non_negative, n_boot=n_boot)``, whose result is
    returned. This raises if either side has no horizons (argmax of empty).
    """
    tp = AxB['tp']
    horizons = np.unique(tp)
    mean_rho = np.array([np.mean(AxB[tp == h]['rho']) for h in horizons])

    def _best_horizon(mask):
        # Horizon inside `mask` with the largest mean rho.
        return horizons[mask][np.argmax(mean_rho[mask])]

    best_past = _best_horizon(horizons < 0)
    best_future = _best_horizon(horizons >= 0)
    return ccm.EX_gt_EY(AxB[tp == best_past].rho.values,
                        AxB[tp == best_future].rho.values, n_boot=n_boot)

def get_p2_standard(df):
    """Criterion-2 p-value from paired samples at the extreme library sizes.

    Takes the rho samples at the smallest and at the largest ``lib_size`` in
    ``df``, pairs them by position, and returns 1 minus the fraction of pairs
    in which the large-library rho exceeds the small-library rho. Both sides
    must have the same number of samples (asserted).
    """
    at_min = df[df.lib_size == np.min(df.lib_size)]
    at_max = df[df.lib_size == np.max(df.lib_size)]
    small, large = at_min.rho.values, at_max.rho.values
    assert large.size == small.size
    n_improved = np.sum(large > small)
    return 1 - n_improved / small.size

def get_p2_alternate(df):
    """Criterion-2 p-value from a rank trend of median rho over library size.

    Groups ``df`` by ``lib_size`` and takes the median ``rho`` of each group,
    then computes Kendall's tau between library size and those medians.

    Returns:
        The one-sided p-value for an increasing trend: half the two-sided
        p-value when tau > 0, otherwise 1.0. This conservatively uses 1 rather
        than 1 - p/2 for tau <= 0. It is also 1.0 when tau is undefined
        (NaN, e.g. constant medians or a single library size); the previous
        arithmetic form returned NaN in that case.
    """
    medians = df.groupby('lib_size', as_index=False).agg({"rho": "median"})
    kcorr, kpval_2t = kendalltau(medians.lib_size, medians.rho)
    if kcorr > 0:
        return kpval_2t / 2
    # Non-positive or undefined correlation: no evidence of convergence.
    return 1.0

def get_pvals(stochsol):
    """Combine the CCM criteria into per-direction p-values for one trial.

    Every criterion is run twice: the "AxB" runs pass the columns of
    ``stochsol`` in their given order and the "BxA" runs pass them swapped
    (``stochsol[:, (1, 0)]``). Presumably ``stochsol`` is a (time, 2) array
    with columns (A, B), as built in run_condition -- TODO confirm for other
    callers.

    Criteria (smaller value = stronger evidence):
      1. surrogate test at library size ``lib[1]``: (number of surrogate rho
         values >= the real rho, plus 1) / (number of surrogates + 1);
      2. rho at max library size exceeds rho at min library size:
         get_p2_standard, or get_p2_alternate when ``altstats`` is set;
      3. 1 if the single rho from the criterion-3 run is <= 0, else 0;
      4. computed only when ``use_criterion_4`` is True, and never part of
         the combined p-value.

    Returns:
        (pvals, AxB_tau, AxB_E, BxA_tau, BxA_E, pval_details), where pvals is
        a 2x2 array holding the max of criteria 1-3 for AxB at [0, 1] and for
        BxA at [1, 0], with NaN on the diagonal. The tau/E values are read from
        the last results bound to AxB/BxA (the criterion-3 runs, or the
        criterion-4 runs when that criterion is enabled). pval_details maps
        'AxB'/'BxA' to (p1, p2, p3, p4); p4 is 0 when criterion 4 is off.
    """
    # criterion 1: at max library size, rho values from real data are greater than from surrogate data
    AxB, AxB_surr = ccm.run_ccm(data = stochsol, data_filename = "data.csv", temp_folder = f"tmp{job_idx}",
    			max_tau=max_tau, max_E=max_E, min_lib_size=lib[1], max_lib_size=lib[1],
    			num_samples=1, lib=lib, pred=pred, rand_libs=False,
    			prediction_horizon=[0], surr_test=True, surr_period=surr_period_x,
    			n_surr_datasets=999, n_surr_samples=1, surr_type = surr_type,
    			script_filename = "script.R", remove_temp_files=True)

    BxA, BxA_surr = ccm.run_ccm(data = stochsol[:,(1,0)], data_filename = "data.csv", temp_folder = f"tmp{job_idx}",
    			max_tau=max_tau, max_E=max_E, min_lib_size=lib[1], max_lib_size=lib[1],
    			num_samples=1, lib=lib, pred=pred, rand_libs=False,
    			prediction_horizon=[0], surr_test=True, surr_period=surr_period_y,
    			n_surr_datasets=999, n_surr_samples=1, surr_type = surr_type,
    			script_filename = "script.R", remove_temp_files=True)

    # Monte Carlo p-value: the real rho counts as one extra draw (+1 on both
    # sides), so p1 is never 0. NOTE(review): AxB.rho[0] / BxA.rho[0] are
    # label-based lookups and assume the result index contains 0 -- confirm.
    p1_AxB = (np.sum(AxB_surr.rho >= AxB.rho[0]) + 1) / (AxB_surr.rho.size + 1)
    p1_BxA = (np.sum(BxA_surr.rho >= BxA.rho[0]) + 1) / (BxA_surr.rho.size + 1)

    # run ccm for criteria 2
    # NOTE(review): these two runs happen even when `altstats` is True, in
    # which case both results are overwritten below without being used.
    AxB = ccm.run_ccm(data = stochsol, data_filename = "data.csv", temp_folder = f"tmp{job_idx}",
    			max_tau=max_tau, max_E=max_E, min_lib_size=15, max_lib_size=-1,
    			lib_size_step=-1, num_samples=1000, lib=lib, pred=pred,
    			prediction_horizon=[0], surr_test=False, remove_temp_files=True)

    BxA = ccm.run_ccm(data = stochsol[:,(1,0)], data_filename = "data.csv", temp_folder = f"tmp{job_idx}",
    			max_tau=max_tau, max_E=max_E, min_lib_size=15, max_lib_size=-1,
    			lib_size_step=-1, num_samples=1000, lib=lib, pred=pred,
    			prediction_horizon=[0], surr_test=False, remove_temp_files=True)

    # criterion 2: rho values at max library size are greater than at min library size
    if not altstats:
        p2_AxB = get_p2_standard(AxB)
        p2_BxA = get_p2_standard(BxA)
    else:
        # Alternate statistic: random libraries over a range of library sizes,
        # scored by a Kendall-tau trend test on the per-size median rho.
        AxB = ccm.run_ccm(data=stochsol[:,(0,1)], data_filename = "data.csv",
                          temp_folder = f"tmp{job_idx}",
                          max_E=max_E, max_tau=max_tau, prediction_horizon=[0],
                          rand_libs=True, min_lib_size=15, max_lib_size=-1,
                          lib_size_step=3, num_samples=50, replace=False,
                          remove_temp_files=True)
        BxA = ccm.run_ccm(data=stochsol[:,(1,0)], data_filename = "data.csv",
                          temp_folder = f"tmp{job_idx}",
                          max_E=max_E, max_tau=max_tau, prediction_horizon=[0],
                          rand_libs=True, min_lib_size=15, max_lib_size=-1,
                          lib_size_step=3, num_samples=50, replace=False,
                          remove_temp_files=True)
        p2_AxB = get_p2_alternate(AxB)
        p2_BxA = get_p2_alternate(BxA)

    # criterion 3: rho values at max library size are greater than 0
    AxB = ccm.run_ccm(data = stochsol, data_filename = "data.csv", temp_folder = f"tmp{job_idx}",
                max_tau=max_tau, max_E=max_E, min_lib_size=-1, max_lib_size=-1,
                lib_size_step=-1, num_samples=1, lib=lib, pred=pred,
                prediction_horizon=[0], surr_test=False, remove_temp_files=True,
                rand_libs=False)
    BxA = ccm.run_ccm(data = stochsol[:,(1,0)], data_filename = "data.csv", temp_folder = f"tmp{job_idx}",
                max_tau=max_tau, max_E=max_E, min_lib_size=-1, max_lib_size=-1,
                lib_size_step=-1, num_samples=1, lib=lib, pred=pred,
                prediction_horizon=[0], surr_test=False, remove_temp_files=True,
                rand_libs=False)
    # Binary "p-value": 1 (fail) when rho <= 0, else 0. `.item()` raises
    # unless the run produced exactly one rho value.
    p3_AxB = int(AxB.rho.values.item() <= 0)
    p3_BxA = int(BxA.rho.values.item() <= 0)

    # criterion 4: predictions into the past have higher rho than predictions into the future
    p4_AxB = 0
    p4_BxA = 0
    if use_criterion_4:
        AxB = ccm.run_ccm(data = stochsol[:,(0,1)], data_filename = "data.csv", temp_folder = f"tmp{job_idx}",
            max_tau=max_tau, max_E=max_E, min_lib_size=lib[1], max_lib_size=lib[1],
        	lib_size_step=1, num_samples=300, lib=lib, pred=pred,
        	prediction_horizon = np.arange(-36,37,6), surr_test=False,
            remove_temp_files=True)
        p4_AxB = get_criterion4_pval(AxB)
        BxA = ccm.run_ccm(data = stochsol[:,(1,0)], data_filename = "data.csv", temp_folder = f"tmp{job_idx}",
            max_tau=max_tau, max_E=max_E, min_lib_size=lib[1], max_lib_size=lib[1],
        	lib_size_step=1, num_samples=300, lib=lib, pred=pred,
        	prediction_horizon = np.arange(-36,37,6), surr_test=False,
            remove_temp_files=True)
        p4_BxA = get_criterion4_pval(BxA)

    # Combined p-value per direction = the largest (least significant) of
    # criteria 1-3, so every criterion must pass for significance.
    pvals = np.zeros([2,2]);
    pvals[0,1] = np.max([p1_AxB, p2_AxB, p3_AxB]) # B affects A = AxB, ignoring p4
    pvals[1,0] = np.max([p1_BxA, p2_BxA, p3_BxA]) # A affects B = BxA, ignoring p4
    pvals[0,0] = np.nan
    pvals[1,1] = np.nan
    pval_details = {'AxB' : (p1_AxB, p2_AxB, p3_AxB, p4_AxB),
                    'BxA' : (p1_BxA, p2_BxA, p3_BxA, p4_BxA)}
    # NOTE(review): tau/E come from whichever run last rebound AxB/BxA
    # (criterion 3, or criterion 4 when enabled) -- confirm this is intended.
    return pvals, AxB.tau[0], AxB.E[0], BxA.tau[0], BxA.E[0], pval_details

def check_ev_valid(stochsol):
    """Check whether each variable's embedding can predict that variable.

    Column c of ``stochsol`` (c = 0, then c = 1) is cross-mapped onto itself
    (``stochsol[:, (c, c)]``) at library size ``lib[1]``, with 300 samples and
    prediction horizon 1. For each column, the fraction of rho samples that
    are <= 0 is computed.

    Returns:
        (fraction_A, fraction_B): fractions of non-positive rho for columns 0
        and 1. Larger values mean the self-prediction is poor.
    """
    fractions = []
    for col in (0, 1):
        self_map = ccm.run_ccm(data=stochsol[:, (col, col)], data_filename="data.csv",
                               temp_folder=f"tmp{job_idx}",
                               max_tau=max_tau, max_E=max_E,
                               min_lib_size=lib[1], max_lib_size=lib[1],
                               num_samples=300, lib=lib, pred=pred,
                               prediction_horizon=[1], remove_temp_files=True)
        fractions.append(np.sum(self_map.rho <= 0) / self_map.rho.size)
    return tuple(fractions)

### Set up dynamics
# Parameters of the two-species model simulated in run_condition:
#   dA/dt = A * (r0 + a0*A + eps0*env_A(t))
#   dB/dt = B * (r1 + a1*B + eps1*env_B(t) + b*A)
r0 = 0.2     # intrinsic growth rate of A
r1 = 0.1     # intrinsic growth rate of B
a0 = -0.1    # self-inhibition of A
a1 = -0.2    # self-inhibition of B
b = 0.3      # effect of A on B
eps0 = 0.05  # effect of environmental driver on A
# B is forced by its own environmental driver only in the "twodrivers" case.
eps1 = 0.1 if system_case == "twodrivers" else 0.0  # effect of environmental driver on B
x0 = [2, 4.5]  # initial values of A and B

### Systematic simulations
def run_condition(sigp, sigm, n_trials=100, show_first=True,
                  t_eval_=np.arange(0, pred[1]*2, 1), x0=None):
    """Simulate ``n_trials`` noisy A/B series for one noise setting and score inference.

    For each trial the system
        dA/dt = A * (r0 + a0*A + eps0*env_A(t))
        dB/dt = B * (r1 + a1*B + eps1*env_B(t) + b*A)
    is simulated with process noise (std ``sigp`` on A, ``2.5*sigp`` on B).
    Measurement noise (std ``sigm/1.5`` on A, ``sigm`` on B) is then added,
    and the first half of the series is discarded. Inference runs on every
    trial: mvgc when ``surr_type == "granger"``; otherwise get_pvals plus the
    check_ev_valid embedding check.

    Args:
        sigp: process-noise level.
        sigm: measurement-noise level.
        n_trials: number of trials; trial seeds are ``n_trials*job_idx + trial``.
        show_first: when True, the first trial of a condition with sigp == 0
            or sigm == 0 is saved to a CSV file.
        t_eval_: simulation time points, before burn-in removal.
        x0: initial [A, B]; defaults to [2, 4.5].

    Returns:
        (pvals, scores, seeds, embed, valid, pval_details), where scores is
        ``score(pvals, valid=valid)``. In the granger branch ``embed`` and
        ``pval_details`` are empty lists.
    """
    if x0 is None:
        x0 = [2, 4.5]  # avoid a shared mutable default list
    data = []
    seeds = []
    for trial in range(n_trials):
        # run simulation
        seed = n_trials*job_idx + trial
        seeds.append(seed)
        # Driver phases are drawn after seeding with seed + 1000;
        # simulate_stochastic_process below reseeds with `seed` itself.
        np.random.seed(seed + 1000)
        theta0 = np.random.uniform(0, 2*np.pi)
        theta1 = np.random.uniform(0, 2*np.pi)

        def env(t):
            return (np.sin(t * 5 / 6 + theta0) + np.sin(t + theta0),
                    np.sin(t/np.sqrt(10) + theta1))

        def lhs(t, x):
            # Evaluate the drivers once per call (env is a pure function of t).
            drive_a, drive_b = env(t)
            return np.array([x[0] * (r0 + a0*x[0] + eps0*drive_a),
                             x[1] * (r1 + a1*x[1] + eps1*drive_b + b*x[0])])

        stochsol = simulate_stochastic_process(lhs, x0, t_eval_,
                                               noise=(sigp, sigp * 2.5), seed=seed)
        # Measurement noise continues from the RNG state left by the simulation.
        stochsol = stochsol + np.random.normal(0, (sigm / 1.5, sigm), size=stochsol.shape)
        # "burn" off first half of time series to promote stationarity
        start_idx = int(len(t_eval_) / 2)
        stochsol = stochsol[start_idx:, :]
        data.append(stochsol)
        if trial == 0 and show_first and (sigp == 0 or sigm == 0):
            filename = f"{system_case}_sigp_{sigp}_sigm_{sigm}".replace('.', ',') + ".csv"
            np.savetxt(filename, stochsol)

    # run inference
    pvals = []
    embed = []
    valid = []
    pval_details = []
    if surr_type == "granger":
        pvals, valid, _info = mvgc.get_mvgc_pvals(data,
                              temp_folder=f"tmp{job_idx}", remove_temp=False)
    else:
        for series in data:
            result_ = get_pvals(series)
            pvals.append(result_[0])
            embed.append(result_[1:-1])  # (AxB tau, AxB E, BxA tau, BxA E)
            valid.append(check_ev_valid(series))
            pval_details.append(result_[-1])
    return pvals, score(pvals, valid=valid), seeds, embed, valid, pval_details

n_trials = 10

# Per-condition summary: one entry per (sigp, sigm) pair, plus the outcome
# fractions that run_condition's score dict reports under each category name.
RES = {"sigp": [],
       "sigm": [],
       "Correct": [],
       "Empty": [],
       "Full": [],
       "Opposite": [],
       "Invalid": []}

details = {}
# NOTE(review): sigp = 8 is far larger than the other process-noise levels;
# confirm it is intended.
for sigp in [0, 0.05, 0.15, 8]:
    for sigm in [0, 0.15, 1]:
        # now we are returning "pval_details" for CCM
        pvals, scores, seeds, embed, valid, pval_details = run_condition(
            sigp, sigm, n_trials=n_trials, show_first=False)
        RES["sigp"].append(sigp)
        RES["sigm"].append(sigm)
        for cat in scores:
            RES[cat].append(scores[cat])
        deets = {'pvals': pvals, 'seeds': seeds, 'embed': embed,
                 'valid': valid, 'pval_details': pval_details}
        details[(sigp, sigm)] = deets
df = pd.DataFrame.from_dict(RES)
df.to_csv(f'data{job_idx}.csv')
# Context manager guarantees the pickle file is flushed and closed.
with open(f"details{job_idx}.pickle", "wb") as pickle_file:
    pickle.dump(details, pickle_file)
