import json, sympy, scipy, brian2

import numpy as np
from numpy import exp

import brianutils, model_utils
from brianutils import units

import pickle as pkl
from scipy.interpolate import interp1d
import copy
from scipy.signal import find_peaks, argrelextrema

def get_fI(I_in,spikemon,statemon,return_first_spike_time=False):
    '''Builds the frequency-input (f-I) curve from brian2 monitors.

    Neurons are visited in index order. A neuron contributes a non-zero
    rate only if it fired more than two spikes and the peak voltage in its
    first inter-spike window differs from the peak in its last inter-spike
    window by less than 0.01 mV; otherwise its rate is recorded as 0. The
    scan stops at the first such "silent" neuron that follows a neuron with
    a non-zero rate, so the returned arrays may be shorter than I_in.

    Parameters
    ----------
    I_in : numpy array
        input currents, one per recorded neuron
    spikemon : brian2.SpikeMonitor
        only ``spike_trains()`` is used
    statemon : brian2.StateMonitor
        only ``t`` and ``v`` are used
    return_first_spike_time : bool, optional
        set to True to additionally return the spike-time array
        (default is False)

    Returns
    -------
    numpy array
        input currents, truncated to the number of neurons scanned
    numpy array
        firing rates (inverse of the first inter-spike interval, divided
        by Hz), 0 for silent neurons
    numpy array
        spike-peak time per neuron, 0 for silent neurons; only returned if
        return_first_spike_time is True

    NOTE(review): the spike-peak time is computed as the second spike time
    plus the peak offset found in the *last* inter-spike window. Confirm
    this is intended rather than the first spike time plus the offset in
    the first window.
    '''

    t_start = statemon.t[0]
    step = statemon.t[1] - statemon.t[0]
    voltages = statemon.v[:]
    trains = spikemon.spike_trains()
    rates, peak_times = [], []

    def _sample(t):
        # recording index of time t, relative to the first recorded sample
        return int((t-t_start)/step)

    for idx in range(len(I_in)):
        train = trains[idx]

        on_limit_cycle = False
        if len(train)>2:
            first_window = voltages[idx,_sample(train[0]):_sample(train[1])]
            last_window = voltages[idx,_sample(train[-2]):_sample(train[-1])]

            # spiking counts as stable when the peak amplitude no longer drifts
            peak_drift = np.abs(np.max(first_window)-np.max(last_window))
            on_limit_cycle = peak_drift < 0.01*brian2.mV

        if on_limit_cycle:
            rates.append((1/(train[1]-train[0]))/brian2.Hz)
            peak_times.append(train[1] + step*np.argmax(last_window))
            continue

        # silent neuron: stop once firing had already started, else pad with 0
        if len(rates)>0 and rates[-1]>0:
            break
        rates.append(0)
        peak_times.append(0)

    currents = I_in[:len(rates)]
    if return_first_spike_time:
        return currents, np.array(rates), np.array(peak_times)
    return currents, np.array(rates)


