# imports
import json, brian2, sys, os, copy
import numpy as np
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import model_utils, brianutils
import helpers.analysis as han
import helpers.plotting.plotting as hplpl
import helpers.plotting.schematics as hpls
import helpers.simulation as hs

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.patches import Ellipse, Rectangle
import matplotlib
matplotlib.use('Agg')
from brianutils import units

#%% load params from cfg file
# NOTE(review): the cfg entries below are turned into values with eval() in the
# brian2 `units` namespace (presumably so strings like '0.01*ms' become
# quantities). That is only acceptable because cfg/ is a trusted,
# project-local file -- never point this at external input.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)

# Simulation timing; the time step is also installed as brian2's default clock.
dt = brian2.defaultclock.dt = eval(S['S']['dt'], units)
inf_run_trace = eval(S['S']['inf_run']['trace'], units)
runtime_trace = eval(S['S']['runtime']['trace'], units)
inf_run_fi = eval(S['S']['inf_run']['fi'], units)
runtime_fi = eval(S['S']['runtime']['fi'], units)

# Parameter sweeps.
I_in_fi = eval(S['ranges']['I_in_fi'], units)
pnfs = eval(S['ranges']['PN_freqs'], units)
pnf_idx = eval(S['ranges']['pnf_idx'])
I_pumps = eval(S['ranges']['I_pump'], units)

# Path of the model JSON, plus the two colormaps used for plotting.
model = S['S']['model']
cmap = plt.get_cmap(S['plotting']['cmaps'][0])
cmap2 = plt.get_cmap(S['plotting']['cmaps'][1])


#%% load model and disable ion displacement
# Ion displacement is switched off by removing the second ODE of the model
# (presumably the Na_in equation -- confirm against the model file) and
# freezing Na_in as a parameter at its initial value Na_in_0.
with open(model) as model_file:
    M = json.load(model_file)
M['ode'].pop(1)
M['parameters']['Na_in'] = M['parameters']['Na_in_0']
M['init_states'].pop('Na_in')

#%% get fI curve
print('getting fI curve')
statemon, spikemon, _ = hs.run_sim(M, {'I_stim': I_in_fi}, inf_run_fi, runtime_fi, mode='fi', variables=['v'])

#%% set everything that is not on the fI curve to NaN.
# Peak-to-trough excursion of v for each input current, in mV.
amps_np = (np.max(statemon.v, axis=1) - np.min(statemon.v, axis=1))/brian2.mV

# Inputs that evoke fewer than two spikes are treated as not on the fI curve.
# Spikes are counted per neuron once with bincount instead of rescanning the
# whole spike train for every input.
n_spikes = np.bincount(np.asarray(spikemon.i, dtype=int), minlength=len(I_in_fi))
amps_np[np.flatnonzero(n_spikes[:len(I_in_fi)] < 2)] = np.nan

#%% get fI curves for increasing pump currents with co-expression of Na_l and pumps
amps = []     # one amplitude curve (mV, NaN where off the fI curve) per pump current
apss = []     # per pump current: one first-ISI voltage trace per input ([] if < 2 spikes)
spiking = []  # NOTE(review): never filled or read in this script

for ips in I_pumps:
    # Ip_mean is stored as a string expression: spaces in the printed quantity
    # become '*' (e.g. '1. uA' -> '1.*uA').
    M['parameters']['Ip_mean'] = str(ips).replace(' ', '*')
    statemon, spikemon, _ = hs.run_sim(M, {'I_stim': I_in_fi}, inf_run_fi, runtime_fi, mode='fi', variables=['v'])

    amp = (np.max(statemon.v, axis=1) - np.min(statemon.v, axis=1))/brian2.mV
    aps = []
    for k in range(len(I_in_fi)):
        # filter the spike train for this input once and reuse it
        spk_t = spikemon.t[spikemon.i == k]
        if len(spk_t) < 2:
            # not on the fI curve: no amplitude, no AP trace
            amp[k] = np.nan
            aps.append([])
        else:
            # samples strictly after the first spike up to and including the second
            first_isi = (statemon.t > spk_t[0]) & (statemon.t <= spk_t[1])
            aps.append(statemon.v[k, first_isi])
    amps.append(amp)
    apss.append(aps)

