import sys
import os

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import scipy.optimize
import scipy.integrate
import h5py_wrapper as h5w
from pathos.multiprocessing import ProcessingPool as PathosPool
import pathos

from mft_helpers import siegert
from mft_helpers import mft_plotting
# CHECK_SIEGERT = True
CHECK_SIEGERT = False
# DEBUG = True
DEBUG = False

########################################
def derived_parameters(pars):
    '''Calculate parameter values derived from other parameters.

    Mutates ``pars`` in place and also returns it.

    Adds:
        NE  -- number of excitatory neurons, int(beta * N) (truncated toward zero)
        NI  -- number of inhibitory neurons, N - NE
        KX  -- copy of pars['ext_K']
        JEE -- weight of EE synapses (before homeostasis; pA), copy of pars['J']
        R_m -- membrane resistance tau_m / C_m
    '''
    # Builtin int instead of np.int: np.int was only an alias of int and was removed
    # in NumPy 1.24 (AttributeError on current NumPy).
    pars['NE'] = int(pars['beta'] * pars['N'])  ## number of excitatory neurons
    pars['NI'] = pars['N'] - pars['NE']  ## number of inhibitory neurons
    pars['KX'] = pars['ext_K']
    pars['JEE'] = pars['J']  ## weight of EE synapses (before homeostasis; pA)
    pars['R_m'] = pars['tau_m'] / pars['C_m']  ## membrane resistance (MOhm?)

    return pars

#############################################
def get_parameter_set(analysis_pars):
    """
    Load the parameter space of an analysis from ``~/<data_root_path>/<parameterspace_label>``.

    The directory is prepended to ``sys.path`` on every call (repeated calls add duplicate
    entries), and the ``mft_parameters`` module found there is imported.
    NOTE: Python caches imported modules, so if ``mft_parameters`` was already imported
    (e.g. by an earlier call with another label) the cached module is used.

    :param analysis_pars: dict with keys 'data_root_path' and 'parameterspace_label'
    :return: tuple (mft_parameters.set_parameter_space(), data_path)
    """
    data_path = '{}/{}/{}'.format(os.path.expanduser("~"),
                                  analysis_pars['data_root_path'],
                                  analysis_pars['parameterspace_label'])
    sys.path.insert(0, data_path)
    import mft_parameters as data_pars
    return data_pars.set_parameter_space(), data_path


########################################
def convert_weights_to_fb_weights(J, pars):
    '''
    Convert synaptic weights (pA) to the weights used in Fourcaud-Brunel 2002 (mV; see notes lam_max_FB.pdf).

    The conversion factor is tau_s / C_m. Works elementwise when J is an array.

    :param J: synaptic weight(s) in pA
    :param pars: parameter dict providing 'tau_s' and 'C_m'
    :return: converted weight(s)
    '''
    tau_s_over_C_m = pars['tau_s'] / pars['C_m']
    return tau_s_over_C_m * J

########################################
def psc_max_2_psp_max(psc_max, tau_m, tau_s, R_m):
    '''
    Convert a PSC maximum (pA) to the PSP maximum (mV) for exponential PSCs.

    Undefined for tau_m == tau_s (division by zero).

    :param psc_max: PSC amplitude
    :param tau_m: membrane time constant
    :param tau_s: synaptic time constant
    :param R_m: membrane resistance
    :return: PSP amplitude
    '''
    ratio = tau_m / tau_s
    delta = tau_m - tau_s
    # Difference of the two exponentials evaluated at the PSP peak time.
    peak_shape = ratio ** (-tau_m / delta) - ratio ** (-tau_s / delta)
    return psc_max * R_m * tau_s / (tau_s - tau_m) * peak_shape


def compute_kappas(pars, n_stim, m, is_sigma=False):
    """
    Compute the effective coupling factors (kappas) for one active (A) and the non-active (NA) clusters.

    Each kappa is the sum of a uniform part, scaled by 1 + gamma*g (mean) or
    1 + gamma*g**2 (variance, ``is_sigma=True``), and a modularity-dependent part
    scaled by (1 - noise_alpha).

    :param pars: dict providing 'g' and 'noise_alpha'
    :param n_stim: number of clusters N_C
    :param m: modularity
    :param is_sigma: if True, return the factors for the input variance instead of the mean
    :return: tuple (kappa_A_A, kappa_A_NA, kappa_NA_A, kappa_NA_NA)
    """
    n_clusters = n_stim
    g = -pars['g']
    gamma = 0.25  # K_I/K_E
    leak = 1. - pars['noise_alpha']

    # Second factor of g: 1 for the mean, g again (-> g**2) for the variance.
    g2 = g if is_sigma else 1.
    shared = 1. + gamma * g * g2

    one_minus_m = 1. - m
    spread = (n_clusters - 1.) * one_minus_m
    denom = spread + 1.
    frac_one = 1. / n_clusters
    frac_rest = (n_clusters - 1.) / n_clusters

    kappa_A_A = frac_one * shared + leak / denom
    kappa_A_NA = frac_rest * shared + leak * spread / denom
    kappa_NA_A = frac_one * shared + leak * one_minus_m / denom
    kappa_NA_NA = frac_rest * shared + leak * (1. + (n_clusters - 2.) * one_minus_m) / denom
    return kappa_A_A, kappa_A_NA, kappa_NA_A, kappa_NA_NA


