# imports
import json, brian2, sys, os
import numpy as np
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import model_utils, brianutils
import helpers.analysis as han
import helpers.plotting.plotting as hplpl
import helpers.plotting.schematics as hpls
import helpers.simulation as hs

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.patches import Ellipse, Rectangle
import matplotlib
#matplotlib.use('Agg')
from brianutils import units

#%% load params from cfg file
# NOTE(review): every value below goes through eval(), which executes arbitrary
# Python taken from the config file -- only run this with trusted cfg files.
# `units` is used as eval's globals, presumably so unit names in the config
# strings resolve -- TODO confirm against brianutils.units.
with open('cfg/simulation_params.json', encoding='utf-8') as cfg_file:
    S = json.load(cfg_file)

dt = brian2.defaultclock.dt = eval(S['S']['dt'],units)
inf_run_trace = eval(S['S']['inf_run']['trace'],units)
runtime_trace = eval(S['S']['runtime']['trace'],units)
inf_run_fi = eval(S['S']['inf_run']['fi'],units)
runtime_fi = eval(S['S']['runtime']['fi'],units)
I_in_fi = eval(S['ranges']['I_in_fi'],units)
pnfs = eval(S['ranges']['PN_freqs'],units)
# evaluated with this module's globals (no `units` mapping); used as an index below
pnf_idx = eval(S['ranges']['pnf_idx'])
I_pumps = eval(S['ranges']['I_pump'],units)
# path of the model JSON file, loaded in the next cell
model = S['S']['model']
cmap = plt.get_cmap(S['plotting']['cmaps'][0])

#%% load model and disable ion displacement
with open(model, encoding='utf-8') as model_file:
    M = json.load(model_file)
# Remove the second ODE and turn Na_in into a fixed parameter equal to its
# initial value Na_in_0 (it is also dropped from the initial states).
# NOTE(review): presumes ode[1] is the Na_in equation -- confirm against the model file.
M['ode'].pop(1)
M['parameters']['Na_in'] = M['parameters']['Na_in_0']
M['init_states'].pop('Na_in')

#%% get mean input current for naturalistic PN stimuli (200-600 Hz)
print('getting mean input current for naturalistic PN stimuli (200-600 Hz)')
statemon, M_inp = hs.run_sim(M,{},inf_run_trace,runtime_trace,pnfs=pnfs)

# Evaluate the input expressions once (they do not depend on the frequency),
# then average row i over the first period 1/pnfs[i] of its own stimulus.
I_stim_trace = model_utils.eval_func(M_inp, "I_stim", statemon)
syn_clamp_trace = model_utils.eval_func(M_inp, "syn_clamp", statemon)
I_stims = [np.mean(I_stim_trace[i,:int(1/pnf/dt)]) for i, pnf in enumerate(pnfs)]
mean_syns = [np.mean(syn_clamp_trace[i,:int(1/pnf/dt)]) for i, pnf in enumerate(pnfs)]

# save mean inputs (exist_ok avoids the check-then-create race)
os.makedirs('data/', exist_ok=True)
np.save('data/mean_syns.npy',mean_syns)

#%% get fI curve
# f-I curve of the current model: firing rate f_ at each injected current I_in,
# derived from simulations driven with the I_in_fi range.
print('getting fI curve')
statemon, spikemon, _ = hs.run_sim(M, {'I_stim':I_in_fi}, inf_run_fi, runtime_fi,mode='fi',variables=['v'])
I_in, f_ = han.get_fI(I_in_fi, spikemon, statemon)

#%% check where the shifted fI curve (through increasing pump current)
#   intersects with the mean input current
# For every pump current ip, take the index of the fI-curve sample whose
# shifted input I_in + ip lies closest to the reference mean input
# I_stims[pnf_idx]; f_[intersects] then gives the firing rate there.
intersects = [np.argmin(np.abs(I_in+ip-I_stims[pnf_idx])) for ip in I_pumps]

