import json, brian2, sys, copy
import numpy as np
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import model_utils, brianutils
import helpers.analysis as han
import helpers.plotting.plotting as hplpl
import helpers.plotting.schematics as hpls
import helpers.simulation as hs

import matplotlib.pyplot as plt
import pickle as pkl
from matplotlib import gridspec
import matplotlib
matplotlib.use('Agg')
from brianutils import units

#%% Load simulation settings.
# Every setting below is stored as a string and evaluated with eval(), using
# brian2's `units` as the namespace where physical units are needed.
# NOTE: eval() executes arbitrary code -- the config file must be trusted,
# local input; never point this script at an externally supplied config.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)

# Simulation time step; also installed as brian2's default clock step.
dt = brian2.defaultclock.dt = eval(S['S']['dt'],units)
inf_run_trace = eval(S['S']['inf_run']['trace'],units)
runtime_trace = eval(S['S']['runtime']['trace'],units)
inf_run_fi = eval(S['S']['inf_run']['fi'],units)
runtime_fi = eval(S['S']['runtime']['fi'],units)
I_in_fi = eval(S['ranges']['I_in_fi'],units)
pnfs = eval(S['ranges']['PN_freqs'],units)
pnf_idx = eval(S['ranges']['pnf_idx'])
model = S['S']['model']  # path of the model JSON file (loaded further down)

# Multipliers applied to the pump current in the pump-density sweeps.
pump_densities = eval(S['ranges']['pump_density'])
# Initial bracket for the Na_in steady-state binary search.
na_in_min, na_in_max = eval(S['ranges']['Na_in_ss'],units)

# Long-square stimulus settings (T0+T1+T2 is later used as the total duration).
long_square = S['stimuli']['long_square']
T0 = eval(long_square['T0'],units)
T1 = eval(long_square['T1'],units)
T2 = eval(long_square['T2'],units)
bl_idx = eval(long_square['baseline_idx'])
rise_idx = eval(long_square['rise_idx'])
dt_rec = eval(S['S']['rec_dt']['voltage_dependent_pump'],units)

cmap = plt.get_cmap(S['plotting']['cmaps'][0])

#%% Load fitted quantities and the model definition.
I_pump = np.load('data/fitted_I_pump.npy')*brian2.amp
mean_syns = np.load('data/mean_syns.npy')
fs = np.load('data/fr_fitted.npy')*brian2.Hz
# Input level used for each deviation condition in panel D:
# 'high' takes the mean synaptic input at rise_idx, 'low' is zero.
stim_dev = {'high': mean_syns[rise_idx], 'low': 0}

# M is the model exactly as stored on disk. M_vd is an independent deep copy
# that is turned into the voltage-dependent-pump variant below, so edits to it
# never leak into M.
with open(model) as model_file:
    M = json.load(model_file)
M_vd = copy.deepcopy(M)
#%% SIMULATIONS
# 1. Simulate the model with a voltage-dependent pump until steady state.
# It cannot sit exactly at steady state, because the pump cannot depend on
# neurotransmitter release.

# Build the voltage-dependent variant: both K currents are scaled by 2/3,
# the pump's mean current is tied to half of the total K current (rather than
# a fitted constant parameter), and g_Na_l is forced to zero.
M_vd['definitions'].update({
    'I_K': '(2/3)*g_K_max*n**4*(v-E_K)',
    'I_K_L': '(2/3)*g_K_l*(v-E_K)',
})
M_vd['parameters'].pop('Ip_mean')
M_vd['functions'].update({
    'Ip_mean': '(1/2)*(I_K+I_K_L)',
    'g_Na_l': '0*uS',
})

# Reach steady state faster by bisecting on the initial Na_in value.
print('using binary search to get steady state ionic concentrations')

