"""
Figure 5a - conductance-based model: readout signals and spike rasters.
"""

import os
import sys
import pickle as pkl
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np

# Abort immediately when run under Python 3: this script is meant for Python 2.7.
# NOTE(review): an `assert` is skipped entirely under `python -O`, so this guard
# only fires in normal (non-optimized) runs.
assert sys.version_info < (3, 0), "Please run this script using Python2.7"


def plot_fig5_readout(data, sig_offset, shifts, channels, filename):
    """
    Plot readout output (tomato) against target (black) for E0, E2 and E5.

    Panels are laid out with one row per entry of ``channels`` and one column
    per population. For each population the 'V_m' readout arrays are cut to a
    window of ``sig_len * 200`` samples starting at ``sig_offset + shift``.

    :param data: pickled data from simulation; ``data[pop]['V_m']['output']``
        and ``data[pop]['V_m']['target']`` are sliced as 2D arrays
        ``[channel, sample]``
    :param sig_offset: time offset relative to beginning of recorded interval
    :param shifts: best shift found after retraining, one per population in
        the order E0, E2, E5
    :param channels: channel numbers to plot (one row of panels each)
    :param filename: file name, saved inside the 'plots/' directory
    :return: None
    :raises ValueError: if output and target windows differ in shape
    """
    populations = ['E0', 'E2', 'E5']
    sig_len = 20
    state_var = 'V_m'
    lw = 3.

    figure = plt.figure(figsize=(20., 3))
    # The grid is sized from the inputs so any number of channels can be drawn;
    # at least two rows are kept so one- or two-channel plots keep the same
    # panel size.
    gs = GridSpec(max(len(channels), 2), len(populations))

    for pop_idx, pop in enumerate(populations):
        start = sig_offset + shifts[pop_idx]
        stop = start + sig_len * 200
        r_output = data[pop][state_var]['output'][:, start:stop]
        r_target = data[pop][state_var]['target'][:, start:stop]

        # Explicit check instead of `assert` so it also runs under `python -O`.
        if r_target.shape != r_output.shape:
            raise ValueError("output shape %s does not match target shape %s for %s"
                             % (r_output.shape, r_target.shape, pop))
        x_values = np.arange(r_target.shape[1])

        for ch_idx, ch in enumerate(channels):
            ax = figure.add_subplot(gs[ch_idx, pop_idx])
            ax.plot(x_values, r_target[ch, :], '-', color='k', lw=lw)
            ax.plot(x_values, r_output[ch, :], '-', color='tomato', lw=lw)
            ax.set_ylim(-.5, 1.5)
            # Removing the ticks also removes their labels.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(False)
            ax.grid(False)

    figure.tight_layout()
    figure.savefig(os.path.join('plots/', filename))


def plot_raster(data_, filename):
    """
    Plot spike rasters of E0, E2 and E5 side by side, one panel each.

    For population ``pop`` the excitatory spikes ``data_[pop]`` (royal blue)
    and the inhibitory spikes ``data_[pop.replace('E', 'I')]`` (red) are drawn
    in the same panel. Column 0 of each array is used as spike time and
    column 1 as neuron id. Excitatory ids are shifted by +2000 and inhibitory
    ids by -8000 before plotting.

    For population Ek the y-range is ``[k * 10000 - 100, (k + 1) * 10000 + 100]``.
    The x-range spans all spike times in the panel, rounded to multiples of
    100 and padded by 10.

    :param data_: pickled data from simulation
    :param filename: file name, saved inside the 'plots/' directory
    :return: None
    """
    n_neurons = 10000
    base = 100.
    c_exc = 'royalblue'
    c_inh = 'tab:red'

    figure = plt.figure(figsize=(20., 3.))
    gs = GridSpec(1, 3)

    for idx, pop in enumerate(['E0', 'E2', 'E5']):
        spikes_exc = data_[pop]
        spikes_inh = data_[pop.replace('E', 'I')]
        times_exc = spikes_exc[:, 0]
        times_inh = spikes_inh[:, 0]
        # NOTE(review): the +2000 / -8000 id offsets presumably map the
        # recorded ids into the plotted y-range; confirm against the simulation.
        neurons_exc = spikes_exc[:, 1] + 2000
        neurons_inh = spikes_inh[:, 1] - 8000

        axes_spikes = figure.add_subplot(gs[idx])
        axes_spikes.plot(times_exc, neurons_exc, '.', color=c_exc, markersize=1)
        axes_spikes.plot(times_inh, neurons_inh, '.', color=c_inh, markersize=1)

        pop_idx = int(pop[1])
        axes_spikes.set_ylim(pop_idx * n_neurons - 100, (pop_idx + 1) * n_neurons + 100)

        # Numpy reductions are vectorized (builtin min/max loop in Python). They
        # also still work when one of the two populations has no spikes.
        all_times = np.concatenate((times_exc, times_inh))
        if all_times.size:
            axes_spikes.set_xlim(base * round(all_times.min() / base) - 10,
                                 base * round(all_times.max() / base) + 10)

        # Hide every tick mark and tick label on both axes.
        axes_spikes.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
        axes_spikes.set_yticks([])
        axes_spikes.grid(False)
        axes_spikes.set_frame_on(False)

    figure.tight_layout()
    figure.savefig(os.path.join('plots/', filename))


if __name__ == "__main__":
    # Pickles are binary data. Open them in 'rb' mode so that no newline
    # translation can corrupt them (text mode breaks on Windows).
    # NOTE(review): pickle.load executes arbitrary code from the file; these
    # files are presumably locally generated simulation output. Never load
    # pickles from an untrusted source.
    with open('data/fig_5a_readout_signal_m=0.9_GWN_sigma=3.0.pkl', 'rb') as f:
        data_readout = pkl.load(f)
    with open('data/fig_5a_raster_m=0.9_GWN_sigma=3.0.pkl', 'rb') as f:
        data_spikes = pkl.load(f)

    # savefig does not create missing directories, so make sure 'plots/' exists.
    if not os.path.isdir('plots'):
        os.makedirs('plots')

    plot_fig5_readout(data_readout, 0, [13, 22, 30], [0, 9], 'fig5_a_readout.pdf')
    plot_raster(data_spikes, 'fig5_a_raster.pdf')
