import logging  # coding: utf-8
import matplotlib.pyplot
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.signal import argrelextrema, argrelmin
import os
import h5py_wrapper as h5w

# from mft_helpers.data_io import dataset_energy
from mft_helpers import calc
from helpers import h_figdef
from pathos.multiprocessing import ProcessingPool as PathosPool

# Named colours used by the figures of this module. The three 'nu_first_layer_*' entries are empty-string
# placeholders.
mod_colors = dict(
    purple='#867CE8',
    fgreen='forestgreen',
    fgreen_comp='#8B228B',
    msgreen='mediumseagreen',
    msgreen_comp='#B33C7E',
    g1='#5FA075',
    g1_comp='#A05F8A',
    nu_first_layer_low='',
    nu_first_layer_standard='',
    nu_first_layer_high='',
    # Four-colour set. Earlier candidates were e.g. ['#C71585', '#C7B015', '#15C757', '#152CC7'],
    # ['#1F77B4', '#A71FB4', '#B45C1F', '#2CB41F'] and ['tab:red', 'tab:orange', 'tab:blue', 'tab:green'].
    tetradic_set=['tab:red', '#D926B7', 'cornflowerblue', 'mediumseagreen'],
)

# Default figure sizes (passed to matplotlib as figsize) and font sizes shared by the plotting functions below.
fig_defaults = dict(
    figsize=dict(main=(6, 4), zoomed=(3, 2), grid_cross_sec=(15, 10)),
    ticksize=dict(main_fp=14),
    labelsize=dict(main_fp=18),
    legendsize=dict(main_fp=12),
)

def compute_kappas(pars, n_stim, m, is_sigma=False):
    r"""
    Return the effective coupling factors between stimulated (A) and non-stimulated (NA) clusters.

    Note that the prefactor \mathcal{J} = tau * K_E * J is not included.

    :param pars: parameter dict; reads ``'g'`` (used with flipped sign) and ``'noise_alpha'``.
    :param n_stim: used as N_C in the formulas below.
    :param m: modularity; a scalar or a numpy array (every operation is element-wise).
    :param is_sigma: if True the inhibitory contribution uses g squared (variance terms) instead of g (mean terms).
    :return: tuple ``(kappa_A_A, kappa_A_NA, kappa_NA_A, kappa_NA_NA)``.
    """
    n_c = n_stim
    g_neg = -pars['g']
    gamma = 0.25  # K_I / K_E
    noise_share = 1. - pars['noise_alpha']

    # for the variance the inhibitory weight enters squared, hence one extra factor of g
    g_extra = g_neg if is_sigma else 1.
    recurrent = 1. + gamma * g_neg * g_extra

    to_self = 1. / n_c * recurrent
    to_others = (n_c - 1.) / n_c * recurrent
    denom = (n_c - 1.) * (1. - m) + 1.

    kappa_A_A = to_self + noise_share * 1. / denom
    kappa_A_NA = to_others + noise_share * ((n_c - 1.) * (1. - m)) / denom
    kappa_NA_A = to_self + noise_share * (1. - m) / denom
    kappa_NA_NA = to_others + noise_share * (1. + (n_c - 2.) * (1. - m)) / denom
    return kappa_A_A, kappa_A_NA, kappa_NA_A, kappa_NA_NA


########################################################################################################################
# Plotting functions
########################################################################################################################
def plot_fixed_points(results, pars, stability_analysis=False, filename_suffix='', vis_zero_log=False, save=True,
                      overlapping=False, frameon=False, loc_legend_stability='center', loc_legend_stimulated='lower',
                      figsize=None, legendsize=None, narrowlegend=False):
    """
    Plots the stable and unstable fixed points as a function of the modularity.

    >> Fig 7B

    For every modularity ``m`` the stimulated rate of each fixed point is drawn as a circle just left of ``m`` and
    the non-stimulated rate as a cross just right of it. Black marks stable fixed points, red marks unstable ones and
    those whose stability is undetermined (-1).

    :param results: dict mapping ``m`` to ``(nu, state, stability)``. ``nu[cs, 0]`` / ``nu[cs, 1]`` are the
        stimulated / non-stimulated rates of fixed point ``cs``, and ``stability[cs]`` is True, False or -1.
        ``state`` is not used here.
    :param pars: parameter dict; reads 'method', 'intv_initial_ratesA', 'g', 'nuX' and, when saving,
        'figure_root_path'.
    :param stability_analysis: unused, kept for backward compatibility.
    :param filename_suffix: [str] appended to the output file name.
    :param vis_zero_log: if True, rates are clipped from below at 1e-2 so that (near-)zero rates stay visible on the
        log axis. Otherwise the clip value is -1., which leaves non-negative rates unchanged.
    :param save: if True, save the figure as PDF below ``pars['figure_root_path']``.
    :param overlapping: draw enlarged (and, for stable stimulated points, semi-transparent) markers.
    :param frameon: draw frames around the legends.
    :param loc_legend_stability: location of the stable/unstable legend; falsy to omit it.
    :param loc_legend_stimulated: location of the stimulated/non-stimulated legend. For the 'piecewise_linear'
        method it only switches the legend on/off ('center left' is used). Falsy to omit it.
    :param figsize: 'narrow', an explicit (w, h) tuple, or None for ``fig_defaults['figsize']['main']``.
    :param legendsize: legend font size; defaults to 1.2 x ``fig_defaults['legendsize']['main_fp']``.
    :param narrowlegend: also use the narrow handle-text padding in the stability legend.
    :return: (fig, ax, filename)
    """
    if isinstance(figsize, str) and figsize == 'narrow':
        fig = plt.figure(figsize=(4, 4))
    elif figsize:
        fig = plt.figure(figsize=figsize)
    else:
        fig = plt.figure(figsize=fig_defaults['figsize']['main'])
    ax_A = plt.subplot(111)

    m_min = 500
    m_max = 0
    line_stable = None
    line_unstable = None
    c_stable = 'k'
    c_unstable = 'tab:red'
    plot_lower_log_ylim = 1e-2 if vis_zero_log else -1.  # just some lower bound on low-rate FPs for sensible log plots
    m_crit = None

    # Marker styling does not depend on the fixed point, so it is computed once; scaling inside the per-FP loop
    # would compound the stimulated marker size for every additional fixed point at the same m.
    ms_base = 5
    ms_S = ms_base * 1.3 if overlapping else ms_base
    ms_NS = ms_base * 1.5 if overlapping else ms_base
    alpha_S = 0.3 if overlapping else 1

    for idx, (m, (nu, _state, stability)) in enumerate(results.items()):
        m_min = min(m_min, m)
        m_max = max(m_max, m)
        n_fp = nu.shape[0]
        # legend labels are attached only to the points of the first m
        la = 'stimulated' if idx == 0 else None
        lna = 'non-stimulated' if idx == 0 else None

        if m_crit is None and n_fp > 1:  # simple detection of m_crit: first m with more than one fixed point
            m_crit = m

        for cs in range(n_fp):
            nu_A = max(nu[cs, 0], plot_lower_log_ylim)
            nu_NA = max(nu[cs, 1], plot_lower_log_ylim)
            stab = stability[cs]
            # accept numpy booleans as well: ``np.True_ is True`` evaluates to False
            if isinstance(stab, (bool, np.bool_)):
                if stab:
                    line_stable, = ax_A.plot([m - 0.001], [nu_A], 'o', color=c_stable, label=la, lw=3., ms=ms_S,
                                             alpha=alpha_S)
                    ax_A.plot([m + 0.001], [nu_NA], 'x', color=c_stable, label=lna, lw=3., ms=ms_NS)
                else:
                    line_unstable, = ax_A.plot([m - 0.001], [nu_A], 'o', color=c_unstable, label=la, lw=3., ms=ms_S)
                    print("unstable fp at nu_A={} for m={}".format(nu_A, m))
                    ax_A.plot([m + 0.001], [nu_NA], 'x', color=c_unstable, label=lna, lw=3., ms=ms_NS)
            elif stab == -1:
                # -1: stability could not be determined because plugging the value back in fails. This is treated
                # as a numerical artefact and the point is drawn like an unstable one.
                ax_A.plot([m - 0.001], [nu_A], 'o', color=c_unstable, label=la, lw=3., ms=ms_S)
                print("unstable fp at nu_A={} for m={}".format(nu_A, m))
                ax_A.plot([m + 0.001], [nu_NA], 'x', color=c_unstable, label=lna, lw=3., ms=ms_NS)

    labelsize = fig_defaults['labelsize']['main_fp']
    ticksize = fig_defaults['ticksize']['main_fp']
    legendsize = fig_defaults['legendsize']['main_fp'] * 1.2 if legendsize is None else legendsize
    handletextpad = 0.01

    legend_active = None
    if pars['method'] != 'piecewise_linear':
        ax_A.set_ylim([3e-2, 1e3])
        ax_A.set_yscale('log')
        if loc_legend_stimulated:
            # NOTE(review): the default 'lower' is not a valid matplotlib legend location (recent matplotlib
            # versions raise ValueError for it); callers presumably pass an explicit location -- confirm.
            legend_active = ax_A.legend(loc=loc_legend_stimulated, fontsize=12, handletextpad=handletextpad,
                                        frameon=frameon)
    else:
        ax_A.set_yscale('log')
        ax_A.set_title(f'm_crit: {m_crit}')
        if loc_legend_stimulated:
            legend_active = ax_A.legend(loc='center left', fontsize=legendsize, handletextpad=handletextpad,
                                        frameon=frameon)

    ax_A.set_xlim([0., 1.0])
    ax_A.set_xlabel('m', fontsize=labelsize)
    ax_A.set_ylabel(r'$\nu$ (spks/sec)', fontsize=labelsize)
    ax_A.tick_params(axis='both', labelsize=ticksize)

    if loc_legend_stability:
        # only pass handles that exist (e.g. there may be no unstable fixed point at all)
        entries = [(handle, text) for handle, text in ((line_stable, 'stable'), (line_unstable, 'unstable'))
                   if handle is not None]
        if entries:
            handles, labels = zip(*entries)
            ax_A.legend(list(handles), list(labels), loc=loc_legend_stability, frameon=frameon, fontsize=legendsize,
                        handletextpad=handletextpad if narrowlegend else None)
            # a second legend call replaces the first, so re-attach the stimulated/non-stimulated legend
            if legend_active is not None:
                ax_A.add_artist(legend_active)

    fig.tight_layout()
    filename = 'fixed_points_initRates={}_m=[{}-{}]_g={}_met={}_nuX={}{}.pdf'.format(
        pars['intv_initial_ratesA'],
        np.round(m_min, decimals=2), np.round(m_max, decimals=2),
        pars['g'],
        pars['method'],
        pars['nuX'],
        filename_suffix
    )

    if save:
        fig.savefig('{}/{}'.format(pars['figure_root_path'], filename))
        logging.info('Finished plotting figure: {}/{}'.format(pars['figure_root_path'], filename))
    return fig, ax_A, filename