N_BISECTION_STEPS = 10
for _step in range(N_BISECTION_STEPS):
    na_in_cur = (na_in_max+na_in_min)/2
    # str() of the quantity separates value and unit with a space; a '*'
    # turns it into an expression the model loader can evaluate (presumably).
    M_vd['init_states']['Na_in'] = str(na_in_cur).replace(' ','*')

    # Short run driven by the naturalistic mean synaptic input.
    statemon, spikemon, _ = hs.run_sim(M_vd, {'syn_clamp':[mean_syns[pnf_idx]]}, inf_run_fi, runtime_fi, mode='syn_clamp', spikemon_rec=True)

    # Still rising (later-half peak above earlier-half peak) means the steady
    # state lies above the current guess, so raise the lower bound; otherwise
    # lower the upper bound.
    half = int(statemon.N/2)
    still_rising = np.max(statemon.Na_in[0,-half:]) > np.max(statemon.Na_in[0,:half])
    na_in_min, na_in_max = (na_in_cur, na_in_max) if still_rising else (na_in_min, na_in_cur)

# Currents and spike times of the final bisection run (closest to steady state).
t, v, currents = han.get_currents(M_vd,statemon,mode='all')
spikes = spikemon.t[:]


#%% 2. fI curves for voltage-dependent pump
# Copy of the voltage-dependent model with Na_in frozen at the steady-state
# value found by the bisection above: it moves from init_states to parameters.
M_vd_cion = copy.deepcopy(M_vd)
# NOTE(review): ODE entry 1 is dropped -- presumably the Na_in equation, since
# Na_in becomes a parameter on the next line; confirm against the model file.
M_vd_cion['ode'].pop(1)
M_vd_cion['parameters']['Na_in'] = M_vd_cion['init_states'].pop('Na_in')

I_vd = []  # input-current axis of each fI curve, one entry per pump density
f_vd = []  # matching firing rates, one entry per pump density

print('getting fI curves for varying (voltage-dependent) pump densities')
for mult in pump_densities:
    # Scale the voltage-dependent mean pump current by the density multiplier.
    M_vd_cion['functions']['Ip_mean'] = str(mult)+'*(1/2)*(I_K+I_K_L)'
    statemon, spikemon, _ = hs.run_sim(M_vd_cion, {'I_stim':I_in_fi}, inf_run_fi, runtime_fi,mode='fi',variables=['v'])
    I_in, f_ = han.get_fI(I_in_fi, spikemon, statemon)
    I_vd.append(I_in)
    f_vd.append(f_)

#%% 2.b fI curves for constant pump
# Constant-pump model with Na_in frozen at Na_in_0 and the mean pump current
# fixed to the fitted value at pnf_idx (written in uA, 4 decimals).
M_cion = copy.deepcopy(M)

# NOTE(review): ODE entry 1 is dropped -- presumably the Na_in equation, since
# Na_in becomes a parameter on the next line; confirm against the model file.
M_cion['ode'].pop(1)
M_cion['parameters']['Na_in'] = M_cion['parameters']['Na_in_0']
M_cion['parameters']['Ip_mean'] = '%.4f*uA'%(I_pump[pnf_idx]/brian2.uA)

I_nvd = []  # input-current axis of each fI curve, one entry per pump density
f_nvd = []  # matching firing rates, one entry per pump density

print('getting fI curves for varying (voltage-agnostic) pump densities')
for mult in pump_densities:
    # The density multiplier prefixes the (voltage-independent) pump expression.
    M_cion['definitions']["I_pump"] = str(mult)+"*Ip_mean*4*(1+exp((Na_in_0 - Na_in)/(3*mmolar)))**-1*(1+exp((K_out_0 - K_out)/mmolar))**-1"
    statemon, spikemon, _ = hs.run_sim(M_cion, {'I_stim':I_in_fi}, inf_run_fi, runtime_fi,mode='fi',variables=['v'])
    I_in, f_ = han.get_fI(I_in_fi, spikemon, statemon)
    I_nvd.append(I_in)
    f_nvd.append(f_)

#%% combine in one plot.
# Throw-away image whose only purpose is to serve as the mappable for the
# panel-C colour bar; its values are the pump densities in percent.
# np.asarray guarantees a numeric *100 even when the config yields a plain
# list (list*100 would repeat the list 100 times instead of scaling it).
plt.figure()
im = plt.imshow([np.asarray(pump_densities)*100],cmap=cmap)
plt.close()

# Small fonts throughout the figure.
plt.rc('axes', labelsize=8)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=8)
plt.rc('ytick', labelsize=8)
plt.rc('legend', fontsize=8)
plt.rc('axes', titlesize=10)
plt.rc('font', size=8)

