"""
Figure 3a - synaptic currents
"""

import os
import sys
import pickle as pkl
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
import helper

# Explicit check instead of `assert`: asserts are stripped when Python runs with -O.
if sys.version_info >= (3, 0):
    raise RuntimeError("Please run this script using Python2.7")


def plot_fig3_currents_ei(data_, filename):
    """
    Main plot, stimulated / non-stimulated sub-populations.

    For every module, the mean excitatory plus mean inhibitory synaptic current
    (``I_syn_ex`` + ``I_syn_in``) is plotted against the p1 values. Module 0 is
    drawn only as two markers just left of the first p1 tick (its value at the
    first p1); the remaining modules are drawn as solid (stimulated) and dashed
    (non-stimulated) lines.

    :param data_: pickled data from simulation, accessed as
        ``data_[module_idx]['active' | 'non_active'][p1][<current key>]``
    :param filename: file name of the figure, saved inside ``plots/``
    :return: None
    """
    helper.usetex_font()
    figure = plt.figure(figsize=(10, 11))

    ax_act = plt.subplot2grid((1, 1), (0, 0))
    axes = [ax_act]

    n_modules = 6
    p1_values = np.around(np.arange(0.8, 0.91, 0.02), decimals=3)
    x_pos = np.arange(len(p1_values))

    color_palette = sns.cubehelix_palette(20, start=.5, rot=-.75)[::2]
    ticksize = 28
    legendsize = 24
    lw = 4

    def _total_current(trials):
        # Mean excitatory + mean inhibitory synaptic current, one value per p1.
        return [np.mean(trials[p1v]['I_syn_ex']) + np.mean(trials[p1v]['I_syn_in'])
                for p1v in p1_values]

    for module_idx in range(n_modules):
        tmp_res_active = _total_current(data_[module_idx]['active'])
        tmp_res_nonactive = _total_current(data_[module_idx]['non_active'])

        if module_idx == 0:
            # SSN_0 is shown only as two markers left of the first p1 tick,
            # using its value at the first p1.
            ax_act.scatter([-0.2], [tmp_res_active[0]], marker='o', s=275, color=color_palette[0],
                           edgecolors='k', label=r'stimulated $\mathrm{SSN}_0$')
            ax_act.scatter([-0.2], [tmp_res_nonactive[0]], marker='o', s=275, color=color_palette[0],
                           linewidth=2., linestyle='--', edgecolors='k',
                           label=r'non-stimulated $\mathrm{SSN}_0$')
        else:
            c = color_palette[module_idx]
            # Only the last module carries legend labels, so each line style
            # appears exactly once in the legend.
            is_last = module_idx == n_modules - 1
            ax_act.plot(x_pos, tmp_res_active, '-', linewidth=lw, color=c,
                        label='stimulated' if is_last else None)
            ax_act.plot(x_pos, tmp_res_nonactive, '--', linewidth=lw, color=c,
                        label='non-stimulated' if is_last else None)

    for ax_ in axes:
        ax_.set_xticks(x_pos)
        ax_.set_xticklabels([str(x)[:4] for x in p1_values.tolist()])
        # Extra room on the left for the SSN_0 markers drawn at x = -0.2.
        ax_.set_xlim([-0.35, len(p1_values) - 1 + 0.3])
        ax_.ticklabel_format(style='sci', axis='y', scilimits=(0, 0), useMathText=True)
        ax_.yaxis.get_offset_text().set_fontsize(24)
        ax_.tick_params(axis='both', labelsize=ticksize, direction='out')
        ax_.grid(False)

    ax_act.legend(loc='lower left', prop={'size': legendsize}, handlelength=3.5, frameon=False)
    for ax in axes:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

    figure.tight_layout()
    figure.savefig(os.path.join('plots/', filename))
    # Release the figure; otherwise pyplot keeps it alive after saving.
    plt.close(figure)


def plot_fig3_currents_merged_ei(data_, filename):
    """
    Inset plot

    Plots, for every module, the mean excitatory plus mean inhibitory synaptic
    current (``I_syn_ex`` + ``I_syn_in``) against the p1 values as a solid line.
    Only the first and last p1 values are labelled on the x-axis.

    :param data_: pickled data from simulation, accessed as
        ``data_[module_idx][p1][<current key>]``
    :param filename: file name of the figure, saved inside ``plots/``
    :return: None
    """
    helper.usetex_font()
    figure = plt.figure(figsize=(6, 4))

    ax = plt.subplot2grid((1, 1), (0, 0))

    n_modules = 6
    p1_values = np.around(np.arange(0.8, 0.91, 0.02), decimals=3)
    x_pos = np.arange(len(p1_values))

    color_palette = sns.cubehelix_palette(20, start=.5, rot=-.75)[::2]
    # NOTE(review): a tick label size of 2 is unusually small (the main plot
    # uses 28) -- confirm this value is intended.
    ticksize = 2

    for module_idx in range(n_modules):
        module_data = data_[module_idx]
        total_current = [np.mean(module_data[p1v]['I_syn_ex']) +
                         np.mean(module_data[p1v]['I_syn_in']) for p1v in p1_values]
        ax.plot(x_pos, total_current, '-', color=color_palette[module_idx], linewidth=3.5)

    # Label only the first and last p1 value; positions and labels are derived
    # from p1_values so they stay correct if the p1 grid changes.
    ax.set_xticks([0, len(p1_values) - 1])
    ax.set_xticklabels([str(x)[:4] for x in p1_values[[0, -1]].tolist()])
    ax.tick_params(axis='both', labelsize=ticksize)
    ax.grid(False)

    figure.tight_layout()
    figure.savefig(os.path.join('plots/', filename))
    # Release the figure; otherwise pyplot keeps it alive after saving.
    plt.close(figure)


if __name__ == "__main__":
    # Pickles may use a binary protocol: open in 'rb' so reading is not
    # affected by text-mode newline translation (e.g. on Windows).
    # NOTE: pickle.load executes arbitrary code -- only load trusted files.
    with open('data/fig_3_synaptic_currents_GWN_sigma=0.0.pkl', 'rb') as f:
        data = pkl.load(f)
    with open('data/fig_3_mergedEI_syn_currents_GWN_sigma=0.0.pkl', 'rb') as f:
        data_merged = pkl.load(f)

    # Both plotting functions save into 'plots/'; create it if missing
    # (Python 2 os.makedirs has no exist_ok argument).
    if not os.path.isdir('plots'):
        os.makedirs('plots')

    plot_fig3_currents_ei(data, 'fig3_a.pdf')
    plot_fig3_currents_merged_ei(data_merged, 'fig3_a_inset.pdf')