#%% now do the amplitudes of fI curves in figure 6
# NOTE(review): eval of a trusted cfg string (evaluated without `units`).
pump_densities = eval(S['ranges']['pump_density'])
# NOTE(review): I_pump is loaded but not used anywhere in this script.
I_pump = np.load('data/fitted_I_pump.npy')*brian2.amp

# Voltage-dependent pump variant: I_K and I_K_L are redefined with a 2/3
# prefactor, Ip_mean turns from a fixed parameter into a function of those K
# currents, and g_Na_l (presumably the Na leak conductance) is set to zero.
with open(model) as model_file:
    M_vd = json.load(model_file)

M_vd['definitions']['I_K'] = '(2/3)*g_K_max*n**4*(v-E_K)'
M_vd['definitions']['I_K_L'] = '(2/3)*g_K_l*(v-E_K)'

M_vd['parameters'].pop('Ip_mean')
M_vd['functions']['Ip_mean'] = '(1/2)*(I_K+I_K_L)'
M_vd['functions']['g_Na_l'] = '0*uS'

#%% fI curves for voltage-dependent pump
# Same model with ion displacement disabled: drop the second ODE and clamp
# Na_in at its initial-state value.
M_vd_cion = copy.deepcopy(M_vd)
M_vd_cion['ode'].pop(1)
M_vd_cion['parameters']['Na_in'] = M_vd_cion['init_states']['Na_in']
M_vd_cion['init_states'].pop('Na_in')

amps_vd = []   # one amplitude curve (mV, NaN off the fI curve) per pump density
apss_vd = []   # per pump density: one first-ISI voltage trace per input ([] if < 2 spikes)

print('getting fI curves for varying (voltage-dependent) pump densities')
for mult in pump_densities:
    # scale the voltage-dependent pump current by the density multiplier
    M_vd_cion['functions']['Ip_mean'] = str(mult)+'*(1/2)*(I_K+I_K_L)'
    statemon, spikemon, _ = hs.run_sim(M_vd_cion, {'I_stim': I_in_fi}, inf_run_fi, runtime_fi, mode='fi', variables=['v'])

    amp = (np.max(statemon.v, axis=1) - np.min(statemon.v, axis=1))/brian2.mV
    aps = []
    for k in range(len(I_in_fi)):
        spk_t = spikemon.t[spikemon.i == k]
        if len(spk_t) < 2:
            amp[k] = np.nan
            aps.append([])
        else:
            # samples strictly after the first spike up to and including the second
            first_isi = (statemon.t > spk_t[0]) & (statemon.t <= spk_t[1])
            aps.append(statemon.v[k, first_isi])
    amps_vd.append(amp)
    apss_vd.append(aps)

#%% plot APs for weak and strong inputs
# Clamped synaptic input strengths; entry k drives simulated neuron k
# (row k of the monitors below).
syns_rev = [0.05, 0.5]

# A without pump;
#%% load model and disable ion displacement
with open(model) as model_file:
    M = json.load(model_file)
M['ode'].pop(1)
M['parameters']['Na_in'] = M['parameters']['Na_in_0']
M['init_states'].pop('Na_in')

#%% get states
statemon, spikemon, _ = hs.run_sim(M, {'syn_clamp': syns_rev}, inf_run_fi, runtime_fi, mode='syn_clamp', spikemon_rec=True)

#%% get APs
# Per input: v between the first and second spike, rolled so its peak sits in
# the middle of the window, in mV. Only the len(syns_rev) simulated neurons are
# visited. An input without a usable inter-spike interval contributes an empty
# array instead of being skipped, so aps_np[k] always belongs to syns_rev[k]
# (the figure indexes it that way).
aps_np = []
for k in range(len(syns_rev)):
    spk_t = spikemon.t[spikemon.i == k]
    ap = statemon.v[k, (statemon.t > spk_t[0]) & (statemon.t <= spk_t[1])] if len(spk_t) >= 2 else []
    if len(ap) == 0:
        aps_np.append(np.array([]))
    else:
        aps_np.append(np.roll(ap, int(len(ap)/2) - np.argmax(ap))/brian2.mV)

#%% get currents
# plot_current_scape needs an axes to draw on; these scratch figures are never
# saved -- only the integrated Na/K quantities it returns are kept.
ion_sum_np = []
currents_np = []   # per input: [t, v, currents, spike times] for re-plotting in panel B