def compute_kappas_multi(pars, n_stim, m, is_sigma=False):
    """
    Compute kappas for two input streams (two active clusters A1, A2).

    The factors are written for an active cluster; the same factors serve A1 and A2
    symmetrically in the callers.

    :param pars: dict providing 'g' and 'noise_alpha'
    :param n_stim: number of clusters N_C
    :param m: modularity
    :param is_sigma: if True, return the factors for the input variance (g squared)
    :return: tuple (kappa_A_A, kappa_A_Aother, kappa_A_NA, kappa_NA_A, kappa_NA_NA)
    """
    n_clusters = n_stim
    g = -pars['g']
    gamma = 0.25  # K_I/K_E
    leak = 1. - pars['noise_alpha']

    # Second factor of g: 1 for the mean, g again (-> g**2) for the variance.
    g2 = g if is_sigma else 1.
    shared = 1. + gamma * g * g2

    one_minus_m = 1. - m
    denom = (n_clusters - 1.) * one_minus_m + 1.
    frac_one = 1. / n_clusters
    frac_rest = (n_clusters - 2.) / n_clusters  # N_C - 2 non-active clusters

    kappa_A_A = frac_one * shared + leak / denom
    # Input from the other active stream carries the extra (1 - m) factor.
    kappa_A_Aother = frac_one * shared + leak * one_minus_m / denom
    kappa_A_NA = frac_rest * shared + leak * ((n_clusters - 2.) * one_minus_m) / denom
    # Numerically equal to kappa_A_Aother.
    kappa_NA_A = frac_one * shared + leak * one_minus_m / denom
    kappa_NA_NA = frac_rest * shared + leak * (1. + (n_clusters - 3.) * one_minus_m) / denom
    return kappa_A_A, kappa_A_Aother, kappa_A_NA, kappa_NA_A, kappa_NA_NA


