"""
Contains functions to compute stationary firing rates and DC susceptibilities.

(Moritz Helias;  Jannis Schuecker; Tom Tetzlaff, t.tetzlaff@fz-juelich.de, 2018)

Modified by David Dahmen and Barna Zajzon.
"""

import scipy
import scipy.integrate
import scipy.stats
import scipy.special
from scipy.special import erf, zetac, lambertw, erfcx, dawsn, roots_legendre
#from matplotlib.pylab import *  # for plot
from numpy import *
import numpy as np
from scipy.integrate import quad


def nu0_fb433(tau_m, tau_s, tau_r, V_th_rel, V_0_rel, mu, sigma):
    """
    Stationary firing rate with a first-order synaptic-filtering correction.

    Starts from the delta-PSC rate ``nu0`` (as computed by ``nu_0``) and
    applies a correction proportional to ``sqrt(tau_s)``::

        nu = nu0 * (1 - sqrt(tau_s * tau_m / 2) * alpha * nu0 * dPhi)

    with ``alpha = sqrt(2) * |zeta(1/2)|`` and
    ``dPhi = Phi(sqrt(2) * y_th) - Phi(sqrt(2) * y_r)``; the product
    ``nu0 * dPhi`` is obtained from ``_nu0_dPhi``. The name presumably refers
    to Eq. 4.33 of Fourcaud & Brunel (2002) -- NOTE(review): confirm.

    Works on plain numbers/arrays (no physical-quantity objects).

    Parameters:
    -----------
    tau_m: float
        Membrane time constant in seconds.
    tau_s: float
        Synaptic time constant, in the same time unit as ``tau_m``.
    tau_r: float
        Refractory time in seconds.
    V_th_rel: float
        Relative threshold potential in mV.
    V_0_rel: float
        Relative reset potential in mV.
    mu: float
        Mean neuron activity in mV.
    sigma: float
        Standard deviation of neuron activity in mV.

    Returns:
    --------
    float or ndarray:
        Corrected stationary firing rate, shaped like the result of ``nu_0``.
    """
    # zetac(x) = zeta(x) - 1 is used instead of zeta, because zeta does not
    # give finite values for arguments smaller than 1.
    alpha = np.sqrt(2.) * abs(zetac(0.5) + 1)

    rate_delta = nu_0(tau_m, tau_r, V_th_rel, V_0_rel, mu, sigma)
    rate_times_dphi = _nu0_dPhi(tau_m, tau_r, V_th_rel, V_0_rel, mu, sigma)
    relative_shift = np.sqrt(tau_s * tau_m / 2) * alpha * rate_times_dphi
    return rate_delta * (1 - relative_shift)