#%% co-express sodium leak channels with pump
# The model's 'g_Na_l' function expresses the Na leak conductance in terms of
# the pump current; report it, then evaluate it for every entry of I_pumps.
coexpression_rule = M['functions']['g_Na_l']
print('co-expression rule: g_Na_l = ' + coexpression_rule)
gNal = model_utils.eval_func(M, 'g_Na_l', Ip_mean=I_pumps)

#%% get fI curves for increasing pump currents with co-expression of Na_l and pumps
print('getting fI curves with co-expression rule')
f_intersects_coex = []
I_in_coex = []
f_coex = []

for ips in I_pumps:

    # Ip_mean is stored as a unit expression string: the space in str(ips) is
    # replaced by '*' (assumes the quantity prints as "<value> <unit>").
    # NOTE: M is mutated; Ip_mean keeps the last I_pumps value after the loop.
    M['parameters']['Ip_mean'] = str(ips).replace(' ','*')

    statemon, spikemon, _ = hs.run_sim(M, {'I_stim':I_in_fi}, inf_run_fi, runtime_fi,mode='fi',variables=['v'])
    I_in_, f__ = han.get_fI(I_in_fi, spikemon, statemon)
    I_in_coex.append(I_in_)
    f_coex.append(f__)

    # firing rate of this (pump-including) fI curve at the reference mean input
    f_intersects_coex.append(f__[np.argmin(np.abs(I_in_-I_stims[pnf_idx]))])

#%% fine-tune pump rates to match energetic demand
def _fit_pump_current(M, syn_clamp, max_iter=20, tol=0.001*brian2.nA):
    """Iteratively fit one pump current per simulated cell to its Na/K load.

    Starting from zero pump current, each cell is simulated under its own
    ``syn_clamp`` value; the mean Na and K currents over one inter-spike
    interval are turned into a new pump-current estimate. Iteration stops as
    soon as no cell's estimate changes by ``tol`` or more, or after
    ``max_iter`` rounds.

    Parameters
    ----------
    M : dict
        Model description passed through to ``hs.run_sim``.
    syn_clamp : sequence
        One synaptic-clamp value per simulated cell.
    max_iter : int
        Maximum number of simulate/update rounds.
    tol : brian2 current quantity
        Convergence threshold on the largest per-cell change.

    Returns
    -------
    Ip : brian2 current quantity array
        Fitted pump current per cell.
    statemon, spikemon
        Monitors of the last simulation. On convergence that simulation was
        run with ``Ip``; otherwise ``Ip`` is the last, un-simulated estimate
        and a warning is printed.
    """
    n_cells = len(syn_clamp)
    Ip = np.zeros(n_cells)*brian2.uA

    for _ in range(max_iter):

        # run short simulations for the given mean inputs
        statemon, spikemon, _ = hs.run_sim(M, {'syn_clamp':syn_clamp,'Ip_mean':Ip}, inf_run_fi, runtime_fi, mode='syn_clamp', spikemon_rec=True)

        # get ISIs (period (T) of the limit cycles (lc))
        Tlc = [np.mean(np.diff(spikemon.t[spikemon.i==i])) for i in range(n_cells)]

        # get mean Na and K currents over one period (converted to time steps)
        INa_means, IK_means = han.get_currents(M, statemon, (Tlc/dt).astype(int))

        # new estimate: mean of -I_Na/3 and I_K/2 (presumably the 3 Na : 2 K
        # pump stoichiometry), re-attached to the amp unit
        new_Ip = brian2.amp*(-np.array(INa_means)/3 + np.array(IK_means)/2)/2

        # exit loop if pump current converged
        if np.max(np.abs(Ip - new_Ip)) < tol:
            break

        Ip = new_Ip
    else:
        # loop ran out without `break`: returned Ip was never simulated
        print('warning: pump current fit did not converge after %d iterations' % max_iter)

    return Ip, statemon, spikemon


print('tuning pump rates to match energetic demand')
Ip_pred, statemon, spikemon = _fit_pump_current(M, mean_syns)