f = plt.figure(figsize=(8,5))

# Outer layout: top row holds panels A and B, bottom row holds panels C and D.
gs0 = f.add_gridspec(2,1,figure=f,height_ratios=[0.6,1.5],hspace=0.75)

gs = gridspec.GridSpecFromSubplotSpec(1,2,subplot_spec=gs0[0],width_ratios=[2,1],wspace=1)

# Panel A header: invisible axes that only carries the titles.
ax = f.add_subplot(gs[0])
ax.set_title('A\n\n',loc='left',x=-0.275,weight='bold')
ax.set_title('Currents per AP\n\n',loc='center',fontsize=8)
ax.axis('off')

gs_ = gridspec.GridSpecFromSubplotSpec(2,2,subplot_spec=gs[0],height_ratios=[1,1],hspace=0.2, wspace=0.05)

# Left: currents of the constant-pump model, precomputed and stored on disk.
ax = f.add_subplot(gs_[:,0])
# NOTE(review): unpacking relies on the archive's key order (t, v, currents,
# spikes) -- confirm against the script that wrote currents_fig0.npz.
with np.load('data/currents_fig0.npz') as currents_npz:
    t_c, v_c, currents_c, spikes_c = currents_npz.values()
with open(model) as model_file:
    M_const = json.load(model_file)  # fresh, unmodified copy of the model
AP_c, Na_int_c, K_int_c = hplpl.plot_current_scape(ax, M_const, t_c*brian2.second, v_c, currents_c, spikes_c*brian2.second, dt, inf_run_fi, sep=False)

ax.set_ylim([-17.5,27.5])
ax.set_yticks([-15,0,15])
# Backslashes are escaped: '\m' is an invalid escape (SyntaxWarning on 3.12+).
ax.set_ylabel('Current\ncontribution [$\\mu$A]')
ax.set_title('Constant',fontsize=8)
# Annotation inset 5% from the top-left corner of the data limits.
# NOTE(review): 0.13 is hard-coded; presumably the syn_clamp value used to
# produce currents_fig0.npz -- confirm.
x_lo, x_hi = ax.get_xlim()
y_lo, y_hi = ax.get_ylim()
ax.text(x_lo + (x_hi-x_lo)*0.05, y_hi - (y_hi-y_lo)*0.05,'$\\mathrm{syn}_\\mathrm{clamp}$\n= 0.13',fontsize=9,horizontalalignment='left',verticalalignment='top')

legend_handles, legend_labels = ax.get_legend_handles_labels()

# Right: currents of the voltage-dependent steady-state run (no bar).
ax = f.add_subplot(gs_[:,1])
# NOTE(review): the model passed here is M (constant pump) although the
# currents come from M_vd -- confirm plot_current_scape only needs the parts
# that both models share.
AP, Na_int, K_int = hplpl.plot_current_scape(ax, M, t, v, currents, spikes, dt, inf_run_fi, sep=False,bar=False)
ax.set_ylim([-17.5,27.5])
ax.set_yticks([-15,0,15])
ax.set_yticklabels([])
ax.set_ylabel('')
ax.set_title('Voltage\ndependent',fontsize=8)

# One shared legend (handles from the left axes), placed right of the panel.
ax.legend(legend_handles,legend_labels,frameon=False,ncol=1,loc='upper left',bbox_to_anchor=(1,1),handlelength=0.5)

# Panel B: compare the integrated currents of both pump variants.
ax = f.add_subplot(gs[1])
ax.set_title('Energetic\nrequirements\n',loc='center',fontsize=8)
ax.set_title('B\n\n',loc='left',x=-0.65,weight='bold')
hplpl.plot_totcur_compare(ax,Na_int_c,K_int_c,K_int_c/2,Na_int,K_int,K_int/2)

gs__ = gridspec.GridSpecFromSubplotSpec(1,2,subplot_spec=gs0[1],wspace=0.65,width_ratios=[1,2.5])
# Panel C grid: two stacked fI axes in the wide left column, colour bar in the
# narrow right column.
gs = gridspec.GridSpecFromSubplotSpec(2,2,width_ratios=[10,0.5],subplot_spec=gs__[0],hspace=1,height_ratios=[1,1])

