"""
Figures 1c,d: performance across layers as a function of modularity.
"""

import os
import sys
import pickle as pkl
import seaborn as sns
from matplotlib import pyplot as plt
import numpy as np

# Explicit check rather than ``assert``: assertions are stripped under
# ``python -O``, which would silently let the script run on Python 3.
if sys.version_info >= (3, 0):
    raise RuntimeError("Please run this script using Python2.7")

# Population colours: every second entry (starting at index 0) of a 20-step
# cubehelix ramp -- presumably to increase contrast between adjacent curves.
color_palette = sns.cubehelix_palette(n_colors=20, start=.5, rot=-.75)[0::2]


def plot_performance(data, filename, percentage=False, m_switch=0.83):
    """
    Plot per-population performance over the p1 grid, one figure per p2 value.

    :param data: unpickled simulation results, a nested dict indexed as
        ``data[p2][population]``; each entry is a sequence with one value per
        p1 grid point (0.75, 0.775, ..., 1.0 -- 11 points)
    :param filename: output file name, written inside ``plots/`` (created if
        missing). When ``data`` holds more than one p2 value, ``_p2=<value>``
        is inserted before the extension so the figures do not overwrite
        each other.
    :param percentage: True if the values are gains in % (y-range -50..50);
        otherwise raw values (y-range 0.1..0.4) plus a red reference line at 0.35
    :param m_switch: p1 value at which the grey vertical marker is drawn
    :return: None
    :raises ValueError: if ``data`` is empty
    """
    p1_values = np.around(np.arange(0.75, 1.01, 0.025), decimals=3)
    # Curves are drawn against the grid index; p1_values only label the ticks.
    x = np.arange(len(p1_values))

    # Sorted so the figure order is deterministic (dict order is arbitrary in Python 2).
    p2_values = sorted(data.keys())
    if not p2_values:
        raise ValueError("data is empty: nothing to plot")
    populations = sorted(next(iter(data.values())).keys())

    ticksize = 42
    lw = 5.

    out_dir = 'plots'
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    root, ext = os.path.splitext(filename)

    # Map m_switch onto the index axis (p1_values[i] is drawn at x = i). An
    # 11-point grid spans only 10 intervals, so this must interpolate rather
    # than scale by the number of points.
    x_switch = np.interp(m_switch, p1_values, x)

    for p2v in p2_values:
        figure = plt.figure(figsize=(int(len(p1_values) * .95), 7))
        ax = plt.subplot2grid((1, 1), (0, 0))

        for idx, pop in enumerate(populations):
            values = np.array(data[p2v][pop])
            # Cycle through the palette instead of raising IndexError when
            # there are more populations than colours.
            color = color_palette[idx % len(color_palette)]
            ax.plot(x, values, '-', color=color, label=pop, linewidth=lw)

        if not percentage:
            ax.plot(x, [0.35] * len(p1_values), '-', color='tab:red', linewidth=lw)

        # Vertical m_switch marker; its y-extent exceeds both y-limits set below.
        ax.plot([x_switch] * 2, [-60, 60.], '-', color='0.7', linewidth=lw)

        ax.grid(False)
        if not percentage:
            ax.set_ylim(0.1, 0.4)
            ax.set_yticks([0.15, 0.25, 0.35])
            ax.set_yticklabels([0.15, 0.25, 0.35])
        else:
            ax.set_ylim(-50, 50)
        ax.set_xticks(x)
        ax.set_xticklabels(p1_values)
        ax.tick_params(axis='both', labelsize=ticksize, direction='out')
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Label only every other x tick to keep the large labels readable.
        for label in ax.xaxis.get_ticklabels()[1::2]:
            label.set_visible(False)

        if len(p2_values) > 1:
            out_name = '{0}_p2={1}{2}'.format(root, p2v, ext)
        else:
            out_name = filename

        figure.tight_layout()
        figure.savefig(os.path.join(out_dir, out_name))
        # Release the figure; otherwise every call leaks one figure per p2 value.
        plt.close(figure)

if __name__ == "__main__":
    # Pickle files must be opened in binary mode; text mode can corrupt them
    # on platforms that translate line endings (e.g. Windows).
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open('data/fig_1_perf_modules_raw.pkl', 'rb') as f:
        data_raw = pkl.load(f)
    with open('data/fig_1_perf_modules_pcnt.pkl', 'rb') as f:
        data_pcnt = pkl.load(f)

    # Fig. 1c from the raw data; Fig. 1d from the percentage-gain data.
    plot_performance(data_raw, 'fig1_c.pdf')
    plot_performance(data_pcnt, 'fig1_d.pdf', percentage=True)