def stationary_state_supercluster(nuA, nuNA, pars):
    """
    Main function to compute the stationary states for the superclusters (one input stream).

    Given the current rates of the active (A) and non-active (NA) populations, compute the
    mean (mu) and standard deviation (sigma) of their input and pass them through the
    transfer function selected by pars['method'].

    :param nuA: rate of the active population (presumably 1/s, given the 1e-3 / 1e3 factors)
    :param nuNA: rate of the non-active population
    :param pars: parameter dict; reads 'method', 'modularity', 'noise_alpha', 'g', 'JEE',
        'tau_s', 'C_m', 'tau_m', 'tau_r', 'V_th', 'V_reset', 'nuX'
    :return: [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    :raises ValueError: if pars['method'] is not 'nu0_fb433' or 'nu0_fb'
    """
    method = pars['method']
    if method == 'nu0_fb433':
        siegert_func = siegert.nu0_fb433
    elif method == 'nu0_fb':
        siegert_func = siegert.nu0_fb
    else:
        # Explicit raise: an `assert` is stripped under `python -O`, which would
        # leave siegert_func unbound and fail later with a confusing error.
        raise ValueError("Unknown method? {!r}".format(method))

    n_stim = 10  # also the number of clusters
    m = pars['modularity']
    alpha = pars['noise_alpha']  # 0.25, used to reduce the background noise in the deeper modules

    kappa_A_A, kappa_A_NA, kappa_NA_A, kappa_NA_NA = compute_kappas(pars, n_stim, m)
    kappa_A_A_std, kappa_A_NA_std, kappa_NA_A_std, kappa_NA_NA_std = \
        compute_kappas(pars, n_stim, m, is_sigma=True)

    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    KE = 800  # hard-coded; presumably the excitatory in-degree K_E

    drive = pars['tau_m'] * 1e-3 * KE

    muA = drive * J * (alpha * pars['nuX'] + kappa_A_A * nuA + kappa_A_NA * nuNA)
    muNA = drive * J * (alpha * pars['nuX'] + kappa_NA_A * nuA + kappa_NA_NA * nuNA)

    sigmaA = np.sqrt(drive * J ** 2 * (alpha * pars['nuX'] + kappa_A_A_std * nuA + kappa_A_NA_std * nuNA))
    sigmaNA = np.sqrt(drive * J ** 2 * (alpha * pars['nuX'] + kappa_NA_A_std * nuA + kappa_NA_NA_std * nuNA))

    def _rate(mu, sigma):
        # Output rate of the transfer function, scaled by 1e3 (presumably 1/ms -> 1/s).
        return np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'],
                                   pars['V_th'], pars['V_reset'], mu, sigma) * 1e3)

    nuA_out = _rate(muA, sigmaA)
    nuNA_out = _rate(muNA, sigmaNA)

    state = [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    return state


def stationary_state_supercluster_multi(nuA1, nuA2, nuNA, pars):
    """
    Main function to compute the stationary states for the superclusters, in the case of two input streams.

    The two active populations A1 and A2 use the same kappas with their roles swapped.
    The non-active population NA receives the same kappa_NA_A from both.

    :param nuA1: rate of active population 1 (presumably 1/s)
    :param nuA2: rate of active population 2
    :param nuNA: rate of the non-active population
    :param pars: parameter dict (same keys as stationary_state_supercluster)
    :return: [nuA1_out, nuA2_out, nuNA_out, muA1, muA2, muNA, sigmaA1, sigmaA2, sigmaNA]
    :raises ValueError: if pars['method'] is not 'nu0_fb433' or 'nu0_fb'
    """
    method = pars['method']
    if method == 'nu0_fb433':
        siegert_func = siegert.nu0_fb433
    elif method == 'nu0_fb':
        siegert_func = siegert.nu0_fb
    else:
        # Explicit raise: an `assert` is stripped under `python -O`.
        raise ValueError("Unknown method? {!r}".format(method))

    n_stim = 10  # also the number of clusters
    m = pars['modularity']
    alpha = pars['noise_alpha']  # 0.25, used to reduce the background noise in the deeper modules

    kappa_A_A, kappa_A_Aother, kappa_A_NA, kappa_NA_A, kappa_NA_NA = compute_kappas_multi(pars, n_stim, m)
    kappa_A_A_std, kappa_A_Aother_std, kappa_A_NA_std, kappa_NA_A_std, kappa_NA_NA_std = \
        compute_kappas_multi(pars, n_stim, m, is_sigma=True)

    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    KE = 800  # hard-coded; presumably the excitatory in-degree K_E

    drive = pars['tau_m'] * 1e-3 * KE

    muA1 = drive * J * (alpha * pars['nuX']
                        + kappa_A_A * nuA1 + kappa_A_Aother * nuA2 + kappa_A_NA * nuNA)
    muA2 = drive * J * (alpha * pars['nuX']
                        + kappa_A_Aother * nuA1 + kappa_A_A * nuA2 + kappa_A_NA * nuNA)
    muNA = drive * J * (alpha * pars['nuX']
                        + kappa_NA_A * nuA1 + kappa_NA_A * nuA2 + kappa_NA_NA * nuNA)

    sigmaA1 = np.sqrt(drive * J ** 2 * (alpha * pars['nuX']
                                        + kappa_A_A_std * nuA1 + kappa_A_Aother_std * nuA2 + kappa_A_NA_std * nuNA))
    sigmaA2 = np.sqrt(drive * J ** 2 * (alpha * pars['nuX']
                                        + kappa_A_Aother_std * nuA1 + kappa_A_A_std * nuA2 + kappa_A_NA_std * nuNA))
    sigmaNA = np.sqrt(drive * J ** 2 * (alpha * pars['nuX']
                                        + kappa_NA_A_std * nuA1 + kappa_NA_A_std * nuA2 + kappa_NA_NA_std * nuNA))

    def _rate(mu, sigma):
        # Output rate of the transfer function, scaled by 1e3 (presumably 1/ms -> 1/s).
        return np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'],
                                   pars['V_th'], pars['V_reset'], mu, sigma) * 1e3)

    nuA1_out = _rate(muA1, sigmaA1)
    nuA2_out = _rate(muA2, sigmaA2)
    nuNA_out = _rate(muNA, sigmaNA)

    state = [nuA1_out, nuA2_out, nuNA_out, muA1, muA2, muNA, sigmaA1, sigmaA2, sigmaNA]
    return state


####################################################################


