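"""
Figure 2c, left panels: firing rates predicted by MFT for the stimulated and
non-stimulated populations in each layer, plotted as a function of modularity.
"""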

import sys
import os
sys.path.append(os.environ.get('NEST_PYTHON_PREFIX'))
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

assert sys.version_info >= (3, 0), "Please run this script using Python3"

import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns

from mft_helpers import mft_parameters as parameters
from mft_helpers import calc


def restructure_per_population(results, initial_rates, modularity, n_layers):
    """Reorganise per-modularity layer rates into per-population, per-layer series.

    :param results: mapping from modularity value to per-layer rate pairs, indexed as
        ``results[m][layer - 1][col]`` for ``layer`` in ``1..n_layers-1``; column 0 is
        read as the stimulated rate, column 1 as the non-stimulated rate.
    :param initial_rates: first-layer (SSN0) rates; index 0 is stimulated, 1 non-stimulated.
    :param modularity: iterable of modularity values, used as keys into ``results``.
    :param n_layers: total number of layers, including the first layer (SSN0).
    :return: dict with keys ``'stimulated'`` and ``'non_stimulated'``; each value is a
        list of ``n_layers`` lists holding one rate per modularity value.
    """
    per_pop = {}
    for col, subpop in enumerate(('stimulated', 'non_stimulated')):
        # layer 0 comes from SSN0 and does not depend on modularity
        layers = [[initial_rates[col] for _ in modularity]]
        for layer in range(1, n_layers):
            layers.append([results[m][layer - 1][col] for m in modularity])
        per_pop[subpop] = layers
    return per_pop


if __name__ == "__main__":
    pars = parameters.set_parameter_space()  # create parameter set
    pars = calc.derived_parameters(pars)

    modularity = np.arange(0.78, 0.881, 0.01)
    lmbda = 0.05
    max_layers = 5
    n_layers = max_layers + 1  # SSN0 plus the max_layers downstream layers

    # compute firing rates in SSN0
    rates_ssn0 = calc.nu_first_layer(pars, lmbds=[lmbda])
    initial_rates = rates_ssn0[lmbda]
    results, modularity_values, limits = calc.nu_vs_layers(pars, modularity_values=modularity,
                                                           rates_initial_layer=initial_rates,
                                                           max_layers=max_layers)

    # restructure the results in a per population manner
    results_per_pop = restructure_per_population(results, initial_rates, modularity, n_layers)

    # plotting params
    ticksize = 24

    fig = plt.figure(figsize=(7, 9))
    # every other colour of a cubehelix ramp; ramp length grows only if more layers
    # than colours are requested (20 steps -> 10 colours, as before, for <= 10 layers)
    color_palette = sns.cubehelix_palette(max(20, 2 * n_layers), start=.5, rot=-.75)[::2]

    for subpop_idx, (subpop, ylim) in enumerate(zip(['stimulated', 'non_stimulated'], [(0, 30), (-0.1, 6)])):
        ax = plt.subplot2grid((2, 1), (subpop_idx, 0))
        linestyle = '-' if subpop == 'stimulated' else '--'
        for layer_idx, layer_rates in enumerate(results_per_pop[subpop]):
            ax.plot(modularity, layer_rates, linestyle, linewidth=4.,
                    color=color_palette[layer_idx])

        # axis formatting is per-axis, so apply it once after all layers are drawn
        ax.set_ylim(ylim)
        ax.set_xticks(modularity[::2])
        ax.tick_params('both', labelsize=ticksize, direction='out')
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.grid(False)
        if subpop_idx == 0:
            ax.set_yticks([0., 10., 20., 30.])

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.34)
    os.makedirs("plots", exist_ok=True)  # savefig does not create missing directories
    fig.savefig("plots/fig2_c_theory.pdf")