def nu_0(tau_m, tau_r, V_th_rel, V_0_rel, mu, sigma):
    """
    Calculates stationary firing rates for delta shaped PSCs.

    Evaluates the Siegert formula

        1 / nu = tau_r
                 + tau_m * sqrt(pi) * int_{y_r}^{y_th} exp(u**2) (1 + erf(u)) du

    with reduced threshold ``y_th = (V_th_rel - mu) / sigma`` and reduced
    reset ``y_r = (V_0_rel - mu) / sigma``. The integral is evaluated
    separately in three regimes (``y_th < 0``; ``0 < y_r``;
    ``y_r <= 0 <= y_th``) using scaled special functions (``erfcx``,
    ``dawsn``) and a fixed-order Gauss-Legendre quadrature. In the last two
    regimes the ``exp(-y_th**2)`` factor of the rate is applied only at the
    end, so large ``|y|`` does not overflow.

    ``V_th_rel``, ``V_0_rel``, ``mu`` and ``sigma`` may be scalars or arrays
    and are broadcast against each other. ``tau_m`` and ``tau_r`` are
    expected to be scalars, because they are combined with masked subsets of
    the reduced potentials.

    NOTE(review): the regime split assumes ``y_r <= y_th``, i.e. ``sigma > 0``.

    Parameters:
    -----------
    tau_m: float
        Membrane time constant in seconds.
    tau_r: float
        Refractory time in seconds.
    V_th_rel: float or array_like
        Relative threshold potential in mV.
    V_0_rel: float or array_like
        Relative reset potential in mV.
    mu: float or array_like
        Mean neuron activity in mV.
    sigma: float or array_like
        Standard deviation of neuron activity in mV.

    Returns:
    --------
    float or ndarray:
        Stationary firing rate in Hz. A float is returned when the broadcast
        result holds a single value (shape ``()`` or ``(1,)``); otherwise an
        array of the broadcast shape.

    Raises:
    -------
    ValueError
        If any ``V_th_rel`` is smaller than ``V_0_rel``, or if the reduced
        threshold and reset cannot be broadcast together.
    """
    if np.any(V_th_rel - V_0_rel < 0):
        raise ValueError('V_th should be larger than V_0!')
    y_th = (V_th_rel - mu) / sigma
    y_r = (V_0_rel - mu) / sigma

    # Bring into appropriate shape: broadcast against each other, so that any
    # mix of scalar- and array-valued parameters works, then flatten to 1d,
    # which the quadrature helpers expect. The shape is restored at the end.
    y_th, y_r = np.broadcast_arrays(y_th, y_r)
    out_shape = y_th.shape
    y_th = np.ravel(y_th)
    y_r = np.ravel(y_r)

    # determine order of quadrature
    params = {'start_order': 10, 'epsrel': 1e-12, 'maxiter': 10}
    gl_order = _get_erfcx_integral_gl_order(y_th=y_th, y_r=y_r, **params)

    # separate domains
    mask_exc = y_th < 0
    mask_inh = 0 < y_r
    mask_interm = (y_r <= 0) & (0 <= y_th)

    # calculate rescaled siegert
    nu = np.zeros(shape=y_th.shape)
    params = {'tau_m': tau_m, 't_ref': tau_r, 'gl_order': gl_order}
    nu[mask_exc] = _siegert_exc(y_th=y_th[mask_exc],
                                y_r=y_r[mask_exc], **params)
    nu[mask_inh] = _siegert_inh(y_th=y_th[mask_inh],
                                y_r=y_r[mask_inh], **params)
    nu[mask_interm] = _siegert_interm(y_th=y_th[mask_interm],
                                      y_r=y_r[mask_interm], **params)

    # include exponential contributions (the inh/interm helpers return the
    # rate without its exp(-y_th**2) factor)
    nu[mask_inh] *= np.exp(-y_th[mask_inh]**2)
    nu[mask_interm] *= np.exp(-y_th[mask_interm]**2)

    # convert back to scalar if only one value calculated
    if out_shape in ((), (1,)):
        return nu.item(0)
    return nu.reshape(out_shape)



