"""
Figure 8c: similarity-score parameter scan with the dynamical-regime boundaries overlaid.
"""
import sys
import os
sys.path.append(os.environ.get('NEST_PYTHON_PREFIX'))
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

import pickle as pkl
import numpy as np
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
import seaborn as sns

import helper
# from mft_helpers import mft_parameters as parameters
# from mft_helpers import calc


def plot_2d_parscans(image_arrays=None, axis=None, fig_handle=None, labels=None, cmap='coolwarm', boundaries=None,
                     limits=None, interpolation='nearest', display=True, norm=None, ticksize=None, cbar_size='10%',
                     cbar_label=None, **kwargs):
    """
    Plot each 2D array as an image in its corresponding axis, optionally with a colorbar.

    :param image_arrays: sequence of 2D arrays, one per entry of `axis`
    :param axis: sequence of matplotlib.axes.Axes to draw into (same length as `image_arrays`)
    :param fig_handle: figure owning the axes; when given, a colorbar is appended to the right of every axis
    :param labels: optional sequence of per-axis titles
    :param cmap: colormap passed to `imshow`
    :param boundaries: optional sequence of contour levels (one entry per axis) drawn in black over each image
    :param limits: optional sequence of (vmin, vmax) pairs, one per axis
    :param interpolation: interpolation mode passed to `imshow`
    :param display: unused; kept for backward compatibility
    :param norm: optional normalization passed to `imshow`
    :param ticksize: colorbar tick-label size (24 is used when this is falsy)
    :param cbar_size: width of the appended colorbar axis, passed to `append_axes`
    :param cbar_label: optional colorbar label
    :param kwargs: extra properties applied to every axis via `ax.set(**kwargs)`
    :return: the colorbar created for the last axis, or None when no colorbar was created
        (`fig_handle` is None or `axis` is empty)
    :raises AssertionError: if `image_arrays` and `axis` differ in length
    :raises ValueError: if an element of `axis` is not a matplotlib Axes
    """
    # None defaults avoid sharing one mutable list between calls
    image_arrays = [] if image_arrays is None else image_arrays
    axis = [] if axis is None else axis
    labels = [] if labels is None else labels
    boundaries = [] if boundaries is None else boundaries
    limits = [] if limits is None else limits

    assert len(image_arrays) == len(axis), "Number of provided arrays must match number of axes"

    origin = 'upper'
    cbar = None  # bound up front so the return below never hits an unbound local
    for idx, ax in enumerate(axis):
        if not isinstance(ax, matplotlib.axes.Axes):
            raise ValueError('ax must be matplotlib.axes.Axes instance.')

        vmin, vmax = limits[idx] if limits else (None, None)
        img = ax.imshow(image_arrays[idx], aspect='auto', origin=origin, cmap=cmap, interpolation=interpolation,
                        vmin=vmin, vmax=vmax, norm=norm)
        if boundaries:
            # NOTE(review): contour's origin='lower' (without `extent`) samples at half-integer coordinates,
            # while imshow puts pixel centres on integers -- confirm the intended alignment.
            cont = ax.contour(image_arrays[idx], boundaries[idx], origin='lower', colors='k', linewidths=2)
            # label through the target axes rather than pyplot's "current axes" state
            ax.clabel(cont, fmt='%2.1f', colors='k', fontsize=12)
        if labels:
            ax.set_title(labels[idx])

        if fig_handle is not None:
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", cbar_size, pad="4%")
            cbar = fig_handle.colorbar(img, cax=cax)
            cbar.ax.tick_params(labelsize=ticksize if ticksize else 24)
            if cbar_label:
                cbar.ax.set_ylabel(cbar_label, fontsize=20)
        ax.set(**kwargs)

    if len(axis) > 0:
        # one redraw after all axes are populated instead of one per axis
        plt.draw()
    return cbar


def plot_scan(data_, filename=None):
    """
    Plot the similarity-score scan with the three regime boundaries overlaid.

    :param data_: 4-sequence ``(score, coord_wlc, coord_wta, coord_cox)``. `score` is the 2D image shown with
        `imshow` (colour limits fixed to [-1, 1]); each ``coord_*`` is an (x, y) pair of sequences drawn as a
        dashed line on the image.
        NOTE(review): the tick labels assume `score` has 31 rows (p2 grid) and 21 columns (p1 grid), with row 0
        corresponding to the largest p2 value -- confirm against the data file.
    :param filename: output file name, written inside the ``plots/`` directory (created if missing).
        When None, the figure is not saved.
    :return: the created matplotlib figure
    """
    score, coord_wlc, coord_wta, coord_cox = data_

    # figure parameters
    helper.usetex_font()
    ticksize = 18
    labelsize = 26

    # parameter grids behind the scan; used only to label the ticks
    p1_values = np.round(np.arange(0.8, 1.001, 0.01).astype(np.float32), decimals=4)
    p2_values = np.round(np.arange(0.40, 1.01, 0.02).astype(np.float32), decimals=4)

    figure = plt.figure(figsize=(7, 6))
    ax_cc = plt.subplot2grid((1, 1), (0, 0))

    cbar = plot_2d_parscans(image_arrays=[score], axis=[ax_cc],
                            fig_handle=figure, limits=[(-1., 1.)], ticksize=16)
    cbar.set_label('Similarity score', fontsize=16)
    cbar.set_ticks([-1, -0.5, 0, 0.5, 1.])

    # regime boundaries, drawn in this order (cox, wlc, wta) so later lines sit on top
    for coords, color in ((coord_cox, 'tab:red'), (coord_wlc, 'royalblue'), (coord_wta, 'k')):
        ax_cc.plot(coords[0], coords[1], '--', color=color, linewidth=3.)

    ax_cc.set_xticks(np.arange(len(p1_values))[::4])
    ax_cc.set_xticklabels(p1_values[::4])
    # the image is drawn with origin='upper' (row 0 at the top), so the y labels run in descending order
    ax_cc.set_yticks(np.arange(len(p2_values))[::5])
    ax_cc.set_yticklabels(p2_values[::-1][::5])

    ax_cc.grid(False)
    ax_cc.tick_params(axis='both', labelsize=ticksize)
    ax_cc.set_xlabel('m', fontsize=labelsize)
    ax_cc.set_ylabel(r'Intensity ratio $\lambda_2 / \lambda_1$', fontsize=labelsize)

    sns.despine()

    figure.tight_layout()
    if filename is not None:
        # previously a None filename crashed inside os.path.join, and a missing plots/ dir crashed savefig
        os.makedirs('plots', exist_ok=True)
        figure.savefig(os.path.join('plots/', filename))
    return figure


if __name__ == "__main__":
    # Script entry: load the precomputed scan results and render the panel to plots/fig8_c.pdf.
    # The pickle is a local data file shipped with the figure scripts; only unpickle trusted files.
    scan_path = 'data/fig_8_dynamical_regimes_cc_score_boundaries.pkl'
    with open(scan_path, 'rb') as scan_file:
        scan_data = pkl.load(scan_file)
    plot_scan(scan_data, filename='fig8_c.pdf')
