"""
Figure 1 S1 - performance scan over noise and modularity
"""


import os
import sys
import pickle as pkl
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
import helper

# Explicit raise instead of `assert`: asserts are stripped when Python runs
# with -O, which would silently let this Python-2-only script run on Python 3.
if sys.version_info >= (3, 0):
    raise RuntimeError("Please run this script using Python2.7")

# Ten colours: every other entry of a 20-colour cubehelix ramp.
color_palette = sns.cubehelix_palette(20, start=.5, rot=-.75)[::2]


def _style_axes(ax, p1_values, p2_values, ticksize, label_y):
    """
    Apply the shared (m, noise) tick layout to one heat-map axis.

    Three x ticks are shown (first, middle and last m value) and every other
    y tick. The y tick labels are reversed because the images are drawn with
    ``np.flipud`` and ``origin='upper'``, i.e. the last noise value is the
    top row.

    :param ax: matplotlib axis to format
    :param p1_values: m values along the x axis
    :param p2_values: noise values along the y axis
    :param ticksize: font size of the tick labels
    :param label_y: if True, print the noise values as y tick labels,
        otherwise hide the y tick labels
    """
    mid = len(p1_values) // 2
    ax.grid(False)
    ax.set_xticks([0, mid, len(p1_values) - 1])
    ax.set_yticks(np.arange(len(p2_values))[::2])
    ax.set_xticklabels([p1_values[0], p1_values[mid], p1_values[-1]], fontsize=ticksize)
    if label_y:
        ax.set_yticklabels(p2_values[::-1][::2], fontsize=ticksize)
    else:
        ax.set_yticklabels([])


def plot_grid(data_error, data_rel, filename, conductance=False, error_threshold=0.35):
    """
    Plot the noise/m performance scan: one NRMSE heat map for each of the
    populations 'E0', 'E2' and 'E5', plus one heat map of the relative
    performance gain.

    :param data_error: pickled data from simulation; a mapping from
        population name ('E0', 'E2', 'E5') to a 2D NRMSE array, presumably
        indexed [m, noise] (TODO confirm against the simulation output)
    :param data_rel: 2D array of performance gain (%) with the same layout
        as the arrays in ``data_error``
    :param filename: file name of the figure; it is saved inside ``plots/``
    :param conductance: if True, use the NRMSE colour range [0, 0.2] instead
        of [0.1, 0.4] and a narrower figure
    :param error_threshold: NRMSE cells at or above this value are marked
        with a black cross
    :return: None
    """
    p1_values = np.around(np.arange(0.75, 1.01, 0.025), decimals=3)
    p2_values = np.around(np.arange(0., 3.01, 0.5), decimals=3)
    populations = ['E0', 'E2', 'E5']

    lim_min = 0.1 if not conductance else 0.
    lim_max = 0.4 if not conductance else 0.2
    # Colour-bar ticks every 0.1 over the actual colour range. A fixed
    # [0.1, 0.2, 0.3, 0.4] falls outside [0, 0.2] when conductance=True.
    n_ticks = int(round((lim_max - lim_min) / 0.1)) + 1
    cbar_ticks = np.around(np.linspace(lim_min, lim_max, n_ticks), decimals=3)
    ticksize = 22
    cbar_ticksize = 18
    fontsize = 28
    helper.usetex_font()

    if not conductance:
        figure = plt.figure(figsize=(int(len(p1_values) * .5 * 3), 4.))
    else:
        figure = plt.figure(figsize=(int(len(p1_values) * 1.1), 4))

    # Panel a: one NRMSE heat map per population.
    gs = GridSpec(1, 3, left=0.07, right=0.63, top=0.92, bottom=0.22, wspace=0.15)
    for idx, pop in enumerate(populations):
        ax = figure.add_subplot(gs[idx])
        # Transpose and flip once so the last noise value is the top row.
        image = np.flipud(data_error[pop].T)
        plt1 = ax.imshow(image, aspect='auto', origin='upper', cmap='coolwarm',
                         interpolation='nearest', vmin=lim_min, vmax=lim_max, norm=None)
        rows, cols = np.where(image >= error_threshold)
        ax.scatter(cols, rows, s=32, c='k', marker='x')

        # Only the left-most panel carries y tick labels and axis labels.
        _style_axes(ax, p1_values, p2_values, ticksize, label_y=(idx == 0))
        if idx == 0:
            ax.set_ylabel(r'$\mathrm{noise}\ \sigma_\xi$', fontsize=fontsize)
            ax.set_xlabel('m', fontsize=fontsize)

        # A single shared colour bar, attached to the right-most panel.
        if idx == len(populations) - 1:
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("right", '10%', pad="4%")
            cbar = figure.colorbar(plt1, cax=cax)
            cbar.set_ticks(cbar_ticks)
            cbar.ax.tick_params(labelsize=cbar_ticksize)
            cbar.ax.set_ylabel('NRMSE', fontsize=20)

    # Panel b: relative performance gain, colour range spanning the data.
    gs = GridSpec(1, 1, left=0.77, right=0.93, top=0.92, bottom=0.22, wspace=0.05)
    # NaN-safe limits: a single NaN cell would otherwise make vmin/vmax NaN.
    rel_min = np.nanmin(data_rel)
    rel_max = np.nanmax(data_rel)

    ax = figure.add_subplot(gs[0])
    cmap = ListedColormap(sns.color_palette("BrBG_r", 256))
    img_2d = ax.imshow(np.flipud(data_rel.T), aspect='auto', origin='upper', cmap=cmap,
                       interpolation='nearest', vmin=rel_min, vmax=rel_max, norm=None)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", '10%', pad="4%")
    cbar = figure.colorbar(img_2d, cax=cax)
    cbar.ax.tick_params(labelsize=cbar_ticksize)
    cbar.ax.set_ylabel('Performance gain (%)', fontsize=20)
    ax.set_ylabel(r'$\mathrm{noise}\ \sigma_\xi$', fontsize=fontsize)
    ax.set_xlabel('m', fontsize=fontsize)
    _style_axes(ax, p1_values, p2_values, ticksize, label_y=True)

    figure.savefig(os.path.join('plots/', filename))
    # pyplot keeps every figure alive until closed; release it after saving.
    plt.close(figure)


if __name__ == "__main__":
    # Pickle files must be opened in binary mode. Text mode can corrupt the
    # byte stream (newline translation on Windows) and fails on Python 3.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open('data/fig_1s1_performance2d.pkl', 'rb') as f:
        data_perf_error = pkl.load(f)
    with open('data/fig_1s1_performance2d_relative.pkl', 'rb') as f:
        data_perf_rel = pkl.load(f)

    plot_grid(data_perf_error, data_perf_rel, 'fig1s1.pdf')