def stationary_state_cluster(nuA, nuNA, pars):
    """
    Main function to compute the stationary states for the clusters of a deeper layer.

    The input combines four parts:
    - background: alpha * nuX;
    - recurrent input from this layer's rates (nuA, nuNA), with the uniform factor
      (1 + gamma*g), or (1 + gamma*g**2) for the variance;
    - feed-forward input from pars['nuAfp'] / pars['nuNAfp'], with
      modularity-dependent weights scaled by (1 - alpha).

    :param nuA: rate of the active population in this layer (presumably 1/s)
    :param nuNA: rate of the non-active population in this layer
    :param pars: parameter dict; additionally reads 'nuAfp' and 'nuNAfp'
    :return: [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    :raises ValueError: if pars['method'] is not 'nu0_fb433' or 'nu0_fb'
    """
    method = pars['method']
    if method == 'nu0_fb433':
        siegert_func = siegert.nu0_fb433
    elif method == 'nu0_fb':
        siegert_func = siegert.nu0_fb
    else:
        # Explicit raise: an `assert` is stripped under `python -O`.
        raise ValueError("Unknown method? {!r}".format(method))

    n_stim = 10  # also the number of clusters
    m = pars['modularity']
    alpha = pars['noise_alpha']  # 0.25, used to reduce the background noise in the deeper modules

    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    KE = 800  # hard-coded; presumably the excitatory in-degree K_E

    N_C = n_stim
    g = -pars['g']
    gamma = 0.25  # K_I/K_E

    nuAfp = pars['nuAfp']
    nuNAfp = pars['nuNAfp']

    drive = pars['tau_m'] * 1e-3 * KE

    muA = drive * J * (
        alpha * pars['nuX']
        + 1. / N_C * (1. + gamma * g) * nuA
        + (1. - alpha) * 1. / ((N_C - 1.) * (1. - m) + 1.) * nuAfp
        + (N_C - 1.) / N_C * (1 + gamma * g) * nuNA
        + (1 - alpha) * ((N_C - 1.) * (1. - m)) / ((N_C - 1.) * (1 - m) + 1.) * nuNAfp)

    muNA = drive * J * (
        alpha * pars['nuX']
        + 1. / N_C * (1. + gamma * g) * nuA
        + (1. - alpha) * (1. - m) / ((N_C - 1.) * (1. - m) + 1.) * nuAfp
        + (N_C - 1.) / N_C * (1. + gamma * g) * nuNA
        + (1. - alpha) * (1. + (N_C - 2.) * (1. - m)) / ((N_C - 1.) * (1. - m) + 1.) * nuNAfp)

    sigmaA = np.sqrt(drive * J ** 2 * (
        alpha * pars['nuX']
        + 1. / N_C * (1. + gamma * g**2) * nuA
        + (1. - alpha) * 1. / ((N_C - 1.) * (1. - m) + 1.) * nuAfp
        + (N_C - 1.) / N_C * (1 + gamma * g**2) * nuNA
        + (1 - alpha) * ((N_C - 1.) * (1. - m)) / ((N_C - 1.) * (1 - m) + 1.) * nuNAfp))

    sigmaNA = np.sqrt(drive * J ** 2 * (
        alpha * pars['nuX']
        + 1. / N_C * (1. + gamma * g**2) * nuA
        + (1. - alpha) * (1. - m) / ((N_C - 1.) * (1. - m) + 1.) * nuAfp
        + (N_C - 1.) / N_C * (1. + gamma * g**2) * nuNA
        + (1. - alpha) * (1. + (N_C - 2.) * (1. - m)) / ((N_C - 1.) * (1. - m) + 1.) * nuNAfp))

    def _rate(mu, sigma):
        # Output rate of the transfer function, scaled by 1e3 (presumably 1/ms -> 1/s).
        return np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'],
                                   pars['V_th'], pars['V_reset'], mu, sigma) * 1e3)

    nuA_out = _rate(muA, sigmaA)
    nuNA_out = _rate(muNA, sigmaNA)

    state = [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    return state


####################################################################


def stationary_state_first_layer(nuA, nuNA, pars):
    """
    Main function to compute the stationary states in the first layer.

    The active population receives external drive (1 + lmbd) * nuX, with
    lmbd = pars['lmbd'] as the input intensity. The non-active population receives nuX.
    Recurrent input uses the uniform factor (1 + gamma*g), or (1 + gamma*g**2) for
    the variance.

    :param nuA: rate of the active population (presumably 1/s)
    :param nuNA: rate of the non-active population
    :param pars: parameter dict; additionally reads 'lmbd'
    :return: [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    :raises ValueError: if pars['method'] is not 'nu0_fb433' or 'nu0_fb'
    """
    method = pars['method']
    if method == 'nu0_fb433':
        siegert_func = siegert.nu0_fb433
    elif method == 'nu0_fb':
        siegert_func = siegert.nu0_fb
    else:
        # Explicit raise: an `assert` is stripped under `python -O`.
        raise ValueError("Unknown method? {!r}".format(method))

    n_stim = 10  # also the number of clusters
    J = convert_weights_to_fb_weights(pars['JEE'], pars)
    KE = 800  # hard-coded; presumably the excitatory in-degree K_E

    N_C = n_stim
    g = -pars['g']
    gamma = 0.25  # K_I/K_E
    lmbd = pars['lmbd']

    drive = pars['tau_m'] * 1e-3 * KE

    muA = drive * J * ((1. + lmbd) * pars['nuX']
                       + 1. / N_C * (1. + gamma * g) * nuA
                       + (N_C - 1.) / N_C * (1 + gamma * g) * nuNA)

    muNA = drive * J * (1. * pars['nuX']
                        + 1. / N_C * (1. + gamma * g) * nuA
                        + (N_C - 1.) / N_C * (1. + gamma * g) * nuNA)

    sigmaA = np.sqrt(drive * J ** 2 * ((1. + lmbd) * pars['nuX']
                                       + 1. / N_C * (1. + gamma * g**2) * nuA
                                       + (N_C - 1.) / N_C * (1 + gamma * g**2) * nuNA))

    sigmaNA = np.sqrt(drive * J ** 2 * (1. * pars['nuX']
                                        + 1. / N_C * (1. + gamma * g**2) * nuA
                                        + (N_C - 1.) / N_C * (1. + gamma * g**2) * nuNA))

    def _rate(mu, sigma):
        # Output rate of the transfer function, scaled by 1e3 (presumably 1/ms -> 1/s).
        return np.abs(siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'],
                                   pars['V_th'], pars['V_reset'], mu, sigma) * 1e3)

    nuA_out = _rate(muA, sigmaA)
    nuNA_out = _rate(muNA, sigmaNA)

    state = [nuA_out, nuNA_out, muA, muNA, sigmaA, sigmaNA]
    return state



