import matplotlib.pyplot
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.signal import argrelextrema, argrelmin
import os
import h5py_wrapper as h5w

from mft_helpers.data_io import dataset_energy
from mft_helpers import calc
from pathos.multiprocessing import ProcessingPool as PathosPool

# Lookup table from short color names to color strings (hex codes or named
# colors). Entries ending in `_comp` sit next to their base entry; presumably
# they are the complementary hue of that base color -- TODO confirm.
mod_colors = {
    'purple': '#867CE8',
    'fgreen': 'forestgreen', 'fgreen_comp': '#8B228B',
    'msgreen': 'mediumseagreen', 'msgreen_comp': '#B33C7E',
    'g1': '#5FA075', 'g1_comp': '#A05F8A',
    # The three first-layer rate colors are currently empty strings.
    **dict.fromkeys(
        ('nu_first_layer_low', 'nu_first_layer_standard', 'nu_first_layer_high'),
        '',
    ),
    # A list of four colors rather than a single color string.
    'tetradic_set': ['tab:red', '#D926B7', 'cornflowerblue', 'mediumseagreen'],
}

# Default figure settings, grouped by category. Each category maps a variant
# name to its value: `figsize` holds 2-tuples (presumably (width, height) as
# used for a matplotlib figsize -- TODO confirm), the others hold integers.
fig_defaults = {
    'figsize': {'main': (6, 4), 'zoomed': (3, 2), 'grid_cross_sec': (15, 10)},
    'ticksize': {'main_fp': 14},
    'labelsize': {'main_fp': 18},
    'legendsize': {'main_fp': 12},
}


def plot_nu_vs_layers_fig2(results, pars, rates_initial_layer):
    """
    [Fig 2d] Plots the firing rates as a function of the number of layers. Custom colormaps and plotting ranges.

    Draws a figure with two stacked log-scale panels: the top panel shows
    column 0 of each result array as solid lines, the bottom panel shows
    column 1 as dashed lines. One line is drawn per value of m, colored from
    a seaborn palette. The figure is written to 'plots/fig2_d.pdf'; the
    'plots' directory is created if it does not exist.

    :param results: dict mapping a numeric m to a 2-D array indexable as
        nu[:, 0] and nu[:, 1]. Only entries with 0.75 <= m <= 0.9 are plotted,
        and only the first 50 rows of each array are used.
        (Presumably one row per layer -- confirm against the caller.)
    :param pars: unused; kept so existing callers keep working.
    :param rates_initial_layer: unused; kept so existing callers keep working.
    :return: None. The only result is the saved PDF; the figure is left open.
    """
    palette_names = ['twilight_shifted']
    color_offset = 3  # increase the number of colors when generating palette, to increase range and reduce extremities
    lw = 4.
    ticksize = 24
    max_rows = 50  # only the first `max_rows` rows of each result array are drawn
    out_path = 'plots/fig2_d.pdf'

    # trim unwanted m values (independent of the palette, so done once up front)
    results = {m: v for m, v in results.items() if 0.75 <= m <= 0.9}
    n_colors = len(results) + color_offset * 2

    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # NOTE: every palette name writes to the same out_path, so with more than
    # one entry in palette_names only the last figure survives on disk.
    for cname in palette_names:
        if cname == 'div':
            palette = sns.diverging_palette(250, 15, s=75, l=40, sep=1, n=n_colors, center="light")
        elif cname == 'circular':
            palette = sns.color_palette(n_colors=n_colors)
        else:
            palette = sns.color_palette(cname, n_colors)

        # Skip the first `color_offset` colors; the matching surplus at the
        # other end is never reached because only len(results) lines are drawn.
        palette = palette[color_offset:]

        fig = plt.figure(figsize=(11, 9))
        for active in [True, False]:
            # active -> top row (index 0), otherwise bottom row (index 1)
            ax = plt.subplot2grid((2, 1), (1 - int(active), 0))

            column = 0 if active else 1
            # Solid lines rely on matplotlib's default style; dashed lines get
            # an explicit '--' format string, exactly as in the two cases.
            fmt = () if active else ('--',)

            for idx, (m, nu) in enumerate(results.items()):
                nu = nu[:max_rows]
                ax.plot(range(1, len(nu) + 1), nu[:, column], *fmt,
                        label='m=' + str(np.round(m, 2)), color=palette[idx], lw=lw)

            ax.set_yscale('log')
            ax.set_ylim([6e-1, 5e2] if active else [2e-1, 5e2])

            ax.tick_params(axis='both', labelsize=ticksize)
            sns.despine(ax=ax)

        fig.tight_layout()
        fig.subplots_adjust(hspace=0.34)
        fig.savefig(out_path)