import json, sympy, scipy, brian2, sys, os

import numpy as np
from numpy import exp

import brianutils, model_utils
from brianutils import units

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib as mpl

import pickle as pkl
from scipy.interpolate import interp1d
import copy
from scipy.signal import find_peaks, argrelextrema

# Module-level constant.
# NOTE(review): N_TR is not referenced anywhere in this section of the file;
# its meaning, and whether other code still uses it, should be confirmed.
N_TR = 1

def PN_input(pnfs, inf_run, runtime):
    """Generate regular spike trains for duration inf_run + runtime at frequencies pnfs.

    Input ``i`` fires periodically at frequency ``pnfs[i]``, starting at
    t = 0, so its k-th spike is at ``k * (1 / pnfs[i])``. Every input gets
    the same number of spikes, ``M = 1 + int(max(pnfs) * (inf_run + runtime))``.
    That is enough for the fastest input to cover the whole duration, so
    slower inputs have spike times extending beyond ``inf_run + runtime``.

    Parameters
    ----------
    pnfs : list
        list of PN frequencies, in spikes per second (the time arguments are
        in seconds); every frequency must be strictly positive
    inf_run : float
        time (in seconds) for the inf run
    runtime : float
        recorded simulation time (in seconds)

    Returns
    -------
    numpy array
        float array of shape (N, M); row i holds the M spike times of input i
    numpy array
        float array of shape (N, M); row i is filled with the index i

    Raises
    ------
    ValueError
        if any frequency in pnfs is not strictly positive
    """
    if any(pnf <= 0 for pnf in pnfs):
        raise ValueError("all PN frequencies must be strictly positive")

    n_inputs = len(pnfs)
    n_spikes = 1 + int(np.max(pnfs) * (inf_run + runtime))

    times = np.zeros((n_inputs, n_spikes))
    for row, pnf in enumerate(pnfs):
        period = 1 / pnf
        # With a float step, np.arange may return one element too many
        # because of rounding in stop/step; the slice keeps exactly n_spikes.
        times[row] = np.arange(0, n_spikes * period, period)[:n_spikes]

    # Row i is filled with the (float) input index i.
    idxs = np.repeat(np.arange(n_inputs, dtype=float)[:, None], n_spikes, axis=1)
    return times, idxs


