"""
Figure 9 - figure supplement 1 - Evolution of similarity score for 12 sub-networks
"""
import sys
import os
sys.path.append(os.environ.get('NEST_PYTHON_PREFIX'))
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')

import pickle as pkl
import numpy as np
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
import seaborn as sns

import helper

def plot_2d_parscans(image_arrays=None, axis=None, fig_handle=None, labels=None, cmap='coolwarm', boundaries=None,
                     limits=None, interpolation='nearest', display=True, norm=None, ticksize=None, cbar_size='10%',
                     cbar_label=None, **kwargs):
    """
    Plot a list of arrays as images, one per axis, optionally with contour lines, titles and colorbars.

    :param image_arrays: sequence of arrays, each one handed to ``imshow`` of the axis with the same index
        (``None`` is treated as an empty sequence)
    :param axis: sequence of ``matplotlib.axes.Axes``; must have the same length as ``image_arrays``
        (``None`` is treated as an empty sequence)
    :param fig_handle: figure used to create a colorbar to the right of every axis; no colorbar is drawn if None
    :param labels: optional sequence of axis titles, indexed like ``axis``
    :param cmap: colormap passed to ``imshow``
    :param boundaries: optional sequence of contour levels, indexed like ``axis``; each entry is passed to
        ``ax.contour`` and the resulting lines are labelled
    :param limits: optional sequence of ``(vmin, vmax)`` pairs, indexed like ``axis``
    :param interpolation: interpolation mode passed to ``imshow``
    :param display: not used by this function; kept so existing callers keep working
    :param norm: normalisation object passed to ``imshow``
    :param ticksize: label size of the colorbar ticks; 24 is used when this is None (or otherwise falsy)
    :param cbar_size: width of the colorbar axes, as understood by ``divider.append_axes``
    :param cbar_label: optional label of the colorbar
    :param kwargs: forwarded to ``ax.set`` for every axis
    :return: the colorbar created for the last axis, or None if no colorbar was created
    :raises AssertionError: if the number of arrays differs from the number of axes
    :raises ValueError: if an entry of ``axis`` is not a ``matplotlib.axes.Axes`` instance
    """
    # None stands in for "empty" so that no mutable object is shared between calls via the defaults
    image_arrays = [] if image_arrays is None else image_arrays
    axis = [] if axis is None else axis

    # raised explicitly (instead of a bare assert) so the check also runs under ``python -O``;
    # the exception type is kept for backward compatibility
    if len(image_arrays) != len(axis):
        raise AssertionError("Number of provided arrays must match number of axes")

    cbar = None
    origin = 'upper'
    for idx, ax in enumerate(axis):
        if not isinstance(ax, matplotlib.axes.Axes):
            raise ValueError('ax must be matplotlib.axes.Axes instance.')

        vmin, vmax = limits[idx] if limits else (None, None)
        image = ax.imshow(image_arrays[idx], aspect='auto', origin=origin, cmap=cmap, interpolation=interpolation,
                          vmin=vmin, vmax=vmax, norm=norm)
        if boundaries:
            cont = ax.contour(image_arrays[idx], boundaries[idx], origin='lower', colors='k', linewidths=2)
            # label through the axis itself rather than through pyplot's current-axes state
            ax.clabel(cont, fmt='%2.1f', colors='k', fontsize=12)
        if labels:
            ax.set_title(labels[idx])

        if fig_handle is not None:
            # carve the colorbar axes out of the image axes so the panel keeps its grid position
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", cbar_size, pad="4%")
            cbar = fig_handle.colorbar(image, cax=cax)
            cbar.ax.tick_params(labelsize=ticksize if ticksize else 24)
            if cbar_label:
                cbar.ax.set_ylabel(cbar_label, fontsize=20)
        ax.set(**kwargs)
        plt.draw()

    return cbar


def plot_scan(populations, filename=None):
    """
    Plot the score scan of every population as an image panel, all panels arranged in a single row.

    For each entry of ``populations`` the file ``data/fig_9_coex_regions_pop=<pop>.pkl`` is unpickled into
    a ``(score, flipped_s2na)`` pair. ``score`` is drawn as an image with the colour range fixed to [-1, 1];
    positions where ``flipped_s2na`` equals 1 or 3 and ``score`` is non-negative are marked with a black cross.
    Only the last panel gets a colorbar, and only the first panel gets a y label and y tick labels.

    :param populations: sequence of population names used to build the data file names; one panel per entry
    :param filename: name of the output file, written into ``plots/``; if None the figure is not saved
    :return: None
    """
    # figure parameters
    helper.usetex_font()
    ticksize = 16
    labelsize = 22

    # parameter grids; here they are only used to position and label the ticks
    # NOTE(review): presumably score has shape (len(p2_values), len(p1_values)) - confirm against the data files
    p1_values = np.round(np.arange(0.8, 1.001, 0.01).astype(np.float32), decimals=4)
    p2_values = np.round(np.arange(0.50, 1.01, 0.02).astype(np.float32), decimals=4)
    x_tick_pos = [0, len(p1_values) // 2, len(p1_values) - 1]
    y_tick_pos = np.arange(len(p2_values))[::5]

    figure = plt.figure(figsize=(18, 4))

    limits = [(-1., 1.)]
    n_panels = len(populations)
    for pop_idx, pop in enumerate(populations):
        is_last = pop_idx == n_panels - 1

        # NOTE: unpickling can execute arbitrary code; only load data files from a trusted source
        with open('data/fig_9_coex_regions_pop={}.pkl'.format(pop), 'rb') as f:
            score, flipped_s2na = pkl.load(f)

        ax_cc = plt.subplot2grid((1, n_panels), (0, pop_idx), fig=figure)
        # the image is drawn exactly once; passing the figure only for the last panel
        # restricts the colorbar to that panel
        cbar = plot_2d_parscans(image_arrays=[score], axis=[ax_cc],
                                fig_handle=figure if is_last else None,
                                limits=limits, ticksize=ticksize)

        for valid_val in (1, 3):
            # np.where yields (row, column) indices; columns run along x, rows along y
            hits = np.where((flipped_s2na == valid_val) & (score >= 0.))
            ax_cc.scatter(hits[1], hits[0], s=50, c='k', marker='x')

        ax_cc.grid(False)
        ax_cc.tick_params(axis='both', labelsize=ticksize)
        ax_cc.set_xticks(x_tick_pos)
        ax_cc.set_yticks(y_tick_pos)
        ax_cc.set_xticklabels(p1_values[x_tick_pos])
        ax_cc.set_xlabel('m', fontsize=labelsize)

        if pop_idx == 0:
            ax_cc.set_ylabel(r'Intensity ratio $\lambda_2 / \lambda_1$', fontsize=labelsize)
            # labels are taken from the reversed grid, i.e. the top image row is labelled with the largest value
            ax_cc.set_yticklabels(p2_values[::-1][::5])
        else:
            ax_cc.set_yticklabels([])

        # configured independently of the branch above so that a single-panel figure
        # (first panel == last panel) also gets a labelled colorbar
        if is_last:
            cbar.set_ticks([-1, -0.5, 0, 0.5, 1.])
            cbar.set_label('Similarity score', fontsize=16)

    figure.tight_layout()
    # os.path.join cannot handle the default filename=None, so saving is skipped in that case
    if filename is not None:
        figure.savefig(os.path.join('plots/', filename))


if __name__ == "__main__":
    # one panel per listed population; the figure is written to plots/fig9_s1.pdf
    plot_scan(['E0', 'E2', 'E5', 'E8', 'E11'], filename='fig9_s1.pdf')