####################################################################
def optimize_self_consistent_state(rates, pars):
    """
    Residual of the self-consistency condition for the supercluster model.

    :param rates: sequence (nuA, nuNA); absolute values are used, so signs are ignored
    :param pars: parameter dict forwarded to stationary_state_supercluster
    :return: [nuA - nuA_out, nuNA - nuNA_out], which is zero at a fixed point
    """
    nuA, nuNA = np.abs(rates[0]), np.abs(rates[1])
    nuA_out, nuNA_out = stationary_state_supercluster(nuA, nuNA, pars)[:2]
    return [nuA - nuA_out, nuNA - nuNA_out]


####################################################################
def optimize_self_consistent_state_layers(rates, pars):
    """
    Residual of the self-consistency condition for a deeper layer.

    The feed-forward input comes from pars['nuAfp'] / pars['nuNAfp'], which
    stationary_state_cluster reads.

    :param rates: sequence (nuA, nuNA); absolute values are used
    :param pars: parameter dict forwarded to stationary_state_cluster
    :return: [nuA - nuA_out, nuNA - nuNA_out]
    """
    nuA, nuNA = np.abs(rates[0]), np.abs(rates[1])
    nuA_out, nuNA_out = stationary_state_cluster(nuA, nuNA, pars)[:2]
    return [nuA - nuA_out, nuNA - nuNA_out]

####################################################################
def optimize_self_consistent_state_first_layer(rates, pars):
    """
    Residual of the self-consistency condition for the first layer.

    :param rates: sequence (nuA, nuNA); absolute values are used
    :param pars: parameter dict forwarded to stationary_state_first_layer
    :return: [nuA - nuA_out, nuNA - nuNA_out]
    """
    nuA, nuNA = np.abs(rates[0]), np.abs(rates[1])
    nuA_out, nuNA_out = stationary_state_first_layer(nuA, nuNA, pars)[:2]
    return [nuA - nuA_out, nuNA - nuNA_out]