def plot_nu_vs_layers(results, pars):
    """
    [Original] Plots the firing rates as a function of the number of layers.

    The top panel shows column 0 of every rate array (labelled nu_A, log y-axis) and the bottom panel column 1
    (labelled nu_NA, linear y-axis). Both panels are limited to [0.2, 500]. Saved as
    '<figure_root_path>/nu_vs_layers_g=<g>.pdf'.

    :param results: dict mapping modularity ``m`` to a 2D array whose row ``i`` is plotted at layer ``i + 1``.
    :param pars: parameter dict; reads 'figure_root_path' and 'g'.
    """
    fig = plt.figure(figsize=(6, 4))
    panel_A = plt.subplot(211)
    panel_NA = plt.subplot(212)

    for m, nu in results.items():
        layers = range(1, len(nu) + 1)
        curve_label = 'm=' + str(np.round(m, 2))
        for panel, column in ((panel_A, 0), (panel_NA, 1)):
            panel.plot(layers, nu[:, column], label=curve_label, lw=2.)

    panel_A.legend(loc='best', fontsize=3.)
    panel_A.set_yscale('log')
    panel_A.set_ylabel(r'$\nu_{A}$')
    panel_NA.set_ylabel(r'$\nu_{NA}$')
    for panel in (panel_A, panel_NA):
        panel.set_ylim([2e-1, 5e2])
    fig.savefig(os.path.join(pars['figure_root_path'], 'nu_vs_layers_g={}.pdf'.format(pars['g'])))


def plot_nu_vs_layers_fig2(results, pars, rates_initial_layer):
    """
    [Fig 2d] Plots the firing rates as a function of the number of layers. Custom colormaps and plotting ranges.

    Only modularities 0.75 <= m <= 0.9 and at most the first 50 layers are shown. The top panel has column 0 (solid)
    and the bottom panel column 1 (dashed), both on log axes. Saved as 'fig2_d.pdf' in the working directory.

    :param results: dict mapping ``m`` to a 2D per-layer rate array (row = layer).
    :param pars: unused.
    :param rates_initial_layer: unused.
    """
    palette_names = ['twilight_shifted']
    color_offset = 3  # extra colours generated on each side of the palette, so its extremes are not used
    lw = 4.
    ticksize = 24

    # keep only the modularity range shown in the figure
    results = {m: v for m, v in results.items() if 0.75 <= m <= 0.9}

    for cname in palette_names:
        n_colors = len(results) + 2 * color_offset

        if cname == 'div':
            palette = sns.diverging_palette(250, 15, s=75, l=40, sep=1, n=n_colors, center="light")
        elif cname == 'circular':
            palette = sns.color_palette(n_colors=n_colors)
        else:
            palette = sns.color_palette(cname, n_colors)
        palette = palette[color_offset:]  # skip the first color_offset colours

        fig = plt.figure(figsize=(11, 9))
        for active in (True, False):
            ax = plt.subplot2grid((2, 1), (0 if active else 1, 0))

            for idx, (m, nu) in enumerate(results.items()):
                nu = nu[:50]
                layers = range(1, len(nu) + 1)
                curve_label = 'm=' + str(np.round(m, 2))
                if active:
                    ax.plot(layers, nu[:, 0], label=curve_label, color=palette[idx], lw=lw)
                else:
                    ax.plot(layers, nu[:, 1], '--', label=curve_label, color=palette[idx], lw=lw)

            ax.set_yscale('log')
            ax.set_ylim([6e-1, 5e2] if active else [2e-1, 5e2])
            ax.tick_params(axis='both', labelsize=ticksize)
            sns.despine(ax=ax)

        fig.tight_layout()
        plt.subplots_adjust(hspace=0.35)
        fig.savefig('fig2_d.pdf')


