# imports
import json, brian2, sys
import numpy as np
sys.path.append('py/')
sys.path.append('py/helpers/')

from helpers import model_utils, brianutils
import helpers.analysis as han
import helpers.plotting.plotting as hplpl
import helpers.plotting.schematics as hpls
import helpers.simulation as hs

import matplotlib.pyplot as plt
import pickle as pkl
from matplotlib import gridspec
import matplotlib
matplotlib.use('Agg')
from brianutils import units
from matplotlib import patches
import matplotlib.image as mpimg

#%% load cfg
# Simulation settings: S['S'] supplies 'dt' and 'model', S['plotting'] the colormap names.
with open('cfg/simulation_params.json') as cfg_file:
    S = json.load(cfg_file)
# S['S']['dt'] is a string expression evaluated with the brianutils `units`
# mapping as globals; the result is set as brian2's default clock step and
# also kept in the local `dt`.
# NOTE(review): eval() executes arbitrary code from the config file -- only
# run this script with trusted configuration files.
dt = brian2.defaultclock.dt = eval(S['S']['dt'],units)
model = S['S']['model']
cmap = plt.get_cmap(S['plotting']['cmaps'][1])
greys = plt.get_cmap('Greys')  # used for the grey zoom-window shading

# schematics for panel A of the chirp and rises figures
circuit_img = mpimg.imread('fig/img/k_buffer.png')
circuit_img2 = mpimg.imread('fig/img/synapse.png')

#%% fitted pump currents
I_pump = np.load('data/I_pump_fitted_com.npy')*brian2.amp

# baseline pump current for different scenarios, indexed by figure column j:
# chirp uses I_pump[0] for both columns, rises uses I_pump[3] / I_pump[2]
ipump = {'chirp':[I_pump[0],I_pump[0]],'rises':[I_pump[3],I_pump[2]]}

#%% filenames
# per-column file-name suffixes passed to the plotting helpers; for 'chirp'
# they also select data/chirp_<suffix>.npz
# (presumably '0.5'/'1.0' are synapse strengths -- TODO confirm)
fnames = {'chirp':['','buf'],'rises':['0.5','1.0']}

#%% plotting parameters
# zoom-ins: start (ts_) and end (te_) times in seconds of each zoom window
ts_ = {'chirp':[0.08,1.28],'rises':[5,23]}
te_ = {'chirp':[0.16,1.36],'rises':[5.1,23.1]}

# titles
titles = ['A','B','C','D','E','F','G']  # panel letters
titles_ = {'chirp':['Without extracellular\npotassium buffer','With extracellular\npotassium buffer'],'rises':['Weak synapse','Strong synapse']}  # column headers

# plot limits
xlims = {'chirp':[0,2],'rises':[0,25]}       # time axis [s]
ylims = {'chirp':[3.8,4.2],'rises':[1.2,1.5]}  # pump-current panels
ylimsf = {'chirp':[0,600],'rises':[180,310]}   # plot_chirp panels
vlim = [-100,10]                             # membrane voltage [mV]

# markersize scatterplots
ms_ = {'chirp':5, 'rises':2}

#%% Plot chirp figure
# Figure 4 (saved to fig/4.pdf). One column per entry of fnames['chirp'];
# the column headers (titles_['chirp']) are "without" / "with" extracellular
# potassium buffer. Rows: plot_chirp output, pump current, and two zoom-in
# windows (ts_/te_) each showing spike rasters above the membrane voltage.
plt.rc('axes', labelsize=8)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=8)
plt.rc('ytick', labelsize=8)
plt.rc('legend', fontsize=8)
plt.rc('axes', titlesize=10)
plt.rc('font', size=8)

mode='chirp'
f = plt.figure(figsize=(8,7))
gs_ = f.add_gridspec(2,1,height_ratios=[1,2],hspace=0.5)