#######################################
def fixed_points(pars, stability_analysis=False, multithreaded=False, modularity_values=np.arange(0., 1.001, 0.1)):
    """
    Compute fixed points of the supercluster model for each modularity value.

    For each m, pars['n_initial_rates'] random initial rates are drawn from
    pars['intv_initial_ratesA'] / pars['intv_initial_ratesNA'] and passed to
    scipy.optimize.root. Solutions are then filtered and cleaned:
    - keep those whose squared residual is below |pars['max_error']|;
    - de-duplicate on nu_A binned at pars['rate_precision'];
    - truncate the rates to that precision.

    :param pars: parameter dict (mutated: 'modularity' is overwritten for each m)
    :param stability_analysis: if True, integrate d(nu)/dt = Phi(nu) - nu from each fixed
        point over (0, 2000). The point is labelled stable (True) when the end point is
        closest (Euclidean) to that same fixed point, otherwise False. It is labelled -1
        if the integration raised or produced non-finite rates.
    :param multithreaded: sequential or parallel (one pathos pool worker per m) computation
    :param modularity_values: iterable of modularity values m
    :return: dict m -> (nu, state, stability), where:
        - nu is an (n_fp, 2) array of [nu_A, nu_NA];
        - state is a list of stationary_state_supercluster outputs;
        - stability is a list that is empty unless stability_analysis is True.
    """
    print('\nComputing self-consistent states...')
    np.random.seed(pars['seed_initial_rates'])

    results = {}

    def _calc_fp_for_m(args_dict):
        m = args_dict['m']
        print("_calc_fp_for_m() for m={}".format(m))
        nu = []
        f = []
        state = []

        pars['modularity'] = m

        ## search for fixed points for multiple random initial rates
        for _ in range(pars['n_initial_rates']):
            initial_rates = (np.random.uniform(pars['intv_initial_ratesA'][0], pars['intv_initial_ratesA'][1], 1),
                             np.random.uniform(pars['intv_initial_ratesNA'][0], pars['intv_initial_ratesNA'][1], 1))
            sol = scipy.optimize.root(optimize_self_consistent_state, initial_rates, args=(pars,))
            nu += [np.abs(sol.x)]
            f += [sol.fun]

        ## extract solutions whose squared residual is < |max_error|
        err = np.sum(np.array(f) ** 2, axis=1)
        ind = np.where(err < np.abs(pars['max_error']))[0]
        nu = np.array(nu)[ind, :]

        ## remove multiplicity in solutions (de-duplicated on nu_A only), then truncate
        ## both rates to rate_precision (astype(int) rounds toward zero)
        _, ind = np.unique((nu[:, 0] / pars['rate_precision']).astype(int), return_index=True)
        nu = (nu[ind, :] / pars['rate_precision']).astype(int) * pars['rate_precision']

        stability = []
        if stability_analysis:
            def derivative(t, nu_):
                # Rate dynamics whose equilibria are the fixed points of the transfer function.
                return np.array(stationary_state_supercluster(nu_[0], nu_[1], pars))[:2] - nu_

            for i, rates in enumerate(nu):
                try:
                    end_rates = scipy.integrate.solve_ivp(derivative, (0, 2000), rates).y[:, -1]
                except Exception as e:
                    print("WARNING! Exception occurred during stability analysis: {}".format(str(e)))
                    stability.append(-1)
                    continue
                if not np.all(np.isfinite(end_rates)):
                    print("WARNING! Non-finite rates during stability analysis for {}".format(rates))
                    stability.append(-1)
                    continue
                # Euclidean distance to every fixed point; a plain |sum of differences| would
                # let opposite-signed components cancel.
                dists = np.linalg.norm(nu - end_rates, axis=1)
                stability.append(bool(np.argmin(dists) == i))

        n_fp = nu.shape[0]
        print('\t{} fixed points found for m = {}'.format(n_fp, m))
        for cs in range(n_fp):
            # stability is empty unless stability_analysis is True.
            stab = str(stability[cs]) if stability_analysis else 'not computed'
            print('\t\t nu_A=%.2f/s, nu_NA=%.2f/s, ' % (nu[cs, 0], nu[cs, 1]) + 'stability = ' + stab)

        for cs in range(n_fp):
            print('\tRecomputing the state for nu[A, NA] = ({}) for m = {}'.format(nu[cs, :], m))
            state += [stationary_state_supercluster(nu[cs, 0], nu[cs, 1], pars)]
        return nu, state, stability

    if multithreaded:
        thread_args_dict = [{'m': x} for x in modularity_values]
        pool = PathosPool(len(thread_args_dict))
        pool_results = pool.map(_calc_fp_for_m, thread_args_dict)
        results = {m: pool_results[idx] for idx, m in enumerate(modularity_values)}
    else:
        for m in modularity_values:
            results[m] = _calc_fp_for_m({'m': m})

    print('\nFinished computing self-consistent limit states...')
    return results


