import matplotlib.pyplot
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.signal import argrelextrema, argrelmin
import os
import h5py_wrapper as h5w

from mft_helpers.data_io import dataset_energy
from mft_helpers import calc
from pathos.multiprocessing import ProcessingPool as PathosPool

# Named colours for the figures in this module. Values are either matplotlib
# colour names or hex codes. Entries ending in '_comp' sit next to a base entry
# of the same name (presumably its complementary colour -- confirm).
mod_colors = dict(
    purple='#867CE8',
    fgreen='forestgreen',
    fgreen_comp='#8B228B',
    msgreen='mediumseagreen',
    msgreen_comp='#B33C7E',
    g1='#5FA075',
    g1_comp='#A05F8A',
    # The three first-layer rate colours are left as empty strings here.
    nu_first_layer_low='',
    nu_first_layer_standard='',
    nu_first_layer_high='',
    # Four-colour set kept as an ordered list.
    tetradic_set=['tab:red', '#D926B7', 'cornflowerblue', 'mediumseagreen'],
)

# Shared figure defaults, grouped by setting and then by figure variant.
# 'figsize' values are (width, height) tuples as passed to matplotlib's
# `figsize`; the remaining groups hold font sizes.
fig_defaults = dict(
    figsize=dict(
        main=(6, 4),
        zoomed=(3, 2),
        grid_cross_sec=(15, 10),
    ),
    ticksize=dict(main_fp=14),
    labelsize=dict(main_fp=18),
    legendsize=dict(main_fp=12),
)


def plot_nu_vs_layers_fig2(results, pars, rates_initial_layer):
    """
    [Fig 2d] Plots the firing rates as a function of the number of layers. Custom colormaps and plotting ranges.

    For every palette name a figure with two stacked panels is drawn: the top panel shows column 0 of
    each rate array as solid lines, the bottom panel shows column 1 as dashed lines. Both panels use a
    logarithmic y-axis. Only entries with 0.75 <= m <= 0.9 are drawn, and only the first 50 rows of each
    rate array are used. The figure is saved to 'plots/fig2_d.pdf'; the 'plots' directory is created if
    it does not exist yet. With more than one palette name, each figure overwrites the same file.

    :param results: mapping m -> 2D array of rates indexed as nu[layer, column]; `results.items()` is
                    iterated and each value is sliced with `nu[:50]` and `nu[:, 0]` / `nu[:, 1]`.
    :param pars: not used by this function.
    :param rates_initial_layer: not used by this function.
    :return: None
    """
    palette_names = ['twilight_shifted']
    color_offset = 3  # increase the number of colors when generating palette, to increase range and reduce extremities
    lw = 4.
    ticksize = 24
    max_layers = 50  # only the first `max_layers` rows of every rate array are plotted

    # trim unwanted m values (independent of the palette, so done once)
    selected = {m: v for m, v in results.items() if 0.75 <= m <= 0.9}
    # Request 2 * color_offset extra colours; the first `color_offset` are dropped below and the last
    # `color_offset` are never reached, so neither end of the palette is used.
    n_colors = len(selected) + color_offset * 2

    # savefig does not create missing directories
    os.makedirs('plots', exist_ok=True)

    for cname in palette_names:
        if cname == 'div':
            palette = sns.diverging_palette(250, 15, s=75, l=40, sep=1, n=n_colors, center="light")
        elif cname == 'circular':
            palette = sns.color_palette(n_colors=n_colors)
        else:
            palette = sns.color_palette(cname, n_colors)
        palette = palette[color_offset:]

        fig = plt.figure(figsize=(11, 9))
        for active in (True, False):
            # active -> top panel (row 0), otherwise bottom panel (row 1)
            ax = plt.subplot2grid((2, 1), (1 - int(active), 0))

            column = 0 if active else 1
            # solid (matplotlib default) for column 0, dashed for column 1
            line_fmt = () if active else ('--',)

            for color, (m, nu) in zip(palette, selected.items()):
                nu = nu[:max_layers]
                # layers are numbered starting from 1 on the x-axis
                ax.plot(range(1, len(nu) + 1), nu[:, column], *line_fmt,
                        label='m=' + str(np.round(m, 2)), color=color, lw=lw)

            ax.set_yscale('log')
            ax.set_ylim([6e-1, 5e2] if active else [2e-1, 5e2])

            ax.tick_params(axis='both', labelsize=ticksize)
            sns.despine(ax=ax)

        fig.tight_layout()
        plt.subplots_adjust(hspace=0.34)
        fig.savefig('plots/fig2_d.pdf')