# --- top row: raster legend text, circuit schematic (panel A), intro traces ---
gs = gridspec.GridSpecFromSubplotSpec(1,3,width_ratios=[2,1,6],subplot_spec=gs_[0])
ax = f.add_subplot(gs[1])
ax.text(0.75,0.87,'PN spikes',horizontalalignment='center')
ax.text(0.75,0.55,'EOD',horizontalalignment='center')
ax.axis('off')
ax = f.add_subplot(gs[:2])
hpls.plot_image(ax,circuit_img)

# shift the schematic further left than the gridspec places it
box = ax.get_position()
box.x0 = box.x0 - 0.15
ax.set_position(box)
ax.set_title('A',loc='left',weight='bold')
hplpl.plot_pn_above_eod(f, gs[2], dt)

# --- bottom: rows 0/2/4 hold plots, rows 1/3 are spacers ---
gsc = gridspec.GridSpecFromSubplotSpec(5,2,subplot_spec=gs_[1],height_ratios=[2,1,2,1,2])

for j,fname in enumerate(fnames[mode]):
    # row 0: plot_chirp output with the zoom-in windows shaded grey
    ax = f.add_subplot(gsc[0,j])

    for ts,te in zip(ts_[mode],te_[mode]):
        inset = patches.Rectangle((ts,0),te-ts,600,fc=greys(0.4),ec=greys(0.4))
        ax.add_patch(inset)

    # panel letter (B / D) on the left, column header centred
    ax.set_title(titles[j*2+1],loc='left',weight='bold',x=-0.25 if j==0 else -0.02)
    ax.set_title(titles_[mode][j]+'\n',loc='center')

    hplpl.plot_chirp(ax,fname,mode,model,ms=ms_[mode])
    ax.set_ylim(ylimsf[mode])

    if j==0:
        ax.legend(frameon=False,handletextpad=0,labelspacing=0,loc='upper right',bbox_to_anchor=(1.05,1.1))

    if j==1:
        ax.set_ylabel('')
        ax.set_yticklabels([])
    else:
        ax.set_xlabel('')
        ax.set_xticklabels([])

    ax.set_xlim(xlims[mode])
    ax.set_xticks(xlims[mode])

    # row 2: pump current; the two middle x-ticks sit at the centre of each
    # zoom window and carry the letter of the matching zoom panel (C / E)
    ax = f.add_subplot(gsc[2,j])

    for ts,te in zip(ts_[mode],te_[mode]):
        inset = patches.Rectangle((ts,0),te-ts,600,fc=greys(0.4),ec=greys(0.4))
        ax.add_patch(inset)

    hplpl.plot_chirp_pump(ax,fname,mode,model,xlims[mode][0],xlims[mode][1],ipump[mode][j])
    ax.set_xlim(xlims[mode])
    ax.set_ylim(ylims[mode])

    zoom_letter = 'C' if j==0 else 'E'
    ax.set_xticks([xlims[mode][0], ts_[mode][0]+(te_[mode][0]-ts_[mode][0])/2, ts_[mode][1]+(te_[mode][1]-ts_[mode][1])/2, xlims[mode][1]])
    ax.set_xticklabels([xlims[mode][0],zoom_letter,zoom_letter,xlims[mode][1]])
    ax.get_xticklabels()[-2].set_weight("bold")
    ax.get_xticklabels()[-3].set_weight("bold")

    ax.set_xlabel('Time [s]',labelpad=-10)

    if j==1:
        ax.set_ylabel('')
        ax.set_yticklabels([])

    # row 4: one sub-column per zoom window, raster on top, voltage below
    gsc__ = gridspec.GridSpecFromSubplotSpec(2,2,subplot_spec=gsc[4,j])

    # load this column's simulation output once (it used to be re-read for
    # every zoom window) and close the .npz archive after extracting arrays
    with np.load('data/%s_%s.npz'%(mode,fname),allow_pickle=True) as a:
        v, pnfs, st = a['v'], a['pnfs'][0], a['st'][0]

    for i,(ts,te) in enumerate(zip(ts_[mode],te_[mode])):
        ax = f.add_subplot(gsc__[0,i])
        if i==0:
            ax.set_title(titles[j*2+2+i],loc='left',weight='bold',x=-0.5 if j==0 else -0.05)

        ax.eventplot(pnfs[(pnfs>=ts)&(pnfs<=te)],color=cmap(0),lineoffsets=-0.5)
        ax.eventplot(st[(st>=ts)&(st<=te)],color='k')

        ax.set_xlim([ts,te])
        ax.axis('off')

        ax = f.add_subplot(gsc__[1,i])

        # Size the time axis from the slice itself: int((te-ts)/dt) can differ
        # by one from int(te/dt)-int(ts/dt) through float rounding, which makes
        # the x and y lengths disagree and plot() raise.
        v_seg = v[0,int(ts/dt):int(te/dt)]
        ax.plot(np.linspace(ts,te,len(v_seg)),v_seg*1000,color='k')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xlim([ts,te])
        # ticks 10 % inside each edge, rounded to two decimals
        ax.set_xticks([np.round((ts+0.1*(te-ts))*100)/100,np.round((te-0.1*(te-ts))*100)/100])
        ax.set_ylim(vlim)
        ax.set_yticks([-70,0])

        if i==0 and j==0:
            ax.set_ylabel('Membrane \nvoltage [mV]')
        else:
            ax.set_yticklabels([])

        ax.set_xlabel('Time [s]')