#######################################
def nu_vs_layers(pars, modularity_values=None, rates_initial_layer=np.array([8.5,5.9]), max_layers=50):
    """
    Propagate self-consistent rates through a chain of layers for each modularity value.

    Layer l + 1 is solved with stationary_state_cluster, using the rates of layer l as
    feed-forward input (pars['nuAfp'], pars['nuNAfp']). The first input is
    rates_initial_layer.

    :param pars: parameter dict (mutated: 'modularity', 'nuAfp', 'nuNAfp')
    :param modularity_values: iterable of m values; defaults to np.arange(0., 1.001, 0.1)
        (the same grid as fixed_points) when None
    :param rates_initial_layer: (nu_A, nu_NA) of the input layer
    :param max_layers: number of layers to propagate through
    :return: tuple (results, modularity_values, limits), where:
        - results maps m -> (max_layers, 2) array of per-layer rates;
        - modularity_values is the grid actually used;
        - limits is an array of last-layer rates per m (the input rates if max_layers == 0).
    """
    print('\nComputing self-consistent states across layers...')
    np.random.seed(pars['seed_initial_rates'])

    if modularity_values is None:
        # Iterating None crashed; fall back to the grid used by fixed_points.
        modularity_values = np.arange(0., 1.001, 0.1)

    initial_rates = 5 * np.ones(2)  # initial guess for the root finder (has no influence on result)

    results = {}
    limits = []
    for m in modularity_values:
        pars['modularity'] = m

        previous_rates = rates_initial_layer
        print('m = ' + str(m))
        nu = []
        for _ in range(max_layers):
            pars['nuAfp'] = previous_rates[0]
            pars['nuNAfp'] = previous_rates[1]
            sol = scipy.optimize.root(optimize_self_consistent_state_layers, initial_rates, args=(pars,))
            previous_rates = np.abs(sol.x)
            nu.append(previous_rates)
        results[m] = np.array(nu)
        # previous_rates is the last layer's solution (or the input rates if max_layers == 0).
        limits.append(previous_rates)

    print('\nFinished computing self-consistent states across layers...')
    return results, modularity_values, np.array(limits)


#######################################
def nu_first_layer(pars, lmbds=None):
    """
    Computes self-consistent states in the first layer for a range of input intensities.

    :param pars: parameter dict (mutated: 'lmbd' is overwritten for each value)
    :param lmbds: list / array of lambdas (input intensity); defaults to np.arange(0., 0.26, 0.01)
    :return: dict lambda -> array [nu_A, nu_NA] (absolute values of the root found)
    """
    print('\nComputing self-consistent states in first layer...')
    np.random.seed(pars['seed_initial_rates'])

    if lmbds is None:
        lmbds = np.arange(0., 0.26, 0.01)
    start = 5. * np.ones(2)  # starting guess for the root finder (does not influence the result)

    results = {}
    for intensity in lmbds:
        pars['lmbd'] = intensity
        solution = scipy.optimize.root(optimize_self_consistent_state_first_layer, start, args=(pars,))
        results[intensity] = np.abs(solution.x)

    print('\nFinished computing self-consistent states in first layer...')
    return results


#######################################
def test_saturation(pars):
    """
    Plot the transfer-function output rate against the mean input mu for a fixed sigma.

    The figure is saved as 'siegert_saturation.pdf' inside pars['data_root_path'].
    A fresh figure is created and closed afterwards, so the plot neither draws onto
    nor leaks pyplot's current figure.

    :param pars: parameter dict; 'method' selects nu0_fb433, and any other value uses nu0_fb
    """
    if pars['method'] == 'nu0_fb433':
        siegert_func = siegert.nu0_fb433
    else:
        siegert_func = siegert.nu0_fb

    sigma = 1.
    mus = np.arange(1., 2001., 10)  # mean input values scanned along the x axis
    result = [siegert_func(pars['tau_m'], pars['tau_s'], pars['tau_r'], pars['V_th'], pars['V_reset'], mu, sigma) * 1e3
              for mu in mus]

    fig, ax = plt.subplots()
    ax.plot(mus, result)
    # The x values are the mean input mu, and sigma is a standard deviation
    # (the old labels said "nu" and "variance").
    ax.set_xlabel('mu (mean input)')
    ax.set_title("Saturation of the Siegert function with sigma = {}".format(sigma))
    fig.savefig(os.path.join(pars['data_root_path'], 'siegert_saturation.pdf'))
    plt.close(fig)


def _opt_calc_f_NA(rates, pars):
    """
    Residual for solving the non-active rate rNA while the active rate is held fixed.

    The active rate is always taken from pars['rA_fixed']. Its residual is reported as 0,
    so the root finder only adjusts rNA.

    :param rates: tuple (nuA, nuNA); only nuNA (absolute value) is used
    :param pars: parameter dict containing 'rA_fixed'
    :return: [0., nuNA - nuNA_out]
    """
    nuNA = np.abs(rates[1])
    nuNA_out = stationary_state_supercluster(pars['rA_fixed'], nuNA, pars)[1]
    return [0., nuNA - nuNA_out]