for cidx in range(len(syns_rev)):
    t, v_, currents = han.get_currents(M, statemon, mode='all', pnf_idx=cidx)
    spikes = spikemon.t[spikemon.i == cidx]

    currents_np.append([t, v_, currents, spikes])

    f = plt.figure()
    ax = f.add_subplot()
    AP, Na_int, K_int = hplpl.plot_current_scape(ax, M, t, v_, currents, spikes, dt, inf_run_fi, sep=False, bar=False)
    ion_sum_np.append([Na_int, K_int])

#%% do this for model with pump. I have to tune it then.
# Fixed-point iteration: simulate with the current guess of the constant pump
# current, measure the mean Na and K currents over one limit-cycle period, and
# update the guess until it moves by less than 0.001 nA (1 pA).
print('tuning pump rates to match energetic demand')
Ip_pred = np.zeros(len(syns_rev))*brian2.uA

for it in range(20):

    # run short simulations for naturalistic mean inputs
    statemon, spikemon, _ = hs.run_sim(M, {'syn_clamp': syns_rev, 'Ip_mean': Ip_pred}, inf_run_fi, runtime_fi, mode='syn_clamp', spikemon_rec=True)

    # mean ISI = period (T) of the limit cycle (lc) for each input
    Tlc = [np.mean(np.diff(spikemon.t[spikemon.i == k])) for k in range(len(syns_rev))]

    # mean Na and K currents, given each period as a number of time steps
    INa_means, IK_means = han.get_currents(M, statemon, (Tlc/dt).astype(int))

    # pump current balancing those fluxes; the /3 and /2 presumably reflect the
    # 3 Na+ : 2 K+ pump stoichiometry, and the two estimates are averaged
    new_Ip_pred = brian2.amp*(-np.array(INa_means)/3 + np.array(IK_means)/2)/2

    # exit loop if pump current converged
    if np.max(np.abs(Ip_pred - new_Ip_pred)) < 0.001*brian2.nA:
        break

    Ip_pred = new_Ip_pred
else:
    # the monitors used below then belong to the last simulated (unconverged) guess
    print('warning: pump current did not converge within 20 iterations')

#%% get APs
# Peak-centred first-ISI traces per input (mV); an empty array keeps
# aps_[k] aligned with syns_rev[k] when an input has no usable ISI.
aps_ = []
for k in range(len(syns_rev)):
    spk_t = spikemon.t[spikemon.i == k]
    ap = statemon.v[k, (statemon.t > spk_t[0]) & (statemon.t <= spk_t[1])] if len(spk_t) >= 2 else []
    if len(ap) == 0:
        aps_.append(np.array([]))
    else:
        aps_.append(np.roll(ap, int(len(ap)/2) - np.argmax(ap))/brian2.mV)

#%% get currents
# scratch figures only serve plot_current_scape; its Na/K outputs are kept
ion_sum = []
currents_ = []   # per input: [t, v, currents, spike times] for panel B
for cidx in range(len(syns_rev)):
    t, v_, currents = han.get_currents(M, statemon, mode='all', pnf_idx=cidx)
    spikes = spikemon.t[spikemon.i == cidx]

    currents_.append([t, v_, currents, spikes])

    f = plt.figure()
    ax = f.add_subplot()
    AP, Na_int, K_int = hplpl.plot_current_scape(ax, M, t, v_, currents, spikes, dt, inf_run_fi, sep=False, bar=False)
    ion_sum.append([Na_int, K_int])

#%% now do the same for the voltage-dependent pump
# 1. simulate model with voltage dependent pump until steady state
# it will not be at steady state exactly because the pump cannot depend on neuro-
# transmitter release.

# Fresh voltage-dependent pump model (ion displacement left enabled): I_K and
# I_K_L get a 2/3 prefactor, Ip_mean becomes a function of them, g_Na_l is 0.
with open(model) as model_file:
    M_vd = json.load(model_file)
M_vd['definitions']['I_K'] = '(2/3)*g_K_max*n**4*(v-E_K)'
M_vd['definitions']['I_K_L'] = '(2/3)*g_K_l*(v-E_K)'

M_vd['parameters'].pop('Ip_mean')
M_vd['functions']['Ip_mean'] = '(1/2)*(I_K+I_K_L)'
M_vd['functions']['g_Na_l'] = '0*uS'

