"""
Figure 8b: bifurcation diagram (fixed points vs. modularity) for a piecewise-linear activation function.
"""
import sys
import os
sys.path.append(os.environ.get('NEST_PYTHON_PREFIX'))
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

assert sys.version_info >= (3, 0), "Please run this script using Python3"
import numpy as np
import pickle as pkl

from mft_helpers import mft_parameters as parameters
from mft_helpers import calc
from mft_helpers import mft_plotting
from mft_helpers import siegert
from mft_helpers.siegert import piecewise_linear_approx as pla_
from matplotlib import pyplot as plt
import seaborn as sns

# Set to True to rerun the fixed-point search (overwriting the stored
# results file); leave False to reuse the previously stored data.
recompute = False
# Activation-function approximation handed to the solver through pars['method'].
method = 'piecewise_linear'


if __name__ == "__main__":
    # Script entry point: load (or recompute) the fixed points over a sweep of
    # modularity values and plot the bifurcation diagram for the piecewise-linear
    # activation function (Figure 8b).
    pars = parameters.set_parameter_space()  # create parameter set
    pars = calc.derived_parameters(pars)
    print(pars)

    nu_x = 12.  # value assigned to pars['nuX'] below; also reported in the log line
    print("\n\nCalculating bifurcation diagram of a piecewise linear activation function "
          "for nuX={:g}".format(nu_x))

    stability_analysis = 'matrix'
    filename = 'data/fig_8b_data.pkl'  # stored results of the fixed-point search

    pars['method'] = method
    pars['nuX'] = nu_x
    # modularity values at which fixed points are searched (x-axis of the diagram)
    modularity = np.arange(0.6, 0.991, 0.01)

    if recompute:
        # compute self-consistent stationary state (firing rate, mean and SD of input; "working point")
        results = calc.fixed_points(pars, modularity_values=modularity,
                                    stability_analysis=stability_analysis, multithreaded=False)
        # 'data/' may not exist yet (e.g. fresh checkout); create it so the dump cannot fail
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, 'wb') as f:
            pkl.dump(results, f)
    else:
        # NOTE: pickle.load can execute arbitrary code -- only load result files produced by this script.
        try:
            with open(filename, 'rb') as f:
                results = pkl.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                "Stored results '{}' not found; set recompute = True to run the fixed point "
                "search.".format(filename)) from err

    fig, ax, _fig_filename = mft_plotting.plot_fixed_points(
        results, pars, stability_analysis=stability_analysis, filename_suffix='_fig_8B_pcws_lin_bifurcation',
        save=False, loc_legend_stability='lower left', loc_legend_stimulated=None, figsize=(7, 4),
        legendsize=mft_plotting.fig_defaults['legendsize']['main_fp'] * 1.3)

    # y-range shared by the vertical marker line and the axis limits, so the two stay in sync
    y_lim = (1e-1, 700)
    # vertical marker at modularity = 0.8169
    # NOTE(review): presumably the bifurcation point shown in this diagram -- confirm
    ax.plot([0.8169] * 2, list(y_lim), '-', c='tab:cyan', lw=3., zorder=-1)
    mft_plotting.sns.despine(ax=ax)
    ax.set_xlim(0.6, 1.)
    ax.set_ylim(*y_lim)
    ax.set_title(None)  # clear any title left on the axes by plot_fixed_points

    # enlarge tick and axis labels relative to the package defaults
    labelsize = mft_plotting.fig_defaults['labelsize']['main_fp'] * 1.2
    ticksize = mft_plotting.fig_defaults['ticksize']['main_fp'] * 1.2
    ax.tick_params('both', labelsize=ticksize)
    ax.yaxis.label.set_size(labelsize)
    ax.xaxis.label.set_size(labelsize)

    fig_dir = pars['figure_root_path']
    os.makedirs(fig_dir, exist_ok=True)  # savefig fails if the output directory is missing
    fig.savefig(os.path.join(fig_dir, 'fig8_b.pdf'))
    print('Finished plotting')
