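"""
Figure 8 S1 a: output rate (nu) versus input (mu) from the piecewise-linear
approximation in mft_helpers.siegert, for the baseline, extended and shifted
linear ranges (pcwlin_mu_min / pcwlin_mu_max).
"""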
import sys
import os
sys.path.append(os.environ.get('NEST_PYTHON_PREFIX'))
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

assert sys.version_info >= (3, 0), "Please run this script using Python3"
import numpy as np
import pickle as pkl

from mft_helpers import mft_parameters as parameters
from mft_helpers import calc
from mft_helpers import mft_plotting
from mft_helpers import siegert
from mft_helpers.siegert import piecewise_linear_approx as pla_
from matplotlib import pyplot as plt
import seaborn as sns


# Run-configuration flags.
# NOTE(review): none of these is read anywhere in this script's visible code;
# confirm whether another module relies on them before removing.
recompute: bool = False  # set True to rerun the fixed point search; otherwise stored data is used
method: str = 'nu0_fb433'
round_rates: str = 'manual'  # previously working version


def sample_rates(rate_fn, mus):
    """Evaluate ``rate_fn`` at every element of ``mus``.

    Args:
        rate_fn: callable taking one input value and returning a scalar rate.
        mus: sequence of input values.

    Returns:
        A float ``np.ndarray`` of length ``len(mus)``. Results are written
        element-wise into a ``np.zeros`` buffer, so each ``rate_fn`` result is
        coerced into a single float slot.
    """
    rates = np.zeros(len(mus))
    for i, mu in enumerate(mus):
        rates[i] = rate_fn(mu)
    return rates


def main():
    """Plot the piecewise-linear rate curves for three linear ranges and save them as a PDF."""
    pars = parameters.set_parameter_space()  # create parameter set
    pars = calc.derived_parameters(pars)

    mus = np.arange(-200, 401, 10)

    # siegert.experiment_pars is bound to the *same* dict object as pars, so the
    # in-place edits of pars['pcwlin_mu_min'] / pars['pcwlin_mu_max'] below are
    # presumably what pla_ reads -- TODO confirm in siegert.piecewise_linear_approx.
    siegert.experiment_pars = pars

    def pla(mu):
        """Piecewise-linear approximation at input ``mu``, scaled by 1e3."""
        return pla_(pars['tau_m'], pars['tau_s'], pars['tau_r'],
                    pars['V_th'], pars['V_reset'], mu, None) * 1e3

    plas_baseline = sample_rates(pla, mus)

    # "extended": only the upper end of the linear range is raised.
    pcwlin_dyn_range_shift = 60
    pars['pcwlin_mu_max'] += pcwlin_dyn_range_shift
    plas_extended = sample_rates(pla, mus)

    # "shifted": both ends move by the same amount. The edits are cumulative
    # with the extension above, so relative to baseline mu_min is +60 and
    # mu_max is +120.
    # NOTE(review): confirm that shifting the *extended* range (rather than
    # the baseline range) is what the figure intends.
    pars['pcwlin_mu_min'] += pcwlin_dyn_range_shift
    pars['pcwlin_mu_max'] += pcwlin_dyn_range_shift
    plas_shifted = sample_rates(pla, mus)

    mft_plotting.usetex_font()
    fig = plt.figure(figsize=mft_plotting.fig_defaults['figsize']['main'])

    lw = 3.5
    labelsize = mft_plotting.fig_defaults['labelsize']['main_fp']
    ticksize = mft_plotting.fig_defaults['ticksize']['main_fp']
    legendsize = mft_plotting.fig_defaults['legendsize']['main_fp'] * 1.3

    plt.plot(mus, plas_baseline, '-', color='k', label='baseline', lw=lw)
    plt.plot(mus, plas_extended, '-', color='darkkhaki', label='extended', lw=lw)
    plt.plot(mus, plas_shifted, '-', color='peru', label='shifted', lw=lw)

    plt.xlabel(r'$\mu$', fontsize=labelsize)
    plt.ylabel(r'$\nu$', fontsize=labelsize)
    plt.tick_params('both', labelsize=ticksize)
    sns.despine()

    plt.legend(fontsize=legendsize)
    plt.tight_layout()
    fig.savefig(os.path.join(pars['figure_root_path'], 'fig8_s1_a.pdf'))
    plt.close(fig)  # release the figure once it has been written


if __name__ == "__main__":
    main()
