"""
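Figure 6a: heat map of the E5 / E0 firing-rate ratio ('active' sub-population) across modularity and map size.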
"""
import sys
import os
# Extend the import path so that NEST's Python bindings (if NEST_PYTHON_PREFIX is set) and modules living in
# parent directories (presumably including `helper`, imported below) can be found.
_nest_prefix = os.environ.get('NEST_PYTHON_PREFIX')
if _nest_prefix is not None:
    # Only append when the variable is set; appending None would add a meaningless sys.path entry.
    sys.path.append(_nest_prefix)
sys.path.extend(['../', '../../', '../../../'])

assert sys.version_info < (3, 0), "Please run this script using Python2.7"

import pickle as pkl
import numpy as np
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import pyplot as plt
import seaborn as sns

import helper


def plot_2d_parscans(image_arrays=None, axis=None, fig_handle=None, labels=None, cmap='coolwarm', boundaries=None,
                     limits=None, interpolation='nearest', display=True, norm=None, ticksize=None, cbar_size='10%',
                     cbar_label=None, **kwargs):
    """
    Plot each 2D array as an image on its matching axis, optionally adding contours, a title and a colorbar.

    :param image_arrays: list of 2D arrays, one per axis (None is treated as an empty list)
    :param axis: list of matplotlib.axes.Axes, same length as image_arrays (None is treated as an empty list)
    :param fig_handle: figure used to attach a colorbar to the right of each axis; no colorbar when None
    :param labels: optional list of titles, one per axis
    :param cmap: colormap passed to imshow
    :param boundaries: optional list with the contour levels (passed to ax.contour) for each axis
    :param limits: optional list of (vmin, vmax) pairs, one per axis
    :param interpolation: imshow interpolation method
    :param display: unused; kept for backward compatibility
    :param norm: optional normalization passed to imshow
    :param ticksize: colorbar tick label size; 24 when not given
    :param cbar_size: colorbar width relative to the axis, e.g. '10%'
    :param cbar_label: optional colorbar label
    :param kwargs: extra properties applied to every axis via ax.set()
    :return: the image returned by imshow for the last axis, or None when no axes are given
    :raises AssertionError: if image_arrays and axis differ in length
    :raises ValueError: if an element of axis is not a matplotlib.axes.Axes
    """
    # None defaults avoid lists shared between calls; treat a missing value as "nothing to plot".
    image_arrays = [] if image_arrays is None else image_arrays
    axis = [] if axis is None else axis
    assert len(image_arrays) == len(axis), "Number of provided arrays must match number of axes"

    origin = 'upper'
    image = None  # returned unchanged when axis is empty (previously an UnboundLocalError)
    for idx, ax in enumerate(axis):
        if not isinstance(ax, matplotlib.axes.Axes):
            raise ValueError('ax must be matplotlib.axes.Axes instance.')

        vmin, vmax = limits[idx] if limits else (None, None)
        image = ax.imshow(image_arrays[idx], aspect='auto', origin=origin, cmap=cmap, interpolation=interpolation,
                          vmin=vmin, vmax=vmax, norm=norm)
        if boundaries:
            cont = ax.contour(image_arrays[idx], boundaries[idx], origin='lower', colors='k', linewidths=2)
            plt.clabel(cont, fmt='%2.1f', colors='k', fontsize=12)
        if labels:
            ax.set_title(labels[idx])

        if fig_handle is not None:
            # Carve a colorbar axis out of the right edge of `ax` so the plot keeps its own size ratio.
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", cbar_size, pad="4%")
            cbar = fig_handle.colorbar(image, cax=cax)
            cbar.ax.tick_params(labelsize=ticksize if ticksize else 24)
            if cbar_label:
                cbar.ax.set_ylabel(cbar_label, fontsize=20)
        ax.set(**kwargs)
        plt.draw()
    return image


def _colormap_midpoint(data, reference):
    """
    Return where `reference` lies within the value range of `data`, as a fraction (0 = minimum, 1 = maximum).

    Passed as `midpoint` to helper.shiftedColorMap, presumably so the colormap centre falls on `reference`.
    """
    data_max = np.max(data)
    return 1 - (data_max - reference) / (data_max - np.min(data))


def plot_scan(data_, filename):
    """
    Plot the firing-rate scan over modularity and map size for every non-reference population and save it.

    Each population other than 'E0' is compared with 'E0' for the 'active' sub-population: the quotient (default)
    or difference is shown as a heat map whose diverging colormap is shifted around the neutral value.

    :param data_: nested mapping data_[population][sub_population] -> 2D array; presumably indexed as
        (modularity, map size) - confirm against the scan that produced the data
    :param filename: output file name; the figure is written to 'plots/<filename>'
    :return: None
    """
    populations = ['E0', 'E5']
    reference_pop = 'E0'  # every other population is compared against this one
    sub_pop = 'active'

    interpolation = 'lanczos'
    ticksize = 20
    labelsize = 24
    helper.usetex_font()
    relative = True
    method = "quotient"  # "quotient" or "difference"
    if method == 'difference':
        cbar_label = r'$\nu^\mathrm{S}_5 - \nu^\mathrm{S}_0$'
    else:
        cbar_label = r'$\nu^\mathrm{S}_5 / \nu^\mathrm{S}_0$'
    modularity = np.round(np.arange(0.72, .921, 0.01), decimals=3)
    cmap = sns.diverging_palette(220, 20, as_cmap=True)

    out_dir = 'plots/'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)  # Python 2 has no exist_ok, hence the explicit check

    for pop in populations:
        if pop == reference_pop:
            continue
        pop_data = data_[pop][sub_pop]

        if relative:
            if method == 'difference':
                pop_data = data_[pop][sub_pop] - data_[reference_pop][sub_pop]
                # NOTE(review): fixed midpoint, apparently assuming differences span [-10, 60] - confirm.
                midpoint = 1 - 60. / (60. - (-10.))
            elif method == 'quotient':
                pop_data = data_[pop][sub_pop] / data_[reference_pop][sub_pop]
                # centre the colormap on a ratio of 1, i.e. no change relative to the reference population
                midpoint = _colormap_midpoint(pop_data, 1.)
            else:
                raise ValueError("Unknown method: {}".format(method))
            tmp_cmap = helper.shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
            fig_height = 6 * 0.7
        else:
            if sub_pop == 'active':
                # centre the colormap on the mean rate of the reference population (same for both methods)
                midpoint = _colormap_midpoint(pop_data, np.mean(data_[reference_pop][sub_pop]))
                tmp_cmap = helper.shiftedColorMap(cmap, midpoint=midpoint, name='shifted')
            else:
                tmp_cmap = cmap
            fig_height = 6 * 0.8
        norm = None

        # Create exactly one figure per population (the non-relative path used to create and leak an extra one).
        figure = plt.figure(figsize=(6, fig_height))
        ax = figure.add_subplot(111)
        plot_2d_parscans(image_arrays=[np.flipud(pop_data.T)], axis=[ax], fig_handle=figure,
                         interpolation=interpolation, labels=None, norm=norm, cmap=tmp_cmap, ticksize='x-large',
                         cbar_size='5%', cbar_label=cbar_label)
        ax.grid(False)
        ax.set_xticks(np.arange(len(modularity))[::4])
        ax.set_xticklabels(modularity[::4])
        ax.set_yticks([0, 3, 5, 8])
        # labels reversed because the image rows were flipped with np.flipud above
        ax.set_yticklabels([0.2, 0.5, 1.0, 2.5][::-1])
        ax.set_ylabel(r"Map size $(d)x$", fontsize=labelsize)
        ax.set_xlabel("m", fontsize=labelsize)

        ax.tick_params(axis='both', direction='out', labelsize=ticksize)
        figure.tight_layout()
        # NOTE(review): all non-reference populations share one output file; only 'E5' exists here.
        figure.savefig(os.path.join(out_dir, filename))
        plt.close(figure)  # release the figure; it is not reused after saving


if __name__ == "__main__":
    # Load the pre-computed map-size scan and render panel (a) of Figure 6.
    # NOTE: unpickling can execute arbitrary code, so only load trusted data files.
    scan_file = 'data/fig_6_rates_mapsize_scan.pkl'
    with open(scan_file, 'rb') as pkl_file:
        data = pkl.load(pkl_file)
    plot_scan(data, 'fig6_a.pdf')