# one bisection bracket per synaptic input
n_syn = len(syns_rev)
na_in_min, na_in_max = eval(S['ranges']['Na_in_ss'], units)
na_in_min = np.ones(n_syn)*na_in_min
na_in_max = np.ones(n_syn)*na_in_max

#%% bisect the Na_in start value for every input
# If Na_in reaches a higher maximum in the second half of the recording than in
# the first, the start value was too low and becomes the new lower bound;
# otherwise it becomes the new upper bound. statemon.N is assumed to be the
# number of recorded samples -- TODO confirm.
for it in range(10):
    na_in_cur = (na_in_max + na_in_min)/2

    # run short simulations for naturalistic mean inputs
    statemon, spikemon, _ = hs.run_sim(M_vd, {'syn_clamp': syns_rev, 'Na_in': na_in_cur}, inf_run_fi, runtime_fi, mode='syn_clamp', spikemon_rec=True)

    for k in range(n_syn):
        if np.max(statemon.Na_in[k, -int(statemon.N/2):]) > np.max(statemon.Na_in[k, :int(statemon.N/2)]):
            na_in_min[k] = na_in_cur[k]
        else:
            na_in_max[k] = na_in_cur[k]

#%% get APs
# Peak-centred first-ISI traces per input (mV); an empty array keeps
# aps_vd[k] aligned with syns_rev[k] when an input has no usable ISI.
aps_vd = []
for k in range(n_syn):
    spk_t = spikemon.t[spikemon.i == k]
    ap = statemon.v[k, (statemon.t > spk_t[0]) & (statemon.t <= spk_t[1])] if len(spk_t) >= 2 else []
    if len(ap) == 0:
        aps_vd.append(np.array([]))
    else:
        aps_vd.append(np.roll(ap, int(len(ap)/2) - np.argmax(ap))/brian2.mV)

#%% get currents
# scratch figures only serve plot_current_scape; its Na/K outputs are kept
ion_sum_vd = []
currents_vd = []   # per input: [t, v, currents, spike times] for panel B
for cidx in range(n_syn):
    t, v_, currents = han.get_currents(M_vd, statemon, mode='all', pnf_idx=cidx)
    spikes = spikemon.t[spikemon.i == cidx]

    currents_vd.append([t, v_, currents, spikes])

    f = plt.figure()
    ax = f.add_subplot()
    AP, Na_int, K_int = hplpl.plot_current_scape(ax, M_vd, t, v_, currents, spikes, dt, inf_run_fi, sep=False, bar=False)
    ion_sum_vd.append([Na_int, K_int])

#%% PLOTTING RESULTS

def plot_amps(ax, amps):
    """Draw spike amplitude versus input current on ``ax``.

    ``amps`` is either
      * a single amplitude curve with one value per entry of the module-level
        ``I_in_fi``; it is then drawn once for every pump current in
        ``I_pumps``, shifted along the x axis by that current; or
      * a sequence of such curves; each one is drawn against ``I_in_fi``.

    Curves are coloured along ``cmap`` in sweep order. Axis limits
    (0-2.5 uA, 60-100 mV), labels and spines are fixed for the figure panels.
    """
    # A single curve has scalar entries, a family of curves has array entries.
    # Comparing len(amps) with len(I_in_fi) alone misfires whenever the number
    # of curves happens to equal the number of input currents.
    single_curve = len(amps) > 0 and np.ndim(amps[0]) == 0
    if single_curve:
        curves = [((I_in_fi+ip)/brian2.uA, amps) for ip in I_pumps]
    else:
        curves = [(I_in_fi/brian2.uA, amp) for amp in amps]

    # max(..., 1) avoids a ZeroDivisionError when only one curve is drawn
    denom = max(len(curves)-1, 1)
    for k, (x, y) in enumerate(curves):
        # draw on the given axes, not on pyplot's "current" axes
        ax.plot(x, y, c=cmap(k/denom))

    ax.set_xlim([0,2.5])
    ax.set_ylim([60,100])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # raw string: '\m' is an invalid escape sequence (SyntaxWarning on 3.12+)
    ax.set_xlabel(r'Input current [$\mu$A]')
    ax.set_ylabel('Spike amplitude [mV]')