def plot_energy_multi(pars, rates_Ax, potentials, verbose=False):
    """
    Main function to plot a 2D figure of the potential on the r1,r2 grid.

    For every key `m` in `potentials`, the array `-potentials[m]` is drawn as an image, its local minima
    (strict minima along both array axes) are marked, and the figure is saved to 'plots/fig9_s2_m=<m>.pdf'.
    The 'plots' directory is created if it does not exist yet.

    Every minimum is marked with an 'x'. A white circle is added underneath when
    `calc.test_siegert_A_multi_correctness` returns a truthy value for the corresponding pair of rates
    (presumably a check that the minimum is a valid solution -- confirm against `calc`).

    Side effects visible to the caller:
      * an 'info' entry, if present, is deleted from `potentials` in place;
      * `pars['modularity']` is overwritten with `m` for every minimum that is tested;
      * the global matplotlib rcParams 'xtick.minor.visible' / 'ytick.minor.visible' are set to False.

    :param potentials: mapping m -> 2D array of the potential; rows are labelled nu_S1, columns nu_S2.
    :param rates_Ax: 1D numpy array with the rate value of every grid index (used for both axes).
    :param pars: parameter mapping handed to `calc.test_siegert_A_multi_correctness`.
    :param verbose: if True, print every minimum found and forward the flag to the `calc` check.
    :return: None
    """
    # Local import: only needed here, and done once instead of once per figure.
    from matplotlib.colors import SymLogNorm

    try:
        del potentials['info']
    except KeyError:
        # no metadata entry to strip
        pass

    # plotting params
    lognorm = False
    plot_text = False
    plot_manual_zero = True
    set_ticklabels_int = True
    plt.rcParams['ytick.minor.visible'] = False
    plt.rcParams['xtick.minor.visible'] = False

    # trunc_plot = (0., 3.)
    # NOTE(review): the truncation branch is disabled. When enabled, the image and the minima indices are
    # in cropped coordinates while ticks and rates_Ax lookups use uncropped indices -- verify before use.
    trunc_plot = None
    minsearch_order = 20  # number of neighbors to consider for local minima
    minsearch_pad = 10e6
    minsearch_pad_width = 25  # 0 or value, for padding.. solves the edge problem
    mec = 'k'
    mfc = 'white'
    lw = 2.5
    s_circle = 100
    s_marker = 60
    ticksize = 14
    fontsize = 22
    margin = 3

    def mark_minimum(axis, col, row, circled):
        """Draw the 'x' marker at (col, row), on top of a white circle when `circled` is true."""
        if circled:
            axis.scatter(x=col, y=row, s=s_circle, marker='o', edgecolor=mec, facecolor=mfc)
        axis.scatter(x=col, y=row, s=s_marker, marker='x', color=mec, lw=lw)

    # savefig does not create missing directories
    os.makedirs('plots', exist_ok=True)

    for m in potentials.keys():
        fig, ax = plt.subplots(figsize=fig_defaults['figsize']['main'])
        tmp_potentials = -potentials[m]  # + U_offset
        if trunc_plot:
            tr_min_idx = np.where(rates_Ax == trunc_plot[0])[0][0]
            tr_max_idx = np.where(rates_Ax == trunc_plot[1])[0][0]
            tmp_potentials = tmp_potentials[tr_min_idx:tr_max_idx, tr_min_idx:tr_max_idx]
            print(f"Warning! Truncating plot to {trunc_plot}!!")

        ax.set_ylabel(r'$\nu_{\mathrm{S}1}$', fontsize=fontsize)  # rows in tmp_potentials
        ax.set_xlabel(r'$\nu_{\mathrm{S}2}$', fontsize=fontsize)  # columns in tmp_potentials

        norm = SymLogNorm(linthresh=1.) if lognorm else None
        img = ax.matshow(tmp_potentials, cmap='coolwarm', origin='lower', norm=norm)

        # Pad with a large constant so that grid points on the border can still qualify as minima,
        # then keep the points that are strict minima along axis 0 AND along axis 1.
        padded = np.pad(tmp_potentials, pad_width=minsearch_pad_width, constant_values=minsearch_pad)
        minima_axis0 = argrelmin(padded, axis=0, order=minsearch_order, mode='wrap')
        minima_axis1 = argrelmin(padded, axis=1, order=minsearch_order, mode='wrap')
        minima = set(zip(*minima_axis0)) & set(zip(*minima_axis1))

        # sorted only to make the processing / printing order reproducible
        for padded_row, padded_col in sorted(minima):
            # tmp_potential indices, not the plot
            min_i, min_j = padded_row - minsearch_pad_width, padded_col - minsearch_pad_width
            nu_row, nu_col = rates_Ax[min_i], rates_Ax[min_j]
            if verbose:
                print(f"Minimum found for m = {m} @ ({nu_row}, {nu_col}), indices = ({min_i}, {min_j})")
            pars['modularity'] = m
            # tighter tolerance when both rates are below 10, looser one otherwise
            etol = 0.2 if nu_row < 10. and nu_col < 10. else 3.
            is_valid = calc.test_siegert_A_multi_correctness(nu_row, nu_col, pars, etol=etol, verbose=verbose)
            mark_minimum(ax, min_j, min_i, circled=is_valid)

        if plot_text:
            for i in range(len(rates_Ax)):
                for j in range(len(rates_Ax)):
                    ax.text(j, i, "{:.2f}".format(tmp_potentials[i][j]), fontsize=4)

        # manually plot a minimum at 0, may be missed if too large step size is used
        if plot_manual_zero:
            mark_minimum(ax, 0, 0, circled=True)

        ax.tick_params(axis='x', which='both', labelbottom=True, labeltop=False, top=False,)
        if trunc_plot:
            yticks = np.linspace(tr_min_idx, tr_max_idx, 5).astype(int)  # variables are calculated and set before
        else:
            yticks = np.linspace(0, len(rates_Ax), 5).astype(int)
        # the linspace end point is one past the last valid index
        yticks[-1] -= 1

        if set_ticklabels_int:
            tick_labels = rates_Ax[yticks].astype(int)
        else:
            tick_labels = np.round(rates_Ax[yticks], decimals=2)

        ax.set_xticks(yticks)
        ax.set_yticks(yticks)
        ax.set_xticklabels(tick_labels)
        ax.set_yticklabels(tick_labels)
        ax.tick_params(axis='both', labelsize=ticksize)
        # leave some room around the grid so that markers on the border are not clipped
        ax.set_ylim((-margin, max(yticks) + margin))
        ax.set_xlim((-margin, max(yticks) + margin))

        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", "10%", pad="4%")
        cbar = fig.colorbar(img, cax=cax,
                            format=matplotlib.ticker.LogFormatterMathtext() if lognorm else None)
        # Two qualitative ticks, placed 10% of the value range inside either end of the colorbar.
        pot_min, pot_max = np.min(tmp_potentials), np.max(tmp_potentials)
        pot_diff = abs(pot_max - pot_min)
        cbar.set_ticks((pot_min + 0.1 * pot_diff, pot_max - 0.1 * pot_diff))
        cbar.ax.set_yticklabels(['low', 'high'], fontsize=ticksize)
        cbar.ax.set_ylabel(r'Potential $U$', fontsize=18)

        filename = f"plots/fig9_s2_m={m}"
        if trunc_plot:
            filename += f"_truncated={trunc_plot}"
        fig.tight_layout()
        fig.savefig(f"{filename}.pdf")