import sys
import os

import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import scipy.optimize
import scipy.integrate
import h5py_wrapper as h5w
from pathos.multiprocessing import ProcessingPool as PathosPool
import pathos
import copy

from mft_helpers import siegert
from mft_helpers import mft_plotting
from mft_helpers import calc
from mft_helpers.siegert import piecewise_linear_approx as pla_

# Module-level debug switch (False by default). It is not referenced in the visible part
# of this file; presumably consulted elsewhere — set to True to enable debug behaviour.
DEBUG = False

def calculate_effective_coupling(pars, m, lmbda, n_ssn=6):
    """
    Computes the effective coupling in each SSN for a given modularity m and input
    intensity lambda.

    The rates of the first SSN come from ``calc.nu_first_layer``. The state of every
    following SSN is obtained by solving ``calc.optimize_self_consistent_state_layers``
    with the previous SSN's rates as input ('nuAfp' / 'nuNAfp'). For each SSN the
    derivative ``siegert.d_nu_d_nu_in_fb`` is evaluated at the (mu, sigma) of the
    stimulated ("A") and non-stimulated ("NA") populations.

    NOTE: ``pars`` is modified in place: 'modularity', 'nuAfp' and 'nuNAfp' are written
    here, and 'full_state_vars' is read after the calc helpers run (presumably they
    update it — confirm in mft_helpers.calc).

    :param pars: parameter dict; must provide 'JEE', 'tau_m', 'tau_s', 'tau_r' (in ms),
        'V_th', 'V_reset' and a 6-tuple 'full_state_vars'
        (nuA, nuNA, muA, muNA, sigmaA, sigmaNA).
    :param m: modularity, stored in ``pars['modularity']``.
    :param lmbda: input intensity passed to ``calc.nu_first_layer``.
    :param n_ssn: number of SSNs to evaluate (default 6, the previously hard-coded value).
    :return: ``[stim, nonstim]`` — two lists of length ``n_ssn`` with the effective
        coupling of the stimulated and non-stimulated population per SSN.
    """
    import warnings

    pars['modularity'] = m
    # rates of the first SSN (SSN0); used as input rates for SSN1
    rates_ssn0 = calc.nu_first_layer(pars, lmbds=[lmbda])[lmbda]
    _, _, muA, muNA, sigmaA, sigmaNA = pars['full_state_vars']

    # initial guess for the root finder (stated to have no influence on the result)
    initial_rates = 5 * np.ones(2)
    previous_rates = rates_ssn0

    J = calc.convert_weights_to_fb_weights(pars['JEE'], pars)

    tau_m = pars['tau_m'] * 1e-3  # ms -> s
    tau_s = pars['tau_s'] * 1e-3
    tau_r = pars['tau_r'] * 1e-3
    V_th = pars['V_th']
    V_reset = pars['V_reset']

    w_ecs = [[], []]
    for ssn in range(n_ssn):
        # effective coupling of the current SSN (stimulated, then non-stimulated)
        w_ecs[0].append(siegert.d_nu_d_nu_in_fb(tau_m, tau_s, tau_r, V_th, V_reset, J, None, muA, sigmaA))
        w_ecs[1].append(siegert.d_nu_d_nu_in_fb(tau_m, tau_s, tau_r, V_th, V_reset, J, None, muNA, sigmaNA))

        if ssn == n_ssn - 1:
            break  # last SSN: no next layer to solve for

        # feed the previous SSN's rates forward and solve for the next SSN's state
        pars['nuAfp'] = previous_rates[0]
        pars['nuNAfp'] = previous_rates[1]
        sol = scipy.optimize.root(calc.optimize_self_consistent_state_layers, initial_rates, args=(pars,))
        if not sol.success:
            # previously ignored silently; keep going but make non-convergence visible
            warnings.warn(
                f"SSN {ssn + 1}: self-consistent rate solver did not converge ({sol.message})",
                RuntimeWarning,
                stacklevel=2,
            )
        previous_rates = np.abs(sol.x)

        _, _, muA, muNA, sigmaA, sigmaNA = pars['full_state_vars']

    return w_ecs