#%% create part of figure with amplitude curves
# Throw-away images that only serve as colorbar mappables for the amplitude
# panels: imv spans the pump-density range (0-150), imnv the pump-current
# range (0-2.5), both on the cividis colormap.
# NOTE(review): the curves themselves are coloured with `cmap` from the cfg;
# the colorbars only match it if that colormap is cividis -- confirm.
imv = plt.imshow([[0,150]],cmap='cividis')
imnv = plt.imshow([[0,2.5]],cmap='cividis')

#%% now do comparisons with the bars. no pump vs pump (+leak) vs vpump
cmapr = plt.get_cmap('Reds_r')
cmapb = plt.get_cmap('Blues_r')
# elementary charge, used to turn integrated charge into ion counts (C = A*s)
charge = 1.602176634e-19 * brian2.coulomb

#%%
# Compact font sizes for the multi-panel figure.
plt.rcParams.update({
    'axes.labelsize': 8,    # fontsize of the x and y labels
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 8,
    'axes.titlesize': 10,
    'font.size': 8,
})

f = plt.figure(figsize=(8,8))
# Outer layout: amplitude/current/AP panels on top (height 3), ion bars below (height 1).
gs0 = f.add_gridspec(2,1,figure=f,height_ratios=[3,1],hspace=0.3)
gs00 = gridspec.GridSpecFromSubplotSpec(2,1,subplot_spec=gs0[0],height_ratios=[1,2],hspace=0.8)

# Panel A: spike amplitude vs input current for the three pump scenarios.
gs = gridspec.GridSpecFromSubplotSpec(1,3,subplot_spec=gs00[0],wspace=0.9)
# raw string: '\m' is an invalid escape sequence (SyntaxWarning on 3.12+)
pump_current_label = r'Pump current [$\mu$A]'
panel_a = [
    # (amplitude data, title, colorbar mappable, colorbar label, colorbar ticks)
    (amps_np, 'Without co-expression\n', imnv, pump_current_label, [0,2.5]),
    (amps, 'With co-expression\n', imnv, pump_current_label, [0,2.5]),
    (amps_vd, 'Voltage-dependent\npump', imv, 'Pump density [%]', [0,150]),
]
for col, (amp_data, title, mappable, cb_label, cb_ticks) in enumerate(panel_a):
    ax = f.add_subplot(gs[col])
    plot_amps(ax, amp_data)
    ax.set_yticks([60,100])
    ax.set_title(title)
    if col == 0:
        ax.set_title('A\n',loc='left',x=-0.7,weight='bold')
    cb = plt.colorbar(mappable,ax=ax)
    cb.set_label(cb_label)
    cb.set_ticks(cb_ticks)

gs = gridspec.GridSpecFromSubplotSpec(2,5,subplot_spec=gs00[1],wspace=0.5)

# Panel C (column 4): peak-centred APs, one row per synaptic input, overlaying
# the three pump scenarios. x is in samples, centred on the AP peak.
ap_sets = [aps_np, aps_, aps_vd]
# NOTE(review): cmap2(3) passes an int, which indexes colormap entry 3 rather
# than the 0-1 fraction 3 -- confirm that is intended.
ap_colors = ['k', cmap2(3), 'b']
ap_styles = ['-', ':', ':']
ap_widths = [3, 2, 2]
ap_labels = ['without\npump', 'constant\npump', 'voltage-\ndependent\npump']

for i in [0,1]:
    ax = f.add_subplot(gs[i,4])
    if i == 0:
        ax.set_title('C\n',loc='left',x=-1,weight='bold')
    for j, (app, col, style, wtd, lab) in enumerate(zip(ap_sets, ap_colors, ap_styles, ap_widths, ap_labels)):
        n_samp = len(app[i])
        ax.plot(np.linspace(-n_samp/2, n_samp/2, n_samp), app[i], color=col, linestyle=style, linewidth=wtd, label=lab)
        if j == 0:
            # 500-sample bar at the bottom right of the reference trace
            ax.plot([n_samp/2-500, n_samp/2], [-90,-90], color='k', linewidth=5)
    ax.set_ylim([-90,40])
    # label taken from syns_rev so it cannot drift from the simulated inputs
    ax.text(ax.get_xlim()[1],ax.get_ylim()[1],'$syn=%s$'%syns_rev[i],verticalalignment='top',horizontalalignment='right')
    ax.axis('off')
h2,l2 = ax.get_legend_handles_labels()

