# imports
import json, brian2, sys
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import brianutils
import helpers.simulation as hs
import helpers.analysis as han

import numpy as np
from brianutils import units

#%% load simulation params from cfg
# NOTE(review): the config stores quantities as string expressions that are
# materialised with eval() using `units` as the global namespace. This is only
# acceptable while cfg/simulation_params.json is trusted, project-owned input.
# Use a context manager so the config file handle is closed deterministically.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)

# integration step (also installed as the brian2 default clock step),
# initial run lengths, recording step and trace runtime
dt = brian2.defaultclock.dt = eval(S['S']['dt'],units)
inf_run = eval(S['S']['inf_run']['stim_protocol'],units)
rec_dt = eval(S['S']['rec_dt']['constant_pump'],units)
inf_run_trace = eval(S['S']['inf_run']['trace'],units)
runtime_trace = eval(S['S']['runtime']['trace'],units)

# baseline input frequencies of the two stimulus protocols
fbl_chirp = eval(S['stimuli']['chirp']['baseline_freq'],units)
fbl_rises = eval(S['stimuli']['rises']['baseline_freq'],units)

# NOTE: the positional unpacking below relies on the key order of the
# corresponding JSON objects; reordering keys in the config breaks it.
T0_chirp, T1_chirp, T2_chirp, n_chirp, chirp_duration = [eval(v,units) for v in S['stimuli']['chirp']['long'].values()]
T_chirp_short, n_chirp_short, chirp_duration_short = [eval(v,units) for v in S['stimuli']['chirp']['short'].values()]

# the first entry of the 'rises' object is discarded
# (presumably 'baseline_freq', which is read separately above -- TODO confirm)
_, fd_rises, T0_rises, T1_rises, n_rises, tau_rises = [eval(v,units) for v in S['stimuli']['rises'].values()]

# path of the model description file (opened with json.load further below)
model = S['S']['model']

# target membrane potential and search interval for the sodium conductance
vm_target = eval(S['ranges']['vm_target'],units)
gna_range = eval(S['ranges']['gNa_range'],units)

#%% total simulated time per stimulus protocol and baseline input frequencies
# long chirp: T0, then n_chirp repetitions of (T1 + chirp_duration/fbl_chirp), then T2
_chirp_cycle = T1_chirp + chirp_duration/fbl_chirp
runtime_chirp = T0_chirp + _chirp_cycle*n_chirp + T2_chirp

# short chirp: same layout, with T_chirp_short used for every segment length
_chirp_cycle_short = T_chirp_short + chirp_duration_short/fbl_chirp
runtime_chirp_short = T_chirp_short + _chirp_cycle_short*n_chirp_short + T_chirp_short

# frequency rises: T0, n_rises windows of length tau_rises, then T1
runtime_rises = T0_rises + n_rises*tau_rises + T1_rises

# baseline frequencies of both protocols, handed to the tuning simulations
_baseline_freqs = [fbl_chirp, fbl_rises]
pnfs = _baseline_freqs*brian2.Hz

#%% tune pump currents and sodium conductance for spiking with fixed amplitude
# for periodic inputs
#
# Iteratively (max. 15 rounds) re-estimates the pump current from the mean
# Na/K currents and bisects g_Na_max between gna_range[0] and gna_range[1]
# so that the peak membrane potential approaches vm_target.

print('tuning pump rates and sodium channel conductance to match energetic demand and target spike amplitude')

# Load the model description; the context manager closes the file handle.
with open(model) as model_file:
    M = json.load(model_file)
# Replace the Na_in state by a constant parameter fixed at Na_in_0.
# (entry 1 of M['ode'] is presumably the Na_in equation -- TODO confirm)
M['ode'].pop(1)
M['parameters']['Na_in'] = M['parameters']['Na_in_0']
M['init_states'].pop('Na_in')

# One entry per simulated cell; every per-cell array below must have the same
# length, so derive it once instead of mixing len(pnfs)*2 with a literal 4.
n_sims = len(pnfs)*2

Ip_pred = np.zeros(n_sims)*brian2.uA
gna_cur = eval(M['parameters']['g_Na_max'],units).repeat(n_sims)
bifpar = {'Ip_mean':[0,0]*brian2.uA, "P_Na": eval(M['parameters']['P_Na'],units)*eval(S['ranges']['syns']), "g_Na_max": gna_cur}

# bisection bounds for the sodium conductance of every cell
gna_min = np.ones(n_sims)*gna_range[0]
gna_max = np.ones(n_sims)*gna_range[1]

