"""
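Figure 2 S1: rate gains nu_3 - nu_2 and nu_5 - nu_4 (computed with
mft_helpers.calc) as a function of modularity m (x-axis) and lambda (y-axis).
A dashed white line marks, for each lambda, the first m at which the gain
becomes positive.

The gains are cached in mft_helpers/mft_data/rate_gains_intensity.pkl (set
``recompute = True`` to ignore the cache) and the figure is written to
plots/fig2_s1.pdf.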
"""

import sys
import os
sys.path.append(os.environ.get('NEST_PYTHON_PREFIX'))
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

assert sys.version_info >= (3, 0), "Please run this script using Python3"

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
import pickle as pkl

from mft_helpers import mft_parameters as parameters
from mft_helpers import calc

recompute = False

if __name__ == "__main__":
    pars = parameters.set_parameter_space()  # create parameter set
    pars = calc.derived_parameters(pars)

    modularity = np.round(np.arange(0.72, 0.921, 0.01), decimals=2)
    Lambda_range = np.round(np.arange(0.04, .141, 0.005), decimals=3)

    # compute firing rates in SSN0
    rates_ssn0 = calc.nu_first_layer(pars, lmbds=Lambda_range)

    try:
        if recompute:
            raise
        with open('mft_helpers/mft_data/rate_gains_intensity.pkl', 'rb') as f:
            rate_gains = pkl.load(f)
    except:
        # store rate gains for 2 modules, indexed by their number
        rate_gains = {3: np.zeros((len(Lambda_range), len(modularity))),
                      5: np.zeros((len(Lambda_range), len(modularity)))}

        for lmbda_idx, lmbda in enumerate(Lambda_range):
            initial_rates = rates_ssn0[lmbda]
            results, modularity_values, limits = calc.nu_vs_layers(pars, modularity_values=modularity,
                                                                   rates_initial_layer=initial_rates, max_layers=5)
            for m_idx, m in enumerate(modularity):
                rate_gains[3][lmbda_idx, m_idx] = results[m][2, 0] - results[m][1, 0]  # SSN0 is not stored, hence -1
                rate_gains[5][lmbda_idx, m_idx] = results[m][4, 0] - results[m][3, 0]  # SSN0 is not stored, hence -1

        with open('mft_helpers/mft_data/rate_gains_intensity.pkl', 'wb') as f:
            pkl.dump(rate_gains, f)

    # plotting params
    ticksize = 24
    axes = []  # list of axes, to set common params at the end

    fig = plt.figure(figsize=(12, 4))

    for idx, pop_idx in enumerate([3, 5]):
        ax = plt.subplot2grid((1, 2), (0, idx))
        cmap = sns.diverging_palette(220, 20, as_cmap=True)
        cmap_name = 'greenred'
        results_diff = np.flipud(rate_gains[pop_idx])
        argmaxY = np.argmax(results_diff > 0., axis=1)
        img = ax.imshow(results_diff, cmap=cmap, vmin=-3., vmax=30., interpolation='lanczos')
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", "10%", pad="4%")
        cbar = fig.colorbar(img, cax=cax, format='%d')
        cbar.ax.tick_params(labelsize=14)

        fontsize = 24
        ax.set_ylabel(r'$\lambda$', fontsize=fontsize)
        ax.set_xlabel('m', fontsize=fontsize)
        cbar.set_label(r'$\nu_{} - \nu_{}$'.format(pop_idx, pop_idx - 1), fontsize=20)

        ax.plot(argmaxY, np.arange(results_diff.shape[0]), '--', color='white', linewidth=2.)

        ax.set_xticks(np.arange(len(modularity))[::5])
        ax.set_yticks(np.arange(len(Lambda_range))[::4])

        ax.set_xticklabels(modularity[::5])
        ax.set_yticklabels(Lambda_range[::-1][::4])

        ax.grid(False)
        ax.tick_params(labelsize=16., direction='out')

    fig.tight_layout()
    fig.savefig("plots/fig2_s1.pdf")