def get_currents(M, statemon, Tlc = None, mode = 'mean', pnf_idx=0):
    '''Gets all currents

    Every current is obtained through ``model_utils.eval_func`` on the
    recorded state. The two AChR currents are multiplied by the evaluated
    'syn_clamp' term.

    Parameters
    ----------
    M : dict
        dictionary with model equations
    statemon : brian2.StateMonitor
    Tlc : sequence of int, optional
        limit cycle period in timesteps, one entry per recorded neuron;
        required when mode is 'mean' (the last Tlc[i] samples of neuron i
        are averaged)
    mode : string, optional
        If not 'mean', return time evolution of currents, otherwise return
        mean (default is 'mean')
    pnf_idx : int, optional
        neuron idx to get currents from when mode is not 'mean'
        (default is 0)

    Returns
    -------
    tuple or list
        If mode is 'mean': a tuple ``(mean_Na, mean_K)`` of two lists with
        the mean total Na and total K current per neuron.
        Otherwise: a list ``[t, v, currents]`` where ``currents`` is
        ``[I_NaT, I_NaP, I_Na_L, I_AChRNa, I_K, I_K_L, I_AChRK, I_pump]``
        for neuron pnf_idx. I_Na_L and I_pump are reported as 0 when
        eval_func returned a plain int for them.

    Raises
    ------
    TypeError
        If mode is 'mean' and Tlc is None.
    '''

    I_NaT = model_utils.eval_func(M,'I_NaT',statemon)
    I_NaP = model_utils.eval_func(M,'I_NaP',statemon)
    I_Na_L = model_utils.eval_func(M,'I_Na_L',statemon)

    syn_app = model_utils.eval_func(M,'syn_clamp',statemon)
    I_AChRNa = model_utils.eval_func(M,'I_AChRNa',statemon)*syn_app

    I_Na_tot = I_NaT+I_NaP+I_Na_L+I_AChRNa

    I_K = model_utils.eval_func(M,'I_K',statemon)
    I_K_L = model_utils.eval_func(M,'I_K_L',statemon)
    I_AChRK = model_utils.eval_func(M,'I_AChRK',statemon)*syn_app

    I_K_tot = I_K + I_K_L + I_AChRK

    I_pump = model_utils.eval_func(M,'I_pump',statemon)

    if mode=='mean':
        if Tlc is None:
            raise TypeError("Tlc (limit cycle period per neuron, in timesteps) "
                            "is required when mode is 'mean'")
        mean_Na = [np.mean(I_Na_tot[i,-T:]) for i,T in enumerate(Tlc)]
        mean_K = [np.mean(I_K_tot[i,-T:]) for i,T in enumerate(Tlc)]
        return mean_Na, mean_K

    def _optional_trace(current):
        # A plain int cannot be indexed per neuron; report it as 0.
        # Presumably eval_func returns int 0 when the model does not define
        # the current -- TODO confirm against model_utils.eval_func.
        # Checking each current on its own also covers the case where only
        # I_pump is an int, which previously crashed on I_pump[pnf_idx].
        return 0 if isinstance(current, int) else current[pnf_idx]

    currents = [
        I_NaT[pnf_idx],
        I_NaP[pnf_idx],
        _optional_trace(I_Na_L),
        I_AChRNa[pnf_idx],
        I_K[pnf_idx],
        I_K_L[pnf_idx],
        I_AChRK[pnf_idx],
        _optional_trace(I_pump),
    ]
    return [statemon.t[:], statemon.v[pnf_idx], currents]

def entrainment_index(st0,st1):
    '''Gets entrainment index

    The spike train that starts first defines the reference intervals
    ("bins"); every spike of the other train is converted to a phase
    relative to those intervals and the mean of exp(2*pi*i*phase) over
    these spikes is returned.

    Parameters
    ----------
    st0 : numpy array
        spiketimes of neuron 0 (non-empty, increasing)
    st1 : numpy array
        spiketimes of neuron 1 (non-empty, increasing)

    Returns
    -------
    complex
        entrainment index; its modulus is 1 when all phases coincide

    NOTE(review): for a spike falling in reference interval k, the phase is
    measured from bins[k-1] and normalised by bins[k-1]-bins[k], i.e. by
    the preceding interval with a negative sign. For k == 0 this wraps
    around to the last reference spike. Confirm this indexing is intended.
    '''

    # The train that fires first provides the reference intervals.
    if st0[0] < st1[0]:
        bins, rest = st0, st1
    else:
        bins, rest = st1, st0

    bins = np.asarray(bins)
    rest = np.asarray(rest)

    ix = np.digitize(rest,bins) - 1 # saving the left time-interval bound
    phase = (rest-bins[ix-1]) / (bins[ix-1]-bins[ix])
    return np.sum(np.exp(2j*np.pi*phase))/len(rest)

