import sys
import os

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import scipy.optimize
import scipy.integrate
import h5py_wrapper as h5w
from pathos.multiprocessing import ProcessingPool as PathosPool
import pathos
import copy

from mft_helpers import siegert
from mft_helpers import mft_plotting
from mft_helpers.siegert import piecewise_linear_approx as pla_
import h5py

# Module-level switches (set to True to enable).
# NOTE(review): neither CHECK_SIEGERT nor DEBUG is read by any function in this
# section of the file -- presumably consumed elsewhere; confirm before relying on them.
CHECK_SIEGERT = False
DEBUG = False

########################################
def derived_parameters(pars):
    '''Calculate parameter values derived from other parameters.

    The dictionary ``pars`` is updated in place and also returned. Added keys:

        NE  -- number of excitatory neurons, int(beta * N)
        NI  -- number of inhibitory neurons, N - NE
        KX  -- copy of ``ext_K``
        JEE -- weight of EE synapses before homeostasis (pA), copy of ``J``
        R_m -- membrane resistance, tau_m / C_m

    :param pars: [dict] must contain 'beta', 'N', 'ext_K', 'J', 'tau_m', 'C_m'
    :return: the same (mutated) ``pars`` dictionary
    '''
    # ``np.int`` was merely an alias of the builtin ``int`` and was removed in
    # NumPy 1.24; the builtin truncates identically.
    pars['NE'] = int(pars['beta'] * pars['N'])  ## number of excitatory neurons
    pars['NI'] = pars['N'] - pars['NE']  ## number of inhibitory neurons
    pars['KX'] = pars['ext_K']
    pars['JEE'] = pars['J']  ## weight of EE synapses (before homeostasis; pA)
    pars['R_m'] = pars['tau_m'] / pars['C_m']  ## membrane resistance (MOhm?)

    return pars

#############################################
def get_parameter_set(analysis_pars):
    """
    Load the parameter space belonging to a data set.

    Imports the module ``mft_parameters`` located in
    ``~/<data_root_path>/<parameterspace_label>`` and returns the result of its
    ``set_parameter_space()``.  The data directory is placed first on
    ``sys.path`` and stays there (so the module's own sibling imports keep
    working).

    :param analysis_pars: [dict] must contain 'data_root_path' and 'parameterspace_label'
    :return: (P, data_path) -- the parameter space and the data directory
    """
    import importlib

    home = os.path.expanduser("~")
    data_path = '%s/%s/%s' % (home, analysis_pars['data_root_path'], analysis_pars['parameterspace_label'])

    # keep data_path at the front of sys.path without accumulating duplicates across calls
    while data_path in sys.path:
        sys.path.remove(data_path)
    sys.path.insert(0, data_path)

    # A plain ``import`` would return a cached ``mft_parameters`` from an earlier
    # call (possibly a different data set); drop it so the one in data_path is loaded.
    sys.modules.pop('mft_parameters', None)
    importlib.invalidate_caches()
    data_pars = importlib.import_module('mft_parameters')

    P = data_pars.set_parameter_space()
    return P, data_path


########################################
def convert_weights_to_fb_weights(J, pars):
    """
    Convert synaptic weights (pA) into the weights used by Fourcaud & Brunel 2002.

    The result is in mV; see notes lam_max_FB.pdf.  The weights are scaled by
    tau_s / C_m.  Alternative factors that were tried earlier are
    tau_s * R_m / tau_m and R_m.

    :param J: synaptic weight(s) in pA (scalar or array)
    :param pars: [dict] must provide 'tau_s' and 'C_m'
    :return: J multiplied by tau_s / C_m
    """
    scale = pars['tau_s'] / pars['C_m']
    return scale * J

########################################
def psc_max_2_psp_max(psc_max, tau_m, tau_s, R_m):
    """
    Converts PSC maximum (pA) to PSP maximum (mV) for exponential PSCs.

    For a current psc_max * exp(-t / tau_s) driving a leaky membrane with
    time constant tau_m and resistance R_m, the PSP is a difference of
    exponentials.  This function returns the peak of that PSP.

    When tau_s == tau_m (both scalars), the general expression is 0/0.  In
    that case its limit is returned.  The PSP then becomes
    psc_max * R_m * (t / tau) * exp(-t / tau), which peaks at psc_max * R_m / e.

    :param psc_max: PSC amplitude (scalar or array)
    :param tau_m: membrane time constant
    :param tau_s: synaptic time constant (same units as tau_m)
    :param R_m: membrane resistance
    :return: PSP maximum
    """
    # degenerate case: analytic limit instead of a ZeroDivisionError / NaN
    if np.ndim(tau_m) == 0 and np.ndim(tau_s) == 0 and tau_m == tau_s:
        return psc_max * R_m / np.e

    ratio = tau_m / tau_s
    return psc_max * R_m * tau_s / (tau_s - tau_m) * (
        ratio ** (-tau_m / (tau_m - tau_s))
        - ratio ** (-tau_s / (tau_m - tau_s))
    )


def compute_kappas(pars, n_stim, m, is_sigma=False):
    """
    Computes effective couplings kappa as defined in Eqs. 17, 18, without the scaling mathcal{J}.

    kappa_X_Y is the coupling onto population X from population Y.  X and Y are
    either the active (A) or the non-active (NA) cluster.

    :param pars: [dict] provides 'g' and 'noise_alpha'
    :param n_stim: number of clusters N_C
    :param m: modularity
    :param is_sigma: [bool] whether we're computing the variance (by squaring g)
    :return: (kappa_A_A, kappa_A_NA, kappa_NA_A, kappa_NA_NA)
    """
    n_clusters = n_stim
    g_eff = -pars['g']
    gamma = 0.25  # K_I/K_E
    feedforward = 1. - pars['noise_alpha']

    # for the variance the inhibitory weight enters squared
    g_factor = g_eff if is_sigma else 1.
    recurrent = 1. + gamma * g_eff * g_factor
    norm = (n_clusters - 1.) * (1. - m) + 1.

    frac_single = 1. / n_clusters
    frac_rest = (n_clusters - 1.) / n_clusters

    return (
        frac_single * recurrent + feedforward * 1. / norm,
        frac_rest * recurrent + feedforward * ((n_clusters - 1.) * (1. - m)) / norm,
        frac_single * recurrent + feedforward * (1. - m) / norm,
        frac_rest * recurrent + feedforward * (1. + (n_clusters - 2.) * (1. - m)) / norm,
    )


def compute_kappas_multi(pars, n_stim, m, is_sigma=False):
    """
    Compute kappas for two input streams (two active clusters).

    The two active clusters each receive input from themselves (A_A), from the
    other active cluster (A_Aother), and from the N_C - 2 non-active clusters (A_NA).
    A non-active cluster receives input from one active cluster (NA_A) and from
    the non-active clusters (NA_NA).

    :param pars: [dict] provides 'g' and 'noise_alpha'
    :param n_stim: number of clusters N_C
    :param m: modularity
    :param is_sigma: [bool] whether we're computing the variance (by squaring g)
    :return: (kappa_A_A, kappa_A_Aother, kappa_A_NA, kappa_NA_A, kappa_NA_NA)
    """
    n_clusters = n_stim
    g_eff = -pars['g']
    gamma = 0.25  # K_I/K_E
    feedforward = 1. - pars['noise_alpha']

    # for the variance the inhibitory weight enters squared
    g_factor = g_eff if is_sigma else 1.
    recurrent = 1. + gamma * g_eff * g_factor
    norm = (n_clusters - 1.) * (1. - m) + 1.

    frac_single = 1. / n_clusters
    frac_rest = (n_clusters - 2.) / n_clusters

    self_term = frac_single * recurrent + feedforward * 1. / norm
    # input from the other active stream carries the additional (1 - m) factor
    other_term = frac_single * recurrent + feedforward * (1. - m) / norm
    a_from_na = frac_rest * recurrent + feedforward * ((n_clusters - 2.) * (1. - m)) / norm
    na_from_a = frac_single * recurrent + feedforward * (1. - m) / norm
    na_from_na = frac_rest * recurrent + feedforward * (1. + (n_clusters - 3.) * (1. - m)) / norm

    return self_term, other_term, a_from_na, na_from_a, na_from_na


