"""
Figure 8c - Analytically derived bounds on modularity, baseline network
"""
import sys
import os
sys.path.append(os.environ.get('NEST_PYTHON_PREFIX'))
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

assert sys.version_info >= (3, 0), "Please run this script using Python3"
import numpy as np
import pickle as pkl

from mft_helpers import mft_parameters as parameters
from mft_helpers import calc
from mft_helpers import mft_plotting
from mft_helpers import siegert
from mft_helpers.siegert import piecewise_linear_approx as pla_
from matplotlib import pyplot as plt
import seaborn as sns

recompute = False  # set True to rerun the fixed-point search; otherwise the cached pickle is loaded
plot_fp_dots = False  # set True to additionally scatter every individual fixed point
method = 'piecewise_linear'  # not read anywhere in this script


if __name__ == "__main__":
    # Figure 8c: analytical modularity bounds (three curves) overlaid with the
    # regions where stable / unstable non-saturated fixed points were found.
    pars = parameters.set_parameter_space()  # create parameter set
    pars = calc.derived_parameters(pars)
    pars['g'] = 12

    filename = 'data/fig_8c_data.pkl'

    nuXs = np.arange(-100., 100.1, 0.1)  # external background input values (x-axis)
    modularity_range = np.arange(0., 1.001, 0.01)  # modularity values m (y-axis)
    sparsify = 1  # plot only every `sparsify`-th nuX (and every `sparsify`-th m for the dots)
    N_C = 10.  # number of clusters / stimuli

    if recompute:
        # use the same N_C as the analytical boundaries below so both stay in sync
        results = calc.fixed_points_piecewise_linear(pars, nuXs, modularity_range, N_C=N_C, multithreaded=False)
        os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
        with open(filename, 'wb') as f:
            pkl.dump(results, f)
    else:
        # NOTE: pickle can execute arbitrary code on load; only load caches written by this script.
        with open(filename, 'rb') as f:
            results = pkl.load(f)

    fig = plt.figure(figsize=mft_plotting.fig_defaults['figsize']['main'])
    ax = fig.add_subplot(111)
    print("Plotting analytical boundaries...")
    lw = 3.
    # Evaluate the analytical boundaries once per nuX; components [0], [1], [2] are
    # the purple, cyan and green curves respectively.
    boundaries = [calc.compute_theoretical_boundaries_pcws_linear(pars, x, N_C) for x in nuXs]
    theoretical_boundaries_cyan = [b[1] for b in boundaries]
    for component, color in ((0, 'mediumpurple'), (1, 'tab:cyan'), (2, 'green')):
        ax.plot(nuXs, [b[component] for b in boundaries], color=color, lw=lw)

    print(f"low bound on m for nuX = 12: {calc.compute_theoretical_boundaries_pcws_linear(pars, 12., N_C)[1]}")

    print("Plotting FP markers...")
    c_stable = 'silver'
    c_unstable = 'tab:red'
    # only the first scatter of each kind carries a legend label
    label_stable_non_saturated = True
    label_unstable_non_saturated = True

    # Per plotted nuX with at least one such fixed point: [min m, max m, index into nuXs]
    stable_non_saturated_set = [[], [], []]
    unstable_non_saturated_set = [[], [], []]

    # Collect the extent of stable / unstable non-saturated fixed points for each nuX
    # (flag 1 = stable, 0 = unstable) and optionally scatter them.
    for idx_nuX, nuX in enumerate(nuXs):
        if idx_nuX % sparsify:
            continue
        fp_non_saturated = np.array(results[idx_nuX]['fp_non_saturated'])
        stable_non_saturated = modularity_range[fp_non_saturated == 1]
        unstable_non_saturated = modularity_range[fp_non_saturated == 0]

        tmp_label = None
        if len(stable_non_saturated):
            if label_stable_non_saturated:
                tmp_label = 'stable (non-saturated)'
                label_stable_non_saturated = False
            stable_non_saturated_set[0].append(stable_non_saturated.min())
            stable_non_saturated_set[1].append(stable_non_saturated.max())
            stable_non_saturated_set[2].append(idx_nuX)

        if plot_fp_dots:
            ax.scatter([nuX] * len(stable_non_saturated[::sparsify]), stable_non_saturated[::sparsify],
                       color=c_stable, s=0.4, label=tmp_label)

        tmp_label = None
        if len(unstable_non_saturated):
            if label_unstable_non_saturated:
                tmp_label = 'unstable (non-saturated)'
                label_unstable_non_saturated = False
            unstable_non_saturated_set[0].append(unstable_non_saturated.min())
            unstable_non_saturated_set[1].append(unstable_non_saturated.max())
            unstable_non_saturated_set[2].append(idx_nuX)

        if plot_fp_dots:
            ax.scatter([nuX] * len(unstable_non_saturated[::sparsify]), unstable_non_saturated[::sparsify],
                       color=c_unstable, s=0.4, label=tmp_label)

            # saturated fixed points are only needed for the optional dot plot
            fp_saturated = np.array(results[idx_nuX]['fp_saturated'])
            stable_saturated = modularity_range[fp_saturated == 1]
            unstable_saturated = modularity_range[fp_saturated == 0]
            ax.scatter([nuX] * len(stable_saturated), stable_saturated, color='k', s=5., marker='h')
            # actually, all saturated FPs are stable when they exist
            ax.scatter([nuX] * len(unstable_saturated[::sparsify]), unstable_saturated[::sparsify], color='r', s=5., marker='h')

    cyan = np.asarray(theoretical_boundaries_cyan)

    # Stable region: from the lowest stable m up to the larger of the highest stable m
    # and the cyan analytical boundary.
    stable_idx = np.asarray(stable_non_saturated_set[2], dtype=int)
    ax.fill_between(nuXs[stable_idx], stable_non_saturated_set[0],
                    np.maximum(stable_non_saturated_set[1], cyan[stable_idx]),
                    lw=0., facecolor=c_stable, label='stable (non-saturated)')

    # Unstable region: between the cyan boundary (clipped to the highest unstable m) and
    # the highest unstable m.
    # NOTE(review): the lower edge uses the *maximum* unstable m (index 1) rather than the
    # minimum (index 0) -- confirm this is intended.
    unstable_idx = np.asarray(unstable_non_saturated_set[2], dtype=int)
    ax.fill_between(nuXs[unstable_idx],
                    np.minimum(unstable_non_saturated_set[1], cyan[unstable_idx]),
                    unstable_non_saturated_set[1],
                    lw=0., facecolor=c_unstable, label='unstable (non-saturated)', alpha=0.3)

    # Hatch the region between the cyan boundary and m = 1 (labelled stable, saturated).
    plt.rcParams['hatch.linewidth'] = 1.5
    ax.fill_between(nuXs, theoretical_boundaries_cyan, [1.] * len(nuXs),
                    hatch='/', edgecolor='k', lw=0., facecolor="none", label='stable (saturated)')

    ax.set_ylim([0., 1.])
    ax.set_xlabel(r'$\nu_X$ (external background input)', fontsize=mft_plotting.fig_defaults['labelsize']['main_fp'])
    ax.set_ylabel('m', fontsize=mft_plotting.fig_defaults['labelsize']['main_fp'])
    xticks = list(range(-100, 101, 25))
    ax.set_xticks(xticks)
    ax.set_xticklabels([str(t) for t in xticks])
    ax.tick_params(axis='both', labelsize=mft_plotting.fig_defaults['ticksize']['main_fp'])
    legend = ax.legend(loc='lower right', fontsize=mft_plotting.fig_defaults['legendsize']['main_fp']*1.2,
                       handlelength=3., frameon=False)

    # `legendHandles` was renamed to `legend_handles` in matplotlib 3.7 and removed in 3.9.
    try:
        legend_handles = legend.legend_handles
    except AttributeError:
        legend_handles = legend.legendHandles
    # Enlarge the first two legend markers; handles without `set_sizes` (e.g. the
    # patches drawn for filled regions) are left unchanged.
    for handle in list(legend_handles)[:2]:
        set_sizes = getattr(handle, 'set_sizes', None)
        if set_sizes is not None:
            set_sizes([22.0])

    sns.despine(ax=ax)
    print("Saving figure...")
    fig.tight_layout()
    fig.savefig(os.path.join(pars['figure_root_path'], 'fig8_c.pdf'))