def _nu0_dPhi(tau_m, tau_r, V_th_rel, V_0_rel, mu, sigma):
    """
    Calculate ``nu0 * (Phi(sqrt(2)*y_th) - Phi(sqrt(2)*y_r))`` safely.

    Here ``nu0`` is the stationary firing rate for delta shaped PSCs,
    evaluated with the same rescaled Siegert helpers ``nu_0`` uses. The
    reduced potentials are ``y_th = (V_th_rel - mu) / sigma`` and
    ``y_r = (V_0_rel - mu) / sigma``, and
    ``Phi(s) = sqrt(pi/2) * exp(s**2/2) * (1 + erf(s/sqrt(2)))``.

    How overflow is avoided:

    - For ``y_th >= 0`` the rate is evaluated without its ``exp(-y_th**2)``
      factor.
    - ``Phi`` of non-negative arguments is evaluated without its
      ``exp(s**2/2)`` factor.
    - The leftover exponentials (``exp(-y_th**2 + y_r**2)`` or
      ``exp(-y_th**2)``) are folded into the reset term before multiplying,
      so large ``|y|`` does not overflow.

    ``V_th_rel``, ``V_0_rel``, ``mu`` and ``sigma`` may be scalars or arrays
    and are broadcast against each other. ``tau_m`` and ``tau_r`` are
    expected to be scalars.

    NOTE(review): the regime split assumes ``y_r <= y_th``, i.e. ``sigma > 0``.

    Parameters:
    -----------
    tau_m: float
        Membrane time constant in seconds.
    tau_r: float
        Refractory time in seconds.
    V_th_rel: float or array_like
        Relative threshold potential in mV.
    V_0_rel: float or array_like
        Relative reset potential in mV.
    mu: float or array_like
        Mean neuron activity in mV.
    sigma: float or array_like
        Standard deviation of neuron activity in mV.

    Returns:
    --------
    float or ndarray:
        A float when the broadcast result holds a single value (shape ``()``
        or ``(1,)``); otherwise an array of the broadcast shape.

    Raises:
    -------
    ValueError
        If any ``V_th_rel`` is smaller than ``V_0_rel``, or if the reduced
        threshold and reset cannot be broadcast together.
    """
    if np.any(V_th_rel - V_0_rel < 0):
        raise ValueError('V_th should be larger than V_0!')
    y_th = (V_th_rel - mu) / sigma
    y_r = (V_0_rel - mu) / sigma

    # Bring into appropriate shape: broadcast against each other and flatten
    # to 1d for the helpers; the shape is restored at the end.
    y_th, y_r = np.broadcast_arrays(y_th, y_r)
    out_shape = y_th.shape
    y_th = np.ravel(y_th)
    y_r = np.ravel(y_r)

    # determine order of quadrature
    params = {'start_order': 10, 'epsrel': 1e-12, 'maxiter': 10}
    gl_order = _get_erfcx_integral_gl_order(y_th=y_th, y_r=y_r, **params)

    # separate domains
    mask_exc = y_th < 0
    mask_inh = 0 < y_r
    mask_interm = (y_r <= 0) & (0 <= y_th)

    # calculate rescaled siegert
    nu = np.zeros(shape=y_th.shape)
    params = {'tau_m': tau_m, 't_ref': tau_r, 'gl_order': gl_order}
    nu[mask_exc] = _siegert_exc(y_th=y_th[mask_exc],
                                y_r=y_r[mask_exc], **params)
    nu[mask_inh] = _siegert_inh(y_th=y_th[mask_inh],
                                y_r=y_r[mask_inh], **params)
    nu[mask_interm] = _siegert_interm(y_th=y_th[mask_interm],
                                      y_r=y_r[mask_interm], **params)

    # calculate rescaled Phi (_Phi_pos omits exp(s**2 / 2), _Phi_neg is exact)
    Phi_th = np.zeros(shape=y_th.shape)
    Phi_r = np.zeros(shape=y_r.shape)
    Phi_th[mask_exc] = _Phi_neg(s=np.sqrt(2)*y_th[mask_exc])
    Phi_r[mask_exc] = _Phi_neg(s=np.sqrt(2)*y_r[mask_exc])
    Phi_th[mask_inh] = _Phi_pos(s=np.sqrt(2)*y_th[mask_inh])
    Phi_r[mask_inh] = _Phi_pos(s=np.sqrt(2)*y_r[mask_inh])
    Phi_th[mask_interm] = _Phi_pos(s=np.sqrt(2)*y_th[mask_interm])
    Phi_r[mask_interm] = _Phi_neg(s=np.sqrt(2)*y_r[mask_interm])

    # include exponential contributions: the rescaled rate carries an extra
    # exp(y_th**2), which cancels the omitted Phi factor at threshold but has
    # to be compensated at reset
    Phi_r[mask_inh] *= np.exp(-y_th[mask_inh]**2 + y_r[mask_inh]**2)
    Phi_r[mask_interm] *= np.exp(-y_th[mask_interm]**2)

    # calculate nu * dPhi
    nu_dPhi = nu * (Phi_th - Phi_r)

    # convert back to scalar if only one value calculated
    if out_shape in ((), (1,)):
        return nu_dPhi.item(0)
    return nu_dPhi.reshape(out_shape)



def _get_erfcx_integral_gl_order(y_th, y_r, start_order, epsrel, maxiter):
    """Determine order of Gauss-Legendre quadrature for erfcx integral."""
    # determine maximal integration range
    a = min(np.abs(y_th).min(), np.abs(y_r).min())
    b = max(np.abs(y_th).max(), np.abs(y_r).max())

    # adaptive quadrature from scipy.integrate for comparison
    I_quad = quad(erfcx, a, b, epsabs=0, epsrel=epsrel)[0]

    # increase order to reach desired accuracy
    order = start_order
    for _ in range(maxiter):
        I_gl = _erfcx_integral(a, b, order=order)[0]
        rel_error = np.abs(I_gl / I_quad - 1)
        if rel_error < epsrel:
            return order
        else:
            order *= 2
    msg = f'Quadrature search failed to converge after {maxiter} iterations. '
    msg += f'Last relative error {rel_error:e}, desired {epsrel:e}.'
    raise RuntimeError(msg)