def plot_fixed_points_vs_lambda(modularity_values, limits_low_input, limits_standard_input, limits_high_input, pars):
    """
    Plot column 0 of the three ``limits_*`` arrays against modularity, one curve per input level.

    The y-axis is logarithmic and labelled nu^S. The figure is saved as 'fixed_points_vs_lambda.pdf' in the working
    directory.

    :param modularity_values: x values, labelled 'm' (ticks at 0.75, 0.8, 0.85, 0.9).
    :param limits_low_input: 2D array; column 0 is the 'low input' curve.
    :param limits_standard_input: 2D array; column 0 is the 'standard input' curve (drawn on top, zorder 3).
    :param limits_high_input: 2D array; column 0 is the 'high input' curve.
    :param pars: unused.
    """
    fig = plt.figure(figsize=fig_defaults['figsize']['main'])
    ax = plt.subplot(111)
    curve_specs = [
        (limits_low_input, 'gray', 'low input', {'markeredgecolor': 'gray', 'mew': 2., 'mfc': 'gray'}),
        (limits_standard_input, 'cornflowerblue', 'standard input', {'zorder': 3}),
        (limits_high_input, 'k', 'high input', {}),
    ]
    for limits, color, label, extra_style in curve_specs:
        ax.plot(modularity_values, limits[:, 0], 'o-', color=color, label=label, lw=2., ms=9, alpha=1.0,
                **extra_style)

    label_fs = fig_defaults['labelsize']['main_fp']
    ax.set_xticks([0.75, 0.8, 0.85, 0.9])
    ax.legend(loc='best', fontsize=fig_defaults['legendsize']['main_fp'])
    ax.set_yscale('log')
    ax.set_xlabel('m', fontsize=label_fs)
    ax.set_ylabel(r'$\nu^\mathrm{S}$ (spks/sec)', fontsize=label_fs)
    ax.tick_params(axis='both', labelsize=fig_defaults['ticksize']['main_fp'])
    sns.despine()

    fig.tight_layout()
    fig.savefig('fixed_points_vs_lambda.pdf')
    print("Finished plotting.")


def plot_nu_first_layer(results, pars, version=1):
    """
    Plot the first-layer stimulated and non-stimulated rates against the stimulus intensity lambda.

    All points are joined by a light-gray line, sorted by lambda. The lambdas 0.01, 0.05 and 0.25 are drawn in their
    own colours. The figure is saved as 'nu_first_layer_g=<g>_v<version>.pdf' in the working directory.

    :param results: dict mapping lambda to ``nu``, where ``nu[0]`` / ``nu[1]`` are the stimulated / non-stimulated
        rates.
    :param pars: parameter dict; reads 'g' (file name only).
    :param version: 1 for the tomato/blue/green highlight palette, anything else for gray/blue/black.
    """
    fig = plt.figure(figsize=fig_defaults['figsize']['main'])
    ax = plt.subplot(111)

    if version == 1:
        colors = ['tomato', 'cornflowerblue', 'mediumseagreen']
    else:
        colors = ['gray', 'cornflowerblue', 'black']

    lmbd_highlighted = [0.01, 0.05, 0.25]
    mew = 1.5
    lw = 2.5
    ms = 9
    base_color = 'lightgray'

    # The connecting line needs x and y in the same order. Pairing the sorted keys with the unsorted
    # ``results.values()`` would scramble the line whenever the dict is not already ordered by lambda.
    lambdas = sorted(results.keys())
    ax.plot(lambdas, [results[lmbd][0] for lmbd in lambdas], '-', color=base_color, lw=lw, ms=ms)
    ax.plot(lambdas, [results[lmbd][1] for lmbd in lambdas], '-', color=base_color, mew=mew, lw=lw, ms=ms)

    n_points = len(results)
    for idx, (lmbd, nu) in enumerate(results.items()):
        # tolerant match, so that computed keys such as 0.05000000000000001 still receive their highlight colour
        matches = [c for h, c in zip(lmbd_highlighted, colors) if np.isclose(lmbd, h)]
        c = matches[0] if matches else base_color
        if idx == n_points - 1:
            # labels only on the last point, so each legend entry appears once
            ax.plot(lmbd, nu[0], 'o-', color=c, label='stimulated', lw=lw, ms=ms)
            ax.plot(lmbd, nu[1], 'x-', color=c, label='non-stimulated', mew=mew, lw=lw, ms=ms)
        else:
            ax.plot(lmbd, nu[0], 'o-', color=c, lw=lw, ms=ms)
            ax.plot(lmbd, nu[1], 'x-', color=c, mew=mew, lw=lw, ms=ms)

    ax.legend(loc='best', fontsize=fig_defaults['legendsize']['main_fp'])
    ax.set_ylim(1., 24.)
    ax.set_yticks([5., 10., 15., 20.])

    ax.set_xlabel(r'$\lambda$', fontsize=fig_defaults['labelsize']['main_fp'])
    ax.set_ylabel(r'$\nu_0$ (spks/sec)', fontsize=fig_defaults['labelsize']['main_fp'])
    ax.tick_params(axis='both', labelsize=fig_defaults['ticksize']['main_fp'])
    sns.despine()

    fig.tight_layout()
    fig.savefig('nu_first_layer_g={}_v{}.pdf'.format(pars['g'], version))


def plot_kappas(pars_):
    """
    4 panel figure with the different kappas (S, NS).

    Two figures are produced: one for the mean (``is_sigma=False``) and one for the variance (``is_sigma=True``)
    kappas. Each shows kappa_{S,S}, kappa_{S,NS}, kappa_{NS,S} and kappa_{NS,NS} over m in [0.1, 1] for three
    (noise_alpha, g) combinations. They are saved as '<figure_root_path>/kappa_mean.pdf' and
    '<figure_root_path>/kappa_variance.pdf'.

    :param pars_: parameter dict; reads 'figure_root_path'. 'noise_alpha' and 'g' are overridden per curve on a
        local copy, so the caller's dict is not modified.
    """
    n_stim = 10
    modularity_values = np.arange(0.1, 1.001, 0.01)
    configs = [(0.25, 12.), (0.85, 12.), (0.25, 6.)]  # (noise_alpha, g) per curve
    curve_colors = ['b', 'r', 'g']
    titles = [r'$\kappa_\mathrm{S,S}$', r'$\kappa_\mathrm{S,NS}$',
              r'$\kappa_\mathrm{NS,S}$', r'$\kappa_\mathrm{NS,NS}$']

    for sigma in [False, True]:
        fig = plt.figure(figsize=(7, 5))
        axes = [plt.subplot(221 + i) for i in range(4)]

        for (noise, g), color in zip(configs, curve_colors):
            # work on a copy: writing into the caller's dict would leak the last (noise, g) pair back to them
            pars_curve = dict(pars_, noise_alpha=noise, g=g)
            # compute_kappas is element-wise, so the whole m grid is evaluated in one call
            kappas = compute_kappas(pars_curve, n_stim, modularity_values, is_sigma=sigma)
            label = r'$\alpha$ = {}, g = {}'.format(noise, g)
            for ax, kappa in zip(axes, kappas):
                ax.plot(modularity_values, kappa, '-', label=label, color=color)

        for ax, title in zip(axes, titles):
            ax.set_ylabel(title)
            ax.set_xlabel('m')
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
        axes[0].legend(loc='best')

        fig.tight_layout()
        fig.savefig(os.path.join(pars_['figure_root_path'], 'kappa_{}.pdf'.format(
            'variance' if sigma else 'mean')))


