"""
Figure 4 S1
"""

import os
import sys
import pickle as pkl
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
import helper
import seaborn as sns

# Explicit check instead of `assert`: asserts are stripped when Python runs
# with -O, which would silently skip this guard.
if sys.version_info >= (3, 0):
    raise RuntimeError("Please run this script using Python2.7")

def plot_stats(data_, p2_value_idx=None, random=False, filename=None):
    """
    Plot mean rates and Pearson CCs of populations E4, I4, E5 and I5.

    Builds a two-row figure (top: 'mean_rates', bottom: 'ccs_pearson') with
    one line per population, and saves it as ``plots/<filename>``.

    :param data_: pickled data from simulation; read as
        ``data_[pop][stat]['mean']``, which is indexed with two indices
        (presumably a 2-D numpy array -- confirm against the simulation code)
    :param p2_value_idx: column of ``data_[pop][stat]['mean']`` that is
        plotted when ``random`` is False. Ignored when ``random`` is True,
        in which case row 0 is plotted instead.
    :param random: selects the x-axis. If True, it uses the 13 values
        0, 1, ..., 12 labelled nu_X^+ (spks/sec). Otherwise it uses the
        6 values 0.85, ..., 0.90 labelled 'm'. The plotted data must have
        the matching length.
    :param filename: name of the output file inside ``plots/``. A legend is
        drawn only when it is 'fig4_s1_a.pdf'.
    :return: None
    :raises ValueError: if ``filename`` is None (there is nowhere to save).
    """
    if filename is None:
        # Fail up front rather than inside os.path.join after all the plotting.
        raise ValueError("plot_stats: filename is required")

    if random:
        p2_values = np.arange(0., 12.1, 1.)
    else:
        p2_values = np.arange(0.85, .901, 0.01)

    layers = ['4', '5']
    subpops = ['E', 'I']
    stats = ['mean_rates', 'ccs_pearson']
    labels = ['spks/sec', 'Pearson CC']
    # y-axis limits per statistic; loop-invariant, so built once here
    lim = {
        'ccs_pearson': (0., 0.8),
        'mean_rates': (1., 16.),
        'corrected_rates': (1., 16.)
    }

    # plotting params
    helper.usetex_font()
    color_palette_red_s = sns.color_palette('OrRd', 6)
    color_palette_blue_s = sns.color_palette('Blues', 6)

    labelsize = 24
    ticksize = 18
    legendsize = 18
    lw = 3.

    figure = plt.figure(figsize=(5, 7))  # this is to match the raster plot size

    for stat_idx, stat in enumerate(stats):
        ax = plt.subplot2grid((2, 1), (stat_idx, 0))

        # Populations are drawn in the order E4, I4, E5, I5 (this is the legend order).
        for layer in layers:
            for subpop in subpops:
                pop = subpop + layer
                if random:
                    mean = data_[pop][stat]['mean'][0, :]
                else:
                    mean = data_[pop][stat]['mean'][:, p2_value_idx]

                # E populations use the blue palette and I populations the red one.
                # Layer 4 takes palette entry 3 and layer 5 takes entry 5.
                palette = color_palette_blue_s if subpop == 'E' else color_palette_red_s
                color_offset = 1 if layer == '4' else 0
                color = palette[int(layer) - color_offset]

                ax.plot(p2_values, mean, '-', color=color, label=pop, lw=lw)

        ax.set_ylim(*lim[stat])
        ax.yaxis.set_major_locator(MaxNLocator(4))
        ax.xaxis.set_major_locator(MaxNLocator(5 if random else 6))
        ax.set_ylabel(labels[stat_idx], fontsize=labelsize)
        ax.set_xlabel('m' if not random else r'$\nu_\mathrm{X}^{+}$ (spks/sec)', fontsize=labelsize)

        ax.grid(False)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.tick_params(axis='both', labelsize=ticksize, direction='out')

        if filename == 'fig4_s1_a.pdf':
            # Attach the legend to this subplot explicitly instead of relying
            # on pyplot's implicit "current axes".
            ax.legend(loc='best', prop={'size': legendsize}, handleheight=1.2, handlelength=2.6, frameon=False)

    figure.tight_layout()
    figure.subplots_adjust(hspace=0.35)
    figure.savefig(os.path.join('plots/', filename))
    # Release the figure. This function is called repeatedly, and pyplot
    # otherwise keeps every figure it created alive.
    plt.close(figure)
    print("Finished plotting Figure 4S1 - " + filename)


if __name__ == "__main__":
    with open('data/stats_baseline.pkl', 'r') as f:
        data = pkl.load(f)
    plot_stats(data, p2_value_idx=0, filename='fig4_s1_a.pdf')

    with open('data/stats_random.pkl', 'r') as f:
        data = pkl.load(f)
    plot_stats(data, p2_value_idx=0, filename='fig4_s1_b.pdf')

    with open('data/stats_with_noise.pkl', 'r') as f:
        data = pkl.load(f)
    plot_stats(data, p2_value_idx=10, random=True, filename='fig4_s1_c.pdf')