def _erfcx_integral(a, b, order):
    """Fixed order Gauss-Legendre quadrature of erfcx from a to b."""
    assert np.all(a >= 0) and np.all(b >= 0)
    x, w = roots_legendre(order)
    x = x[:, np.newaxis]
    w = w[:, np.newaxis]
    return (b - a) * np.sum(w * erfcx((b-a) * x / 2 + (b+a) / 2), axis=0) / 2



def _siegert_exc(y_th, y_r, tau_m, t_ref, gl_order):
    """Siegert rate for the regime ``y_th < 0``.

    In this regime the Siegert integrand is
    ``exp(u**2) * (1 + erf(u)) = erfcx(-u)``, which stays bounded. The
    integral from ``y_r`` to ``y_th`` therefore equals the integral of
    ``erfcx`` from ``|y_th|`` to ``|y_r|``. It is evaluated with a
    Gauss-Legendre rule of order ``gl_order``. No rescaling is applied: the
    returned value is the rate itself.

    NOTE(review): assumes ``y_r <= y_th``, which holds for the callers in
    this module when ``sigma > 0``.
    """
    assert np.all(y_th < 0)
    lower = np.abs(y_th)
    upper = np.abs(y_r)
    integral = _erfcx_integral(lower, upper, gl_order)
    inverse_rate = t_ref + tau_m * np.sqrt(np.pi) * integral
    return 1 / inverse_rate


def _siegert_inh(y_th, y_r, tau_m, t_ref, gl_order):
    """Rescaled Siegert rate for the regime ``0 < y_r``.

    When ``y_r <= y_th``, this also means ``0 < y_th``. Returns
    ``nu * exp(y_th**2)``, i.e. the rate without its ``exp(-y_th**2)``
    factor; the caller has to multiply that factor back in.

    The integrand ``exp(u**2) * (1 + erf(u))`` is split as
    ``2 * exp(u**2) - erfcx(u)``:

    - The first part is integrated in closed form via Dawson's function.
    - The second part is integrated by Gauss-Legendre quadrature of order
      ``gl_order``.

    Everything is pre-multiplied by ``exp(-y_th**2)`` to avoid overflow.
    """
    assert np.all(0 < y_r)
    damping = np.exp(-y_th**2)
    dawson_part = 2 * dawsn(y_th) - 2 * np.exp(y_r**2 - y_th**2) * dawsn(y_r)
    erfcx_part = damping * _erfcx_integral(y_r, y_th, gl_order)
    rescaled_integral = dawson_part - erfcx_part
    return 1 / (damping * t_ref + tau_m * np.sqrt(np.pi) * rescaled_integral)


def _siegert_interm(y_th, y_r, tau_m, t_ref, gl_order):
    """Rescaled Siegert rate for the regime ``y_r <= 0 <= y_th``.

    Returns ``nu * exp(y_th**2)``, i.e. the rate without its
    ``exp(-y_th**2)`` factor; the caller has to multiply that factor back in.

    The Siegert integral is split into two pieces:

    - ``2 * dawsn(y_th)``, the Dawson-function term rescaled by
      ``exp(-y_th**2)``.
    - ``exp(-y_th**2)`` times the integral of ``erfcx`` from ``y_th`` to
      ``|y_r|``, evaluated with a Gauss-Legendre rule of order ``gl_order``.
      The limits may be in either order.
    """
    assert np.all((y_r <= 0) & (0 <= y_th))
    damping = np.exp(-y_th**2)
    tail = damping * _erfcx_integral(y_th, np.abs(y_r), gl_order)
    rescaled_integral = 2 * dawsn(y_th) + tail
    return 1 / (damping * t_ref + tau_m * np.sqrt(np.pi) * rescaled_integral)



def _Phi_neg(s):
    """Calculate Phi(s) for negative arguments"""
    assert np.all(s <= 0)
    return np.sqrt(np.pi / 2.) * erfcx(np.abs(s) / np.sqrt(2))


def _Phi_pos(s):
    """Calculate Phi(s) without exp(-s**2 / 2) factor for positive arguments"""
    assert np.all(s >= 0)
    return np.sqrt(np.pi / 2.) * (2 - np.exp(-s**2 / 2.)*erfcx(s / np.sqrt(2)))
