"""
Figure 8b
"""
import sys
import os
sys.path.append(os.environ.get('NEST_PYTHON_PREFIX'))
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

# assert sys.version_info >= (3, 0), "Please run this script using Python3"

import pickle as pkl
import numpy as np
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
import seaborn as sns

import helper
# from mft_helpers import mft_parameters as parameters
# from mft_helpers import calc


def plot_scan(data_, plot_labels=False, plot_legend=False, filename=None):
    """
    Plot the rate traces of the non-stimulated, S1 and S2 populations.

    :param data_: 6-sequence ``(xticks, a_s1_fr, a_s1_fr_i, a_s2_fr, a_s2_fr_i, na_fr)``
        as stored in the pickled data files. Only ``xticks``, ``a_s1_fr``,
        ``a_s2_fr`` and ``na_fr`` are plotted; the ``*_i`` entries are ignored.
        Each trace is truncated to its first 300 samples.
    :param plot_labels: if False, the x tick labels are hidden.
    :param plot_legend: if True, draw a legend in the upper-left corner.
    :param filename: file name, relative to ``plots/``, to save the figure to.
        The directory is created if needed. If None, the figure is not saved.
    :return: the matplotlib Figure, so callers can display or close it.
    """
    # figure parameters
    lw = 4.
    ticklabelsize = 52
    labelsize = 58
    legend_labelsize = 34
    legend_line_width = 7.0
    cutoff = 300  # only the first `cutoff` samples of each trace are drawn
    c_s1 = '#867CE8'
    c_s2 = '#48D1CC'
    out_dir = 'plots/'

    helper.usetex_font()

    figure = plt.figure(figsize=plt.figaspect(0.35))
    ax = plt.subplot2grid((1, 1), (0, 0))

    # the *_i entries are part of the stored tuple but are not plotted here
    xticks, a_s1_fr, _a_s1_fr_i, a_s2_fr, _a_s2_fr_i, na_fr = data_

    ax.plot(xticks[:cutoff], na_fr[:cutoff], '-', color='k', label='non-stimulated', lw=lw)
    ax.plot(xticks[:cutoff], a_s1_fr[:cutoff], '-', color=c_s1, label='S1', lw=lw)
    ax.plot(xticks[:cutoff], a_s2_fr[:cutoff], '-', color=c_s2, label='S2', lw=lw)

    ax.grid(False)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(axis='both', labelsize=ticklabelsize, direction='out')
    # NOTE(review): tick positions 2000..5000 are labelled 2..5, presumably
    # converting ms to s -- TODO confirm the units of `xticks`.
    ax.set_xticks([2000., 3000., 4000., 5000.])
    ax.set_xticklabels([2, 3, 4, 5] if plot_labels else [])
    ax.set_ylabel(r'Rate $\nu_5$', fontsize=labelsize, labelpad=15)
    ax.yaxis.set_major_locator(MaxNLocator(4))

    if plot_legend:
        lh = ax.legend(loc='upper left', prop={'size': legend_labelsize}, handleheight=1.2, handlelength=1.5)
        lh.set_alpha(1)
        # thicken the legend's line handles relative to the plotted lines
        for line in lh.get_lines():
            line.set_linewidth(legend_line_width)
    sns.despine()

    figure.tight_layout()
    if filename is not None:
        # savefig does not create missing directories
        os.makedirs(out_dir, exist_ok=True)
        figure.savefig(os.path.join(out_dir, filename))
    return figure


if __name__ == "__main__":
    # One entry per panel, drawn in this order:
    # (pickled input data, output file name, extra plot_scan keyword options).
    # NOTE: pickle.load executes arbitrary code; only load these trusted local files.
    panels = (
        ('data/fig_8_wta_fr_m=0.92_lratio=0.96.pkl', 'fig8_b_wlc.pdf', {'plot_labels': True}),
        ('data/fig_8_wta_fr_m=0.82_lratio=0.86.pkl', 'fig8_b_coex.pdf', {}),
        ('data/fig_8_wta_fr_m=0.9_lratio=0.5.pkl', 'fig8_b_wta.pdf', {'plot_legend': True}),
    )
    for pkl_path, pdf_name, options in panels:
        with open(pkl_path, 'rb') as handle:
            panel_data = pkl.load(handle)
        plot_scan(panel_data, filename=pdf_name, **options)