#%%
# save pump current fits
np.save('data/fitted_I_pump.npy',Ip_pred)

# get firing rates of tuned cells (one per simulated mean input)
fs = [1/np.mean(np.diff(spikemon.t[spikemon.i==i])) for i in range(len(mean_syns))]

np.save('data/fr_fitted.npy',fs)

# get voltage, currents and spikes from tuned cell
t, v_, currents = han.get_currents(M,statemon,mode='all',pnf_idx=pnf_idx)
spikes = spikemon.t[spikemon.i==pnf_idx]

np.savez('data/currents_fig0.npz',t=t,v=v_,currents=currents,spikes=spikes)

#%% compare to case without pump
print('running simulation without pump to compare current scapes')
M['parameters']['Ip_mean'] = '0*uA'
statemon_nopump, spikemon_nopump, _ = hs.run_sim(M, {'syn_clamp':mean_syns}, inf_run_fi, runtime_fi, mode='syn_clamp', spikemon_rec=True)
t_nopump, v_nopump, currents_nopump = han.get_currents(M,statemon_nopump,mode='all',pnf_idx=pnf_idx)
spikes_nopump = spikemon_nopump.t[spikemon_nopump.i==pnf_idx]

#%% repeat for bigger range of stimuli (fsD / Ip_predD are plotted in panel B)
# reload the model to discard the Ip_mean edits made above, then disable
# ion displacement again
with open(model, encoding='utf-8') as model_file:
    M = json.load(model_file)
M['ode'].pop(1)
M['parameters']['Na_in'] = M['parameters']['Na_in_0']
M['init_states'].pop('Na_in')

# five synaptic-clamp values log-spaced from 10**-2 to 10**-0.3
syn_show = np.logspace(-2,-0.3,5)
Ip_predD, statemon, spikemon = _fit_pump_current(M, syn_show)

fsD = [1/np.mean(np.diff(spikemon.t[spikemon.i==i])) for i in range(len(syn_show))]

#%% PLOTTING
import matplotlib.image as mpimg

# Throw-away figure: only the image handle `im` is kept, so the pump-current
# colorbar further down has a mappable to draw from.
plt.figure()
im = plt.imshow([I_pumps/brian2.uA], cmap=cmap)
plt.close()

# Global text sizes: 8 for labels, ticks, legend and default font; 10 for axes titles.
for rc_group, rc_settings in (
    ('axes', dict(labelsize=8)),
    ('xtick', dict(labelsize=8)),
    ('ytick', dict(labelsize=8)),
    ('legend', dict(fontsize=8)),
    ('axes', dict(titlesize=10)),
    ('font', dict(size=8)),
):
    plt.rc(rc_group, **rc_settings)

EOD_img = mpimg.imread('fig/img/electrocyte_model.png')

f = plt.figure(figsize=(8, 8))

# Outer layout: schematic row (panel A) above a row three times as tall for the data.
gs00 = f.add_gridspec(2, 1, figure=f, height_ratios=[1, 3], hspace=0.25)
# Data row, split again: panels B-G on top (2 parts), panels H-J below (1 part).
gs0 = gridspec.GridSpecFromSubplotSpec(
    2, 1, subplot_spec=gs00[1], hspace=0.85, height_ratios=[2, 1]
)

# Panel A: electrocyte model schematic image
ax = f.add_subplot(gs00[0])
hpls.plot_image(ax, EOD_img)
ax.set_title('A\n', loc='left', x=0, y=0.8, weight='bold')

# Upper data row: a narrow left margin (gswhat[0], left empty) plus the panel
# area, which is split into a left column (B, E) and a wider right area (C, D, F, G).
gswhat = gridspec.GridSpecFromSubplotSpec(1,2,subplot_spec=gs0[0],width_ratios=[0.1,1],wspace=0)
gs = gridspec.GridSpecFromSubplotSpec(1,2,subplot_spec=gswhat[1],width_ratios=[1,5],wspace=0.5)

