"""
Figures 1b and 1e: readout signals (target vs. readout output traces).
"""

import os
import sys
import pickle as pkl
from matplotlib import pyplot as plt
import numpy as np

# Explicit check instead of `assert`: assertions are stripped under `python -O`,
# which would silently let the script run on an unsupported interpreter.
if sys.version_info >= (3, 0):
    raise RuntimeError("Please run this script using Python2.7")


def plot_fig1_readout(data, sig_offset, shift, channels, filename,
                      populations=('E0', 'E2', 'E5')):
    """
    Plot target (black) and readout output (tomato) traces and save the figure.

    The subplot grid has one column per population and one row per channel.
    All axis ticks, labels and frames are removed.

    :param data: pickled data from simulation; indexed as
        data[pop]['V_m']['output'] and data[pop]['V_m']['target'], each a
        2-D array of (channels, data points)
    :param sig_offset: time offset relative to beginning of recorded interval
    :param shift: best shift for retraining
    :param channels: channel numbers to plot (one subplot row per channel)
    :param filename: file name to save under the 'plots/' directory
    :param populations: population keys to plot (one subplot column each);
        defaults to the populations used in the paper figures
    :raises ValueError: if the target and output windows differ in shape
    :return: None
    """
    sig_len = 10
    state_var = 'V_m'
    # Window of sig_len * 200 data points starting at sig_offset + shift.
    start = sig_offset + shift
    stop = start + sig_len * 200

    # Grid size follows the inputs instead of a hard-coded (2, 3), so any
    # number of channels / populations can be plotted.
    grid_shape = (len(channels), len(populations))
    figure = plt.figure(figsize=(20., 5))

    for pop_idx, pop in enumerate(populations):
        r_output = data[pop][state_var]['output'][:, start:stop]
        r_target = data[pop][state_var]['target'][:, start:stop]

        if r_target.shape != r_output.shape:
            raise ValueError("target shape %s does not match output shape %s for population %s"
                             % (r_target.shape, r_output.shape, pop))
        n_data_points = r_target.shape[1]
        time_axis = np.arange(n_data_points)

        for ch_idx, ch in enumerate(channels):
            # subplot2grid draws into the current figure, i.e. the one created above.
            ax = plt.subplot2grid(grid_shape, (ch_idx, pop_idx))
            ax.plot(time_axis, r_target[ch, :], '-', color='k', lw=5.)
            ax.plot(time_axis, r_output[ch, :], '-', color='tomato', lw=5.)
            ax.set_ylim(-.5, 1.5)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_frame_on(False)
            ax.grid(False)

    figure.tight_layout()

    out_dir = 'plots/'
    # savefig does not create missing directories.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    figure.savefig(os.path.join(out_dir, filename))
    # Release the figure; otherwise every call leaves one open in pyplot.
    plt.close(figure)

if __name__ == "__main__":
    # Pickles are binary data: open in 'rb' so no line-ending translation can
    # corrupt them (text mode does this on Windows).
    # NOTE: pickle.load executes arbitrary code; only load these trusted,
    # locally generated simulation files.
    with open('data/fig_1_readout_signal_m=0.75_GWN_sigma=0.0.pkl', 'rb') as f:
        data_a = pkl.load(f)
    with open('data/fig_1_readout_signal_m=0.9_GWN_sigma=0.0.pkl', 'rb') as f:
        data_b = pkl.load(f)

    plot_fig1_readout(data_a, 600, 14, [4, 7], 'fig1_b.pdf')
    plot_fig1_readout(data_b, 0, 12, [0, 4], 'fig1_e.pdf')