def get_prc(spikes, v, f, inf_run, dt):
    '''Computes the Phase Response Curve from spike times and voltage traces.

    For every (spike train, voltage trace) pair the voltage peak is located
    in the window between the 2nd and 3rd spike time and in the window
    between the 3rd and 4th spike time. The PRC sample is the first peak
    time plus one period (1/f), minus the second peak time, expressed as a
    fraction of the period. The collected samples are passed through
    fix_prc before being returned.

    Parameters
    ----------
    spikes : list
        list of spiketimes for each perturbation (at least 4 per entry)
    v : list
        list of voltage traces for each perturbation
    f : float
        firing rate
    inf_run : brian2.second
        runtime before data was recorded
    dt : brian2.second
        recording timestep

    Returns
    -------
    numpy array
        Phase Response Curve
    '''

    def _sample(t):
        # recording index of time t, relative to the start of the recording
        return int((t-inf_run)/dt)

    period = 1/f
    samples = []

    for train, trace in zip(spikes,v):
        # peak of the cycle between the 3rd and 4th spike time
        late_window = trace[_sample(train[2]):_sample(train[3])]
        actual_peak = train[2] + np.argmax(late_window)*dt

        # peak of the preceding cycle, between the 2nd and 3rd spike time
        early_window = trace[_sample(train[1]):_sample(train[2])]
        previous_peak = train[1] + np.argmax(early_window)*dt

        expected_peak = previous_peak/brian2.second+period
        samples.append((expected_peak - actual_peak/brian2.second)/period)

    return fix_prc(np.array(samples))

def fix_prc(prc):
    '''Fixes inconsistencies in the PRC

    Samples whose magnitude exceeds twice the 80th percentile of all
    magnitudes are treated as outliers and replaced by linear interpolation
    between the neighbouring valid samples. Outliers (or NaNs) at the very
    start or end have no neighbour on one side and take the value of the
    nearest valid sample. The input is not modified.

    Parameters
    ----------
    prc : numpy array
        Phase Response Curve (1-D)

    Returns
    -------
    numpy array
        Phase Response Curve, as a new float array of the same length
    '''
    # Work on a float copy: the caller's array must not receive NaNs, and
    # integer input could not hold them anyway.
    prc = np.array(prc, dtype=float)
    if prc.size == 0:
        return prc

    magnitude = np.abs(prc)
    prc[magnitude > 2*np.percentile(magnitude,80)] = np.nan

    valid = ~np.isnan(prc)
    if np.count_nonzero(valid) < 2:
        # fewer than two support points: nothing to interpolate between
        return prc

    x = np.arange(len(prc))
    known = prc[valid]
    # bounds_error=False with edge fill values keeps interp1d from raising
    # when the first or last sample was discarded.
    f = interp1d(x[valid],known,bounds_error=False,fill_value=(known[0],known[-1]))
    return f(x)

def get_entrainment_range(M, prc, I_stim, pnf, dt, psi_res=100):
    '''Gets entrainment range for model M at input I_stim with freq pnf

    For psi_res phase offsets psi in [0, 1], the PRC (linearly interpolated
    over phase 0..1) is evaluated along one pacemaker period, multiplied by
    the mean-subtracted stimulus scaled by the membrane capacitance, summed
    and multiplied by dt/T_pn. The extremes of the resulting values over
    psi are returned.

    Parameters
    ----------
    M : dictionary
        model equations; M['parameters']['C'] is a string expression that
        is evaluated with the project ``units`` namespace
    prc : numpy array
        Phase Response Curve
    I_stim : scipy.interp1d
        function with stimulus shape
    pnf : float
        pacemaker frequency
    dt : float
        timestep in seconds
    psi_res : int, optional
        resolution for calculating psi

    Returns
    -------
    float
        minimum of the phase-averaged response over all psi
    float
        maximum of the phase-averaged response over all psi

    NOTE(review): the original documentation called these the minimum and
    maximum pacemaker frequency; presumably they are the bounds of the
    frequency detuning that can still be entrained -- confirm.
    '''

    PRC = interp1d(np.linspace(0,1,len(prc)),prc)

    psi_ = np.linspace(0,1,psi_res)
    d_psi = np.zeros(psi_res)
    T_pn = 1/pnf
    T = np.arange(0,T_pn,dt)

    # Everything below is independent of psi, so compute it once rather
    # than on every loop iteration.
    # SECURITY NOTE: eval executes the capacitance string from the model
    # dictionary; only use with trusted model definitions.
    capacitance = eval(M['parameters']['C'],units)
    stimulus = I_stim(T)
    drive = (stimulus-np.mean(stimulus))*brian2.amp/capacitance/brian2.volt
    phase_ramp = pnf/brian2.Hz*T

    for i, psi in enumerate(psi_):
        d_psi[i] = (1/T_pn)*np.sum(PRC(np.mod(psi+phase_ramp,1))*drive)*dt

    return np.min(d_psi), np.max(d_psi)