for _ in range(15):

    # run short simulations for naturalistic mean inputs
    statemon, spikemon, M_inp = hs.run_sim(M, bifpar, inf_run_trace, runtime_trace, pnfs=pnfs, N=2, spikemon_rec=True)

    # get ISIs (period (T) of the limit cycles (lc))
    Tlc = [np.mean(np.diff(spikemon.t[spikemon.i==i])) for i in range(statemon.v.shape[0])]

    # get mean Na and K currents
    INa_means, IK_means = han.get_currents(M_inp, statemon, (Tlc/dt).astype(int))

    # estimate pump current from this: mean of the two estimates -INa/3 and IK/2
    # (presumably reflecting a 3 Na : 2 K pump stoichiometry -- TODO confirm)
    new_Ip_pred = brian2.amp*(-np.array(INa_means)/3 + np.array(IK_means)/2)/2

    # exit loop if pump current converged
    if np.max(np.abs(Ip_pred - new_Ip_pred)) < 0.001*brian2.nA:
        break

    Ip_pred = new_Ip_pred

    # bisection step: cells peaking above the target get a lower upper bound,
    # cells peaking below it get a higher lower bound
    v_peak = np.max(statemon.v[:],axis=1)
    overshoot = v_peak > vm_target
    undershoot = v_peak < vm_target
    gna_max[overshoot] = gna_cur[overshoot]
    gna_min[undershoot] = gna_cur[undershoot]

    bifpar['Ip_mean'] = Ip_pred

    gna_cur = (gna_max + gna_min)/2
    bifpar['g_Na_max'] = gna_cur
else:
    # the loop ran out of iterations without hitting the convergence criterion
    print('warning: pump current did not converge within 15 iterations; saving last estimate')


np.save('data/I_pump_fitted_com.npy',Ip_pred)
np.save('data/I_gNa_fitted_com.npy',gna_cur)

#%% get the percentage of synaptic currents from these simulations
# Uses statemon / spikemon / M_inp left over from the last tuning simulation.
import helpers.plotting.plotting as hplpl
import matplotlib.pyplot as plt

# index of the simulated cell to inspect
pnf_idx = 2
t, v, currents = han.get_currents(M_inp,statemon,mode='all',pnf_idx=pnf_idx)
spikes = spikemon.t[spikemon.i==pnf_idx]

#%% quick look at the first four current traces
for i in range(4):
    plt.plot(currents[i])

#%% current-scape figure for the selected cell
f = plt.figure()
ax = f.add_subplot()

hplpl.plot_current_scape(ax, M, t, v, currents, spikes, dt, inf_run_trace, sep=False,bar=False)

#%% create chirp input spikes
# Builds the list `sts` of input spike times for the long chirp protocol.
isi_bl = 1/(fbl_chirp)  # inter-spike interval at the baseline frequency

# baseline-rate spikes covering inf_run + T0_chirp
sts = np.cumsum(np.ones(int(inf_run*fbl_chirp + T0_chirp*fbl_chirp))*isi_bl).tolist()

# Each chirp: a gap of chirp_duration/fbl_chirp without spikes, followed by
# T1_chirp worth of baseline-rate spikes. Both terms are loop-invariant.
chirp_gap = chirp_duration/fbl_chirp
post_chirp_offsets = np.cumsum(np.ones(int(T1_chirp*fbl_chirp))*isi_bl)
for _ in range(n_chirp):
    sts.extend((sts[-1] + chirp_gap + post_chirp_offsets).tolist())

# trailing baseline-rate spikes covering T2_chirp
sts.extend((sts[-1] + np.cumsum(np.ones(int(T2_chirp*fbl_chirp))*isi_bl)).tolist())

#%% chirp simulation
# TODO(review): check whether overriding g_Na_max is needed here; if results
# are the same without it, dropping the override would be easier to follow.
bifpar = {'Ip_mean':Ip_pred[0],'g_Na_max':gna_cur[0]}
# first row: input spike times, second row: zeros of the same length
pnfs = np.vstack([[sts],[np.zeros(len(sts))]])

#%% run the chirp protocol once without and once with the K_out override
for buf in ['','buf']:
    print('simulating chirps %s buffer'%('without' if buf == '' else 'with'))
    # reload a fresh model for every run; the context manager closes the file
    with open(model) as model_file:
        M = json.load(model_file)
    if 'buf' in buf:
        # buffered case: K_out is replaced by the expression 'K_out_0'
        M['functions']['K_out'] = 'K_out_0'
    statemon_chirps, spikemon_chirps, _ = hs.run_sim(M,bifpar,inf_run,runtime_chirp,pnfs=pnfs,spikemon_rec=True,variables=['Na_in','v'])
    # one spike-time array per recorded neuron
    sts = list(spikemon_chirps.spike_trains().values())
    np.savez('data/chirp_%s.npz'%buf,st=sts,nain=statemon_chirps.Na_in[:,:],v=statemon_chirps.v[:,:],pnfs=pnfs)

#%%
# Short chirp for the schematics.
# Builds the list `sts` of input spike times; every segment has length T_chirp_short.
isi_bl = 1/(fbl_chirp)  # inter-spike interval at the baseline frequency

# baseline-rate spikes covering inf_run + T_chirp_short
sts = np.cumsum(np.ones(int(inf_run*fbl_chirp + T_chirp_short*fbl_chirp))*isi_bl).tolist()