# Header axes over column 4 that only carries the "Action potentials" title.
ax = f.add_subplot(gs[:,4])
ax.set_title('Action potentials')
ax.axis('off')

# Panel B: current scapes, rows = synaptic inputs, columns = pump scenarios.
scape_sets = [currents_np, currents_, currents_vd]
scape_titles = ['Without pump\n', 'Constant pump\n', 'Voltage-dependent\npump']
for i in [0,1]:
    for j, cur in enumerate(scape_sets):
        ax = f.add_subplot(gs[i,j])
        is_first = i == 0 and j == 0
        if is_first:
            ax.set_title('B\n',loc='left',x=-0.8,weight='bold')
        t_cs, v_cs, i_cs, spk_cs = cur[i]
        # the voltage-dependent column uses its own model; only the top-left
        # panel is drawn with bar=True
        hplpl.plot_current_scape(ax, M_vd if j == 2 else M, t_cs, v_cs, i_cs, spk_cs, dt, inf_run_fi, sep=False, bar=is_first)

        ax.set_ylim([-20,35])
        ax.set_yticks([-20,35])

        if j > 0:
            ax.set_ylabel('')
            ax.set_yticklabels([])
            if j == 1:
                # current-scape legend entries, shown in the legend axes below
                h,l = ax.get_legend_handles_labels()
        if i == 0:
            ax.set_title(scape_titles[j])
        if j == 1:
            ax.text(ax.get_xlim()[1],ax.get_ylim()[1],'$syn=%s$'%syns_rev[i],verticalalignment='top',horizontalalignment='right')

# One invisible axes over column 3 holds both legends. A second legend() call
# replaces the first, so the first is re-attached as a plain artist (instead of
# relying on add_subplot creating a second axes for an identical spec, which
# older matplotlib versions did not do).
ax = f.add_subplot(gs[:,3])
scape_legend = ax.legend(h,l,frameon=False,ncol=1,loc='upper left',handlelength=0.5,bbox_to_anchor=(-0.5,1))
ax.add_artist(scape_legend)
ax.legend(h2,l2,frameon=False,ncol=1,loc='lower right',handlelength=1,bbox_to_anchor=(1.5,-0.1))
ax.axis('off')

# Panel D: Na+ and K+ ions exchanged per action potential (integrated charge /
# elementary charge, in units of 1e11), one subplot per synaptic input.
gs = gridspec.GridSpecFromSubplotSpec(1,2,subplot_spec=gs0[1])
scenario_labels = ['Without\npump','Constant\npump','Voltage-\ndependent\npump']
for i, (ion_np, ion, ion_vd) in enumerate(zip(ion_sum_np,ion_sum,ion_sum_vd)):
    ax = f.add_subplot(gs[i])
    if i == 0:
        ax.set_title('D\n',loc='left',x=-0.3,weight='bold')
    ax.set_title('$syn=%s$'%syns_rev[i])
    for j, ion_ in enumerate([ion_np,ion,ion_vd]):
        x_ion = max(0,j*3-1)   # ion bars at 0, 2, 5; pump bars at 3 and 6

        # labels= rather than fmt=: a fmt string without a placeholder raises
        # TypeError in matplotlib < 3.7, where bar_label evaluates fmt % value
        p = ax.bar(x_ion,ion_[0]/charge/1e11, color=cmapb(0.6))
        ax.bar_label(p, labels=['Na$^+$'], label_type='center')

        p = ax.bar(x_ion,ion_[1]/charge/1e11, color=cmapr(0.6))
        ax.bar_label(p, labels=['K$^+$'], label_type='center')
        if j > 0:
            # grey bar: K+ count / 2 (presumably pump cycles at 2 K+ per cycle)
            ax.bar(j*3,ion_[1]/charge/1e11/2, color='grey')

    ax.plot([-1,8],[0,0],'--',c='k')
    ax.set_yticks([-1,0,1])
    ax.set_yticklabels(['-1e11','0','1e11'])
    if i == 0:
        ax.set_ylabel('Ion exchange\nper action potential')

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.set_xticks([0,2.5,5.5])
    ax.set_xticklabels(scenario_labels,rotation=30, ha='right')
    ax.set_xlim([-0.75,7])
f.tight_layout()
# make sure the output directory exists before saving
os.makedirs('fig', exist_ok=True)
f.savefig('fig/A1.pdf')