# Panel C header: invisible axes spanning the top row, only for the titles.
ax = f.add_subplot(gs[0,:])
ax.set_title('C\n\n',loc='left',x=-0.5,weight='bold')
ax.set_title('Tuning curves\n\n',loc='center',fontsize=8)
ax.axis('off')

# (row, input currents, firing rates, title) for each fI panel, top to bottom.
fi_panels = (
    (0, I_nvd, f_nvd, 'Constant'),
    (1, I_vd, f_vd, 'Voltage\ndependent'),
)
for row, I_curves, f_curves, label in fi_panels:
    ax = f.add_subplot(gs[row,0])
    hplpl.plot_fI_pump(ax,I_curves,None,f_curves,None,False)
    if row == 0:
        # Only the lower panel keeps the x label set by plot_fI_pump.
        ax.set_xlabel('')
    ax.set_xticks([0,2])
    # NOTE(review): tick labels are cleared on BOTH panels, so the lower one
    # shows ticks without numbers -- confirm this is intended.
    ax.set_xticklabels([])
    ax.set_xlim([0,2.5])
    ax.set_title(label,fontsize=8)

# Shared colour bar: pump density in percent (mappable built earlier).
ax = f.add_subplot(gs[:,1])
cb = plt.colorbar(im,cax=ax)
cb.set_label('Pump density [%]',labelpad=-10)
cb.set_ticks([50,150])

gs = gridspec.GridSpecFromSubplotSpec(3,2,subplot_spec=gs__[1],height_ratios=[1.5,2,2])

# Panel D header: invisible axes spanning the whole grid, only for the titles.
ax = f.add_subplot(gs[:])
ax.set_title('D\n\n',loc='left',x=-0.25,weight='bold')
ax.set_title('Signal generation\n',loc='center',fontsize=8)
ax.axis('off')

# Left-column axes whose y labels get aligned at the end.
to_align = []

# One column per deviation condition; rows: input, firing rate, pump current.
for i,dev in enumerate(['low','high']):
    is_right_column = (i == 1)

    # Recorded run for this condition. The arrays are stored without units;
    # mM / volt are re-attached when evaluating the pump current.
    with np.load('data/dev_input_vd_%s.npz'%dev,allow_pickle=True) as dev_data:
        st, nain, v_dev, n_dev = dev_data['st'], dev_data['nain'], dev_data['v'], dev_data['n']
    ip = model_utils.eval_func(M_vd, "I_pump", Na_in=nain*brian2.mM,n=n_dev,v=v_dev*brian2.volt)

    # Row 0: input signal.
    ax = f.add_subplot(gs[0,i])
    if not is_right_column:
        to_align.append(ax)
    hplpl.plot_input(ax, mean_syns[bl_idx], stim_dev[dev], T0, T1, T2)
    ax.set_ylim([0,0.2])
    if is_right_column:
        ax.set_ylabel('')
        ax.set_yticklabels([])
    ax.set_yticks([0,0.2])
    ax.set_xlabel('')
    ax.set_xticklabels([])

    # Row 1: firing rates from the stored `st` data.
    ax = f.add_subplot(gs[1,i])
    if not is_right_column:
        to_align.append(ax)
    hplpl.plot_firing_rates(ax, st, None, None, T0, T1, T2)
    if is_right_column:
        ax.set_ylabel('')
        ax.set_yticklabels([])
    ax.set_yticks([0,400])
    ax.set_ylim([-10,450])
    ax.set_xlabel('')
    ax.set_xticklabels([])

    # Row 2: pump current over the full stimulus duration T0+T1+T2.
    ax = f.add_subplot(gs[2,i])
    if not is_right_column:
        to_align.append(ax)
    hplpl.plot_pump(ax, ip, None, dt_rec,T0+T1+T2,N=10000)
    if is_right_column:
        ax.set_ylabel('')
        ax.set_yticklabels([])
    else:
        ax.legend(frameon=False,loc='upper left',bbox_to_anchor=(0.1,1),handlelength=0.5)

f.align_ylabels(to_align)
# Save and close this figure explicitly rather than whatever gcf() returns.
f.savefig('fig/6.pdf',bbox_inches='tight')
plt.close(f)
