"""
Figure 5b - rate-based model: readout signals and state-activity rasters.
"""

import os
import sys
import pickle as pkl
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib import ticker
from mpl_toolkits.axes_grid1 import make_axes_locatable, ImageGrid
import numpy as np

# Explicit check instead of ``assert`` so it still runs under ``python -O``.
if sys.version_info < (3, 0):
    raise RuntimeError("Please run this script using Python3")


def plot_fig5_readout(filename, populations=('E0', 'E2', 'E5'), channels=(2, 9),
                      data_dir='data', plot_dir='plots'):
    """
    Plot target (black) and network output (tomato) traces for selected
    channels of each population and save the resulting figure.

    For every population ``pop`` the pickle
    ``<data_dir>/fig_5b_readout_signal_crnn_pop=<pop>.pkl`` is loaded. It must
    provide ``data['active']['output']`` and ``data['active']['target']`` as
    2-D arrays indexed ``[channel, sample]``. Subplots are arranged with one
    row per channel and one column per population.

    :param filename: name of the output file, written inside ``plot_dir``
    :param populations: population labels to plot (one subplot column each)
    :param channels: channel indices to plot (one subplot row each)
    :param data_dir: directory containing the input pickles
    :param plot_dir: directory the figure is saved into; created if missing
    :raises ValueError: if the output and target windows differ in shape
    :return: None
    """
    sig_len = 20
    sig_offset = 0
    shift = 0
    # Sample window [start, stop) shown in every subplot.
    start = sig_offset + shift
    stop = sig_len * 200 + shift + sig_offset
    lw_tgt = 3.
    lw_out = 3.

    n_rows, n_cols = len(channels), len(populations)
    figure = plt.figure(figsize=(20., 3))
    gs = GridSpec(n_rows, n_cols)
    try:
        for pop_idx, pop in enumerate(populations):
            path = os.path.join(data_dir, 'fig_5b_readout_signal_crnn_pop={}.pkl'.format(pop))
            # pickle can execute arbitrary code on load; only use trusted, locally generated files.
            with open(path, 'rb') as f:
                data = pkl.load(f)

            r_output = data['active']['output'][:, start:stop]
            r_target = data['active']['target'][:, start:stop]
            if r_target.shape != r_output.shape:
                raise ValueError(
                    "target shape {} does not match output shape {} for population {}".format(
                        r_target.shape, r_output.shape, pop))

            samples = np.arange(r_target.shape[1])
            for ch_idx, ch in enumerate(channels):
                ax = figure.add_subplot(gs[ch_idx, pop_idx])
                ax.plot(samples, r_target[ch, :], '-', color='k', lw=lw_tgt)
                ax.plot(samples, r_output[ch, :], '-', color='tomato', lw=lw_out)
                ax.set_ylim(-.5, 1.5)
                # Removing the ticks also removes their labels.
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_frame_on(False)
                ax.grid(False)

        figure.tight_layout()
        os.makedirs(plot_dir, exist_ok=True)
        figure.savefig(os.path.join(plot_dir, filename))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(figure)


def plot_raster(filename, populations=('E0', 'E2', 'E5'), data_dir='data', plot_dir='plots'):
    """
    Plot the stored state activity of each population as an image with its
    own colorbar and save the resulting figure.

    For every population ``pop`` the pickle
    ``<data_dir>/fig_5b_state_activity_crnn_truncated_pop=<pop>.pkl`` is loaded
    and passed directly to ``imshow``. It is presumably a 2-D array (confirm
    against the data producer). Colors are clipped to the range [0, 1].

    :param filename: name of the output file, written inside ``plot_dir``
    :param populations: population labels to plot (one subplot column each)
    :param data_dir: directory containing the input pickles
    :param plot_dir: directory the figure is saved into; created if missing
    :return: None
    """
    v_min, v_max = 0., 1.
    # Colorbar ticks at the bottom, middle and top of the color range.
    cbar_ticks = [v_min, (v_min + v_max) / 2., v_max]

    figure = plt.figure(figsize=(20., 3.))
    gs = GridSpec(1, len(populations))
    try:
        for idx, pop in enumerate(populations):
            path = os.path.join(data_dir, 'fig_5b_state_activity_crnn_truncated_pop={}.pkl'.format(pop))
            # pickle can execute arbitrary code on load; only use trusted, locally generated files.
            with open(path, 'rb') as f:
                activity = pkl.load(f)

            ax = figure.add_subplot(gs[idx])
            image = ax.imshow(activity, aspect='auto', interpolation='nearest', vmin=v_min, vmax=v_max)

            # Attach a slim colorbar axis to the right of the image.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", "5%", pad="4%")
            cbar = figure.colorbar(image, cax=cax, ticks=cbar_ticks)
            cbar.ax.tick_params(labelsize=22)

            # Hide all ticks and tick labels of the image axis.
            ax.tick_params(axis='both', which='both', bottom=False, labelbottom=False, left=False, labelleft=False)
            ax.set_frame_on(False)

        figure.tight_layout()
        os.makedirs(plot_dir, exist_ok=True)
        figure.savefig(os.path.join(plot_dir, filename))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(figure)


if __name__ == "__main__":
    # Render both panels of Figure 5b in order: readout traces, then activity rasters.
    _panels = (
        (plot_fig5_readout, 'fig5_b_readout.pdf'),
        (plot_raster, 'fig5_b_activity.pdf'),
    )
    for _plot, _target in _panels:
        _plot(_target)
