# imports
import json, brian2, sys
import numpy as np
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import model_utils, brianutils
import helpers.analysis as han
import helpers.plotting.plotting as hplpl
import helpers.plotting.schematics as hpls
import helpers.simulation as hs

import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib
matplotlib.use('Agg')
from brianutils import units

#%%
# Simulation parameters read from cfg/. This block is shared with the
# simulation scripts, so the values here must match those used to generate data/.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)

# NOTE(review): config entries are Python expressions evaluated with eval() in
# the `units` namespace (presumably unit names from brianutils -- confirm).
# This is only acceptable because the config file is local and trusted;
# never point this at external input.
inf_run = eval(S['S']['inf_run']['stim_protocol'], units)
rec_dt = eval(S['S']['rec_dt']['constant_pump'], units)

# Path of the model JSON file loaded below.
model = S['S']['model']

# Timing of the long-square stimulus protocol, plus the indices of the
# baseline and rise conditions.
long_square = S['stimuli']['long_square']
T0, T1, T2 = (eval(long_square[key], units) for key in ('T0', 'T1', 'T2'))
bl_idx = eval(long_square['baseline_idx'])
rise_idx = eval(long_square['rise_idx'])

#%% load model and fitted params
# Close the model file deterministically instead of leaking the handle.
with open(model) as model_file:
    M = json.load(model_file)

# Fitted quantities produced by earlier pipeline steps; units are attached here.
I_pump = np.load('data/fitted_I_pump.npy')*brian2.amp   # fitted pump current [A]
mean_syns = np.load('data/mean_syns.npy')               # mean synaptic input (unitless)
fs = np.load('data/fr_fitted.npy')*brian2.Hz            # fitted firing rates [Hz]

# Deviant conditions. 'high' takes the fitted values at rise_idx; 'low' sets the
# deviant rate and input to 0.
f_dev = {'high': fs[rise_idx], 'low': 0}
stim_dev = {'high': mean_syns[rise_idx], 'low': 0}

#%%
# Global font sizes for all figures in this script: 8 pt everywhere except
# axes titles, which are 10 pt.
plt.rcParams.update({
    'axes.labelsize': 8,    # x and y axis labels
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 8,
    'axes.titlesize': 10,
    'font.size': 8,
})


# Figure 2: left column = synaptic input suppressed ('low'), right column =
# synaptic input doubled ('high'). Rows show input, firing rate and pump current.
f = plt.figure(figsize=(8,4),constrained_layout=True)
gs = f.add_gridspec(3,2,figure=f,height_ratios=[1.5,2,2])

# Per-column panel letters (one per row) and column titles.
panel_letters = {'low': 'ABC', 'high': 'DEF'}
column_titles = {'low': 'Synaptic input suppression\n',
                 'high': 'Synaptic input doubling\n'}
T_total = T0+T1+T2   # loop-invariant total protocol duration

for i,dev in enumerate(['low','high']):
    letters = panel_letters[dev]
    is_right = (i == 1)
    # The left column needs a larger offset so the letter clears its y-axis labels.
    letter_x = -0.05 if is_right else -0.25

    # Use the NpzFile as a context manager so the archive is closed once the
    # arrays have been read into memory.
    with np.load('data/dev_input_%s.npz'%dev,allow_pickle=True) as archive:
        st, nain = archive['st'], archive['nain']
    ip = model_utils.eval_func(M, "I_pump", Na_in=nain*brian2.mM,Ip_mean=I_pump[bl_idx])

    # Row 0: synaptic input.
    ax = f.add_subplot(gs[0,i])
    hplpl.plot_input(ax, mean_syns[bl_idx],stim_dev[dev],T0,T1,T2)
    ax.set_ylim([0,0.2])
    if is_right:
        # Share the left column's y-axis annotation.
        ax.set_ylabel('')
        ax.set_yticklabels([])
    ax.set_yticks([0,0.2])
    ax.set_xlim([0,100])
    ax.set_title(column_titles[dev],loc='center')
    ax.set_title(letters[0],loc='left',x=letter_x,weight='bold')
    if not is_right:
        ax.legend(frameon=False,loc='upper right')
    ax.set_xticklabels([])
    ax.set_xlabel('')

    # Row 1: firing rate. The panel letter is set before plotting, as in the
    # original ordering.
    ax = f.add_subplot(gs[1,i])
    ax.set_title(letters[1],loc='left',x=letter_x,weight='bold')
    hplpl.plot_firing_rates(ax, st, fs[bl_idx], f_dev[dev], T0, T1, T2)
    if is_right:
        ax.set_ylabel('')
        ax.set_yticklabels([])
        ax.legend(frameon=False,loc='lower right')
    ax.set_yticks([0,200,400])
    ax.set_ylim([-10,450])
    ax.set_xticklabels([])
    ax.set_xlabel('')

    # Row 2: pump current over the whole protocol.
    ax = f.add_subplot(gs[2,i])
    ax.set_title(letters[2],loc='left',x=letter_x,weight='bold')
    hplpl.plot_pump(ax, ip, I_pump[bl_idx],rec_dt,T_total)
    ax.set_ylim([1,2])
    ax.set_xlim([0,100])
    if is_right:
        ax.set_ylabel('')
        ax.set_yticklabels([])
    ax.set_xlabel('Time [s]')

f.align_ylabels()
f.savefig('fig/2.pdf',bbox_inches='tight')
# Close this specific figure rather than whatever pyplot considers "current".
plt.close(f)