def plot_kappas_reduced(pars):
    """
    Plot the mean kappa_{S,S} and kappa_{NS,NS} against modularity, with kappa_{S,S}(m = 0.75) marked as kappa*.

    The region m < 0.75 is shaded and labelled 'fading regime', and m > 0.75 is labelled 'active regime'. Saved as
    '<figure_root_path>/kappa_reduced.pdf'.

    :param pars: parameter dict; reads 'g' and 'noise_alpha' (via compute_kappas) and 'figure_root_path'.
    """
    labelsize = 18
    ticksize = 14
    n_stim = 10
    m_star = 0.75  # modularity separating the two regimes shown in this figure
    modularity_values = np.arange(0., 1.001, 0.01)

    # compute_kappas is element-wise, so a single call covers the whole grid
    kappas = compute_kappas(pars, n_stim, modularity_values, is_sigma=False)
    kappa_A = kappas[0]
    kappa_NA = kappas[3]
    # evaluate kappa* directly instead of relying on a grid point being exactly 0.75 in floating point
    kappa_star = compute_kappas(pars, n_stim, m_star, is_sigma=False)[0]

    fig = plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    ax.plot(modularity_values, np.zeros_like(modularity_values), 'k-')
    ax.plot(modularity_values, kappa_star * np.ones_like(modularity_values), '--', color='tomato')
    ax.plot(modularity_values, kappa_A, '-', color='cornflowerblue', label=r'$\kappa_\mathrm{S,S}$')
    ax.plot(modularity_values, kappa_NA, '--', color='cornflowerblue', label=r'$\kappa_\mathrm{NS,NS}$')
    ax.fill_between([0., m_star], -2. * np.ones(2), 2. * np.ones(2), color='0.80')
    ax.set_xticks([0., 0.5, m_star, 1.])
    ax.set_yticks([-1., 0., 0.5])
    ax.text(0.45, 0.5, 'fading regime')
    ax.text(0.77, 0.5, 'active regime')
    ax.text(0.71, 2. * kappa_star, r'$\kappa^\ast$', color='tomato', fontdict={'fontsize': 'large'})
    ax.set_ylim([-1.25, 0.75])
    ax.set_xlim([0, 1.])
    ax.set_xlabel('m', fontsize=labelsize)
    ax.set_ylabel(r'$\kappa$', fontsize=labelsize)
    ax.tick_params(axis='both', labelsize=ticksize)

    ax.legend(loc=2, fontsize=16)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    fig.tight_layout()
    fig.savefig(os.path.join(pars['figure_root_path'], 'kappa_reduced.pdf'), dpi=600)


def plot_kappas_mean(pars_):
    """
    4 panel figure with the different (mean) kappas (S, NS).

    The panels show kappa_{S,S}, kappa_{S,NS}, kappa_{NS,S} and kappa_{NS,NS} (``is_sigma=False``) over m in
    [0.1, 1] for three (noise_alpha, g) combinations. Saved as '<figure_root_path>/kappa_scan.pdf'.

    :param pars_: parameter dict; reads 'figure_root_path'. 'noise_alpha' and 'g' are overridden per curve on a
        local copy, so the caller's dict is not modified.
    """
    n_stim = 10
    sigma = False
    labelsize = 16
    ticksize = 12
    modularity_values = np.arange(0.1, 1.001, 0.01)
    configs = [(0.25, 12.), (0.85, 12.), (0.25, 6.)]  # (noise_alpha, g) per curve
    curve_colors = ['b', 'r', 'g']
    titles = [r'$\kappa_\mathrm{S,S}$', r'$\kappa_\mathrm{S,NS}$',
              r'$\kappa_\mathrm{NS,S}$', r'$\kappa_\mathrm{NS,NS}$']

    fig = plt.figure(figsize=(7, 4))
    axes = [plt.subplot(221 + i) for i in range(4)]

    for (noise, g), color in zip(configs, curve_colors):
        # work on a copy: writing into the caller's dict would leak the last (noise, g) pair back to them
        pars_curve = dict(pars_, noise_alpha=noise, g=g)
        # compute_kappas is element-wise, so the whole m grid is evaluated in one call
        kappas = compute_kappas(pars_curve, n_stim, modularity_values, is_sigma=sigma)
        label = r'$\alpha$ = {}, g = -{}'.format(noise, g)
        for ax, kappa in zip(axes, kappas):
            ax.plot(modularity_values, kappa, '-', label=label, color=color)

    for ax, title in zip(axes, titles):
        ax.set_ylabel(title, fontsize=labelsize)
        ax.set_xlabel('m', fontsize=labelsize)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.tick_params(axis='both', labelsize=ticksize)
    axes[0].legend(loc='best', fontsize=9)

    fig.tight_layout()
    fig.savefig(os.path.join(pars_['figure_root_path'], 'kappa_scan.pdf'), dpi=600)


def plot_noise_vs_layers(results_all, initial_rate_active, na_rates, pars_):
    """
    Plots the threshold modularity value for different NS and S cluster firing rate ratios.

    :param results_all: y values, plotted as m_switch.
    :param initial_rate_active: initial stimulated rate; ``na_rates`` is divided by it to form the x values.
    :param na_rates: initial non-stimulated rates (anything supporting division, e.g. a numpy array).
    :param pars_: unused.

    NOTE(review): the output path 'figures/noise_vs_layers.pdf' is hard-coded relative to the working directory
    rather than derived from ``pars_['figure_root_path']``.
    """
    fig = plt.figure(figsize=fig_defaults['figsize']['main'])
    ax = plt.subplot(111)
    rate_ratio = na_rates / initial_rate_active
    label_fs = fig_defaults['labelsize']['main_fp']

    ax.plot(rate_ratio, results_all, 'o-', color='k', lw=2.5, ms=9)
    ax.set_xticks(np.arange(0.1, 0.91, 0.2))
    ax.set_xlabel(r'${\nu^\mathrm{NS}_0} / {\nu^\mathrm{S}_0}$', fontsize=label_fs)
    ax.set_ylabel(r'$\mathrm{m}_\mathrm{switch}$', fontsize=label_fs)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='both', labelsize=fig_defaults['ticksize']['main_fp'])

    fig.tight_layout()
    fig.savefig('figures/noise_vs_layers.pdf')