plt.savefig('fig/4.pdf',dpi=1000,bbox_inches='tight')
plt.close()

#%% plot rises
# Figure 5 (saved to fig/5.pdf). One column per entry of fnames['rises'];
# the column headers (titles_['rises']) are "Weak synapse" / "Strong synapse".
# Rows: plot_chirp output, pump current, and two zoom-in windows (ts_/te_),
# each with a spike raster above the membrane voltage.
mode = 'rises'
f = plt.figure(figsize=(8,8))
gs_ = f.add_gridspec(2,1,height_ratios=[1,2],hspace=0.25)

# --- top row: raster legend text, synapse schematic (panel A), intro traces ---
gs = gridspec.GridSpecFromSubplotSpec(1,3,width_ratios=[2,1,6],subplot_spec=gs_[0])
ax = f.add_subplot(gs[1])
ax.text(1.1,0.225,'PN spikes',horizontalalignment='right')
ax.text(1.1,0.075,'EOD',horizontalalignment='right')
ax.axis('off')
ax = f.add_subplot(gs[:2])
ax.set_title('A',loc='left',weight='bold',x=-0.2)
hpls.plot_image(ax,circuit_img2)
ax.axis('off')
hplpl.plot_rises_intro(f,gs[2],dt)

# --- bottom: plot_chirp row, pump row, spacer, zoom-in block ---
gsc = gridspec.GridSpecFromSubplotSpec(4,2,subplot_spec=gs_[1],height_ratios=[0.75,0.5,0.5,2])

# Synapse simulation output shared by both columns. It used to be re-opened
# three times per column and never closed; read it once here instead.
# Column j uses v[1-j] and the entries of st[0] whose st[1] equals 1-j
# (presumably spike times and neuron indices -- TODO confirm).
with np.load('data/%s_syn.npz'%mode,allow_pickle=True) as syn:
    v_syn, pnfs, st_syn = syn['v'], syn['pnfs'][0], syn['st']

