# imports
import json, brian2, sys
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import brianutils
import helpers.simulation as hs

import numpy as np
from brianutils import units

#%%
# Load the shared simulation configuration. The file is closed right after
# parsing instead of leaking the handle.
# NOTE(review): config values are evaluated with eval() against the brianutils
# `units` namespace, so cfg/simulation_params.json is effectively trusted code;
# never point this at an untrusted file.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)

# Global integration step (also installed on Brian's default clock), the
# stimulation-protocol warm-up run and the recording interval for this model.
dt = brian2.defaultclock.dt = eval(S['S']['dt'], units)
inf_run = eval(S['S']['inf_run']['stim_protocol'], units)
rec_dt = eval(S['S']['rec_dt']['voltage_dependent_pump'], units)

# Path of the JSON model description loaded further below.
model = S['S']['model']

# Epoch durations and sample indices of the long-square stimulus protocol.
long_square = S['stimuli']['long_square']
T0, T1, T2 = (eval(long_square[key], units) for key in ('T0', 'T1', 'T2'))
# Index expressions are evaluated without the units namespace (plain numbers).
bl_idx = eval(long_square['baseline_idx'])
rise_idx = eval(long_square['rise_idx'])

# Warm-up and run durations for the short syn_clamp simulations.
inf_run_fi = eval(S['S']['inf_run']['fi'], units)
runtime_fi = eval(S['S']['runtime']['fi'], units)

# Initial bracket for the Na_in steady-state bisection.
na_in_min, na_in_max = eval(S['ranges']['Na_in_ss'], units)

#%%
# Total duration of the long-square protocol: the sum of its three epochs.
runtime = T0 + T1 + T2
# Precomputed mean synaptic input; indexed below with config-provided indices
# (format of the array assumed to match bl_idx / rise_idx — TODO confirm).
mean_syns = np.load('data/mean_syns.npy')
# Synaptic level used during the middle epoch for each deviation condition.
stim_dev = dict(high=mean_syns[rise_idx], low=0)

#%%
# Load the model description; the handle is closed as soon as it is parsed.
with open(model) as model_file:
    M = json.load(model_file)

# Model overrides: scale both K+ current definitions by 2/3, replace the
# Ip_mean parameter with a function of those currents, and zero g_Na_l.
M['definitions']['I_K'] = '(2/3)*g_K_max*n**4*(v-E_K)'
M['definitions']['I_K_L'] = '(2/3)*g_K_l*(v-E_K)'

M['parameters'].pop('Ip_mean')
M['functions']['Ip_mean'] = '(1/2)*(I_K+I_K_L)'
M['functions']['g_Na_l'] = '0*uS'

# get to steady state quicker by changing Na_in_init
print('using binary search to get steady state ionic concentrations')

# Ten halvings shrink the [na_in_min, na_in_max] bracket by a factor of 2**10.
for _ in range(10):
    na_in_cur = (na_in_max + na_in_min) / 2
    # The quantity's string form is turned into an evaluable expression by
    # joining value and unit with '*' (presumably e.g. '10. mM' -> '10.*mM').
    M['init_states']['Na_in'] = str(na_in_cur).replace(' ', '*')

    # run short simulations for naturalistic mean inputs
    statemon, spikemon, _ = hs.run_sim(M, {'syn_clamp': [mean_syns[bl_idx]]}, inf_run_fi, runtime_fi,
                                       mode='syn_clamp', spikemon_rec=True)

    # Compare peak Na_in in the second half of the recording with the first
    # half: if it is still rising, the initial value was too low.
    half = int(statemon.N) // 2
    if half == 0:
        # With fewer than two samples one half is empty and np.max would fail
        # with an opaque "zero-size array" error.
        raise RuntimeError('Na_in bisection needs at least two recorded samples; '
                           'check runtime_fi / recording interval')
    if np.max(statemon.Na_in[0, -half:]) > np.max(statemon.Na_in[0, :half]):
        na_in_min = na_in_cur
    else:
        na_in_max = na_in_cur

# NOTE(review): M now holds the last *tested* midpoint, which is one endpoint of
# the final bracket (within (max-min)/2**10 of the bracket centre).

#%%
# Run the long-square protocol (baseline -> deviation level -> baseline) for
# each deviation condition and save spike times plus Na_in, v and n traces.
for dev in ('high', 'low'):
    print('simulating %s synaptic inputs with voltage-dependent pump' % dev)

    _baseline_syn = mean_syns[bl_idx]
    # Each entry pairs a synaptic level with a time boundary (presumably the
    # epoch end time — confirm against hs.run_sim).
    _syn_schedule = [
        [_baseline_syn, T0],
        [stim_dev[dev], T0 + T1],
        [_baseline_syn, T0 + T1 + T2],
    ]

    statemon, spikemon, _ = hs.run_sim(
        M, {}, inf_run, runtime,
        mode='long_square',
        syns=_syn_schedule,
        spikemon_rec=True,
        variables=['Na_in', 'v', 'n'],
        rec_dt=rec_dt,
    )

    np.savez(
        'data/dev_input_vd_%s.npz' % dev,
        st=spikemon.t[:],
        nain=statemon.Na_in[0],
        v=statemon.v[0],
        n=statemon.n[0],
    )