def plot_energy_single_input(dr_grid, potentials, U_offset, pars, zoomed_plot=False):
    """
    Plot the negated single-input potential for a range of modularities, marking its local minima.

    Curves for m < 0.76 are drawn as grey dashed lines and the others as solid red-shaded lines. The y-axis is
    logarithmic and has no ticks. Saved as 'potential_single_mod=<m_min>-<m_max>[_zoomed].pdf' in the working
    directory.

    :param dr_grid: 1D numpy array of rate values (x axis); only values in [1, 300] ([1, 5] when zoomed) are shown.
    :param potentials: dict mapping m to a 1D potential array defined on ``dr_grid``. Only m in [0.7, 0.85]
        ([0.7, 0.8] when zoomed) are plotted. The caller's dict is left unchanged.
    :param U_offset: constant added to every plotted curve and minimum marker.
    :param pars: unused.
    :param zoomed_plot: small figure restricted to rates in [1, 5] and m <= 0.8.
    """
    fig = plt.figure(figsize=fig_defaults['figsize']['zoomed' if zoomed_plot else 'main'])
    ax = plt.subplot(111)
    plot_rmax = 5 if zoomed_plot else 300
    plot_rmin = 1
    m_max = 0.8 if zoomed_plot else 0.85
    m_min = 0.7
    m_split = 0.76  # curves below this modularity are drawn grey/dashed (set manually)
    legend_m = [0.7, 0.75, 0.76, 0.8, 0.85]

    # Select into a new dict instead of deleting from the caller's one, so repeated calls (e.g. zoomed and full)
    # see the same input.
    potentials = {k: v for k, v in potentials.items() if m_min <= k <= m_max}

    idxs = np.where((dr_grid <= plot_rmax) & (dr_grid >= plot_rmin))[0]
    logx = False
    logy = True
    plot_minima = True
    # NOTE(review): this threshold is compared with grid indices, not potential values; with 0. it only excludes
    # index 0 -- confirm intent.
    plot_min_th = 0.

    dr_grid = dr_grid[idxs]

    colors_greys = sns.light_palette('grey', n_colors=6)
    colors_reds = sns.dark_palette('tomato', reverse=True, n_colors=len(potentials))
    cg_cnt = 0
    cr_cnt = 0

    for m, potential in potentials.items():
        tmp_potentials = -potential[idxs]
        local_minima = argrelextrema(tmp_potentials, np.less)  # compute local minima
        print(f"local minima for m = {m}: {local_minima}")

        # choose colour and line style by modularity
        if m < m_split:
            c = colors_greys[cg_cnt]
            cg_cnt += 1
            lstyle = '--'
        else:
            c = colors_reds[cr_cnt]
            cr_cnt += 1
            lstyle = '-'

        label = 'm = {:.2f}'.format(m) if np.any([np.isclose(m, x) for x in legend_m]) else None

        ax.plot(dr_grid, tmp_potentials + U_offset, label=label, color=c, ls=lstyle)  # plot -U to get minima

        if plot_minima:
            for lmin_idx in local_minima:
                plot_indices = lmin_idx[lmin_idx > plot_min_th]
                ax.scatter(dr_grid[plot_indices], tmp_potentials[plot_indices] + U_offset,
                           marker='o', edgecolor='k', facecolor='k', zorder=3, s=64)

    ax.tick_params(axis='both', labelsize=fig_defaults['ticksize']['main_fp'])

    filename = f"potential_single_mod={min(potentials.keys())}-{max(potentials.keys())}"
    if logy:
        ax.set_yscale('log')
    ax.minorticks_off()
    ax.set_yticks([])

    if not zoomed_plot:
        ax.set_xlabel(r'$\nu_S$ (spks/sec)', fontsize=fig_defaults['labelsize']['main_fp'])
        ax.set_ylabel('Potential', fontsize=fig_defaults['labelsize']['main_fp'])
    else:
        filename += '_zoomed'
        ax.set_xticks(np.arange(1, 5.1, 1))
    sns.despine(ax=ax)

    if logx:
        ax.set_xscale('log')

    fig.tight_layout()
    fig.savefig(f"{filename}.pdf")


def plot_energy_single_input_specific(potentials, dr_grid):
    """
    Plots the potentials for specific m values.

    Only entries whose key is close to one of m = 0.76, 0.83 or 0.87 are drawn. Each is shown as the negated
    potential divided by 10, for rates in [1, 30], on a symlog y-axis without ticks. The caller's dict is left
    unchanged. A second legend with empty marker handles labels 'low input', 'standard input' and 'high input'.
    Saved as 'potentials_three_m.pdf' in the working directory.

    :param potentials: dict mapping m to a 1D potential array defined on ``dr_grid``.
    :param dr_grid: 1D numpy array of rate values (x axis).
    """
    U_offset = 0

    fig = plt.figure(figsize=fig_defaults['figsize']['main'])
    ax = plt.subplot(111)
    plot_rmax = 30
    plot_rmin = 1

    plot_m = [0.76, 0.83, 0.87]
    # select into a new dict instead of deleting entries from the caller's dict
    potentials = {k: v for k, v in potentials.items() if any(np.isclose(k, x) for x in plot_m)}

    idxs = np.where((dr_grid <= plot_rmax) & (dr_grid >= plot_rmin))[0]
    logx = False
    logy = True

    dr_grid = dr_grid[idxs]
    colors = sns.dark_palette('tomato', reverse=True, n_colors=25)
    init_fp_colors = ['gray', 'cornflowerblue', 'k']
    pot_lines = []
    dotlines = []

    for m_idx, m in enumerate(potentials):
        tmp_potentials = -potentials[m][idxs] / 10

        # plot -U to get minima; every 12th palette colour keeps the (at most three) curves well apart
        pl, = ax.plot(dr_grid, tmp_potentials + U_offset, '-', lw=3, c=colors[m_idx * 12])
        pot_lines.append(pl)

        # empty line: it only provides a marker handle for the input-level legend
        dl, = ax.plot([], [], 'o', markersize=12, color=init_fp_colors[m_idx])
        dotlines.append(dl)

    ax.tick_params(axis='both', labelsize=fig_defaults['ticksize']['main_fp'])
    legend_pot_lines = ax.legend(pot_lines, ["m = {:.2f}".format(m) for m in potentials.keys()], loc='lower left',
                                 fontsize=fig_defaults['legendsize']['main_fp'])
    # a second legend call replaces the first one, so the first is re-added as an artist
    ax.legend(dotlines, ['low input', 'standard input', 'high input'], loc='center right', markerscale=.8,
              fontsize=fig_defaults['legendsize']['main_fp'])
    ax.add_artist(legend_pot_lines)

    if logy:
        ax.set_yscale('symlog')
    ax.minorticks_off()
    ax.set_yticks([])

    ax.set_xlabel(r'$\nu^\mathrm{S}$ (spks/sec)', fontsize=fig_defaults['labelsize']['main_fp'])
    ax.set_ylabel('Potential', fontsize=fig_defaults['labelsize']['main_fp'])
    sns.despine(ax=ax)

    if logx:
        ax.set_xscale('log')

    fig.tight_layout()
    fig.savefig("potentials_three_m.pdf")