for j,fname in enumerate(fnames[mode]):
    v = v_syn[1-j]
    st = st_syn[0][st_syn[1]==1-j]

    # row 0: plot_chirp output with the zoom-in windows shaded grey
    ax = f.add_subplot(gsc[0,j])
    for ts,te in zip(ts_[mode],te_[mode]):
        inset = patches.Rectangle((ts,0),te-ts,600,fc=greys(0.4),ec=greys(0.4))
        ax.add_patch(inset)

    # panel letter (B / E) on the left, column header centred
    ax.set_title(titles[j*3+1]+'\n',loc='left',weight='bold',x=-0.3 if j==0 else -0.05)
    ax.set_title(titles_[mode][j],loc='center')

    # entrainment index over the whole plotted range (t <= xlims[mode][1]);
    # computed once and used for both the console print and the panel title
    t_end = xlims[mode][1]
    R = np.abs(han.entrainment_index(pnfs[(pnfs<=t_end)],st[(st<=t_end)]))
    print(R)
    ax.set_title(r'$\overline{R}$ = %.3f'%R,loc='right',fontsize=8)

    hplpl.plot_chirp(ax,fname,mode,model,ms=ms_[mode])
    ax.set_ylim(ylimsf[mode])

    if j==1:
        ax.set_ylabel('')
        ax.set_yticklabels([])
    else:
        ax.legend(frameon=False,handletextpad=-0.5,loc='lower left',bbox_to_anchor=(0.01,-0.1),ncol=2,columnspacing=0.1)

    # x-ticks: axis ends plus the centre of each zoom window
    zoom_ticks = [xlims[mode][0], ts_[mode][0]+(te_[mode][0]-ts_[mode][0])/2, ts_[mode][1]+(te_[mode][1]-ts_[mode][1])/2, xlims[mode][1]]
    ax.set_xlim(xlims[mode])
    ax.set_xticks(zoom_ticks)

    # row 1: pump current; middle x-ticks carry the zoom panel letters
    ax = f.add_subplot(gsc[1,j])
    for ts,te in zip(ts_[mode],te_[mode]):
        inset = patches.Rectangle((ts,0),te-ts,600,fc=greys(0.4),ec=greys(0.4))
        ax.add_patch(inset)
    hplpl.plot_chirp_pump(ax,fname,mode,model,xlims[mode][0],xlims[mode][1],ipump=ipump[mode][j])
    ax.set_ylim(ylims[mode])
    ax.set_xlim(xlims[mode])
    ax.set_xticks(zoom_ticks)
    ax.set_xticklabels([xlims[mode][0],'C'if j==0 else 'F','D' if j==0 else 'G',xlims[mode][1]])
    ax.get_xticklabels()[-2].set_weight("bold")
    ax.get_xticklabels()[-3].set_weight("bold")

    ax.set_xlabel('Time [s]',labelpad=-10)

    if j==1:
        ax.set_ylabel('')
        ax.set_yticklabels([])

    # row 3: per zoom window i, raster at slot 3*i and voltage at slot 3*i+1
    # (slot 2 is a spacer)
    gsc__ = gridspec.GridSpecFromSubplotSpec(5,1,subplot_spec=gsc[3,j],height_ratios=[1,1,1.5,1,1])

    for i,(ts,te) in enumerate(zip(ts_[mode],te_[mode])):
        pn_win = pnfs[(pnfs>=ts)&(pnfs<=te)]
        st_win = st[(st>=ts)&(st<=te)]

        ax = f.add_subplot(gsc__[i*3])
        ax.set_title(titles[j*3+2+i],loc='left',weight='bold',x=-0.25 if j==0 else -0.05)
        ax.eventplot(pn_win,color=cmap(0),lineoffsets=-0.5)
        ax.eventplot(st_win,color='k')

        ax.set_xlim([ts,te])
        ax.axis('off')

        # entrainment index restricted to this zoom window
        ax.set_title(r'$\overline{R}$ = %.3f' % abs(han.entrainment_index(pn_win,st_win)),loc='right',fontsize=8)

        ax = f.add_subplot(gsc__[i*3+1])
        v_seg = v[int(ts/dt):int(te/dt)]
        ax.plot(np.linspace(ts,te,len(v_seg)),v_seg*1000,color='k')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xlim([ts,te])
        ax.set_xticks([ts,te])
        ax.set_ylim(vlim)
        ax.set_yticks([-70,0])

        if j==0:
            ax.set_ylabel('Membrane \n voltage [mV]')
        else:
            ax.set_yticklabels([])

        if i==1:
            ax.set_xlabel('Time [s]')

plt.savefig('fig/5.pdf',dpi=1000,bbox_inches='tight')
plt.close()
