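"""
Figure 6b: percentage gain of population E5 over the E0 baseline as a
function of m, for several values of d.

Reads ``data/fig_6_performance_fixed_mapsize.pkl`` and writes
``plots/fig6_b.pdf``. Requires Python 2.7.
"""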

import os
import sys
import pickle as pkl
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import helper
import seaborn as sns

# Explicit check instead of ``assert``: asserts are stripped under
# ``python -O``, which would let the script run on an unsupported Python 3.
if sys.version_info >= (3, 0):
    raise RuntimeError("Please run this script using Python2.7")


def plot_performance(data_, filename, percentage=False):
    """
    Plot population performance against m for several values of d and save it.

    With the ``relative = True`` setting below, each curve shows the
    percentage gain of the population over the E0 baseline,
    ``(E0 - pop) / E0 * 100``. It is evaluated only at the m values where
    the population's result is not NaN.

    :param data_: pickled data from simulation; a mapping with (at least) the
        keys 'E0' and 'E5', each indexed as ``data_[pop][p1_idx, p2_idx]``
        (presumably 2-D arrays of shape (len(p1_values), len(p2_values)) --
        TODO confirm against the simulation output)
    :param filename: file name of the figure, saved under ``plots/``
    :param percentage: currently unused; kept for backward compatibility
    :return: None
    """
    # Parameter grid the simulation data was computed on.
    p1_values = np.round(np.arange(0.80, 0.921, 0.01), decimals=2)  # m
    p2_values = [0.01, 0.0625, .1, .125, .2]  # d
    populations = ['E5']

    ticksize = 24
    labelsize = 26
    lw = 3.
    relative = True
    stat = 'nrmse'
    baseline_error = 0.35  # manually set here
    color_palette = sns.color_palette("ch:s=-.2,r=.6", n_colors=len(p2_values))

    if relative:
        baseline_E0_perf = data_['E0']

    for pop in populations:
        helper.usetex_font()
        figure = plt.figure(figsize=(6, 6 * 0.8))
        ax = plt.subplot2grid((1, 1), (0, 0))

        for p2_idx, p2v in enumerate(p2_values):
            # Dotted lines for the extreme d values, solid for the others.
            ls = ':' if p2v in (0.01, 0.2) else '-'
            color = color_palette[p2_idx]
            pop_perf = data_[pop][:, p2_idx]
            print("Result for pop {} and p2v: {} = {}".format(pop, p2v, pop_perf))
            # Skip m values for which there is no (NaN) result.
            valid_p1_idx = np.where(~np.isnan(pop_perf))[0]
            if relative:
                tmp_baseline = baseline_E0_perf[valid_p1_idx, p2_idx]
                gain = (tmp_baseline - pop_perf[valid_p1_idx]) / tmp_baseline * 100.
                ax.plot(p1_values[valid_p1_idx], gain,
                        ls, color=color, label='d={}'.format(p2v), lw=lw)
            else:
                ax.plot(p1_values, pop_perf, ls, color=color, label='d={}'.format(p2v), lw=lw)

        if relative:
            # Zero-gain reference: no improvement over the E0 baseline.
            ax.plot(p1_values, [0.] * len(p1_values), '--', color='tab:red', linewidth=lw)
            ax.set_ylabel('% gain', fontsize=labelsize)
        else:
            # Chance-level reference, drawn once (not once per d curve).
            ax.plot(p1_values, [baseline_error] * len(p1_values), '--', color='tab:red', linewidth=lw)
            ax.set_ylabel(stat.upper(), fontsize=labelsize)
            if stat == 'nrmse':
                ax.set_ylim(0.1, 0.42)
            else:
                ax.set_ylim(0.0, 0.3)

        ax.set_xlabel('m', fontsize=labelsize)
        ax.set_xticks(np.round([0.8, 0.84, 0.88, 0.92], decimals=2))
        ax.yaxis.set_major_locator(MaxNLocator(4))
        ax.grid(False)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.tick_params(axis='both', labelsize=ticksize, direction='out')
        ax.legend(loc='upper left', prop={'size': 'x-large'}, handlelength=2.5)
        figure.tight_layout()

        save_path = os.path.join('plots/', filename)
        save_dir = os.path.dirname(save_path)
        # savefig does not create missing directories (no exist_ok on Py2.7).
        if save_dir and not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        figure.savefig(save_path)
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(figure)


if __name__ == "__main__":
    # Pickle streams are binary: open in 'rb' so platform newline translation
    # cannot corrupt them. The file is the project's own simulation output;
    # never unpickle data from untrusted sources.
    with open('data/fig_6_performance_fixed_mapsize.pkl', 'rb') as f:
        data = pkl.load(f)

    plot_performance(data, 'fig6_b.pdf')