def plot_energy_multi(pars, rates_Ax, potentials):
    """
    Main function to plot a 2D figure of the potential on the r1,r2 grid.

    For every modularity ``m``, one figure 'energy_multi_mod=<m>.pdf' is saved in the working directory. It shows
    the negated potential as a colour map (rows: nu_S1, columns: nu_S2). Grid points that are local minima along both
    axes are marked with a cross, and additionally circled when ``calc.test_siegert_A_multi_correctness`` accepts
    them. A circled cross is always drawn at the origin.

    Side effects: ``pars['modularity']`` is overwritten with each m that has minima; the global rcParams for minor
    ticks are switched off; an 'info' entry is removed from ``potentials`` in place.

    :param pars: parameter dict passed to the correctness test.
    :param rates_Ax: 1D numpy array of the rates labelling both grid axes.
    :param potentials: dict mapping m to a 2D potential array on the rates_Ax x rates_Ax grid (plus an optional
        'info' entry).
    """
    from matplotlib.colors import SymLogNorm

    # 'info' is metadata, not a potential. It is removed in place, so later iteration over the dict only sees
    # modularity keys.
    potentials.pop('info', None)

    # plotting params
    lognorm = False
    plot_text = False
    plot_manual_zero = True
    set_ticklabels_int = True
    matplotlib.pyplot.rcParams['ytick.minor.visible'] = False
    matplotlib.pyplot.rcParams['xtick.minor.visible'] = False

    trunc_plot = None  # e.g. (0., 3.) to restrict the plot to a rate range
    minsearch_order = 20  # number of neighbors to consider for local minima
    minsearch_pad = 10e6  # large padding value, presumably so that minima on the grid boundary can be detected
    minsearch_pad_width = 25  # 0 or value, for padding.. solves the edge problem
    mec = 'k'
    mfc = 'white'
    lw = 2.5
    s_circle = 100
    s_marker = 60
    ticksize = 14
    fontsize = 22

    for m in potentials.keys():
        fig, ax = plt.subplots(figsize=fig_defaults['figsize']['main'])
        tmp_potentials = -potentials[m]
        if trunc_plot:
            tr_min_idx = np.where(rates_Ax == trunc_plot[0])[0][0]
            tr_max_idx = np.where(rates_Ax == trunc_plot[1])[0][0]
            tmp_potentials = tmp_potentials[tr_min_idx:tr_max_idx, tr_min_idx:tr_max_idx]
            print(f"Warning! Truncating plot to {trunc_plot}!!")
            # NOTE(review): below, minimum indices of the truncated array index rates_Ax directly and the ticks use
            # untruncated positions; this is only consistent when tr_min_idx == 0.

        ax.set_ylabel(r'$\nu_{\mathrm{S}1}$', fontsize=fontsize)  # rows in tmp_potentials
        ax.set_xlabel(r'$\nu_{\mathrm{S}2}$', fontsize=fontsize)  # columns in tmp_potentials

        norm = SymLogNorm(linthresh=1.) if lognorm else None
        img = ax.matshow(tmp_potentials, cmap='coolwarm', origin='lower', norm=norm)

        # A grid point counts as a 2D minimum when it is a local minimum along both axes of the padded array.
        padded = np.pad(tmp_potentials, pad_width=minsearch_pad_width, constant_values=minsearch_pad)
        lmininima_x = argrelmin(padded, mode='wrap', order=minsearch_order, axis=0)
        lmininima_y = argrelmin(padded, mode='wrap', order=minsearch_order, axis=1)

        min_cols = set(zip(lmininima_x[0], lmininima_x[1]))
        min_rows = set(zip(lmininima_y[0], lmininima_y[1]))
        minima = min_cols.intersection(min_rows)

        for minimum in minima:
            # shift from padded-array indices back to tmp_potentials indices (not plot coordinates)
            min_i, min_j = minimum[0] - minsearch_pad_width, minimum[1] - minsearch_pad_width
            print(f"Minimum found for m = {m} @ ({rates_Ax[min_i]}, {rates_Ax[min_j]}), indices = ({min_i}, {min_j})")
            pars['modularity'] = m
            # looser tolerance once either rate reaches 10
            etol = 0.2 if rates_Ax[min_i] < 10. and rates_Ax[min_j] < 10. else 3.
            if calc.test_siegert_A_multi_correctness(rates_Ax[min_i], rates_Ax[min_j], pars, etol=etol):
                ax.scatter(x=min_j, y=min_i, s=s_circle, marker='o', edgecolor=mec, facecolor=mfc)
            ax.scatter(x=min_j, y=min_i, s=s_marker, marker='x', color=mec, lw=lw)

        if plot_text:
            for i in range(len(rates_Ax)):
                for j in range(len(rates_Ax)):
                    ax.text(j, i, "{:.2f}".format(tmp_potentials[i][j]), fontsize=4)

        # manually plot a minimum at 0, may be missed if too large step size is used
        if plot_manual_zero:
            ax.scatter(x=0, y=0, s=s_circle, marker='o', edgecolor=mec, facecolor=mfc)
            ax.scatter(x=0, y=0, s=s_marker, marker='x', color=mec, lw=lw)

        ax.tick_params(axis='x', which='both', labelbottom=True, labeltop=False, top=False,)
        if trunc_plot:
            yticks = np.linspace(tr_min_idx, tr_max_idx, 5).astype(int)  # variables are calculated and set before
        else:
            yticks = np.linspace(0, len(rates_Ax), 5).astype(int)
        yticks[-1] -= 1  # keep the last tick a valid index into rates_Ax

        ax.set_xticks(yticks)
        ax.set_yticks(yticks)
        ax.set_xticklabels(np.round(rates_Ax[yticks], decimals=2) if not set_ticklabels_int else rates_Ax[yticks].astype(int))
        ax.set_yticklabels(np.round(rates_Ax[yticks], decimals=2) if not set_ticklabels_int else rates_Ax[yticks].astype(int))
        ax.tick_params(axis='both', labelsize=ticksize)
        margin = 3
        ax.set_ylim((-margin, max(yticks)+margin))
        ax.set_xlim((-margin, max(yticks)+margin))
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", "10%", pad="4%")
        cbar = fig.colorbar(img, cax=cax,
                            format=matplotlib.ticker.LogFormatterMathtext() if lognorm else None)
        # only two qualitative ticks ('low'/'high') placed 10% inside the value range
        pot_diff = abs(np.max(tmp_potentials) - np.min(tmp_potentials))
        cbar.set_ticks((np.min(tmp_potentials) + 0.1*pot_diff, np.max(tmp_potentials) - 0.1*pot_diff ))
        cbar.ax.set_yticklabels(['low', 'high'], fontsize=ticksize)
        cbar.ax.set_ylabel(r'Potential $U$', fontsize=18)

        filename = f"energy_multi_mod={m}"
        if trunc_plot:
            filename += f"_truncated={trunc_plot}"
        fig.tight_layout()
        fig.savefig(f"{filename}.pdf")