def stationary_state_supercluster(nuA, nuNA, pars, verbose=False):
    """
    Main function to compute the stationary states for the superclusters.

    From the rates of the active (A) and non-active (NA) superclusters this
    function computes:

    - the mean input mu, always;
    - the input fluctuations sigma, for 'nu0_fb433' / 'nu0_fb' only.

    Both are mapped through the transfer function selected by ``pars['method']``.

    :param nuA: rate of the active supercluster (1/s)
    :param nuNA: rate of the non-active supercluster (1/s)
    :param pars: [dict] parameters; reads 'method', 'modularity', 'noise_alpha', 'JEE',
        'nuX', 'tau_m', 'tau_s', 'tau_r', 'V_th', 'V_reset' (plus the keys used by
        compute_kappas / convert_weights_to_fb_weights)
    :param verbose: print the input rates and the resulting mean inputs
    :return: [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]; the sigmas are None for
        the 'piecewise_linear' method
    :raises ValueError: for an unknown ``pars['method']``
    :raises AssertionError: if an output rate is not finite or outside [0, 500]
    """
    if verbose:
        print(f"[stationary_state_supercluster]: {nuA}, {nuNA}")

    method = pars['method']
    if method == 'nu0_fb433':
        func = siegert.nu0_fb433
    elif method == 'nu0_fb':
        func = siegert.nu0_fb
    elif method == 'piecewise_linear':
        func = siegert.piecewise_linear_approx
    else:
        # explicit raise: an ``assert False`` is stripped under ``python -O``
        raise ValueError("Unknown method?")

    n_stim = 10  # also the number of clusters
    m = pars['modularity']
    alpha = pars['noise_alpha']  # 0.25, used to reduce the background noise in the deeper modules

    kappa_A_A, kappa_A_NA, kappa_NA_A, kappa_NA_NA = compute_kappas(pars, n_stim, m)

    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    # NOTE(review): KE is hard-coded here while run_stability_analysis reads pars['KE'];
    # confirm both agree for the parameter sets in use.
    KE = 800
    input_scale = pars['tau_m'] * 1e-3 * KE

    muA = input_scale * J * (alpha * pars['nuX'] + kappa_A_A * nuA + kappa_A_NA * nuNA)
    muNA = input_scale * J * (alpha * pars['nuX'] + kappa_NA_A * nuA + kappa_NA_NA * nuNA)  # nuA1 + nuA2 etc

    if verbose:
        print(muA, muNA)

    if method == 'piecewise_linear':
        # the piecewise linear transfer function depends on the mean input only
        sigmaA = None
        sigmaNA = None
    else:
        kappa_A_A_std, kappa_A_NA_std, kappa_NA_A_std, kappa_NA_NA_std = \
            compute_kappas(pars, n_stim, m, is_sigma=True)
        sigmaA = np.sqrt(input_scale * J ** 2
                         * (alpha * pars['nuX'] + kappa_A_A_std * nuA + kappa_A_NA_std * nuNA))
        sigmaNA = np.sqrt(input_scale * J ** 2
                          * (alpha * pars['nuX'] + kappa_NA_A_std * nuA + kappa_NA_NA_std * nuNA))

    nuA_out = np.abs(func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'], muA, sigmaA) * 1e3)
    nuNA_out = np.abs(func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'], muNA, sigmaNA) * 1e3)

    # NaN compares unequal to everything (the former ``nuNA_out == np.nan`` test could
    # never fire), so check finiteness explicitly.
    assert np.isfinite(nuA_out) and np.isfinite(nuNA_out), \
        f"Non-finite output rates {nuA_out}, {nuNA_out} for inputs nuA={nuA}, nuNA={nuNA}"
    assert nuA_out >= 0. and nuNA_out >= 0.
    assert nuA_out <= 500. and nuNA_out <= 500., \
        f"Output rates {nuA_out}, {nuNA_out} exceed 500 for inputs nuA={nuA}, nuNA={nuNA}"

    state = [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    return state


def stationary_state_supercluster_multi(nuA1, nuA2, nuNA, pars):
    """
    Main function to compute the stationary states for the superclusters, in the case of two input streams.

    Two active superclusters (A1, A2) and the non-active remainder (NA) are
    considered.  The couplings come from compute_kappas_multi.

    :param nuA1: rate of the first active supercluster
    :param nuA2: rate of the second active supercluster
    :param nuNA: rate of the non-active superclusters
    :param pars: [dict] parameters; ``pars['method']`` must be 'nu0_fb433' or 'nu0_fb'
    :return: [nuA1_out, nuA2_out, nuNA_out, muA1, muA2, muNA, sigmaA1, sigmaA2, sigmaNA]
    :raises ValueError: for any other ``pars['method']``
    """
    method = pars['method']
    if method == 'nu0_fb433':
        siegert_func = siegert.nu0_fb433
    elif method == 'nu0_fb':
        siegert_func = siegert.nu0_fb
    else:
        # explicit raise: an ``assert False`` is stripped under ``python -O``
        raise ValueError("Unknown method?")

    n_stim = 10  # also the number of clusters
    m = pars['modularity']
    alpha = pars['noise_alpha']  # 0.25, used to reduce the background noise in the deeper modules

    kappa_A_A, kappa_A_Aother, kappa_A_NA, kappa_NA_A, kappa_NA_NA = compute_kappas_multi(pars, n_stim, m)
    kappa_A_A_std, kappa_A_Aother_std, kappa_A_NA_std, kappa_NA_A_std, kappa_NA_NA_std = \
        compute_kappas_multi(pars, n_stim, m, is_sigma=True)

    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    KE = 800
    input_scale = pars['tau_m'] * 1e-3 * KE
    external = alpha * pars['nuX']

    muA1 = input_scale * J * (external + kappa_A_A * nuA1 + kappa_A_Aother * nuA2 + kappa_A_NA * nuNA)
    muA2 = input_scale * J * (external + kappa_A_Aother * nuA1 + kappa_A_A * nuA2 + kappa_A_NA * nuNA)
    muNA = input_scale * J * (external + kappa_NA_A * nuA1 + kappa_NA_A * nuA2 + kappa_NA_NA * nuNA)  # nuA1 + nuA2 etc

    sigmaA1 = np.sqrt(input_scale * J ** 2
                      * (external + kappa_A_A_std * nuA1 + kappa_A_Aother_std * nuA2 + kappa_A_NA_std * nuNA))
    sigmaA2 = np.sqrt(input_scale * J ** 2
                      * (external + kappa_A_Aother_std * nuA1 + kappa_A_A_std * nuA2 + kappa_A_NA_std * nuNA))
    sigmaNA = np.sqrt(input_scale * J ** 2
                      * (external + kappa_NA_A_std * nuA1 + kappa_NA_A_std * nuA2 + kappa_NA_NA_std * nuNA))

    def _rate(mu, sigma):
        """Output rate of the transfer function for input (mu, sigma)."""
        return np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'],
                                   mu, sigma) * 1e3)

    nuA1_out = _rate(muA1, sigmaA1)
    nuA2_out = _rate(muA2, sigmaA2)
    nuNA_out = _rate(muNA, sigmaNA)

    state = [nuA1_out, nuA2_out, nuNA_out, muA1, muA2, muNA, sigmaA1, sigmaA2, sigmaNA]
    return state


####################################################################


def stationary_state_cluster(nuA, nuNA, pars):
    """
    Main function to compute the stationary states for the clusters in a given SSN, considering the stationary
    firing rates in the previous SSN.

    :param nuA: rate of the active/stimulated cluster in this SSN
    :param nuNA: rate of the non-active clusters in this SSN
    :param pars: [dict] parameters; additionally reads 'nuAfp' / 'nuNAfp', the stationary
        rates of the active / non-active subpopulations of the previous SSN
    :return: [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    :raises ValueError: if ``pars['method']`` is not 'nu0_fb433' or 'nu0_fb'
    """
    method = pars['method']
    if method == 'nu0_fb433':
        siegert_func = siegert.nu0_fb433
    elif method == 'nu0_fb':
        siegert_func = siegert.nu0_fb
    else:
        # explicit raise: an ``assert False`` is stripped under ``python -O``
        raise ValueError("Unknown method?")

    n_stim = 10  # also the number of clusters
    m = pars['modularity']

    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    KE = 800

    N_C = n_stim
    g = -pars['g']
    gamma = 0.25  # K_I/K_E
    alpha = pars['noise_alpha']

    nuAfp = pars['nuAfp']  # stationary rates of active/stimulated subpop from previous SSN
    nuNAfp = pars['nuNAfp']  # stationary rates of non-active/non-stimulated subpop from previous SSN

    input_scale = pars['tau_m'] * 1e-3 * KE
    # normalisation of the feed-forward input from the previous SSN
    ff_norm = (N_C - 1.) * (1. - m) + 1.

    muA = input_scale * J * (alpha * pars['nuX']
                             + 1. / N_C * (1. + gamma * g) * nuA
                             + (1. - alpha) * 1. / ff_norm * nuAfp
                             + (N_C - 1.) / N_C * (1 + gamma * g) * nuNA
                             + (1 - alpha) * ((N_C - 1.) * (1. - m)) / ff_norm * nuNAfp)

    muNA = input_scale * J * (alpha * pars['nuX']
                              + 1. / N_C * (1. + gamma * g) * nuA
                              + (1. - alpha) * (1. - m) / ff_norm * nuAfp
                              + (N_C - 1.) / N_C * (1. + gamma * g) * nuNA
                              + (1. - alpha) * (1. + (N_C - 2.) * (1. - m)) / ff_norm * nuNAfp)

    sigmaA = np.sqrt(input_scale * J ** 2 * (alpha * pars['nuX']
                                             + 1. / N_C * (1. + gamma * g**2) * nuA
                                             + (1. - alpha) * 1. / ff_norm * nuAfp
                                             + (N_C - 1.) / N_C * (1 + gamma * g**2) * nuNA
                                             + (1 - alpha) * ((N_C - 1.) * (1. - m)) / ff_norm * nuNAfp))

    sigmaNA = np.sqrt(input_scale * J ** 2 * (alpha * pars['nuX']
                                              + 1. / N_C * (1. + gamma * g**2) * nuA
                                              + (1. - alpha) * (1. - m) / ff_norm * nuAfp
                                              + (N_C - 1.) / N_C * (1. + gamma * g**2) * nuNA
                                              + (1. - alpha) * (1. + (N_C - 2.) * (1. - m)) / ff_norm * nuNAfp))

    nuA_out = np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'], muA, sigmaA) * 1e3)
    nuNA_out = np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'], muNA, sigmaNA) * 1e3)

    state = [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    return state


####################################################################
def stationary_state_first_layer(nuA, nuNA, pars):
    """
    Main function to compute the stationary states for the clusters.

    In this first layer the active cluster receives external drive scaled by
    (1 + pars['lmbd']).  The non-active clusters receive the plain external
    rate pars['nuX'].

    :param nuA: rate of the active cluster
    :param nuNA: rate of the non-active clusters
    :param pars: [dict] parameters
    :return: [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    :raises ValueError: if ``pars['method']`` is not 'nu0_fb433' or 'nu0_fb'
    """
    method = pars['method']
    if method == 'nu0_fb433':
        siegert_func = siegert.nu0_fb433
    elif method == 'nu0_fb':
        siegert_func = siegert.nu0_fb
    else:
        # explicit raise: an ``assert False`` is stripped under ``python -O``
        raise ValueError("Unknown method?")

    n_stim = 10  # also the number of clusters
    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    KE = 800

    N_C = n_stim
    g = -pars['g']
    gamma = 0.25  # K_I/K_E
    lmbd = pars['lmbd']

    input_scale = pars['tau_m'] * 1e-3 * KE

    muA = input_scale * J * ((1. + lmbd) * pars['nuX']
                             + 1. / N_C * (1. + gamma * g) * nuA
                             + (N_C - 1.) / N_C * (1 + gamma * g) * nuNA)

    muNA = input_scale * J * (1. * pars['nuX']
                              + 1. / N_C * (1. + gamma * g) * nuA
                              + (N_C - 1.) / N_C * (1. + gamma * g) * nuNA)

    sigmaA = np.sqrt(input_scale * J ** 2 * ((1. + lmbd) * pars['nuX']
                                             + 1. / N_C * (1. + gamma * g**2) * nuA
                                             + (N_C - 1.) / N_C * (1 + gamma * g**2) * nuNA))

    sigmaNA = np.sqrt(input_scale * J ** 2 * (1. * pars['nuX']
                                              + 1. / N_C * (1. + gamma * g**2) * nuA
                                              + (N_C - 1.) / N_C * (1. + gamma * g**2) * nuNA))

    nuA_out = np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'], muA, sigmaA) * 1e3)
    nuNA_out = np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'], muNA, sigmaNA) * 1e3)

    state = [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    return state


####################################################################
def optimize_self_consistent_state(rates, pars):
    """
    Residual of the supercluster self-consistency equations (for scipy.optimize.root).

    :param rates: [nuA, nuNA]; the absolute value of each entry is used
    :param pars: parameter dictionary forwarded to stationary_state_supercluster
    :return: [nuA - nuA_out, nuNA - nuNA_out]
    """
    rate_in_A, rate_in_NA = np.abs(rates[0]), np.abs(rates[1])
    rate_out = stationary_state_supercluster(rate_in_A, rate_in_NA, pars)
    return [rate_in_A - rate_out[0], rate_in_NA - rate_out[1]]


####################################################################
def optimize_self_consistent_state_layers(rates, pars):
    """
    Residual of the per-layer (SSN) self-consistency equations (for scipy.optimize.root).

    Side effect: a deep copy of the full state returned by stationary_state_cluster
    (rates, mu and sigma) is stored in ``pars['full_state_vars']``.

    :param rates: [nuA, nuNA]; the absolute value of each entry is used
    :param pars: parameter dictionary forwarded to stationary_state_cluster
    :return: [nuA - nuA_out, nuNA - nuNA_out]
    """
    rate_in_A, rate_in_NA = np.abs(rates[0]), np.abs(rates[1])
    full_state = stationary_state_cluster(rate_in_A, rate_in_NA, pars)
    residual = [rate_in_A - full_state[0], rate_in_NA - full_state[1]]
    pars['full_state_vars'] = copy.deepcopy(full_state)  # store all variables, including mu and sigma
    return residual


####################################################################
def optimize_self_consistent_state_first_layer(rates, pars):
    """
    Residual of the first-layer self-consistency equations (for scipy.optimize.root).

    Side effect: a deep copy of the full state returned by stationary_state_first_layer
    (rates, mu and sigma) is stored in ``pars['full_state_vars']``.

    :param rates: [nuA, nuNA]; the absolute value of each entry is used
    :param pars: parameter dictionary forwarded to stationary_state_first_layer
    :return: [nuA - nuA_out, nuNA - nuNA_out]
    """
    rate_in_A, rate_in_NA = np.abs(rates[0]), np.abs(rates[1])
    full_state = stationary_state_first_layer(rate_in_A, rate_in_NA, pars)
    residual = [rate_in_A - full_state[0], rate_in_NA - full_state[1]]
    pars['full_state_vars'] = copy.deepcopy(full_state)  # store all variables, including mu and sigma
    return residual


##################################################################
def run_stability_analysis(rates, pars):
    """
    Decide whether candidate rates form a stable fixed point of the supercluster dynamics.

    The rates are first substituted back into stationary_state_supercluster.  If
    the summed absolute error exceeds 1e-3, the rates are not a fixed point and
    None is returned.  Otherwise a 2x2 stability matrix is built from the
    effective couplings and the slope of the transfer function.  The fixed point
    is stable iff no eigenvalue has a real part >= 0.

    :param rates: [nuA, nuNA] candidate fixed-point rates
    :param pars: [dict] parameters; reads 'method', 'modularity', 'KE', 'JEE', 'noise_alpha',
        'nuX', 'tau_m' and, depending on the method, 'tau_s', 'tau_r', 'V_th', 'V_reset'
    :return: True (stable), False (unstable) or None (not a fixed point)
    :raises NotImplementedError: for method 'nu0_fb'
    :raises ValueError: for an unknown method
    """
    nuA, nuNA = np.abs(rates[0]), np.abs(rates[1])

    # test again whether rates are a fixed point
    residual = np.array(stationary_state_supercluster(nuA, nuNA, pars, verbose=False))[:2] - rates
    if np.sum(np.abs(residual)) > 1e-3:
        return None

    n_stim = 10  # also the number of clusters
    m = pars['modularity']
    KE = pars['KE']
    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    alpha = pars['noise_alpha']
    kappa_A_A, kappa_A_NA, kappa_NA_A, kappa_NA_NA = compute_kappas(pars, n_stim, m)

    input_scale = pars['tau_m'] * 1e-3 * KE
    external = alpha * pars['nuX']
    muA = input_scale * J * (external + kappa_A_A * nuA + kappa_A_NA * nuNA)
    muNA = input_scale * J * (external + kappa_NA_A * nuA + kappa_NA_NA * nuNA)  # nuA1 + nuA2 etc

    method = pars['method']
    if method == 'nu0_fb433':
        print("Computing stability of FP using effective connectivity matrix")
        kappa_A_A_std, kappa_A_NA_std, kappa_NA_A_std, kappa_NA_NA_std = \
            compute_kappas(pars, n_stim, m, is_sigma=True)

        variance_scale = input_scale * J ** 2
        sigmaA = np.sqrt(variance_scale * (external + kappa_A_A_std * nuA + kappa_A_NA_std * nuNA))
        sigmaNA = np.sqrt(variance_scale * (external + kappa_NA_A_std * nuA + kappa_NA_NA_std * nuNA))

        # the derivative helper expects time constants in seconds
        ms_to_s = 1e-3
        tau_m = pars['tau_m'] * ms_to_s
        tau_s = pars['tau_s'] * ms_to_s
        tau_r = pars['tau_r'] * ms_to_s
        V_th = pars['V_th']
        V_reset = pars['V_reset']

        # lin contribution of eff coupling, alpha^tilde
        slope_A = siegert.d_nu_d_nu_in_fb(tau_m, tau_s, tau_r, V_th, V_reset, J, None, muA, sigmaA)
        slope_NA = siegert.d_nu_d_nu_in_fb(tau_m, tau_s, tau_r, V_th, V_reset, J, None, muNA, sigmaNA)

        coupling = np.array([[kappa_A_A * slope_A, kappa_A_NA * slope_NA],
                             [kappa_NA_A * slope_A, kappa_NA_NA * slope_NA]])
        stability_matrix = -1. / pars['KE'] * np.eye(2) + coupling
    elif method == 'nu0_fb':
        raise NotImplementedError('Not yet implemented')
    elif method == 'piecewise_linear':
        def slope(mu):
            """Derivative of the piecewise linear transfer function at mean input mu."""
            return siegert.derivative_piecewise_linear_approx(pars, pars['tau_r'], pars['V_th'], mu)

        leak = -1. / (pars['tau_m'] * pars['KE'] * J) * np.eye(2)
        coupling = np.array([[kappa_A_A * slope(muA), kappa_A_NA * slope(muNA)],
                             [kappa_NA_A * slope(muA), kappa_NA_NA * slope(muNA)]])
        stability_matrix = leak + coupling
    else:
        raise ValueError("Unknown method?")

    # unstable as soon as one eigenvalue has a non-negative real part
    eigenvalues = np.linalg.eigvals(stability_matrix)
    return not any(np.real(ev) >= 0 for ev in eigenvalues)


#######################################
def fixed_points(pars, stability_analysis=False, multithreaded=False, modularity_values=np.arange(0., 1.001, 0.1),
                 round_rates=None):
    """
    Compute fixed points of the supercluster rate equations for several modularity values.

    For every m:

    1. scipy.optimize.root is started from ``pars['n_initial_rates']`` random
       initial rates.  For the piecewise-linear method, four corner cases are
       also tried.
    2. Solutions whose squared residual is below ``pars['max_error']`` are kept.
    3. Duplicates at resolution ``pars['rate_precision']`` are removed.
    4. Optionally, the stability of each remaining fixed point is assessed.

    :param pars: parameter dictionary. ``pars['modularity']`` is overwritten with the current m.
    :param stability_analysis: False (skip), 'matrix' (see run_stability_analysis; entries may be
        None for non-fixed points) or 'manual' (integrate the dynamics starting at the fixed point;
        stable iff the end point is closest to that same fixed point, -1 if integration failed)
    :param multithreaded: sequential or multithreaded (one pathos worker for each m) computation
    :param modularity_values: iterable of modularity values m
    :param round_rates: [str] Whether to round the fixed point rates: options are None, 'numpy', or 'manual'
        (applied for method 'nu0_fb433' only)
    :return: dict m -> (nu, state, stability). nu is an (n_fp, 2) array of [nu_A, nu_NA];
        state is a list of stationary_state_supercluster outputs; stability is a list
        (empty when stability_analysis is falsy).
    :raises ValueError: for an unknown method or round_rates option
    """
    print('\nComputing self-consistent states...')
    np.random.seed(pars['seed_initial_rates'])

    if round_rates and round_rates not in ('numpy', 'manual'):
        # an unknown option previously skipped the de-duplication step silently
        raise ValueError("round_rates must be None, 'numpy' or 'manual', got {!r}".format(round_rates))

    results = {}
    siegert.experiment_pars = pars
    time_interval = (0, 2000)  # integration window of the 'manual' stability check

    def _root(initial_rates):
        """Solve the self-consistency equations from ``initial_rates`` with the method-specific solver."""
        if pars['method'] in ('nu0_fb433', 'nu0_fb'):
            return scipy.optimize.root(optimize_self_consistent_state, initial_rates, args=(pars,))
        if pars['method'] == 'piecewise_linear':
            return scipy.optimize.root(optimize_self_consistent_state, initial_rates, args=(pars,), method='lm')
        raise ValueError('Unknown method?')

    def _derivative(t, nu_):
        """Right-hand side d nu / dt = Phi(nu) - nu of the rate dynamics."""
        return np.array(stationary_state_supercluster(nu_[0], nu_[1], pars, verbose=False))[:2] - nu_

    def _manual_stability(i, rates, all_rates):
        """Integrate from fixed point ``i``; True iff the end point is closest to it (-1 on failure)."""
        try:
            final_rates = scipy.integrate.solve_ivp(_derivative, time_interval, rates).y[:, -1]
            # Euclidean distances; the former |sum of differences| let deviations of
            # opposite sign cancel out.
            dists = [np.linalg.norm(final_rates - other_rates) for other_rates in all_rates]
            return int(np.argmin(dists)) == i
        except Exception as e:  # deliberate best effort: report and mark as undetermined
            print("WARNING! Exception occurred during stability analysis: {}".format(str(e)))
            return -1

    def _calc_fp_for_m(args_dict):
        """Find, de-duplicate and classify the fixed points for the modularity ``args_dict['m']``."""
        m = args_dict['m']
        print("_calc_fp_for_m() for m={}".format(m))
        nu = []
        f = []

        pars['modularity'] = m

        ## search for fixed points for multiple random initial rates
        for _ in range(pars['n_initial_rates']):
            initial_rates = (np.random.uniform(pars['intv_initial_ratesA'][0], pars['intv_initial_ratesA'][1], 1),
                             np.random.uniform(pars['intv_initial_ratesNA'][0], pars['intv_initial_ratesNA'][1], 1))
            sol = _root(initial_rates)
            nu.append(np.abs(sol.x))
            f.append(sol.fun)

        ## for the piecewise linear method, test also explicitly the corner initial rates
        if pars['method'] == 'piecewise_linear':
            nu_refr = 1. / (pars['tau_r'] * 1e-3)
            for r in ([0., 0.], [nu_refr, nu_refr], [0., nu_refr], [nu_refr, 0.]):
                sol = _root((np.array([r[0]]), np.array([r[1]])))
                nu.append(np.abs(sol.x))
                f.append(sol.fun)

        ## extract solutions where the error is < max_error
        # reshape keeps the (n, 2) layout even if no solver run was made
        errors = np.sum(np.array(f, dtype=float).reshape(-1, 2) ** 2, axis=1)  ## squared Euclidean residual
        ind = np.where(errors < np.abs(pars['max_error']))[0]
        nu = np.array(nu, dtype=float).reshape(-1, 2)[ind, :]

        ## remove multiplicity in solutions
        if pars['method'] == 'piecewise_linear':
            _, ind = np.unique((nu / pars['rate_precision']).astype(int), axis=0, return_index=True)
        else:
            _, ind = np.unique((nu[:, 0] / pars['rate_precision']).astype(int), return_index=True)
        nu = nu[ind, :]

        ## for the siegert function, we need to round solution according to rate_precision - otherwise weird results
        if pars['method'] == 'nu0_fb433' and round_rates:
            if round_rates == 'numpy':
                decimal_precision = -int(np.log10(pars['rate_precision']))
                nu = np.round(nu, decimals=decimal_precision)
            else:  # 'manual' (validated at the top)
                nu = (nu / pars['rate_precision']).astype(int) * pars['rate_precision']

        stability = []
        if stability_analysis:
            for i, rates in enumerate(nu):
                if stability_analysis == 'matrix':
                    stability.append(run_stability_analysis(rates, pars))
                elif stability_analysis == 'manual':
                    stability.append(_manual_stability(i, rates, nu))
                else:
                    raise NotImplementedError('Please specify `matrix` or `manual` for stability analysis!')

        n_fp = nu.shape[0]  # number of fixed points
        print('\t{} fixed points found for m = {}'.format(n_fp, m))
        for cs in range(n_fp):
            # stability is empty when no analysis was requested (previously an IndexError)
            stability_str = str(stability[cs]) if cs < len(stability) else 'not computed'
            print('\t\t nu_A=%.2f/s, nu_NA=%.2f/s, ' % (nu[cs, 0], nu[cs, 1]) + 'stability = ' + stability_str)

        state = []
        for cs in range(n_fp):
            print('\tRecomputing the state for nu[A, NA] = ({}) for m = {}'.format(nu[cs, :], m))
            state.append(stationary_state_supercluster(nu[cs, 0], nu[cs, 1], pars))
        return nu, state, stability

    if multithreaded:
        thread_args_dict = [{'m': x} for x in modularity_values]
        pool = PathosPool(len(thread_args_dict))
        pool_results = pool.map(_calc_fp_for_m, thread_args_dict)
        results = {m: pool_results[idx] for idx, m in enumerate(modularity_values)}
    else:
        for m in modularity_values:
            results[m] = _calc_fp_for_m({'m': m})

    print('\nFinished computing self-consistent limit states...')
    return results


def fixed_points_piecewise_linear(pars, nuXs, modularity_range, N_C=10., multithreaded=False, verbose=False):
    """
    For a range of nuX and modularity values, it computes the stable and unstable fixed points
    for the piecewise linear activation function.

    For every (nuX, m), two candidates are checked with run_stability_analysis:

    - the analytically derived non-saturated fixed point;
    - the saturated point (nu_S = pcwlin_nu_max, nu_NS = 0).

    Each candidate is encoded as 1 (stable), 0 (unstable) or NaN (not a fixed
    point / not applicable).  The non-saturated candidate is only counted when
    nu_S > 0, |nu_NS| < 1e-10 and nu_S < pcwlin_nu_max.

    :param pars: [dict] parameters; deep-copied for every nuX. ``siegert.experiment_pars`` is set to it.
    :param nuXs: iterable of external rates nuX
    :param modularity_range: iterable of modularity values m
    :param N_C: number of clusters
    :param multithreaded: compute each nuX in its own pathos worker
    :param verbose: print every candidate
    :return: [list of dicts] Returns a list of dictionaries, for each nuX value. Each dictionary contains two keys
     for saturated and non-saturated FPs, the values being a list of stable / unstable FPs for a given m value.
    """
    J = convert_weights_to_fb_weights(pars['J'], pars)

    mu_min = pars['pcwlin_mu_min']
    mu_max = pars['pcwlin_mu_max']
    nu_max = pars['pcwlin_nu_max']
    siegert.experiment_pars = pars

    alpha = pars['noise_alpha']
    gamma = pars['NI'] / pars['NE']
    g = - pars['g']
    tau = pars['tau_m'] * 1e-3
    KE = pars['KE']
    drive = tau * KE * J  # converts a summed rate into a mean input

    def _report(m, nuX_, mu_S, mu_NS, nu_S, nu_NS, stable):
        """Print one verbose line describing a fixed-point candidate."""
        stable_str = 'NOT A FP' if stable is None else stable
        print('m: {}, nuX: {}; mus: {}, {}  \t\t nus: {}, {}\t\tstable: {}'.format(
            m, nuX_, mu_S, mu_NS, nu_S, nu_NS, stable_str))

    def _encode(stable):
        """Map run_stability_analysis output to 1 (stable), 0 (unstable) or NaN (not a FP)."""
        if stable is True:
            return 1
        if stable is False:
            return 0
        return np.nan

    def __compute_fps_for_nuX(args):
        nuX_ = args['nuX']
        print(f"Computing fixed points of piecewise linear for nuX: {nuX_}")
        pars_ = copy.deepcopy(pars)
        pars_['nuX'] = nuX_

        def _rate(mu):
            """Output rate (1/s) of the piecewise linear transfer function at mean input mu."""
            return pla_(pars_['tau_m'], pars_['tau_s'], pars_['tau_r'], pars_['V_th'], pars_['V_reset'], mu,
                        None) * 1e3

        results = {'fp_non_saturated': [], 'fp_saturated': []}
        for m in modularity_range:
            pars_['modularity'] = m
            norm = (N_C - 1) * (1 - m) + 1
            kappa_SS = 1. / N_C * (1 + gamma * g) + (1 - alpha) * 1. / norm
            kappa_NSS = 1. / N_C * (1 + gamma * g) + (1 - alpha) * (1. - m) / norm

            ###########################
            # non-saturated fixed point
            mu_S = mu_min * (1 + (alpha * nuX_ - mu_min / drive) / (
                mu_min / drive - kappa_SS * nu_max * mu_min / (mu_max - mu_min)))
            mu_NS = drive * (alpha * nuX_ + kappa_NSS * nu_max * (mu_S - mu_min) / (mu_max - mu_min))

            nu_S = _rate(mu_S)
            nu_NS = _rate(mu_NS)
            stable = run_stability_analysis([nu_S, nu_NS], pars_)

            # only an active / silent pair counts as a non-saturated fixed point
            if nu_S > 0 and np.abs(nu_NS) < 1e-10:
                stable_num = _encode(stable)
            else:
                stable_num = np.nan

            if verbose:
                _report(m, nuX_, mu_S, mu_NS, nu_S, nu_NS, stable)

            # a candidate with nu_S >= nu_max is saturated; that case is assessed below
            # (appending NaN here keeps both result lists aligned with modularity_range)
            results['fp_non_saturated'].append(stable_num if nu_S < nu_max else np.nan)

            #######################
            # saturated fixed point
            nu_S = nu_max
            nu_NS = 0.
            stable = run_stability_analysis([nu_S, nu_NS], pars_)
            if verbose:
                # mu_S / mu_NS printed here are those of the non-saturated candidate
                _report(m, nuX_, mu_S, mu_NS, nu_S, nu_NS, stable)
            results['fp_saturated'].append(_encode(stable))
        return results

    pool_results = []
    if multithreaded:
        thread_args_dict = [{'nuX': x} for x in nuXs]
        pool = PathosPool(len(thread_args_dict))
        pool_results = pool.map(__compute_fps_for_nuX, thread_args_dict)
    else:
        for x in nuXs:
            pool_results.append(__compute_fps_for_nuX({'nuX': x}))

    return pool_results


def compute_theoretical_boundaries_pcws_linear(pars, nuX, N_C):
    """
    Returns the theoretical modularity boundaries for a given nuX.

    The boundaries are computed for the piecewise linear transfer function,
    whose kinks are at pcwlin_mu_min and pcwlin_mu_max and whose saturation
    rate is pcwlin_nu_max.

    :param pars: [dict] parameters; reads 'J', 'pcwlin_mu_min', 'pcwlin_mu_max', 'pcwlin_nu_max',
        'noise_alpha', 'NI', 'NE', 'g', 'tau_m', 'KE' (plus the keys used by convert_weights_to_fb_weights)
    :param nuX: external rate
    :param N_C: number of clusters
    :return: (m_crit1, m_crit2, m_crit3)
    """
    J = convert_weights_to_fb_weights(pars['J'], pars)

    mu_min, mu_max, nu_max = pars['pcwlin_mu_min'], pars['pcwlin_mu_max'], pars['pcwlin_nu_max']
    alpha = pars['noise_alpha']
    gamma = pars['NI'] / pars['NE']
    g = -pars['g']
    tau = pars['tau_m'] * 1e-3
    KE = pars['KE']

    # kink positions expressed relative to the feed-forward input at nu_max
    ff_input_at_nu_max = tau * KE * J * (1 - alpha) * nu_max
    eta_min = mu_min / ff_input_at_nu_max
    eta_max = mu_max / ff_input_at_nu_max

    Delta = alpha * nuX / (1 - alpha) / nu_max + 1 / N_C / (1 - alpha) * (1 + gamma * g)

    input_at_nu_max = tau * KE * J * nu_max
    mu_width = mu_max - mu_min
    m_crit1 = mu_width * N_C / input_at_nu_max / (
        1 - alpha + mu_width / input_at_nu_max * (N_C - 1))
    m_crit2 = N_C / (N_C - 1) - 1. / (N_C - 1) * 1. / (eta_max - Delta)
    m_crit3 = N_C / (N_C - 1) - 1. / (N_C - 1) * 1. / (1 - (eta_min - Delta) * (N_C - 1))
    return m_crit1, m_crit2, m_crit3


######################################
def nu_vs_layers(pars, modularity_values=None, rates_initial_layer=np.array([8.5,5.9]), max_layers=50,
                 multithreaded=False):
    """
    Calculate the stationary firing rates in each layer / sub-network.

    For each modularity value m, a deep copy of ``pars`` is made, ``'modularity'`` is set to m,
    and the rates are propagated through ``max_layers`` layers. The self-consistent rates of a
    layer are found with scipy.optimize.root on optimize_self_consistent_state_layers, given the
    rates of the previous layer in ``pars_['nuAfp']`` / ``pars_['nuNAfp']``. The first layer
    receives ``rates_initial_layer``.

    Note: the global numpy RNG is re-seeded with ``pars['seed_initial_rates']``.

    :param pars: parameter dict; deep-copied per modularity value, the caller's dict is not modified
    :param modularity_values: iterable of modularity values to evaluate (required)
    :param rates_initial_layer: rates (nuA, nuNA) fed into the first computed layer
    :param max_layers: number of layers to propagate through (>= 1)
    :param multithreaded: if True, evaluate the modularity values in parallel with a pathos pool
    :return: tuple (results, modularity_values, limits):
        results maps each m to an array with one row of absolute rates per layer;
        modularity_values is the argument as passed in;
        limits is an array with the last-layer rates for each m, in input order
    :raises TypeError: if modularity_values is None
    :raises ValueError: if max_layers < 1
    """
    if modularity_values is None:
        raise TypeError("nu_vs_layers() requires an iterable of modularity values")
    if max_layers < 1:
        raise ValueError(f"max_layers must be >= 1, got {max_layers}")

    print('\nComputing self-consistent states across layers...')
    np.random.seed(pars['seed_initial_rates'])

    # materialise once: the values are iterated several times below
    m_values = list(modularity_values)

    def _calc_fp_for_m(args_dict):
        """Propagate the rates through all layers for one modularity value."""
        m = args_dict['m']
        pars_ = copy.deepcopy(pars)
        pars_['modularity'] = m

        initial_rates = 5 * np.ones(2)  # initial guess for the root finder (has no influence on result)
        previous_rates = rates_initial_layer
        print('m = ' + str(m))
        nu = []
        for _ in range(max_layers):
            # rates of the previous layer act as the fixed input of the current layer
            pars_['nuAfp'] = previous_rates[0]
            pars_['nuNAfp'] = previous_rates[1]
            sol = scipy.optimize.root(optimize_self_consistent_state_layers, initial_rates, args=(pars_,))
            previous_rates = np.abs(sol.x)
            nu.append(previous_rates)
        return np.array(nu), nu[-1]

    if multithreaded and m_values:
        pool = PathosPool(len(m_values))
        per_m = pool.map(_calc_fp_for_m, [{'m': m} for m in m_values])
    else:
        per_m = [_calc_fp_for_m({'m': m}) for m in m_values]

    results = {m: layer_rates for m, (layer_rates, _) in zip(m_values, per_m)}
    limits = [limit for _, limit in per_m]

    print('\nFinished computing self-consistent states across layers...')
    return results, modularity_values, np.array(limits)


#######################################
def nu_first_layer(pars, lmbds=None):
    """
    Compute the self-consistent states of the first layer for several input intensities.

    For every lambda, scipy.optimize.root is applied to
    optimize_self_consistent_state_first_layer, and the absolute value of the root is stored.
    The global numpy RNG is re-seeded with ``pars['seed_initial_rates']``.

    :param pars: parameter dict. NOTE: ``pars['lmbd']`` is overwritten for every lambda and is
        left at the last value of ``lmbds`` on return.
    :param lmbds: list / array of lambdas (input intensity); defaults to np.arange(0., 0.26, 0.01)
    :return: dict mapping each lambda to the array of absolute rates found for it
    """
    print('\nComputing self-consistent states in first layer...')
    np.random.seed(pars['seed_initial_rates'])

    if lmbds is None:
        lmbds = np.arange(0., 0.26, 0.01)
    start_guess = 5. * np.ones(2)  # initial guess for the root finder; does not influence the result

    states = {}
    for intensity in lmbds:
        pars['lmbd'] = intensity
        solution = scipy.optimize.root(optimize_self_consistent_state_first_layer, start_guess, args=(pars,))
        states[intensity] = np.abs(solution.x)

    print('\nFinished computing self-consistent states in first layer...')
    return states


#######################################
def test_saturation(pars):
    """
    Plot the Siegert transfer function over a range of mean inputs and save it as a PDF.

    siegert.nu0_fb433 is used if pars['method'] == 'nu0_fb433', otherwise siegert.nu0_fb.
    It is evaluated for mu = 1, 11, ..., 1991 at a fixed sigma = 1. Each output is multiplied
    by 1e3 (presumably a 1/ms -> Hz conversion; TODO confirm against the siegert module).
    The figure is written to <pars['data_root_path']>/siegert_saturation.pdf.

    :param pars: parameter dict with 'method', 'tau_m', 'tau_s', 'tau_r', 'V_th', 'V_reset'
        and 'data_root_path'
    """
    siegert_func = siegert.nu0_fb433 if pars['method'] == 'nu0_fb433' else siegert.nu0_fb

    sigma = 1.
    mus = np.arange(1., 2001., 10)
    result = [siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'],
                           mu, sigma) * 1e3
              for mu in mus]

    # Draw on a dedicated figure: plotting onto the current figure would overlay any earlier
    # plot, and leaving it open would accumulate lines across repeated calls.
    fig = plt.figure()
    plt.plot(mus, result)
    plt.xlabel('nu (mean)')
    plt.title("Saturation of the Siegert function with variance = {}".format(sigma))
    plt.savefig(os.path.join(pars['data_root_path'], 'siegert_saturation.pdf'))
    plt.close(fig)


def _opt_calc_f_NA(rates, pars):
    """
    Residual for scipy.optimize.root that solves only for the non-active rate rNA.

    The active rate is always taken from pars['rA_fixed'], so the optimizer's rates[0] is
    ignored. The current rNA is |rates[1]|. Both are passed to stationary_state_supercluster.
    The rNA residual is the current rNA minus the rNA returned there. The rA residual is
    identically 0, so the solver only moves rNA.

    :param rates: sequence (nuA, nuNA) proposed by the optimizer
    :param pars: parameter dict containing 'rA_fixed'
    :return: list [0., nuNA - nuNA_out]
    """
    fixed_rA = pars['rA_fixed']
    current_rNA = np.abs(rates[1])

    siegert_rNA = stationary_state_supercluster(fixed_rA, current_rNA, pars)[1]

    return [0., current_rNA - siegert_rNA]


def _opt_calc_f_NA_multi(rates, pars):
    """
    Residual for scipy.optimize.root that solves only for the non-active rate rNA,
    given two fixed active rates.

    rA1 and rA2 come from pars['rA1_fixed'] and pars['rA2_fixed']; rates[0] and rates[1] are
    ignored. The current rNA is |rates[2]|. The residuals for rA1 and rA2 are identically 0,
    so the solver only moves rNA.

    :param rates: sequence (nuA1, nuA2, nuNA) proposed by the optimizer
    :param pars: parameter dict containing 'rA1_fixed' and 'rA2_fixed'
    :return: list [0., 0., nuNA - nuNA_out]
    """
    fixed_rA1 = pars['rA1_fixed']
    fixed_rA2 = pars['rA2_fixed']
    current_rNA = np.abs(rates[2])

    siegert_rNA = stationary_state_supercluster_multi(fixed_rA1, fixed_rA2, current_rNA, pars)[2]

    return [0., 0., current_rNA - siegert_rNA]


def calc_siegert_A(r, pars):
    """
    Return the Siegert output rate of the active population for a fixed active rate r.

    The active rate is fixed to r (stored in pars['rA_fixed']). A consistent non-active rate
    rNA is solved for via _opt_calc_f_NA. The first entry of
    stationary_state_supercluster(r, rNA, pars) is returned.

    :param r: fixed rate rA of the active population
    :param pars: parameter dict; 'rA_fixed' is overwritten with r
    :return: Siegert rate Phi_A(r, rNA)
    """
    pars['rA_fixed'] = r

    # 10. is only the starting guess for rNA
    root_result = scipy.optimize.root(_opt_calc_f_NA, (r, 10.), args=(pars,))
    consistent_rNA = abs(root_result.x[1])

    return stationary_state_supercluster(r, consistent_rNA, pars)[0]


def calc_siegert_A_multi(r1, r2, pars):
    """
    Return the Siegert output rates for fixed active rates (r1, r2).

    r1 and r2 are stored in pars['rA1_fixed'] and pars['rA2_fixed']. A consistent non-active
    rate rNA is solved for via _opt_calc_f_NA_multi. The full output of
    stationary_state_supercluster_multi(r1, r2, rNA, pars) is returned. If the module flag
    CHECK_SIEGERT is set, the output rates are fed back once and both results are printed.

    :param r1: fixed rate of the first active sub-population
    :param r2: fixed rate of the second active sub-population
    :param pars: parameter dict; 'rA1_fixed' and 'rA2_fixed' are overwritten
    :return: result of stationary_state_supercluster_multi for (r1, r2, rNA)
    """
    pars['rA1_fixed'] = r1
    pars['rA2_fixed'] = r2

    # 10. is only the starting guess for rNA
    root_result = scipy.optimize.root(_opt_calc_f_NA_multi, (r1, r2, 10.), args=(pars, ))
    consistent_rNA = abs(root_result.x[2])
    siegert_rates = stationary_state_supercluster_multi(r1, r2, consistent_rNA, pars)

    if CHECK_SIEGERT:
        recheck = stationary_state_supercluster_multi(siegert_rates[0], siegert_rates[1], siegert_rates[2], pars)
        print(f"CHECK_SIEGERT: \n\t{siegert_rates[:3]}\n\t{recheck[:3]},\n\t == {siegert_rates[:3] == recheck[:3]}")
    return siegert_rates


def test_siegert_A_multi_correctness(r1, r2, pars, etol=1e-1):
    """
    Check whether the Siegert rates for fixed (rA1, rA2) = (r1, r2) are self-consistent.

    With (r1, r2) fixed, a consistent rNA is solved for and the rates are passed through
    stationary_state_supercluster_multi. That output is then passed through it a second time.
    The check passes if the first two rates of the second pass are within ``etol`` (absolute)
    of (r1, r2).

    :param r1: fixed rate of the first active sub-population
    :param r2: fixed rate of the second active sub-population
    :param pars: parameter dict; 'rA1_fixed' and 'rA2_fixed' are overwritten
    :param etol: absolute tolerance handed to np.isclose
    :return: numpy bool, True if both rates are reproduced within etol
    """
    pars['rA1_fixed'] = r1
    pars['rA2_fixed'] = r2

    # 10. is only the starting guess for rNA
    root_result = scipy.optimize.root(_opt_calc_f_NA_multi, (r1, r2, 10.), args=(pars, ))
    first_pass = stationary_state_supercluster_multi(r1, r2, abs(root_result.x[2]), pars)
    second_pass = stationary_state_supercluster_multi(first_pass[0], first_pass[1], first_pass[2], pars)

    equal_enough = np.all(np.isclose(second_pass[:2], np.array([r1, r2]), atol=etol))
    print(f"\tCHECK_SIEGERT: \n\t{(r1, r2)}\n\t{first_pass[:3]}\n\t{second_pass[:3]},\n\t == {equal_enough}")
    return equal_enough


def energy_single_trapezoid(dr_grid, modularities, pars, recompute=True, data_full_path=False):
    """
    Compute the potential U(rA) on a grid of rA values, for a single input stream.

    For every modularity m, ``U(rA) = -rA**2 / 2 + G(rA)``. Here G is the cumulative trapezoidal
    integral, starting at dr_grid[0], of Phi_A(rA) = calc_siegert_A(rA, pars) sampled on dr_grid.

    :param dr_grid: (fine) grid of rates on which to compute the energy (integral)
    :param modularities: iterable of modularity values; pars['modularity'] is set to each in turn
    :param pars: parameter dict (modified in place: 'modularity', and 'rA_fixed' via calc_siegert_A)
    :param recompute: if True, compute every modularity from scratch. If False, first load stored
        potentials (one dataset per modularity) from data_full_path, and only compute modularities
        that are not already present (matched with np.isclose).
    :param data_full_path: path to an HDF5 file with previously stored potentials. Loading is
        skipped if it is falsy.
    :return: dictionary {m: array of potentials U on dr_grid}
    """
    try:
        from scipy.integrate import cumulative_trapezoid
    except ImportError:  # SciPy < 1.6 only provides the older name
        from scipy.integrate import cumtrapz as cumulative_trapezoid

    print(f"Computing the energy for modularity values = {modularities}")
    potentials_dict = {}

    if not recompute and data_full_path:
        try:
            with h5py.File(data_full_path, "r") as f:
                for m in f.keys():
                    potentials_dict[float(m)] = np.array(f[m])
        except FileNotFoundError as e:
            print(f"No previously stored data found : {data_full_path}\n\nError: {e}")

    rates = np.asarray(dr_grid, dtype=float)
    for m in modularities:
        pars['modularity'] = m

        # stored keys are floats parsed from the file, hence the tolerant comparison
        already_available = any(np.isclose(m, k) for k in potentials_dict)
        if recompute or not already_available:
            print(f"Calculating Siegert functions over the grid for m = {m}")
            # Phi_A(rA, f(rA)) on the grid, with rNA solved self-consistently for each fixed rA
            siegert_rA_samples = [calc_siegert_A(rA, pars) for rA in dr_grid]

            # cumulative integral of G(rA) along the integration grid dr
            print("Calculating integral G(rA)")
            cum_integral_G_rA = cumulative_trapezoid(siegert_rA_samples, rates, initial=0)

            # potential U(rA) for every rA on the grid
            potentials_dict[m] = -0.5 * rates ** 2 + cum_integral_G_rA
            print(f"Computed potentials for m={m}")
        else:
            print(f"Found previous data for m={m}")  # keep the loaded data

    return potentials_dict


def _mirror_upper_triangle(mat):
    """Return a copy of square matrix `mat` whose lower triangle is the transpose of its upper triangle."""
    mat = np.asarray(mat)
    return np.triu(mat) + np.triu(mat, 1).T


def energy_multi_trapezoid_direct_path(rates_Ax, modularities, pars, recompute=True, data_full_path=None):
    """
    Compute the 2D potential U(rA1, rA2) for two input streams on a grid of rates.

    For a grid point (rA1, rA2):
        U = -rA1**2 / 2 - rA2**2 / 2 + int Phi_1 drA1 + int Phi_2 drA2
    Both integrals use the trapezoidal rule along the straight path from the origin to
    (rA1, rA2). Phi_1 and Phi_2 are entries 0 and 1 of calc_siegert_A_multi.

    Only pairs with index(rA2) >= index(rA1) are computed, in batches on a pathos pool. The lower
    triangle is filled by mirroring, i.e. U is assumed symmetric in (rA1, rA2). Mirroring is
    applied to freshly computed and to loaded data alike.

    :param rates_Ax: (coarser) rate grid for which U(r1, r2) will be computed
    :param modularities: iterable of modularity values; pars['modularity'] is set to each in turn
    :param pars: parameter dict. When anything must be computed it must contain
        'max_cpu_threads_2Dpotentials', the number of grid points evaluated per pool batch.
    :param recompute: if True, compute all modularities from scratch. If False, load stored results
        from data_full_path and only compute modularities missing there.
    :param data_full_path: full path for previously computed data and/or where to store the new one.
        New data is saved after every batch. If None, nothing is loaded or saved.
    :return: dict {m: matrix U} with U[i, j] = U(rates_Ax[i], rates_Ax[j])
    """
    try:
        from scipy.integrate import trapezoid
    except ImportError:  # SciPy < 1.6 only provides the older name
        from scipy.integrate import trapz as trapezoid

    cpu_count = pathos.helpers.cpu_count()
    print(f"Computing the energy for modularity values = {modularities}, with #cpus = {cpu_count}")
    potentials_dict = {}

    if not recompute and data_full_path:
        try:
            with h5py.File(data_full_path, "r") as f:
                for m in f.keys():
                    potentials_dict[float(m)] = np.array(f[m])
        except FileNotFoundError as e:
            print(f"No previously stored data found : {data_full_path}\n\nError: {e}")

    grid_size_scaling = 4
    rates_Ax = np.asarray(rates_Ax)
    n_rates = len(rates_Ax)
    # only the upper triangle (j >= i) is computed; the lower one is mirrored afterwards
    index_pairs = [(i, j) for i in range(n_rates) for j in range(i, n_rates)]
    n_data_points = len(index_pairs)

    def _mth_calc_potential_rA1_rA2(args_dict):
        """
        Multithreaded worker: potential U(rA1, rA2) for one grid point.
        :param args_dict: {'rA1': rate of stream 1, 'rA2': rate of stream 2}
        :return: potential value
        """
        rA1_ = args_dict['rA1']
        rA2_ = args_dict['rA2']
        if DEBUG:
            print(f"--------\nCalculating potential for (rA1, rA2)=({rA1_}, {rA2_})")

        # A zero-length path contributes nothing (and would divide by zero below).
        # At least 2 nodes, so that rates below 1 are not silently integrated as 0.
        integral_rA1 = 0.
        if rA1_ > 0:
            dr_grid_ = np.linspace(0., rA1_, max(int(rA1_) * grid_size_scaling, 2))
            integral_rA1 = trapezoid(
                [calc_siegert_A_multi(z, z * rA2_ / rA1_, pars)[0] for z in dr_grid_], dr_grid_)

        integral_rA2 = 0.
        if rA2_ > 0:
            dr_grid_ = np.linspace(0., rA2_, max(int(rA2_) * grid_size_scaling, 2))
            integral_rA2 = trapezoid(
                [calc_siegert_A_multi(z * rA1_ / rA2_, z, pars)[1] for z in dr_grid_], dr_grid_)

        p = -1 / 2. * (rA1_ ** 2) - 1 / 2. * (rA2_ ** 2) + integral_rA1 + integral_rA2
        if DEBUG:
            print(f"[DEBUG] Computing potential for ({rA1_}, {rA2_}) = {p}", flush=True)
        return p

    for m in modularities:
        pars['modularity'] = m

        if m not in potentials_dict or recompute:
            # 2D potentials are stored in a rA1 x rA2 matrix
            potentials = np.zeros((n_rates, n_rates))
            potentials_dict[m] = potentials
            max_threads = pars['max_cpu_threads_2Dpotentials']

            for batch_start in range(0, n_data_points, max_threads):
                batch_end = batch_start + max_threads
                print(f">>> Processing threads (rates) {batch_start} - {batch_end}")
                batch = index_pairs[batch_start:batch_end]
                thread_args_dict = [{'rA1': rates_Ax[i], 'rA2': rates_Ax[j]} for i, j in batch]
                pool = PathosPool(len(thread_args_dict))
                results = pool.map(_mth_calc_potential_rA1_rA2, thread_args_dict)

                # keep the grid indices with each point instead of searching for the rate values
                for (i, j), p in zip(batch, results):
                    potentials[i, j] = p

                print(f"Computed potentials for m={m} "
                      f"and threads {batch_start} - {batch_end}", flush=True)
                if data_full_path:
                    h5w.save(data_full_path, potentials_dict, overwrite_dataset=True)

            # fill the lower triangle, exactly as is done for data loaded from disk
            potentials_dict[m] = _mirror_upper_triangle(potentials)
        else:
            print(f"Found previous data for m={m}")
            # stored data only holds the upper triangle; mirror it in memory
            potentials_dict[m] = _mirror_upper_triangle(potentials_dict[m])

    return potentials_dict