gs_ = gridspec.GridSpecFromSubplotSpec(2,1,subplot_spec=gs[0],hspace=1.25)

# Panel B: fitted pump current vs firing rate for the syn_show input range
ax = f.add_subplot(gs_[0])
hplpl.plot_fr_vs_pump(ax, fsD, Ip_predD, I_pumps)
# '\\mu' keeps the literal backslash for mathtext without an invalid escape
ax.set_ylabel('Pump\ncurrent [$\\mu$A]')
ax.set_title('B\n',loc='left',x=-0.7,weight='bold')
ax.set_xticks([200,600])
ax.set_ylim([0,3.5])
ax.set_yticks([0,3])

# Panel E: Na leak conductance co-expression rule (every third pump current)
ax = f.add_subplot(gs_[1])
hplpl.plot_coexpression_rule(ax, I_pumps[::3], gNal[::3])
ax.set_title('E\n',loc='left',x=-0.75,weight='bold')
ax.set_xticks([0,3])
ax.set_xlim([0,3.5])
ax.set_ylim([0,25])

# Right area of the upper row: fI curves (C, F), intersection firing rates
# (D, G) and a colorbar column shared by both rows.
gs_ = gridspec.GridSpecFromSubplotSpec(2,3,subplot_spec=gs[1], width_ratios=[2,1,0.05], wspace=0.1,hspace=1.3)

# Panel C: fI curves for every third pump current, without co-expression
ax = f.add_subplot(gs_[0,0])
hplpl.plot_fI_pump(ax, I_in, I_pumps[::3], f_, I_stims[pnf_idx])
ax.set_xticks([0,0.63,3])
ax.set_xlim([0,3.5])
ax.set_ylim([0,600])
ax.set_yticks([0,500])
ax.set_xticklabels(['0','0.63','3'])
ax.text(ax.get_xlim()[0],ax.get_ylim()[1]+10,'Without co-expression',verticalalignment='bottom',horizontalalignment='left')
ax.set_title('C\n',loc='left',x=-0.3,weight='bold')

# Panel D: firing rate at the reference input (f_[intersects]) vs pump current
ax = f.add_subplot(gs_[0,1])
hplpl.plot_pump_vs_freq(ax, I_pumps, f_[intersects], I_stims[pnf_idx],text=False)
ax.set_title('D\n',loc='left',x=-0.1,weight='bold')

ax.set_yticklabels([])
ax.set_ylabel('')
ax.set_ylim([0,600])
ax.set_xticks([0,3])
ax.set_yticks([0,500])

# NOTE(review): 0.63 (here and in the C/F x ticks) is hard-coded, presumably
# to match I_stims[pnf_idx] -- update it if the stimulus config changes.
ax.text(ax.get_xlim()[1], ax.get_ylim()[0],'Input\ncurrent\n= 0.63 $\\mu$A\n',horizontalalignment='right',verticalalignment='bottom')

# Panel F: fI curves for every third pump current, with co-expression
ax = f.add_subplot(gs_[1,0])
hplpl.plot_fI_pump(ax, I_in_coex[::3], I_pumps[::3], f_coex[::3], I_stims[pnf_idx])
ax.set_title('F\n',loc='left',x=-0.3,weight='bold')
ax.set_xticks([0,0.63,3])
ax.set_xlim([0,3.5])
ax.set_yticks([0,500])

ax.set_xticklabels(['0','0.63','3'])
ax.set_ylim([0,600])
ax.text(ax.get_xlim()[0],ax.get_ylim()[1]+10,'With co-expression',verticalalignment='bottom',horizontalalignment='left')

# Colorbar spanning both rows, built from the dummy image `im`
ax = f.add_subplot(gs_[:,2])
cb = plt.colorbar(im, cax=ax)
cb.set_label(r'pump current [$\mu A$]',labelpad=-5)
cb.set_ticks([0,3])