def _opt_calc_f_NA_multi(rates, pars):
    """
    Residual for solving the non-active rate rNA while both active rates are held fixed.

    The active rates are taken from pars['rA1_fixed'] and pars['rA2_fixed']. Their
    residuals are 0, so only rNA is adjusted.

    :param rates: tuple (nuA1, nuA2, nuNA); only nuNA (absolute value) is used
    :param pars: parameter dict containing 'rA1_fixed' and 'rA2_fixed'
    :return: [0., 0., nuNA - nuNA_out]
    """
    nuNA = np.abs(rates[2])
    nuNA_out = stationary_state_supercluster_multi(pars['rA1_fixed'], pars['rA2_fixed'], nuNA, pars)[2]
    return [0., 0., nuNA - nuNA_out]


def calc_siegert_A(r, pars):
    """
    Output rate of the active population when its input rate is fixed to r.

    First a non-active rate rNA consistent with rA = r is solved for. Then the active
    population's output rate for (r, rNA) is returned.

    :param r: fixed active rate (also stored as pars['rA_fixed'])
    :param pars: parameter dict (mutated)
    :return: nuA_out from stationary_state_supercluster(r, rNA, pars)
    """
    pars['rA_fixed'] = r
    # The starting value for rNA (10.) is arbitrary.
    solution = scipy.optimize.root(_opt_calc_f_NA, (r, 10.), args=(pars,))
    consistent_rNA = abs(solution.x[1])
    return stationary_state_supercluster(r, consistent_rNA, pars)[0]


def calc_siegert_A_multi(r1, r2, pars):
    """
    Full two-stream state when the active rates are fixed to r1 and r2.

    rNA is solved consistently with the fixed active rates. If the module flag
    CHECK_SIEGERT is set, the outputs are fed back once and both results are printed.

    :param r1: fixed rate of active population 1 (stored as pars['rA1_fixed'])
    :param r2: fixed rate of active population 2 (stored as pars['rA2_fixed'])
    :param pars: parameter dict (mutated)
    :return: the state list from stationary_state_supercluster_multi(r1, r2, rNA, pars)
    """
    pars['rA1_fixed'] = r1
    pars['rA2_fixed'] = r2

    # The starting value for rNA (10.) is arbitrary.
    rNA_sol = scipy.optimize.root(_opt_calc_f_NA_multi, (r1, r2, 10.), args=(pars, ))
    rA_siegert = stationary_state_supercluster_multi(r1, r2, abs(rNA_sol.x[2]), pars)

    if CHECK_SIEGERT:
        check_rA_siegert = stationary_state_supercluster_multi(*rA_siegert[:3], pars)
        print(f"CHECK_SIEGERT: \n\t{rA_siegert[:3]}\n\t{check_rA_siegert[:3]},\n\t == {rA_siegert[:3] == check_rA_siegert[:3]}")
    return rA_siegert


def test_siegert_A_multi_correctness(r1, r2, pars, etol=1e-1):
    """
    Tests for siegert correctness / stability by plugging the computed rates back in.

    The state is computed for fixed (r1, r2) with rNA solved consistently. Its three
    rates are then fed back into the model once, and the re-evaluated active rates are
    compared against (r1, r2).

    :param r1: fixed rate of active population 1
    :param r2: fixed rate of active population 2
    :param pars: parameter dict (mutated: 'rA1_fixed', 'rA2_fixed')
    :param etol: absolute tolerance of the comparison
    :return: True (numpy bool) if both re-evaluated active rates are within etol of (r1, r2)
    """
    pars['rA1_fixed'] = r1
    pars['rA2_fixed'] = r2

    # The starting value for rNA (10.) is arbitrary.
    rNA_sol = scipy.optimize.root(_opt_calc_f_NA_multi, (r1, r2, 10.), args=(pars, ))
    rA_siegert = stationary_state_supercluster_multi(r1, r2, abs(rNA_sol.x[2]), pars)

    check_rA_siegert = stationary_state_supercluster_multi(*rA_siegert[:3], pars)
    equal_enough = np.all(np.isclose(check_rA_siegert[:2], np.array([r1, r2]), atol=etol))
    print(f"\tCHECK_SIEGERT: \n\t{(r1, r2)}\n\t{rA_siegert[:3]}\n\t{check_rA_siegert[:3]},\n\t == {equal_enough}")
    return equal_enough

