# imports
import json, brian2, sys
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import brianutils
import helpers.simulation as hs

import numpy as np
from brianutils import units
#%%
# Load the simulation configuration. The file handle is closed by `with`
# (the original `json.load(open(...))` leaked it).
# NOTE(review): config values are parsed with eval() against the brian2
# `units` namespace so strings carrying units become Quantities;
# cfg/simulation_params.json must therefore be trusted input.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)

sim_cfg = S['S']
# Integration time step; also installed as brian2's global default clock.
dt = brian2.defaultclock.dt = eval(sim_cfg['dt'], units)
inf_run = eval(sim_cfg['inf_run']['stim_protocol'], units)
rec_dt = eval(sim_cfg['rec_dt']['constant_pump'], units)

# Path of the model-definition JSON file.
model = sim_cfg['model']

# Long-square stimulus protocol: T0, T1, T2 are the three epoch lengths
# (their sum is used as the total runtime).
square_cfg = S['stimuli']['long_square']
T0 = eval(square_cfg['T0'], units)
T1 = eval(square_cfg['T1'], units)
T2 = eval(square_cfg['T2'], units)
# Index expressions (no units) used to pick baseline / rise samples.
bl_idx = eval(square_cfg['baseline_idx'])
rise_idx = eval(square_cfg['rise_idx'])

#%%
# Model definition, read from the path given in the simulation config.
# `with` ensures the file handle is closed (previously leaked).
with open(model) as model_file:
    M = json.load(model_file)
# Fitted pump current; the .npy holds bare numbers, so amperes are attached.
I_pump = np.load('data/fitted_I_pump.npy') * brian2.amp
mean_syns = np.load('data/mean_syns.npy')

# fitted parameters for PN = 300 Hz
# Synaptic input level used for the middle (stimulus) epoch of each run:
# 'high' uses the mean synaptic input at rise_idx, 'low' switches it off.
stim_dev = {'high': mean_syns[rise_idx], 'low': 0}

# Total simulated time: the three stimulus epochs back to back.
runtime = T0 + T1 + T2

#%%
# Baseline synaptic input, used before and after the stimulus epoch.
# Loop-invariant, so it is looked up once.
syn_baseline = mean_syns[bl_idx]

for dev in ('high', 'low'):
    print('running simulation with %s synaptic input' % dev)
    # Piecewise synaptic drive as [level, time] pairs.
    # NOTE(review): assumes hs.run_sim holds each level until the paired
    # time point; confirm against helpers.simulation.
    syns = [
        [syn_baseline, T0],
        [stim_dev[dev], T0 + T1],
        [syn_baseline, T0 + T1 + T2],
    ]
    statemon, spikemon, _ = hs.run_sim(
        M, {'Ip_mean': I_pump[bl_idx]}, inf_run, runtime,
        mode='long_square', syns=syns, spikemon_rec=True,
        variables=['Na_in'], rec_dt=rec_dt,
    )
    # Persist spike times and Na_in at recorded index 0
    # (presumably the first/only recorded neuron).
    np.savez('data/dev_input_%s.npz' % dev,
             st=spikemon.t[:], nain=statemon.Na_in[0])