# Panel G: firing rate at the reference input vs pump current, with co-expression
ax = f.add_subplot(gs_[1,1])
hplpl.plot_pump_vs_freq(ax, I_pumps, f_intersects_coex, I_stims[pnf_idx],text=False)
ax.set_title('G\n',loc='left',x=-0.1,weight='bold')

ax.set_yticklabels([])
ax.set_ylabel('')
ax.set_ylim([0,600])
ax.set_yticks([0,500])
ax.set_xticks([0,3])

# Lower data row: action potentials (H), current scapes with/without pump (I)
# and the bar panel J.
gs = gridspec.GridSpecFromSubplotSpec(1,2,subplot_spec=gs0[1],width_ratios=[1,4],wspace=0.4)
gs__ = gridspec.GridSpecFromSubplotSpec(1,2,subplot_spec=gs[1], wspace=1, width_ratios=[3.5,1])
gs_ = gridspec.GridSpecFromSubplotSpec(1,2,subplot_spec=gs__[0], wspace=0.05)

# Hidden axes spanning both current-scape panels; it only carries the titles
ax = f.add_subplot(gs_[:])
ax.set_title('I\n',loc='left',x=0,weight='bold')
ax.set_title('Currents per AP\n',loc='center',fontsize=8)
ax.axis('off')

# Panel I (left): current scape of the pump-tuned cell
ax = f.add_subplot(gs_[0])
AP, Na_int, K_int = hplpl.plot_current_scape(ax, M, t, v_, currents, spikes, dt, inf_run_fi, sep=False,bar=False)
ax.set_ylim([-17.5,30])
ax.set_yticks([-15,0,15])
# NOTE(review): 0.13 is hard-coded, presumably mean_syns[pnf_idx] -- confirm.
# '\\mathrm' keeps the literal backslash for mathtext without an invalid escape.
ax.text(ax.get_xlim()[0], ax.get_ylim()[1],'\n\n  $\\mathrm{syn}_\\mathrm{clamp}$\n  = 0.13',horizontalalignment='left',verticalalignment='top')
ax.text(ax.get_xlim()[0],ax.get_ylim()[1],' With pump',horizontalalignment='left',verticalalignment='bottom')


# Reuse this panel's legend entries for the legend drawn beside the right panel
h,l = ax.get_legend_handles_labels()

# Panel I (right): current scape of the simulation without pump
ax = f.add_subplot(gs_[1])
AP_nopump, Na_int_nopump, K_int_nopump = hplpl.plot_current_scape(ax, M, t_nopump, v_nopump, currents_nopump, spikes_nopump, dt, inf_run_fi, sep=False, bar=False)
ax.set_ylim([-17.5,30])
ax.set_yticks([-15,0,15])
ax.set_yticklabels([])
ax.set_ylabel('')
ax.text(ax.get_xlim()[0],ax.get_ylim()[1],' Without pump',horizontalalignment='left',verticalalignment='bottom')
ax.legend(h,l,frameon=False,ncol=1,loc='upper left',bbox_to_anchor=(1,0.5),handlelength=0.5)

# Panel H: action potentials with and without pump
ax = f.add_subplot(gs[0])
ax.set_title('H\n',loc='left',x=0,weight='bold')
ax.set_title('Action\npotentials',loc='right',fontsize=8)
hplpl.plot_spikes(ax,AP,AP_nopump,dt)
ax.set_ylim([-0.080,0.080])
ax.legend(frameon=False,loc='upper left')
ax.axis('off')

# Panel J: Na/K totals (Na_int, K_int) returned by the current-scape plots,
# with vs without pump
ax = f.add_subplot(gs__[1])
hplpl.plot_total_currents(ax, Na_int, K_int, Na_int_nopump, K_int_nopump)
ax.set_title('J\n',loc='left',x=-1,weight='bold')
ax.set_title('Energetic\nrequirements',loc='right',fontsize=8)

plt.savefig('fig/1.pdf',bbox_inches='tight',dpi=1000)
plt.close()