# Baseline-rate spike offsets spanning one T_chirp_short segment; reused for
# every chirp and for the trailing segment (loop-invariant).
segment_offsets = np.cumsum(np.ones(int(T_chirp_short*fbl_chirp))*isi_bl)

# each chirp: a gap of chirp_duration_short/fbl_chirp, then one baseline segment
chirp_gap = chirp_duration_short/fbl_chirp
for _ in range(n_chirp_short):
    sts.extend((sts[-1] + chirp_gap + segment_offsets).tolist())

# trailing baseline segment
sts.extend((sts[-1] + segment_offsets).tolist())

#%% short chirp simulation
bifpar = {'Ip_mean':Ip_pred[0],'g_Na_max':gna_cur[0]}
# first row: input spike times, second row: zeros of the same length
pnfs = np.vstack([[sts],[np.zeros(len(sts))]])

# reload a fresh model; the context manager closes the file handle
with open(model) as model_file:
    M = json.load(model_file)
print('simulating short chirp for schematic')
# Run for runtime_chirp_short, the runtime derived from the short-chirp
# parameters; the long-protocol runtime_chirp was used here before, which
# does not match the short stimulus built for this run.
statemon_chirps, spikemon_chirps, _ = hs.run_sim(M,bifpar,inf_run,runtime_chirp_short,pnfs=pnfs,spikemon_rec=True,variables=['Na_in','v'])
# one spike-time array per recorded neuron
sts = list(spikemon_chirps.spike_trains().values())
np.savez('data/chirp_short.npz',st=sts,nain=statemon_chirps.Na_in[:,:],v=statemon_chirps.v[:,:],pnfs=pnfs)

#%% frequency rises --> here the buffer also makes a difference ..
def ex_decay(t,fd,tau,fbl):
    """Return ``fd*exp(-t/tau) + fbl``.

    t   -- point at which the decay is evaluated
    fd  -- offset above ``fbl`` at ``t == 0``
    tau -- decay constant, in the same units as ``t``
    fbl -- value the result approaches for large ``t``
    """
    decay_factor = np.exp(-t/tau)
    return fd*decay_factor + fbl

# Input spike times for the frequency-rise protocol, collected in `sts`.
isi_bl = 1/(fbl_rises)  # inter-spike interval at the baseline frequency

# baseline-rate spikes covering inf_run + T0_rises
n_lead_in = int(fbl_rises*(inf_run+T0_rises))
sts = np.cumsum(np.ones(n_lead_in)*isi_bl).tolist()

# Each rise: the instantaneous rate follows ex_decay (time constant tau_rises/2);
# spikes are appended one ISI at a time until the elapsed time exceeds tau_rises.
tau_decay = tau_rises/2
for _ in range(n_rises):
    elapsed = 0*brian2.second
    for _step in range(100000):  # hard cap on the number of spikes per rise
        rate = ex_decay(elapsed, fd_rises, tau_decay, fbl_rises)
        isi = 1/rate
        elapsed = elapsed+isi
        sts.append(sts[-1]+isi)
        if elapsed > tau_rises:
            break

# trailing baseline-rate spikes covering T1_rises
n_lead_out = int(fbl_rises*T1_rises)
tail = sts[-1] + np.cumsum(np.ones(n_lead_out)*isi_bl)
sts.extend(tail.tolist())

# first row: input spike times, second row: zeros of the same length
zero_row = np.zeros(len(sts))
pnfs = np.vstack([[sts],[zero_row]])

#%% frequency-rise simulation (TODO: this could be parallelized)
# reload a fresh model; the context manager closes the file handle
with open(model) as model_file:
    M = json.load(model_file)
# use the last two fitted pump currents / sodium conductances (entries 2:)
# together with P_Na scaled by the configured 'syns' factors
bifpar = {'Ip_mean':Ip_pred[2:], "P_Na" : eval(M['parameters']['P_Na'],units)*eval(S['ranges']['syns']),'g_Na_max':gna_cur[2:]}

print('simulating frequency rises with strong and weak synapse')
statemon_rises, spikemon_rises, _ = hs.run_sim(M,bifpar,inf_run,runtime_rises,pnfs=pnfs,N=2,spikemon_rec=True,variables=['Na_in','v'])

#%% collect the spike trains of the rises simulation and store the results
# Flatten all spike trains into one list of spike times (`sts`) and keep a
# parallel list (`idxs`) holding the enumeration index of the originating train.
sts, idxs = [], []
for cell, train in enumerate(spikemon_rises.spike_trains().values()):
    sts.extend(train)
    idxs.extend(np.ones(len(train))*cell)

# row 0: spike times, row 1: train index of each spike
spikes_with_index = np.vstack([sts,idxs])
np.savez(
    'data/rises_syn.npz',
    st=spikes_with_index,
    nain=statemon_rises.Na_in[:,:],
    pnfs=pnfs,
    v=statemon_rises.v[:,:],
)
