import json, brian2, sys, copy
import numpy as np
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import model_utils, brianutils
import helpers.analysis as han
import helpers.plotting.plotting as hplpl
import helpers.plotting.schematics as hpls
import helpers.simulation as hs

import matplotlib.pyplot as plt
import pickle as pkl
from matplotlib import gridspec
import matplotlib
matplotlib.use('Agg')
from brianutils import units
from scipy.interpolate import interp1d

#%% load cfg
# NOTE(review): most config values below are passed through eval() with the
# brianutils `units` mapping as the namespace, so the JSON strings are
# presumably unit expressions (e.g. "0.01*ms") -- confirm. eval executes
# arbitrary code: only run this script on a trusted cfg file.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)

# Simulation timing; the first line also sets brian2's global default clock step.
dt = brian2.defaultclock.dt = eval(S['S']['dt'],units)
inf_run_trace = eval(S['S']['inf_run']['trace'],units)
runtime_trace = eval(S['S']['runtime']['trace'],units)
inf_run_fi = eval(S['S']['inf_run']['fi'],units)
runtime_fi = eval(S['S']['runtime']['fi'],units)

# Input currents for the f-I / PRC runs, input frequencies, and the index of
# the frequency used for the entrainment analysis.
I_in_fi = eval(S['ranges']['I_in_fi_prc'],units)
pnfs = eval(S['ranges']['PN_freqs'],units)
pnf_idx = eval(S['ranges']['pnf_idx'])
model = S['S']['model']  # path of the model JSON file opened later in this script

cmap = plt.get_cmap(S['plotting']['cmaps'][0])

# PRC settings: number of perturbation times per input current, and the
# perturbation size (PRCs are divided by it).
n_prc = eval(S['ranges']['prc_res'])
pert = eval(S['ranges']['prc_pert'],units)
runtime_prc = runtime_fi  # PRC runs reuse the f-I run duration

#%% get perturbations with orig param set.
# Load the model description and turn intracellular Na into a fixed parameter:
# Na_in is pinned to Na_in_0 and removed from the initial states.
with open(model) as model_file:
    M = json.load(model_file)
# NOTE(review): assumes M['ode'][1] is the Na_in equation -- confirm against
# the model file, since this index is hard-coded.
M['ode'].pop(1)
M['parameters']['Na_in'] = M['parameters']['Na_in_0']
M['init_states'].pop('Na_in')

# Run the model driven by the input frequencies `pnfs`. The returned monitor
# and input model (M_inp) are used next to evaluate the I_stim current.
print('getting mean inputs')
statemon, M_inp = hs.run_sim(M,{},inf_run_trace,runtime_trace,pnfs=pnfs)

#%% get perturbation for one limit cycle
# Evaluate the stimulus current on the recorded traces. Then, for every input
# frequency in `pnfs`, build an interpolant of that stimulus over one period.
I_stim = model_utils.eval_func(M_inp, "I_stim", statemon).T
f_stims = []
for freq, stim_trace in zip(pnfs, I_stim.T):
    period = 1/freq
    n_samples = int(period/dt)
    # NOTE(review): linspace places n_samples points on [0, period], so the
    # grid spacing is period/(n_samples-1), i.e. larger than dt. The endpoint
    # keeps the whole period inside the interpolation domain.
    cycle_times = np.linspace(0, period, n_samples)
    f_stims.append(interp1d(cycle_times, stim_trace[:n_samples]))

# Mean of the selected stimulus over one period, tagged as a current in amps.
T = np.arange(0,1/pnfs[pnf_idx],dt)
stim_over_cycle = f_stims[pnf_idx](T)
mean_stim = np.mean(stim_over_cycle)*brian2.amp

#%% fI curve
# Run the model in 'fi' mode for the input currents I_in_fi. Then extract the
# f-I relation together with the first spike time for each current.
print('getting fI curve')
fi_drive = {'I_stim': I_in_fi}
statemon, spikemon, _ = hs.run_sim(
    M,
    fi_drive,
    inf_run_fi,
    runtime_fi,
    mode='fi',
)
I_in, f_, fst = han.get_fI(
    I_in_fi,
    spikemon,
    statemon,
    return_first_spike_time=True,
)

#%% get the phase response curves
# Perturbation times: for every current that produced spiking (rate > 0),
# n_prc evenly spaced times from the first spike to one interspike interval
# (1/rate) later. Neuron i*n_prc + k gets perturbation k of spiking current i.
stimes = []
for rate, fstt in zip(f_, fst):
    if rate > 0:
        stimes.extend(np.linspace(fstt*brian2.second, (fstt + 1/rate)*brian2.second, n_prc))

# Only spiking currents take part in the PRC run; compute the mask once.
spiking = f_ > 0
I_spiking = I_in[spiking]
f_spiking = f_[spiking]

print('getting prcs')
statemon, spikemon, M_inp = hs.run_sim(
    M,
    {'I_stim': I_spiking},
    inf_run_fi,
    runtime_prc,
    pnfs=[np.array(stimes), np.arange(len(stimes))],
    mode='prcs',
    pert=pert,
)

# spike_trains() rebuilds the whole index -> spike-times mapping on each call,
# so fetch it once instead of once per neuron.
spike_trains = spikemon.spike_trains()

# One PRC per spiking current, built from its block of n_prc neurons and
# normalised by the perturbation size `pert`.
prcs = [
    han.get_prc(
        [spike_trains[j] for j in range(i*n_prc, (i+1)*n_prc)],
        statemon.v[i*n_prc:(i+1)*n_prc, :],
        rate_i,
        inf_run_fi,
        dt,
    )/pert
    for i, rate_i in enumerate(f_spiking)
]

#%% get entrainment ranges
# Every PRC is evaluated against the same selected stimulus waveform and
# frequency, so look those up once.
selected_stim = f_stims[pnf_idx]
selected_freq = pnfs[pnf_idx]
ranges = [
    han.get_entrainment_range(M, prc, selected_stim, selected_freq, dt)
    for prc in prcs
]

#%%
# Throw-away image whose colour mapping (mean_stim - I_in_fi, in uA, using
# cmap) is handed to plot_prcs. Its own figure is closed right away.
plt.figure()
im = plt.imshow([(mean_stim - I_in_fi)/brian2.uA],cmap=cmap)
plt.close()

# Font sizes are set before the figure is created so every artist picks them up.
plt.rc('axes', labelsize=8)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=8)
plt.rc('ytick', labelsize=8)
plt.rc('legend', fontsize=8)
plt.rc('axes', titlesize=10)
plt.rc('font', size=8)

fig = plt.figure(figsize=(6,4),constrained_layout=True)
gs = fig.add_gridspec(2,2)

# Panel titles go on explicit axes rather than pyplot's "current" axes, which
# the plotting helpers could change (e.g. by adding twin/colorbar axes).
ax = fig.add_subplot(gs[0,0])
hplpl.plot_freq_dpump(ax,I_in,f_,mean_stim,pnfs[pnf_idx])
ax.set_title('A',loc='left',fontweight='bold',x=-0.28)

ax = fig.add_subplot(gs[0,1])
hplpl.plot_prcs(ax,prcs,im)
ax.set_title('B',loc='left',fontweight='bold',x=-0.35)

ax = fig.add_subplot(gs[1,:])
hplpl.plot_entrainment_range(ax,I_in,f_,ranges,mean_stim,pnfs[pnf_idx])
ax.set_title('C',loc='left',fontweight='bold',x=-0.12)

# Save and close this specific figure, not whichever one pyplot considers current.
fig.savefig('fig/A2.pdf',bbox_inches='tight')
plt.close(fig)