def run_sim(M_in, bifpar, inf_run, runtime, mode='periodic', pnfs=[None], syns=None, dt=0.001*brian2.ms, statemon_rec=True, spikemon_rec=False, variables=True, rec_dt = 0.001*brian2.ms, N=1, pert=None):
    """Runs simulation of model M_in in brian2.

    The model dictionary is deep-copied and extended with the input required
    by ``mode``. It is then compiled into a ``brian2.NeuronGroup`` and
    simulated for ``inf_run`` (not recorded), followed by ``runtime``
    (recorded by the requested monitors).

    Parameters
    ----------
    M_in : dict
        dictionary with model equations (keys used here: "ode",
        "init_states", "functions", "definitions", "parameters");
        the caller's dict is not modified
    bifpar : dict
        dictionary with parameters to set on the neuron group. In mode 'fi'
        the length of ``bifpar['I_stim']`` sets the number of neurons; in
        mode 'syn_clamp' the length of ``bifpar['syn_clamp']`` does.
    inf_run : brian2.second
        simulation runtime before recording
    runtime : brian2.second
        recorded simulation runtime
    mode : string, optional
        stimulation mode. Options are: 'periodic', 'long_square', 'fi',
        'prcs', 'syn_clamp' (default is 'periodic')
    pnfs : list or array, optional
        In mode 'periodic', either a 1-D sequence of pacemaker nucleus
        frequencies (converted to spike trains by ``PN_input``) or a
        precomputed ``(times, idxs)`` pair given as an array with more than
        one dimension. In mode 'prcs', a ``(times, idxs)`` pair of
        perturbation times and input indices.
    syns : list, optional
        For mode 'long_square', three ``[amplitude, time]`` pairs (times as
        brian2 quantities). The input is ``syns[0][0]`` before
        ``syns[0][1]``, ``syns[1][0]`` until ``syns[1][1]``, and
        ``syns[2][0]`` afterwards (default is None).
    dt : brian2.second, optional
        simulation timestep (default is 0.001*brian2.ms)
    statemon_rec : bool, optional
        Set to True to record simulation states (default is True)
    spikemon_rec : bool, optional
        Set to True to record spiketimes (default is False). It is forced to
        True in modes 'fi' and 'prcs'.
    variables : bool or list, optional
        Define variables to record. If True, all variables are recorded
        (default is True)
    rec_dt : brian2.second, optional
        Recording timestep (default is 0.001*brian2.ms)
    N : int, optional
        Number of neurons to record per stimulus (default is 1). Tiling
        per-neuron parameters for N > 1 requires spike input (modes
        'periodic' or 'prcs').
    pert : brian2.volt, optional
        Perturbation strength for PRC protocol, added to ``v`` at every
        input spike (default is None)

    Returns
    -------
    The returned value depends on the recording flags:

    - statemon_rec and spikemon_rec: ``(statemon, spikemon, M)``
    - statemon_rec only: ``(statemon, M)``
    - spikemon_rec only: ``(spikemon, M)``
    - neither: ``None``

    where ``statemon`` is a brian2.StateMonitor, ``spikemon`` a
    brian2.SpikeMonitor and ``M`` the dictionary with the model equations
    actually run in brian2.

    Raises
    ------
    ValueError
        if ``mode`` is unknown, or if a per-neuron parameter of length N
        (with N > 1) is given in a mode without spike input
    """

    # Work on a copy so the caller's model dict is never modified.
    M = copy.deepcopy(M_in)

    # Number of distinct spike inputs; only defined for modes with spike input.
    n_inputs = None

    if mode == 'periodic' or mode == 'long_square':

        # add input to model equations: t_p counts time since the last input
        M["ode"].append("dt_p/dt = 1")
        M["init_states"]["t_p"] = "(0*ms)"

        if mode == 'periodic':
            # piece-wise synaptic waveform in t_p: linear rise over 0.05 ms,
            # plateau at 1 until 0.25 ms, then exponential decay (tau 0.1 ms)
            M["functions"]["syn_clamp"] = "((t_p/(0.05*ms)*int(t_p<(0.05*ms)) + 1*int(t_p >= 0.05*ms)*int(t_p<0.25*ms) + exp(-(t_p-0.25*ms)/(0.1*ms))*int(t_p>=0.25*ms)))"

            # get input spiketimes: np.ndim accepts plain lists as well as
            # arrays (``pnfs.shape`` failed for the documented list input)
            if np.ndim(pnfs) <= 1:
                times, idxs = PN_input(pnfs, inf_run, runtime)
            else:
                times, idxs = pnfs

            n_inputs = len(np.unique(idxs))
            N_neurons = N * n_inputs
        else:
            # define piece-wise equation for long square stimulus
            M["functions"]["syn_clamp"] = "(%f*int(t_p<(%f*second)) + %f*int(t_p >= %f*second)*int(t_p<%f*second) + %f*int(t_p>=%f*second))"%(syns[0][0],syns[0][1]/brian2.second,syns[1][0],syns[0][1]/brian2.second,syns[1][1]/brian2.second,syns[2][0],syns[1][1]/brian2.second)
            N_neurons = 1

    elif mode == 'fi' or mode == 'prcs':

        # control I_stim experimentally: turn it into a settable parameter
        if 'I_stim' in M['definitions']:
            M['definitions'].pop('I_stim')
            M['parameters']['I_stim'] = '0*nA'

        # these protocols are evaluated from spike times
        spikemon_rec = True

        if mode == 'fi':
            N_neurons = len(bifpar['I_stim'])
        else:
            times, idxs = pnfs
            n_inputs = len(np.unique(idxs))
            N_neurons = n_inputs

    elif mode == 'syn_clamp':
        N_neurons = len(bifpar['syn_clamp'])

    else:
        raise ValueError(
            "unknown mode %r; expected one of 'periodic', 'long_square', "
            "'fi', 'prcs', 'syn_clamp'" % (mode,)
        )

    ode, _ = model_utils.create_equation_obj(
        M, {k: v for k, v in bifpar.items() if k not in M['init_states']}
    )
    brian2.defaultclock.dt = dt
    brian2.start_scope()

    neurons = brian2.NeuronGroup(N_neurons, model=ode, method='rk4', threshold='v > -30*mV',
                                 refractory='v >= -30*mV')

    model_utils.init_states(ode, neurons, M)

    # set model parameters to specified values
    for key, val in bifpar.items():
        # np.ndim guard: scalar parameters (no len()) are set as-is
        if N > 1 and np.ndim(val) >= 1 and len(val) == N:
            if n_inputs is None:
                raise ValueError(
                    "parameter %r has length N=%d, but tiling per-neuron "
                    "parameters requires spike input (mode 'periodic' or "
                    "'prcs')" % (key, N)
                )
            # Repeat the N per-stimulus values once per input. The unit is
            # recovered from the quantity's string form (e.g. "[1. 2.] nA")
            # so that it survives np.tile; assumes `units` maps unit names to
            # brian2 units. Values without a unit suffix typically fail here
            # and are tiled as plain values.
            try:
                b_unit = eval(str(val).split('] ')[-1], units)
                val = np.tile(val / b_unit, n_inputs) * b_unit
            except Exception:
                val = np.tile(val, n_inputs)
        elif mode == 'prcs':
            # spread each value evenly over the perturbed neurons
            val = val.repeat(int(N_neurons / len(val)))

        setattr(neurons, key, val)

    if mode == 'periodic' or mode == 'prcs':
        # connect neurons to input (perturbations)
        spike_input = brian2.SpikeGeneratorGroup(
                len(times),
                idxs.flatten().astype(int),
                times.flatten()*brian2.second,
                dt=dt,
            )

        if mode == 'periodic':
            # every input spike resets the time since the last input
            synapses = brian2.Synapses(
                spike_input, neurons, on_pre="t_p=0*ms", dt=dt
                )
            if N_neurons > 1:
                # input k drives its own block of N neurons
                synapses.connect(i=np.repeat(range(n_inputs), N), j=np.arange(n_inputs * N))
            else:
                synapses.connect()
        else:
            # Full precision (repr): rounding to a fixed number of decimals
            # would silently turn small perturbations into zero.
            synapses = brian2.Synapses(
                spike_input, neurons, on_pre="v+=%r*mV" % float(pert / brian2.mV), dt=dt
                )
            synapses.connect(j="i")

        net = brian2.Network([neurons, synapses, spike_input])
    else:
        net = brian2.Network([neurons])

    ## RUN simulation
    net.run(inf_run)  # inf run (not recorded)

    # monitors are added only now so that the inf run is not recorded
    if statemon_rec:
        statemon = brian2.StateMonitor(neurons, variables, True, dt=rec_dt)
        net.add(statemon)
    if spikemon_rec:
        spikemon = brian2.SpikeMonitor(neurons)
        net.add(spikemon)

    net.run(runtime)

    if statemon_rec and spikemon_rec:
        return statemon, spikemon, M
    if statemon_rec:
        return statemon, M
    if spikemon_rec:
        return spikemon, M
    return None