def plot_energy_multi_cross_section_single(pars, dataset_energy=None):
    """
    Plot a single plot of the potential cross section, for a specific set of parameters and one direction (rA1 or rA2).

    The potential is loaded with ``load_h5_data(pars)``. For each m in ``plot1_m`` and each fixed rA2 value in
    ``fixed_rate_val``, the negated potential along rA1 (plus ``U_offset``) is plotted on a log y-axis, with its local
    minima marked by crosses. Saved as '<figure_root_path>/energy_multi_crossSecGridMinima<output_label>.pdf' and
    then shown.

    :param pars: parameter dict; reads 'figure_root_path' and 'output_label' (and whatever load_h5_data needs).
    :param dataset_energy: optional mapping ``output_label -> {'rates_Ax': ...}``. It is only consulted when the
        loaded data has no ``['info']['rates_Ax']`` entry. If omitted, a module-level ``dataset_energy`` is used if
        one is defined.
    :raises ValueError: if the rate grid is available from neither source.
    """
    potentials = load_h5_data(pars)

    try:
        rates_Ax = potentials['info']['rates_Ax']
    except (KeyError, TypeError):
        # fall back to an energy dataset: the explicit argument, or a module-level ``dataset_energy`` if one exists
        source = dataset_energy if dataset_energy is not None else globals().get('dataset_energy')
        if source is None:
            raise ValueError("'rates_Ax' is missing from the loaded data; pass `dataset_energy` to supply it") from None
        rates_Ax = source[pars['output_label']]['rates_Ax']

    potentials.pop('info', None)  # metadata, not a potential

    U_offset = 1e4  # arbitrary offset
    plot1_m = [0.83]
    fixed_rate = 'rA2'
    fixed_rate_val = [1.0, 2.0]

    # local minima search params
    minsearch_order = 5  # number of neighbors to consider for local minima
    minsearch_pad = 1e7
    minsearch_pad_width = 20
    ##############
    # plotting
    # squeeze=False keeps ``axes`` two-dimensional even when a grid dimension is 1
    fig, axes = plt.subplots(max(2, len(plot1_m)), len(fixed_rate_val),
                             figsize=fig_defaults['figsize']['grid_cross_sec'], squeeze=False)

    matplotlib.pyplot.rcParams['ytick.minor.visible'] = False
    matplotlib.pyplot.rcParams['xtick.minor.visible'] = False

    for m_idx, m in enumerate(plot1_m):
        logx = False
        logy = True

        for rfix_idx, rfix in enumerate(fixed_rate_val):
            ax = axes[m_idx][rfix_idx]
            if fixed_rate == 'rA2':
                j = np.where(rates_Ax == rfix)[0]
                tmp_potentials_1D = -potentials[m][:, j].T[0]  # just transform column corresponding to rA2_fix to 1D arr
            else:
                raise ValueError(f"unsupported fixed_rate: {fixed_rate}")

            # compute local minima (indices refer to the padded array)
            lmininima = argrelmin(
                np.pad(tmp_potentials_1D, pad_width=minsearch_pad_width, constant_values=minsearch_pad),
                mode='wrap', order=minsearch_order)
            label = f"m = {m}, rfix [{fixed_rate}] = {rfix} "
            # plot -U to get minima
            ax.plot(rates_Ax, tmp_potentials_1D + U_offset, label=label)

            for lmin_idx in lmininima:
                # shift padded-array indices back to grid indices
                grid_idx = lmin_idx - minsearch_pad_width
                ax.plot(rates_Ax[grid_idx], tmp_potentials_1D[grid_idx] + U_offset, 'x', color='k')

            ax.set_title(f"m = {m}", fontsize=16)

            ax.set_xlabel(r'$\nu_{A1}$')
            ax.set_ylabel(r'$U$')
            ax.legend(loc='best')
            sns.despine(ax=ax)

            if logy:
                ax.set_yscale('log')

            if logx:
                ax.set_xscale('log')

    filename = f"{pars['figure_root_path']}/energy_multi_crossSecGridMinima{pars['output_label']}"

    fig.tight_layout()
    fig.savefig(f"{filename}.pdf")
    plt.show()


def _potential_cross_section(potential_2d, rates_Ax, rate_fix):
    """Return the negated potential along the column whose rate is closest to ``rate_fix``.

    :param potential_2d: 2D array whose columns are indexed like ``rates_Ax``.
    :param rates_Ax: 1D rate grid.
    :param rate_fix: rate whose column is extracted.
    :return: 1D array ``-potential_2d[:, col]``.
    """
    rates = np.asarray(rates_Ax)
    # Nearest match instead of exact float equality: an exact ``==`` lookup finds no
    # column (and then fails with IndexError) when ``rate_fix`` is not bit-identical
    # to a grid value. For an exact match this picks the same (first) column.
    col = int(np.argmin(np.abs(rates - rate_fix)))
    if not np.isclose(rates[col], rate_fix):
        logging.warning("rate %s is not on the grid; using nearest grid rate %s", rate_fix, rates[col])
    return -np.asarray(potential_2d)[:, col]


def _decorate_cross_section_axis(axis, title, logx, logy):
    """Apply the title, labels, legend, despine and axis scaling shared by all cross-section panels."""
    axis.set_title(title, fontsize=16)
    axis.set_xlabel(r'$\nu_{A1}$')
    axis.set_ylabel(r'$U$')
    axis.legend(loc='best')
    sns.despine(ax=axis)

    if logy:
        axis.set_yscale('log')
        axis.minorticks_off()
        axis.set_yticks([])

    if logx:
        axis.set_xscale('log')


def plot_energy_multi_cross_section(dataset_energy, pars):
    """Plot 1D cross sections of the stored 2D potentials on a two-row grid and save it as PDF.

    Top row: one panel per m in ``plot1_m``, with one curve per grid rate nu_A2 up to ``plot1_rmax``.
    Bottom row: one panel per fixed nu_A2 in ``plot2_rA2_fix``, with one curve per stored m.
    Every curve is the negated stored potential column (``-potentials[m][:, col]``) plus
    ``U_offset``, plotted against ``rates_Ax``. Curves without a strict local minimum are
    drawn in grey, the others in red, and local minima are marked with black crosses.

    :param dataset_energy: mapping; ``dataset_energy[pars['output_label']]['rates_Ax']`` is the rate grid.
    :param pars: parameter dict; reads 'output_label' and 'figure_root_path'
        (plus whatever ``load_h5_data`` reads).
    :raises FileNotFoundError: if ``load_h5_data`` could not load any stored potentials.
    """
    rates_Ax = np.asarray(dataset_energy[pars['output_label']]['rates_Ax'])
    potentials = load_h5_data(pars)
    if potentials is None:
        raise FileNotFoundError(
            f"No stored potentials could be loaded for output_label={pars['output_label']!r}")
    # the 'info' entry (if present) is metadata, not a potential matrix
    potentials.pop('info', None)

    U_offset = 0  # arbitrary offset
    plot1_m = [0.75, 0.76, 0.9]  # top panels
    plot1_rmax = 2.  # limit of rA2 for which to plot
    plot2_m = list(potentials)  # bottom panels
    plot2_rA2_fix = [0.2, 0.8, 2.]
    # axis scaling shared by both rows and encoded in the output filename
    logx = True
    logy = False

    ##############
    # plotting
    n_cols = max(len(plot1_m), len(plot2_rA2_fix))
    # squeeze=False keeps ``ax`` 2D even when there is a single column
    fig, ax = plt.subplots(2, n_cols, figsize=fig_defaults['figsize']['grid_cross_sec'], squeeze=False)

    matplotlib.pyplot.rcParams['ytick.minor.visible'] = False
    matplotlib.pyplot.rcParams['xtick.minor.visible'] = False

    colors_greys = sns.light_palette('grey', n_colors=len(rates_Ax))
    colors_reds = sns.dark_palette('tomato', reverse=True, n_colors=len(rates_Ax))

    # top panels: fixed m, varying nu_A2
    for col_idx, m in enumerate(plot1_m):
        axis = ax[0][col_idx]
        n_grey = n_red = 0
        for rA2_fix in rates_Ax[rates_Ax <= plot1_rmax]:
            section = _potential_cross_section(potentials[m], rates_Ax, rA2_fix)
            minima = argrelextrema(section, np.less)[0]

            if minima.size == 0:
                # first curve without a minimum is labelled as the threshold
                color = colors_greys[n_grey]
                label = r'threshold $\nu_{A2}$ = ' + str(rA2_fix) if n_grey == 0 else None
                n_grey += 1
            else:
                color = colors_reds[n_red]
                label = None
                n_red += 1

            axis.plot(rates_Ax, section + U_offset, label=label, color=color, ls='-')
            if minima.size:
                axis.plot(rates_Ax[minima], section[minima] + U_offset, 'x', color='k')

        _decorate_cross_section_axis(axis, f"m = {m}", logx, logy)

    # bottom panels: fixed nu_A2, varying m
    for col_idx, rA2_fix in enumerate(plot2_rA2_fix):
        axis = ax[1][col_idx]
        n_grey = n_red = 0
        for m in plot2_m:
            section = _potential_cross_section(potentials[m], rates_Ax, rA2_fix)
            minima = argrelextrema(section, np.less)[0]

            if minima.size == 0:
                color = colors_greys[n_grey]
                label = None
                n_grey += 1
            else:
                # first curve with a minimum is labelled as the threshold
                label = r'threshold m = ' + str(m) if n_red == 0 else None
                color = colors_reds[n_red]
                n_red += 1

            axis.plot(rates_Ax, section + U_offset, label=label, color=color, ls='-')
            if minima.size:
                axis.plot(rates_Ax[minima], section[minima] + U_offset, 'x', color='k')

        _decorate_cross_section_axis(axis, r"$\nu_{A2}$ = " + str(rA2_fix), logx, logy)

    filename = f"{pars['figure_root_path']}/energy_multi_crossSecGrid{pars['output_label']}"

    if logy:
        filename += '_logY'

    if logx:
        filename += '_logX'

    fig.tight_layout()
    fig.savefig(f"{filename}.pdf")



