"""
Figure 2a, both top and bottom rows.
"""

import math
import os
import pickle as pkl
import sys

from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec

# This script only supports Python 2.7. Use an explicit raise rather than
# ``assert`` so the guard is not stripped when Python runs with -O.
if sys.version_info >= (3, 0):
    raise RuntimeError("Please run this script using Python2.7")


def plot_fig2_raster(data_075, data_1, filename, populations=('E0', 'E2', 'E5')):
    """
    Draw the Figure 2a spike rasters and save them to ``plots/<filename>``.

    One row is drawn per dataset (``data_075`` on top, ``data_1`` below) and
    one column per entry of ``populations``. Each panel overlays the spikes of
    population ``E<k>`` (blue) and of its counterpart ``i<k>`` (red); ticks,
    tick labels, grid and frame are hidden so only the spikes remain.

    :param data_075: unpickled simulation data for the top row (m=0.75 in this
        script): a mapping from population names (``'E0'``, ``'i0'``, ...) to
        2-D arrays whose column 0 is plotted as spike time and column 1 as
        neuron id.
    :param data_1: same structure as ``data_075``, drawn in the bottom row
        (m=1.0 in this script).
    :param filename: name of the output file, written inside ``plots/``
        (the directory is created if it does not exist).
    :param populations: excitatory population names to plot, one per column.
        Each must look like ``'E<k>'`` with a matching ``'i<k>'`` entry in both
        datasets.
    :return: None
    """
    n_rows = 2
    n_cols = len(populations)
    figure = plt.figure(figsize=plt.figaspect(0.5))
    gs = GridSpec(n_rows, n_cols, left=0.02, right=0.98, top=0.92, bottom=0.08,
                  wspace=0.15, hspace=0.2)

    c_exc = 'royalblue'
    c_inh = 'tab:red'
    n_neurons = 10000  # height of the y-band reserved for each population index
    base = 100         # x-limits are snapped to multiples of this value
    margin = 10        # extra x padding outside the snapped limits

    for row, data in enumerate((data_075, data_1)):
        for col, pop in enumerate(populations):
            axes_spikes = figure.add_subplot(gs[row, col])
            inh_pop = pop.replace('E', 'i')

            times_exc = data[pop][:, 0]
            times_inh = data[inh_pop][:, 0]
            # Offset neuron ids into the population's y-band (exc +2000,
            # inh -8000). NOTE(review): these offsets depend on how the
            # simulation numbers its neurons -- confirm against the data.
            neurons_exc = data[pop][:, 1] + 2000
            neurons_inh = data[inh_pop][:, 1] - 8000

            axes_spikes.plot(times_exc, neurons_exc, '.', color=c_exc, markersize=1)
            axes_spikes.plot(times_inh, neurons_inh, '.', color=c_inh, markersize=1)

            pop_idx = int(pop[1:])
            t_min = min(times_exc.min(), times_inh.min())
            t_max = max(times_exc.max(), times_inh.max())
            # Floor the lower limit and ceil the upper limit. Rounding to the
            # *nearest* multiple of `base` can pull a limit inward by up to
            # base/2, which exceeds `margin` and clips spikes off the panel.
            axes_spikes.set(
                ylim=[pop_idx * n_neurons - 100, (pop_idx + 1) * n_neurons + 100],
                xlim=[base * math.floor(float(t_min) / base) - margin,
                      base * math.ceil(float(t_max) / base) + margin])

            # Hide tick labels, ticks, grid and frame: only the spikes are shown.
            axes_spikes.set_xticklabels([''])
            axes_spikes.set_yticklabels([''])
            axes_spikes.tick_params(labelsize=24)
            axes_spikes.tick_params(axis='x', which='both', bottom=False)
            axes_spikes.grid(False)
            axes_spikes.set_yticks([])
            axes_spikes.set_frame_on(False)

    out_dir = 'plots'
    # savefig does not create missing directories.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    figure.savefig(os.path.join(out_dir, filename))
    # Release the figure so repeated calls do not accumulate open pyplot figures.
    plt.close(figure)


if __name__ == "__main__":
    with open('data/fig_2_raster_m=0.75_GWN_sigma=3.0.pkl') as f:
        data_a = pkl.load(f)
    with open('data/fig_2_raster_m=1.0_GWN_sigma=3.0.pkl') as f:
        data_b = pkl.load(f)

    plot_fig2_raster(data_a, data_b, 'fig2_a.pdf')
