# imports
import json, brian2, sys
import numpy as np
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import model_utils, brianutils
import helpers.analysis as han
import helpers.plotting.plotting as hplpl
import helpers.plotting.schematics as hpls
import helpers.simulation as hs

import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.image as mpimg
import matplotlib
matplotlib.use('Agg')
from brianutils import units
from matplotlib import patches

#%% load cfg
# Read the simulation/plotting configuration; the `with` block closes the file
# handle as soon as it is parsed.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)

# The config stores values as expression strings (presumably unit-bearing
# quantities) that are evaluated against the `units` namespace.
# NOTE(review): `eval` executes arbitrary code from the config file -- only
# use trusted config files here.
dt = brian2.defaultclock.dt = eval(S['S']['dt'],units)
inf_run_trace = eval(S['S']['inf_run']['trace'],units)
runtime_trace = eval(S['S']['runtime']['trace'],units)
inf_run_fi = eval(S['S']['inf_run']['fi'],units)
runtime_fi = eval(S['S']['runtime']['fi'],units)
I_in_fi = eval(S['ranges']['I_in_fi'],units)
pnfs = eval(S['ranges']['PN_freqs'],units)
# evaluated in this module's globals (not `units`); used as a column index below
pnf_idx = eval(S['ranges']['pnf_idx'])

# path of the model JSON file
model = S['S']['model']

cmap = plt.get_cmap(S['plotting']['cmaps'][1])


#%% load images
# Disabled: these schematic images are not used anywhere in the current figure.
# EOD_img = mpimg.imread('fig/img/fish_field.png')
# circuit_img = mpimg.imread('fig/img/fish_scheme.png')

#%% prepare the electrocyte model with a fixed intracellular Na_in.
with open(model) as model_file:
    M = json.load(model_file)
# Remove ODE entry 1 and the 'Na_in' initial state, and set the parameter
# Na_in to Na_in_0, so Na_in enters the model as a constant.
# NOTE(review): assumes M['ode'][1] is the Na_in equation -- confirm against
# the model file, since the index is hard-coded.
M['ode'].pop(1)
M['parameters']['Na_in'] = M['parameters']['Na_in_0']
M['init_states'].pop('Na_in')

#%% get mean input current for naturalistic PN stimuli (200-600 Hz)
print('simulating electrocyte with periodic input')
statemon, M_inp = hs.run_sim(M,{},inf_run_trace,runtime_trace,pnfs=pnfs)

# Evaluate the stimulus current once (row i belongs to PN frequency pnfs[i])
# instead of re-evaluating it for every frequency and again for the traces.
I_stim_all = model_utils.eval_func(M_inp, "I_stim", statemon)

# Mean stimulus current over the first period of each input:
# int(1/pnf/dt) samples of row i.
I_stims = [np.mean(I_stim_all[i, :int(1/pnf/dt)]) for i, pnf in enumerate(pnfs)]

# save traces (transposed: time along axis 0, one column per PN frequency)
I_stim_periodic = I_stim_all.T
v_traces= statemon.v.T

#%% repeat, now driving the electrocyte with the constant mean currents I_stims.
print('simulating electrocyte with constant input')
md_stim = {'I_stim': I_stims}
statemon, spikemon, _ = hs.run_sim(
    M,
    md_stim,
    inf_run_fi,
    runtime_fi,
    mode='fi',
)

# keep only the samples covering `runtime_trace` (the duration used for the
# periodic-input traces) and transpose to time x input
n_trace_samples = int(runtime_trace / dt)
v_traces_md = statemon.v[:, :n_trace_samples].T

# firing rate for each constant (mean-driven) input current
I_in_md, f_md = han.get_fI(I_stims, spikemon, statemon)

#%% sweep all constant input currents I_in_fi to obtain the full f-I curve.
print('getting fI curve')
fi_stim = {'I_stim': I_in_fi}
statemon, spikemon, _ = hs.run_sim(
    M,
    fi_stim,
    inf_run_fi,
    runtime_fi,
    mode='fi',
)
I_in, f_ = han.get_fI(I_in_fi, spikemon, statemon)