def plot_energy_multi_cross_section_diagonal(pars):
    """Plot the negated potential along anti-diagonals of the stored 2D potentials and save as PDF.

    For each m in ``modularities`` the walk starts at the hard-coded index pair
    ``(i0, j0)`` of ``potentials[m]``. It moves in steps of (row - 1, column + 1) until
    row 0 (or the last column) is reached. The values ``-potentials[m][i, j]``
    (optionally min-max normalised) are drawn as one line per m. An inset bar chart
    shows each curve's ``max - min``, which is printed as the barrier height.

    :param pars: parameter dict; reads 'figure_root_path' and 'output_label'
        (plus whatever ``load_h5_data`` reads).
    :raises FileNotFoundError: if ``load_h5_data`` could not load any stored potentials.
    :raises ValueError: if a start index lies outside the columns of the potential matrix.
    """
    potentials = load_h5_data(pars)
    if potentials is None:
        raise FileNotFoundError(
            f"No stored potentials could be loaded for output_label={pars['output_label']!r}")
    # the 'info' entry (if present) is metadata, not a potential matrix
    potentials.pop('info', None)

    normalize = False
    logx = False
    logy = False
    U_offset = 0  # arbitrary offset
    modularities = [0.83, 0.9, 1.]
    start_idx = [(131, 0), (309, 0), (433, 0)]  # indices in the potential matrix where to start from, for the given m

    ##############
    # plotting
    matplotlib.pyplot.rcParams['ytick.minor.visible'] = False
    matplotlib.pyplot.rcParams['xtick.minor.visible'] = False

    fig, ax = plt.subplots(figsize=fig_defaults['figsize']['main'])
    inset_ax = ax.inset_axes([0.75, 0.75, 0.2, 0.2])
    for m_idx, (m, (i0, j0)) in enumerate(zip(modularities, start_idx)):
        potential_2d = np.asarray(potentials[m])

        # walk (i0, j0), (i0-1, j0+1), ... down to row 0, stopping early at the last column
        n_steps = min(i0 + 1, potential_2d.shape[1] - j0)
        if n_steps <= 0:
            raise ValueError(f"start index {(i0, j0)} lies outside the potential matrix for m = {m}")
        steps = np.arange(n_steps)
        res = -potential_2d[i0 - steps, j0 + steps]

        if normalize:
            span = np.ptp(res)
            res = res - res.min()
            if span > 0:  # a flat curve would otherwise divide by zero
                res = res / span  # normalize potential to 0-1

        # x values are column indices j of the walk (0..i0 when j0 == 0).
        # NOTE(review): the x label reads nu_A1 although indices are plotted; confirm intent.
        line, = ax.plot(j0 + steps, res + U_offset, '-', label=f"m = {m}", lw=3.)

        b = abs(res.max() - res.min())  # barrier height
        inset_ax.bar([m_idx], b, color=line.get_color(), edgecolor='k')
        print(f"Barrier height (absolute potential difference) for m={m}: {b}")

    ax.set_xlabel(r'$\nu_{A1}$')
    ax.set_ylabel(r'$U$')
    ax.legend(loc='center')
    sns.despine(ax=ax)

    if logy:
        ax.set_yscale('log')

    if logx:
        ax.set_xscale('log')

    inset_ax.set_yticks([])
    inset_ax.set_ylabel('barrier')

    filename = f"{pars['figure_root_path']}/wells2D_crossSecDiag_{pars['output_label']}_m={modularities}"
    if normalize:
        filename += '_normalized'

    fig.tight_layout()
    fig.savefig(f"{filename}.pdf")


def get_data_filename(pars, multi):
    """Build the path of the HDF5 file holding stored energy data.

    :param pars: dict providing 'data_root_path' and 'output_label'.
    :param multi: truthy selects the ``energy_multi`` file, falsy the ``energy`` file.
    :return: ``<data_root_path>/energy_multi<output_label>.h5`` or ``<data_root_path>/energy<output_label>.h5``.
    """
    root = pars['data_root_path']
    label = pars['output_label']
    if not multi:
        return f"{root}/energy{label}.h5"
    return f"{root}/energy_multi{label}.h5"


def load_h5_data(pars, multi=True):
    """Load previously stored energy data from the HDF5 file selected by ``get_data_filename``.

    Loading is best-effort: any exception raised by ``h5w.load`` is reported on stdout
    and ``None`` is returned instead of propagating the error.

    :param pars: parameter dict forwarded to ``get_data_filename``.
    :param multi: forwarded to ``get_data_filename`` to choose the file variant.
    :return: whatever ``h5w.load`` returns, or ``None`` if loading failed.
    """
    source = get_data_filename(pars, multi)

    try:
        loaded = h5w.load(source)
    except Exception as e:
        print(f"No previously stored data found. Exception: {str(e)}")
        return None
    return loaded

def usetex_font(font='Liberation Sans'):
    """Make matplotlib render regular text and mathtext in a single sans-serif font.

    Despite its name, this does not enable LaTeX rendering (``text.usetex``). It sets
    the sans-serif text font and switches mathtext to a custom font set that uses the
    same font, with italic as the default math style.

    :param font: font family name used for text and for the mathtext rm/it/bf/cal
        variants (default 'Liberation Sans').
    """
    plt.rc('font', **{'family': 'sans-serif', 'sans-serif': [font]})

    matplotlib.rcParams.update({
        'mathtext.default': 'it',
        'mathtext.fontset': 'custom',  # required for the per-variant fonts below to apply
        'mathtext.bf': f'{font}:bold',
        'mathtext.it': f'{font}:italic',
        'mathtext.rm': font,
        'mathtext.cal': font,
    })
