"""
Figures 10 c,f (bottom panels) - readout signals
"""

import os
import sys
import pickle as pkl
from matplotlib import pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Guard: abort at import time unless the interpreter is older than Python 3.0.
# NOTE(review): because this is an ``assert``, the check is skipped entirely
# when Python is started with -O.
assert sys.version_info < (3, 0), "Please run this script using Python2.7"


def plot_readout(data, sig_offset, shift, filename):
    """
    Plot scaled readout outputs against their targets for three populations.

    For each of the populations 'E0', 'E2' and 'E5' the 'V_m' readout output
    is min-max scaled to [0, 1] (the scaling is fitted on the complete output
    trace), and its trailing ``sig_len`` samples are drawn together with the
    trailing ``sig_len`` samples of the target in one panel of a 1x3 grid.
    Axes are drawn without ticks, labels or frame. The figure is written to
    ``plots/<filename>`` (the directory is created if missing) and closed
    afterwards so repeated calls do not accumulate open figures.

    :param data: pickled data from simulation; accessed as
        ``data[pop]['V_m']['output']`` (must provide ``reshape``) and
        ``data[pop]['V_m']['target']`` (must support slicing)
    :param sig_offset: time offset relative to beginning of recorded interval
        (accepted for interface compatibility; not used by this function)
    :param shift: best shift for retraining
        (accepted for interface compatibility; not used by this function)
    :param filename: where to save, joined onto the ``plots/`` directory
    :return: None
    """
    state_var = 'V_m'
    populations = ['E0', 'E2', 'E5']

    lw_out = 5.
    lw_tgt = 6.
    # Number of trailing samples plotted. The source gives the unit as ms,
    # presumably one sample per ms -- TODO confirm.
    sig_len = 5000

    figure = plt.figure(figsize=(40, 8))
    try:
        for pop_idx, pop in enumerate(populations):
            r_output = data[pop][state_var]['output']
            r_target = data[pop][state_var]['target'][-sig_len:]

            # Scale over the whole trace first, then keep only the tail, so
            # the [0, 1] range refers to the complete output signal.
            scaler = MinMaxScaler(feature_range=(0, 1))
            r_output = scaler.fit_transform(r_output.reshape(-1, 1)).reshape(-1)
            r_output = r_output[-sig_len:]

            ax = plt.subplot2grid((1, 3), (0, pop_idx))
            ax.plot(r_output, '-', color='tomato', lw=lw_out)
            ax.plot(r_target, '-', color='k', lw=lw_tgt)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_frame_on(False)
            ax.grid(False)
            ax.tick_params(axis='both', labelsize=28)

        figure.tight_layout()
        figure.subplots_adjust(wspace=0.1)

        out_path = os.path.join('plots/', filename)
        out_dir = os.path.dirname(out_path)
        # savefig does not create missing directories; do it explicitly
        # (os.makedirs has no exist_ok under Python 2.7).
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        figure.savefig(out_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(figure)


if __name__ == "__main__":
    # Load the pre-computed readout data for both panels and render them.
    # Pickle files are binary data, so they are opened in 'rb' mode; this keeps
    # loading independent of the platform's text-mode newline handling.
    # NOTE(review): pickle.load can execute arbitrary code -- only use it on
    # trusted data files.
    with open('data/fig_10c_fast_dynamic_readout.pkl', 'rb') as f:
        data_a = pkl.load(f)
    with open('data/fig_10f_slow_dynamic_readout.pkl', 'rb') as f:
        data_b = pkl.load(f)

    # sig_offset and shift are passed as 0 for both figures.
    plot_readout(data_a, 0, 0, 'fig10_c_readout.pdf')
    plot_readout(data_b, 0, 0, 'fig10_f_readout.pdf')