#%% shared plotting state and text sizes for the figure.
greys = plt.get_cmap('Greys')

# containers for axes whose y-labels should be aligned with one another
align_these = []
align_those = []

# 8 pt for all text; 10 pt for axes titles (used for the panel letters)
plt.rcParams.update({
    'axes.labelsize': 8,     # fontsize of the x and y labels
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 8,
    'axes.titlesize': 10,
    'font.size': 8,
})

#%% assemble the figure
f = plt.figure(figsize=(8,2.5),constrained_layout=True)

# Outer layout: 2 rows x 3 columns. Column 0 (twice as wide) hosts panel A,
# column 1 stacks panels B (top) and C (bottom), column 2 holds panel D.
# `figure=` is not passed: Figure.add_gridspec always binds the grid to `f`.
gs0 = f.add_gridspec(2,3,width_ratios=[2,1,1])

# Panel A: 2x2 sub-grid spanning both rows of column 0.
# Left column: periodic input; right column: its constant (mean) counterpart.
gs = gridspec.GridSpecFromSubplotSpec(2,2,subplot_spec=gs0[:,0])

# A, top-left: periodic stimulus current for the selected PN frequency
ax = f.add_subplot(gs[0,0])
ax.set_title('A',loc='left',weight='bold',x=-0.6)
hplpl.plot_stim(ax, I_stim_periodic[:,pnf_idx], runtime_trace,c=cmap(0))
align_these.append(ax)
# A, bottom-left: membrane voltage driven by that periodic input
ax = f.add_subplot(gs[1,0])
hplpl.plot_v(ax, v_traces[:,pnf_idx], runtime_trace,c=cmap(0))
align_these.append(ax)

# A, top-right: constant stimulus equal to the mean current of that input;
# y-label and y-tick labels are hidden (the left column shows them)
ax = f.add_subplot(gs[0,1])
hplpl.plot_stim(ax, np.ones(len(v_traces_md[:,pnf_idx]))*I_stims[pnf_idx], runtime_trace)
ax.set_ylabel('')
ax.set_yticklabels([])

# A, bottom-right: voltage response to that constant input
ax = f.add_subplot(gs[1,1])
hplpl.plot_v(ax, v_traces_md[:,pnf_idx], runtime_trace)

ax.set_ylabel('')
ax.set_yticklabels([])

# B: mean stimulus current versus PN frequency
ax = f.add_subplot(gs0[0,1])
hplpl.plot_If(ax, pnfs, I_stims)
ax.set_yticks([0.5,1])
ax.set_ylim([0,1])
ax.set_xticks([200,400,600])
ax.set_title('B',loc='left',weight='bold',x=-0.5)
align_those.append(ax)

# C: f-I curve under constant input. The shaded rectangle spans the mean
# currents from I_stims[0] to I_stims[-1] (x in uA) and 0-600 Hz in y.
ax = f.add_subplot(gs0[1,1])
hplpl.plot_fi(ax, I_in, f_)
ax.set_yticks([200,600])
inset = patches.Rectangle((I_stims[0]/brian2.uA,0),(I_stims[-1]-I_stims[0])/brian2.uA,600,fc=greys(0.4),ec=greys(0.4))
ax.add_patch(inset)

ax.set_title('C',loc='left',weight='bold',x=-0.5)
ax.set_ylabel('Mean-driven\nelectrocyte\nfiring rate [Hz]')
ax.set_xlim([0,2.5])
align_those.append(ax)

# D: mean-driven firing rates f_md plotted against the PN frequencies
ax = f.add_subplot(gs0[:,2])
hplpl.plot_freq_diff(ax, pnfs, f_md)
ax.set_title('D',loc='left',weight='bold',x=-0.4)
ax.set_ylabel('Electrocyte firing rate [Hz]')
ax.set_xticks([200,400,600])
ax.set_yticks([200,400,600])

# Align y-labels within panel A's left column and between stacked panels B/C.
f.align_ylabels(align_these)
f.align_ylabels(align_those)

# save this figure explicitly rather than relying on pyplot's current figure
f.savefig('fig/3.pdf',dpi=300,bbox_inches='tight')
#%%
