import json, sympy, scipy, brian2

import numpy as np
from numpy import exp

import brianutils, model_utils
from brianutils import units

import matplotlib.pyplot as plt

import copy
from matplotlib import gridspec
from matplotlib.patches import ConnectionPatch

# Colormaps and fixed colours shared by the plotting helpers below.
cmap = plt.get_cmap('cividis')
cmapv = plt.get_cmap('viridis')
# NOTE: despite the suffixes, cmapr is the blue map (used for K+ traces)
# and cmapb is the red map (used for Na+ traces).
cmapr = plt.get_cmap('Blues_r')
cmapb = plt.get_cmap('Reds_r')
c_pn = plt.get_cmap('Dark2')(0)
c_bl = plt.get_cmap('Dark2')(3)

# Simulation parameters, read once at import time. The context manager
# closes the file handle instead of leaving that to the garbage collector.
with open('cfg/simulation_params.json') as _params_file:
    S = json.load(_params_file)

def plot_current_scape(ax, M, t, v, currents, spikes, dt, inf_run, sep=False, bar=True):
    """Draw the membrane currents of one spike-to-spike window on *ax*.

    The window starts half an inter-spike interval after ``spikes[0]`` and
    ends half an interval after ``spikes[1]``; both instants are converted to
    sample indices with ``dt`` after subtracting ``inf_run``.

    Parameters
    ----------
    ax : axes-like object that receives the traces.
    M : unused; kept so existing callers keep working.
    t, v : time and voltage traces indexed by sample.
    currents : sequence of eight entries, in the order
        ``(I_NaT, I_NaP, I_Na_L, I_AChRNa, I_K, I_K_L, I_AChRK, I_pump)``.
        ``I_Na_L`` and ``I_pump`` may be a plain ``int`` placeholder when the
        corresponding current is absent.
    spikes : sequence whose first two entries delimit the window.
    dt : sampling step used to convert times to indices.
    inf_run : time offset removed before indexing.
    sep : if true, draw summed Na+ and K+ (and pump / Na+ leak) as lines;
        otherwise draw stacked filled areas.
    bar : if true, label the scale bar with '0.5 ms' (the bar itself is
        always drawn).

    Returns
    -------
    tuple
        ``(v[window], sum(Na currents) * dt, sum(K currents) * dt)``. The Na
        sum includes the leak when it is present; the pump is in neither sum.

    Side effects
    ------------
    Prints summary sums of the Na+ currents to stdout.
    """
    I_NaT, I_NaP, I_Na_L, I_AChRNa, I_K, I_K_L, I_AChRK, I_pump = currents

    # An int in place of an array marks a current that is not in the model.
    has_pump = not isinstance(I_pump, int)
    has_na_leak = not isinstance(I_Na_L, int)

    isi = spikes[1] - spikes[0]
    t0 = int((spikes[0] + isi/2)/dt - inf_run/dt)
    t1 = int((spikes[1] + isi/2)/dt - inf_run/dt)
    win = slice(t0, t1)

    # Slice every trace once instead of re-slicing in each expression.
    tw = t[win]
    nat, nap, achr_na = I_NaT[win], I_NaP[win], I_AChRNa[win]
    k, k_leak, achr_k = I_K[win], I_K_L[win], I_AChRK[win]
    pump = I_pump[win] if has_pump else None
    na_leak = I_Na_L[win] if has_na_leak else None

    # Summation order matches the original expressions term by term.
    na_gated = nap + nat + achr_na
    na_total = na_leak + nap + nat + achr_na if has_na_leak else na_gated
    k_total = k_leak + k + achr_k
    uA = brian2.uA

    if sep:
        ax.plot(tw, na_gated/uA, color=cmapb(0.6), label='Na$^+$')
        ax.plot(tw, k_total/uA, color=cmapr(0.6), label='K$^+$')

        if has_pump:
            ax.plot(tw, pump/uA, color='grey', label='Pump')

        if has_na_leak:
            ax.plot(tw, na_leak/uA, color=cmapb(0), label='Na$^+$ leak')

    else:
        if has_pump:
            # K+ is stacked on top of the pump current.
            ax.fill_between(tw, pump/uA, (pump + k_leak + k + achr_k)/uA, color=cmapr(0.6), label='K$^+$')
            ax.fill_between(tw, pump/uA, color='grey', label='Pump')
        else:
            ax.fill_between(tw, k_total/uA, color=cmapr(0.6), label='K$^+$')

        if has_na_leak:
            # Gated Na+ is stacked on top of the Na+ leak.
            ax.fill_between(tw, na_leak/uA, na_total/uA, color=cmapb(0.6), label='Na$^+$')
            ax.fill_between(tw, na_leak/uA, color=cmapb(0.2), label='Na$^+$ leak')
        else:
            ax.fill_between(tw, na_gated/uA, color=cmapb(0.6), label='Na$^+$')

    # Net current over every current that is present. Building it stepwise
    # also covers "leak without pump", which previously raised a TypeError.
    net = na_total + pump if has_pump else na_total
    net = net + k_leak + k + achr_k
    ax.plot(tw, net/uA, c='k', label='Net')

    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])
    ax.set_xticklabels([])

    ax.plot([t[t0], t[t1]], [0, 0], c='k', linewidth=1, zorder=10)
    ax.set_xlim([t[t0], t[t1]])

    if sep:
        ax.set_ylabel(r'Current [$\mu$A]')
    else:
        # '\\mu' keeps the runtime text identical without an invalid escape.
        ax.set_ylabel('Current\ncontribution [$\\mu$A]')

    # 0.5 ms scale bar, drawn regardless of `bar`.
    bar_start = t[t0]/brian2.second + 0.0003
    ax.plot([bar_start, t[t0]/brian2.second + 0.0008], [-15, -15], c='k', linewidth=2)

    if bar:
        ax.text(bar_start, -17, '0.5 ms', horizontalalignment='left', verticalalignment='top')

    if not has_na_leak:
        print('syn percentage')
        print(np.sum(achr_na)/np.sum(achr_na + nap + nat))
    else:
        print('RMP')
        print(np.sum(na_leak))

        print('syn')
        print(np.sum(achr_na))

        print('AP')
        print(np.sum(nap + nat))

        # Share of the synaptic current in the total Na+ current.
        print('syn percentage')
        print(np.sum(achr_na)/np.sum(achr_na + na_leak + nap + nat))

    return v[win], np.sum(na_total)*dt, np.sum(k_total)*dt



def plot_firing_rates(ax, st, f_bl, dev_f, T0, T1, T2):
    """Plot a firing-rate trace derived from the spike times *st*.

    When ``f_bl`` is given, a reference step (``f_bl`` -> ``dev_f`` -> ``f_bl``
    over the segments T0, T1, T2) and a dashed baseline are drawn first.
    If any inter-spike interval exceeds 10, the rate is a spike count in a
    sliding 0.5-wide window over 0..100; otherwise it is the inverse
    inter-spike interval of the spikes before 100.
    """
    t_on = T0
    t_off = T0 + T1
    t_end = T0 + T1 + T2

    if f_bl is not None:
        ref_style = dict(c=c_bl, zorder=-1, linewidth=2)
        ax.plot([0, t_on], [f_bl, f_bl], **ref_style)
        ax.plot([t_on, t_on, t_off, t_off], [f_bl, dev_f, dev_f, f_bl], **ref_style, clip_on=False)
        ax.plot([t_off, t_end], [f_bl, f_bl], **ref_style, label='Without pump current')
        ax.plot([0, t_end], [f_bl, f_bl], '--', c='grey', zorder=-2)

    sparse = np.max(np.diff(st)) > 10
    if sparse:
        # Long silent gaps: count spikes per window rather than using 1/ISI.
        width = 0.5
        window_starts = np.linspace(0, 100 - width, 10000)
        rates = [len(st[(st > lo) & (st < lo + width)]) / width for lo in window_starts]
        ax.plot(np.linspace(width/2, 100 - width/2, 10000), rates, c='k')
    else:
        early = st[st < 100]
        ax.plot(early[1:], 1/np.diff(early), c='k', zorder=1, clip_on=False)

    ax.set_ylabel('Firing rate\n[Hz]')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xlim([0, t_end])
    ax.set_xlabel('Time [s]')

def plot_input(ax,bl,dev,T0,T1,T2):
    """Draw a step-shaped input protocol on *ax*.

    The input sits at ``bl`` for ``T0``, at ``dev`` for ``T1`` and back at
    ``bl`` for ``T2``; a dashed grey line marks ``bl`` over the full span.
    """
    step_on = T0
    step_off = T0 + T1
    t_total = T0 + T1 + T2

    # (x, y, positional format args, keyword args) in drawing order.
    segments = (
        ([0, t_total], [bl, bl], ('--',), {'c': 'grey', 'label': 'Baseline'}),
        ([0, step_on], [bl, bl], (), {'c': 'k'}),
        ([step_on, step_on, step_off, step_off], [bl, dev, dev, bl], (), {'c': 'k', 'clip_on': False}),
        ([step_off, t_total], [bl, bl], (), {'c': 'k'}),
    )
    for xs, ys, fmt, style in segments:
        ax.plot(xs, ys, *fmt, **style)

    ax.set_ylabel('Synaptic\ninput\n[arb. u.]')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    ax.set_xlim([0, t_total])
    ax.set_xlabel('Time [s]')

def plot_pump(ax,ip,pump_bl,dt_rec,tmax,N=None):
    """Plot the pump current trace *ip* (shown in uA) against time.

    Parameters
    ----------
    ax : axes-like object that receives the traces.
    ip : pump current samples.
    pump_bl : baseline pump current drawn as a dashed grey line, or ``None``
        to omit it.
    dt_rec : spacing between samples of ``ip``; the time axis spans
        ``0 .. len(ip) * dt_rec``.
    tmax : upper x-limit of the axes.
    N : if given, additionally draw a running mean over ``N`` samples and
        draw the raw trace semi-transparent.
    """
    if pump_bl is not None:
        # NOTE(review): the baseline spans a hard-coded 0..100 rather than
        # 0..tmax; it stops short of the axis edge if tmax > 100 — confirm.
        ax.plot([0,100],[pump_bl/brian2.uA,pump_bl/brian2.uA],'--',c='grey')

    t_inst = np.linspace(0,len(ip)*dt_rec,len(ip))
    i_inst = ip/brian2.uA

    if N is not None:
        ax.plot(t_inst,i_inst,c='k',alpha=0.4,label='Instant')
        # 'valid' convolution yields len(ip)-N+1 samples, plotted from
        # N/2 to len(ip)-N/2 samples.
        t_avg = np.linspace(N/2*dt_rec,(len(ip)-N/2)*dt_rec,len(ip)-N+1)
        ax.plot(t_avg,np.convolve(i_inst, np.ones(N)/N, mode='valid'),c='k',label='Time-\naveraged')
    else:
        ax.plot(t_inst,i_inst,c='k')

    # '\\mu' keeps the runtime text identical without an invalid escape.
    ax.set_ylabel('Pump current\n[$\\mu$A]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([0,tmax])
    ax.set_xlabel('Time [s]')

#%% plot shifted fI curve due to pump currents
def plot_fI_pump(ax, I_in, I_pumps, f_, I_stim, show_input=True):
    """Plot a family of f-I curves, one colour per curve.

    If ``f_`` is a list, ``I_in`` and ``f_`` are paired element-wise and each
    pair is one curve. Otherwise the single curve ``(I_in, f_)`` is redrawn
    once per entry of ``I_pumps``, shifted along the current axis by that
    entry.

    Parameters
    ----------
    ax : axes-like object that receives the curves.
    I_in : input current(s), shown in uA.
    I_pumps : pump currents used as horizontal shifts (non-list mode).
    f_ : firing rate(s) belonging to ``I_in``.
    I_stim : current at which a dashed vertical marker is drawn.
    show_input : draw the ``I_stim`` marker when true.
    """
    def _shade(index, count):
        # Spread colours over the whole map; a single curve takes the first
        # colour instead of dividing by zero.
        return cmap(index/(count-1)) if count > 1 else cmap(0.0)

    if isinstance(f_, list):
        for i, (iin, ff) in enumerate(zip(I_in, f_)):
            ax.plot(iin/brian2.uA, ff, c=_shade(i, len(f_)))
    else:
        for i, ip in enumerate(I_pumps):
            ax.plot((I_in+ip)/brian2.uA, f_, c=_shade(i, len(I_pumps)))

    if show_input:
        ax.plot([I_stim/brian2.uA,I_stim/brian2.uA],[0,700],'--',c='k')

    ax.set_xlabel(r'Input current [$\mu$A]')
    ax.set_ylabel(r'Firing rate [Hz]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([0,3.5])
    ax.set_ylim([0,700])

#%% plot pump vs firing rate
def plot_pump_vs_freq(ax, I_pumps, f_intersects, I_stim,text=True):
    """Plot firing rate (*f_intersects*) against pump current (*I_pumps*, in uA).

    When ``text`` is true, the axes are annotated with the value of
    ``I_stim`` in uA.
    """
    ax.plot(I_pumps/brian2.uA,f_intersects,c='k')
    ax.set_xlabel(r'Pump current [$\mu$A]')
    ax.set_ylabel('Firing rate [Hz]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([0,np.max(I_pumps/brian2.uA)])
    ax.set_ylim([0,600])
    if text:
        # '\\mu' keeps the runtime text identical without an invalid escape.
        ax.text(0.75,600,'Input\ncurrent\n= %.2f $\\mu$A'%(I_stim/brian2.uA),verticalalignment='top')

def plot_coexpression_rule(ax, I_pumps, gNal):
    """Scatter sodium-leak conductance (*gNal*, uS) against pump current (uA).

    ``I_pumps`` and ``gNal`` are paired element-wise; each point is coloured
    by its pump current relative to the largest entry of ``I_pumps``.
    """
    peak = np.max(I_pumps)  # loop-invariant colour normalisation
    for ip, gnal in zip(I_pumps, gNal):
        ax.plot(ip/brian2.uA,gnal/brian2.uS,marker='o',c=cmap(ip/peak))
    ax.set_xlabel(r'Pump current [$\mu$A]')
    # '\\mu' keeps the runtime text identical without an invalid escape.
    ax.set_ylabel('Sodium leak\nconductance [$\\mu$S]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

#%% plot firing rate vs required pump current.
def plot_fr_vs_pump(ax, fs, Ip_pred, I_pumps):
    """Scatter pump current (*Ip_pred*, uA) against firing rate (*fs*).

    ``fs`` and ``Ip_pred`` are paired element-wise; each point is coloured by
    its pump current relative to the largest entry of ``I_pumps``.
    """
    peak = np.max(I_pumps)  # loop-invariant colour normalisation
    for f, ip in zip(fs, Ip_pred):
        ax.plot(f,ip/brian2.uA,c=cmap(ip/peak),marker='o')
    ax.set_xlabel('Firing rate [Hz]')
    ax.set_ylabel(r'Pump current [$\mu$A]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

def plot_spikes(ax,AP,AP_nopump,dt):
    """Overlay two action-potential traces on *ax*, centred on sample zero.

    Each trace is plotted against its sample index shifted by half its
    length. A 0.5 ms scale bar (converted to samples with ``dt``) is drawn
    below the minimum of ``AP`` and the axes frame is hidden.

    Both traces are drawn on the passed ``ax``; previously they went to
    pyplot's current axes, which need not be ``ax``.
    """
    ax.plot(np.arange(0,len(AP))-len(AP)/2,AP,label='With pump',color='k')
    ax.plot(np.arange(0,len(AP_nopump))-len(AP_nopump)/2,AP_nopump,':',label='Without pump',color=c_bl)

    floor = np.min(AP)
    ax.plot([-(0.5*brian2.ms)/dt, 0], [floor - 5*brian2.mV, floor - 5*brian2.mV], color='k', linewidth=2,clip_on=False)

    ax.text(-brian2.ms/dt/2,floor - 10*brian2.mV,'0.5 ms',verticalalignment='top',horizontalalignment='center')
    ax.axis('off')

#%% plot energetic requirements
def plot_total_currents(ax, Na_int, K_int, Na_int_nopump, K_int_nopump):
    """Bar chart of Na+ and K+ charge, divided by the elementary charge.

    Bars at x=1 use the ``*_int`` values ('With pump'), bars at x=2 the
    ``*_int_nopump`` values ('Without pump'). Only the x=2 bars receive a
    centred ion label.
    """
    elementary_charge = 1.602176634e-19 * brian2.coulomb # C = A*s

    ion_rows = (
        ('Na$^+$', Na_int, Na_int_nopump, cmapb(0.6)),
        ('K$^+$', K_int, K_int_nopump, cmapr(0.6)),
    )
    for tag, with_pump, without_pump, colour in ion_rows:
        ax.bar(1, with_pump/elementary_charge, color=colour, label=tag)
        labelled = ax.bar(2, without_pump/elementary_charge, color=colour, label=tag)
        ax.bar_label(labelled, label_type='center', fmt=tag)

    ax.plot([0,3],[0,0],'--',c='k')
    ax.set_xticks([1,2])
    ax.set_xticklabels(['With pump','Without pump'])
    ax.set_yticks([-1e11,0,1e11])
    ax.set_ylabel('Ion exchange\nper action potential')

    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_xlim([0.25,2.75])
    # Re-apply the tick labels with a rotation.
    ax.set_xticklabels(ax.get_xticklabels(),rotation=30, ha='right')

def plot_v(ax, v_traces, runtime, c='k'):
    """Plot a voltage trace (in mV) over ``0 .. runtime`` (in ms).

    The axes are fixed to 0..10 on x and -100..20 on y.
    """
    t_ms = np.linspace(0, runtime/brian2.ms, len(v_traces))
    v_mv = v_traces/brian2.mV
    ax.plot(t_ms, v_mv, c=c)

    ax.set_ylabel('Membrane\nvoltage [mV]')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    ax.set_xticks([0,10])
    ax.set_xlim([0,10])
    ax.set_yticks([-80,20])
    ax.set_ylim([-100,20])
    ax.set_xlabel('Time [ms]')

def plot_stim(ax, I_stim, runtime,c='k'):
    """Plot the stimulus current *I_stim* (in uA) over ``0 .. runtime`` (in ms).

    The axes are fixed to 0..10 on x and 0..7 on y; x tick labels are hidden.
    """
    ax.plot(np.linspace(0,runtime/brian2.ms,len(I_stim)),I_stim/brian2.uA,c=c,label='Time-dependent')
    # '\\mu' keeps the runtime text identical without an invalid escape.
    ax.set_ylabel('Input\ncurrent [$\\mu$A]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xticks([0,10])
    ax.set_xlim([0,10])
    ax.set_yticks([0,6])
    ax.set_ylim([0,7])
    ax.set_xticklabels([])

def plot_fi(ax, I_in, f_):
    """Plot firing rate *f_* against input current *I_in* (shown in uA)."""
    ax.plot(I_in/brian2.uA,f_,c='k')
    ax.set_ylim(0,630)
    ax.set_ylabel('Mean driven\nelectrocyte firing rate [Hz]')
    # Raw string: '\m' is not a valid escape sequence in a normal literal.
    ax.set_xlabel(r'Mean input current [$\mu$A]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([-0.5,2.5])

def plot_If(ax, pnfs, I_stims):
    """Plot mean input current (*I_stims*, in uA) against the rates *pnfs*."""
    ax.plot(pnfs,I_stims/brian2.uA,c='k')
    ax.set_xlabel('Pacemaker firing rate [Hz]')
    # '\\mu' keeps the runtime text identical without an invalid escape.
    ax.set_ylabel('Mean input\ncurrent [$\\mu$A]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlim([200,600])

def plot_freq_diff(ax, pnfs, f_md,arrow=False):
    """Compare the identity line ``pnfs`` vs ``pnfs`` with ``f_md`` vs ``pnfs``.

    Parameters
    ----------
    ax : axes-like object that receives the curves.
    pnfs : rates used for the x-axis and for the identity line.
    f_md : rates plotted against ``pnfs`` and labelled 'Mean-driven'.
    arrow : if true, join the two curves with a grey vertical line at every
        point and, where they differ by more than 10, add caret markers at
        both ends.

    The connectors and markers are drawn on ``ax``; previously they went to
    pyplot's current axes, which need not be ``ax``.
    """
    ax.plot(pnfs,pnfs,c=c_pn,label='Time-dependent\ninput')
    ax.plot(pnfs,f_md,c='k',label='Mean-driven')

    if arrow:
        for f1,f2 in zip(pnfs/brian2.Hz,f_md):
            ax.plot([f1,f1],[f1,f2],c='grey',clip_on=False,linewidth=1)

            if np.abs(f1-f2)>10:
                # Matplotlib integer markers: 6 = caret up, 7 = caret down.
                at_md, at_identity = (7, 6) if f1>f2 else (6, 7)
                ax.plot(f1,f2,marker=at_md,markersize=5,c='grey',clip_on=False)
                ax.plot(f1,f1,marker=at_identity,markersize=5,c='grey',clip_on=False)

    ax.set_xlabel('Pacemaker firing rate [Hz]')
    ax.set_ylabel('Electrocyte\nfiring rate [Hz]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.legend(loc='upper left',frameon=False)
    ax.set_xlim([200,600])
    ax.set_ylim([200,600])

#%% simulate a shorter chirp just for the schematic?
def plot_pn(ax, pnfs, ts, te):
    """Draw the events of *pnfs* lying within ``[ts, te]`` as a raster row.

    The axes frame is hidden afterwards.
    """
    in_window = (pnfs >= ts) & (pnfs <= te)
    ax.eventplot(pnfs[in_window], color=c_pn, lineoffsets=-0.5, clip_on=False)
    ax.set_xlim([ts, te])
    ax.axis('off')

def plot_eod(ax, v, ts, te, dt, mode='chirp'):
    """Plot row 0 of *v* between times ``ts`` and ``te`` (values times 1000).

    Parameters
    ----------
    ax : axes-like object that receives the trace.
    v : 2-D array; only ``v[0]`` is used.
    ts, te : start and end time of the excerpt.
    dt : sampling step used to convert times to sample indices.
    mode : with ``'chirp'`` a thick bar and the label 'Chirp' are added.

    Raises
    ------
    ValueError
        If the excerpt has neither ``int((te - ts) / dt)`` samples nor one
        more, e.g. because ``v`` is shorter than ``te``.
    """
    segment = v[0, int(ts/dt):int(te/dt)]

    # Truncation can make the slice one sample longer than int((te-ts)/dt).
    # Size the time axis from the slice itself instead of retrying the plot
    # inside a bare `except`, which also hid unrelated errors.
    n_expected = int((te-ts)/dt)
    if len(segment) not in (n_expected, n_expected + 1):
        raise ValueError(
            'expected %d or %d samples between ts and te, got %d'
            % (n_expected, n_expected + 1, len(segment)))

    ax.plot(np.linspace(ts,te,len(segment)),segment*1000,color='k',clip_on=False)

    if mode=='chirp':
        ax.plot([ts+0.3*(te-ts),te-0.35*(te-ts)],[0,0],color='k',linewidth=4)
        ax.text(ts+0.475*(te-ts),10,'Chirp',horizontalalignment='center',verticalalignment='bottom')

    ax.set_xlim([ts,te])
    ax.axis('off')

def plot_freqs(ax, pnfs, st, ts, te, tss=None, tee=None, mode='chirps'):
    """Scatter inverse inter-event intervals of *pnfs* and *st* in ``[ts, te]``.

    With ``mode='chirps'`` all points are shown on a 0..600 y-range with a
    10 ms scale bar. With any other mode only ``st`` rates of at least 250
    are shown on a 250..310 y-range, a box from ``tss`` to ``tee`` labelled
    'Frequency rise' is outlined, and a 50 ms scale bar is drawn; ``tss``
    and ``tee`` must then be given.
    """
    chirp_view = mode == 'chirps'
    pn_size, cell_size = (25, 25) if chirp_view else (10, 1)

    pn_in_view = pnfs[(pnfs >= ts) & (pnfs <= te)]
    ax.scatter(pn_in_view[1:], 1/np.diff(pn_in_view), pn_size,
               color=c_pn, clip_on=False, rasterized=True)

    st_in_view = st[(st >= ts) & (st <= te)]
    cell_rate = 1/np.diff(st_in_view)
    cell_time = st_in_view[1:]

    ax.set_xlim([ts, te])
    ax.set_ylabel('Firing\nrate [Hz]')

    if chirp_view:
        ax.scatter(cell_time, cell_rate, cell_size, color='k', clip_on=False, rasterized=True)
        ax.set_ylim([0, 600])
        ax.set_yticks([0, 600])

        # Scale bar.
        ax.plot([te-0.01, te], [100, 100], linewidth=2, color='k', clip_on=False)
        ax.text(te-0.01, 0, '10 ms', horizontalalignment='left', verticalalignment='top')
    else:
        shown = cell_rate >= 250
        ax.scatter(cell_time[shown], cell_rate[shown], cell_size, color='k', clip_on=False, rasterized=True)
        ax.text(tss+0.1, 310, 'Frequency rise', horizontalalignment='left', verticalalignment='center')

        # Outline of the tss..tee excerpt.
        ax.plot([tss, tee, tee, tss, tss], [250, 250, 310, 310, 250], color='k', clip_on=False, linewidth=0.75)
        ax.set_ylim([250, 310])
        ax.set_yticks([250, 300])

        # Scale bar.
        ax.plot([te-0.05, te], [255, 255], linewidth=2, color='k', clip_on=False)
        ax.text(te, 250, '50 ms', horizontalalignment='right', verticalalignment='top')

    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([])

def plot_pn_above_eod(f, gs0, dt, ts=0.09,te=0.13):
    """Build the three stacked panels for the short-chirp example inside *gs0*.

    Loads 'data/chirp_short.npz' and draws, top to bottom: the event raster
    (``plot_pn``), the trace excerpt (``plot_eod``) and the rates
    (``plot_freqs``), plus a 'Chirp initiation' marker at ``ts + 0.011``.
    """
    # NOTE(review): eval() on a string from the config file; only safe while
    # that file is trusted.
    offset = eval(S['S']['inf_run']['stim_protocol'],units)

    data = np.load('data/chirp_short.npz',allow_pickle=True)
    v, pn_raw, st_raw = data['v'], data['pnfs'], data['st']

    # TODO: the inf_run offset is still removed here; it should already be
    # removed before the file is saved. Works at the moment because inf_run
    # is set to zero.
    pn_times = pn_raw[0] - offset/brian2.second #- 0.2 --> TODO
    spike_times = st_raw[0] - offset/brian2.second #-0.2

    outer = gridspec.GridSpecFromSubplotSpec(2,1,subplot_spec=gs0,height_ratios=[2,1])
    upper = gridspec.GridSpecFromSubplotSpec(2,1,subplot_spec=outer[0],hspace=0,height_ratios=[1,2])

    raster_ax = f.add_subplot(upper[0])
    plot_pn(raster_ax, pn_times, ts, te)
    trace_ax = f.add_subplot(upper[1])
    plot_eod(trace_ax, v, ts, te, dt)

    rate_ax = f.add_subplot(outer[1])
    plot_freqs(rate_ax, pn_times, spike_times, ts, te)

    # Arrow and label marking the chirp onset.
    onset = ts+0.011
    rate_ax.plot([onset,onset],[100,400],color=c_pn)
    rate_ax.plot(onset,400,'^',color=c_pn)
    rate_ax.text(onset,0,'Chirp\ninitiation',horizontalalignment='center',verticalalignment='top',color=c_pn)

def plot_rises_intro(f, gs0, dt, tsb=4.9,teb=6.5,tss=4.975,tes=5.05):
    """Build the frequency-rise overview with a zoomed excerpt inside *gs0*.

    Loads 'data/rises_syn.npz'. The top panel shows rates over ``tsb..teb``
    (``plot_freqs`` in 'rises' mode); the bottom frame holds inset axes with
    the event raster and trace for ``tss..tes``. Two lines connect the
    excerpt's position in the top panel to the upper corners of the bottom
    frame.
    """
    # NOTE(review): eval() on a string from the config file; only safe while
    # that file is trusted.
    offset = eval(S['S']['inf_run']['stim_protocol'],units)

    data = np.load('data/rises_syn.npz',allow_pickle=True)
    v = data['v']
    pn_times = data['pnfs'][0] - offset/brian2.second
    spike_times = data['st'][0] - offset/brian2.second

    outer = gridspec.GridSpecFromSubplotSpec(2,1,subplot_spec=gs0,height_ratios=[2,1],hspace=0.3)
    lower = gridspec.GridSpecFromSubplotSpec(2,1,subplot_spec=outer[1],hspace=0,height_ratios=[1,2])

    rate_ax = f.add_subplot(outer[0])
    plot_freqs(rate_ax, pn_times, spike_times, tsb, teb, tss, tes, mode='rises')

    # Empty frame spanning the whole lower area; target of the connectors.
    frame_ax = f.add_subplot(lower[:])
    frame_ax.set_xticks([])
    frame_ax.set_yticks([])

    # Hidden host axes carry slightly inset axes for raster and trace.
    raster_host = f.add_subplot(lower[0])
    raster_host.axis('off')
    raster_ax = raster_host.inset_axes([0.05, 0, 0.9, 0.9])
    plot_pn(raster_ax, pn_times, tss, tes)

    trace_host = f.add_subplot(lower[1])
    trace_host.axis('off')
    trace_ax = trace_host.inset_axes([0.05, 0.1, 0.9, 0.9])
    plot_eod(trace_ax, v, tss, tes, dt, mode='rises')

    # Left connector uses the frame's lower x-limit, right one the upper.
    for x_src, side in ((tss, 0), (tes, 1)):
        link = ConnectionPatch(xyA=(x_src,250), xyB=(frame_ax.get_xlim()[side],frame_ax.get_ylim()[1]),
                               coordsA="data", coordsB="data",
                               axesA=rate_ax, axesB=frame_ax, color="k",linewidth=0.75)
        rate_ax.add_artist(link)

def plot_chirp(ax,fname,mode,model,ms=2,arrows=True,sat=False):
    """Plot inverse inter-event intervals from a saved simulation on *ax*.

    Parameters
    ----------
    ax : axes-like object that receives the points.
    fname : suffix of the data file in 'chirp' mode. It also selects the
        events: with '1.0' in the name those with ``st[1] == 0``, with '.5'
        those with ``st[1] == 1``, otherwise all of ``st[0]``.
    mode : 'chirp' loads 'data/chirp_<fname>.npz'; any other value loads
        'data/<mode>_syn.npz'.
    model : unused; kept so existing callers keep working. The file was
        previously opened and parsed only to discard the result.
    ms : marker size for the ``st`` points (``pnfs`` points use twice that).
    arrows : mark every ``pnfs`` point whose inverse interval is below 100
        with an arrow under the axes.
    sat : together with ``mode == 'rises'`` selects the y-label layout.
    """
    # NOTE(review): eval() on a string from the config file; only safe while
    # that file is trusted.
    inf_run = eval(S['S']['inf_run']['stim_protocol'],units)
    offset = inf_run/brian2.second

    if mode=='chirp':
        a = np.load('data/%s_%s.npz'%(mode,fname),allow_pickle=True)
    else:
        a = np.load('data/%s_syn.npz'%mode,allow_pickle=True)

    st, pnfs = a['st'], a['pnfs'][0]-offset

    if '1.0' in fname:
        st = st[0][st[1]==0]-offset
    elif '.5' in fname:
        st = st[0][st[1]==1]-offset
    else:
        st = st[0]-offset

    ax.plot(pnfs[1:], 1/np.diff(pnfs),'.',c=c_pn,markersize=ms*2,rasterized=True,label=r'$\frac{1}{\mathrm{ISI}_\mathrm{PN}}$')
    ax.plot(st[1:], 1/np.diff(st),'.',markersize=ms,c='k',rasterized=True,label=r'$\frac{1}{\mathrm{ISI}_\mathrm{electrocyte}}$')

    if mode=='rises' and not sat:
        ax.set_ylabel('Firing\nrate\n[Hz]')
    else:
        ax.set_ylabel('Firing\nrate [Hz]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xticklabels([])

    if mode=='chirp':
        ax.set_xlim([0,2.5])
        ax.set_ylim([0,600])

    if arrows:
        chirp_times = (pnfs[1:])[1/np.diff(pnfs)<100]
        for chirptime in chirp_times:
            ax.plot([chirptime,chirptime],[-200,-50],color=c_pn,clip_on=False)
            ax.plot(chirptime,-50,'^',clip_on=False,color=c_pn,markersize=3)

        # The label sits next to the last arrow; with no arrows there is
        # nothing to anchor it to (this used to raise a NameError).
        if fname=='' and len(chirp_times) > 0:
            ax.text(chirp_times[-1]+0.1,-125,'chirp initiations',verticalalignment='top',color=c_pn)

def plot_chirp_pump(ax,fname,mode,model,ts,te,ipump,sat=False):
    """Plot the pump current computed from a saved ``nain`` trace.

    Parameters
    ----------
    ax : axes-like object that receives the trace.
    fname : suffix of the data file in 'chirp' mode. 'buf' in the name
        replaces the model's ``K_out`` function by ``'K_out_0'``; '.5' in
        the name selects ``nain[1]`` instead of ``nain[0]``.
    mode : 'chirp' loads 'data/chirp_<fname>.npz'; any other value loads
        'data/<mode>_syn.npz'.
    model : path of the JSON model description handed to
        ``model_utils.eval_func``.
    ts, te : span of the dashed baseline (and x-limits in 'chirp' mode).
    ipump : baseline pump current, also passed as ``Ip_mean``.
    sat : together with ``mode`` / ``fname`` selects label layout.
    """
    # Context manager so the model file handle is closed right away.
    with open(model) as model_file:
        M = json.load(model_file)

    if mode=='chirp':
        a = np.load('data/%s_%s.npz'%(mode,fname),allow_pickle=True)
    else:
        a = np.load('data/%s_syn.npz'%mode,allow_pickle=True)

    nain = a['nain']

    if 'buf' in fname:
        M['functions']['K_out'] = 'K_out_0'

    if '.5' in fname:
        nain = nain[1]
    else:
        nain = nain[0]

    delta_ip = model_utils.eval_func(M, "I_pump", Na_in=nain*brian2.mM,Ip_mean=ipump)

    ax.plot([ts,te],[ipump/brian2.uA, ipump/brian2.uA],'--',c='grey',label='Baseline')
    if fname=='' or sat:
        ax.text(te,ipump/brian2.uA,'Baseline',color='grey',horizontalalignment='right',verticalalignment='top')
    # NOTE(review): the time axis assumes 1e6 samples per unit of time
    # (hard-coded) — confirm against the recording step used for `nain`.
    ax.plot(np.linspace(0,len(nain)/1000000,len(nain)),delta_ip/brian2.uA,c='k')

    # '\\mu' keeps the runtime text identical without an invalid escape.
    if mode=='rises' and not sat:
        ax.set_ylabel('Pump\ncurrent\n[$\\mu$A]')
    else:
        ax.set_ylabel('Pump\ncurrent [$\\mu$A]')
    ax.set_xlabel('Time [s]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    if mode=='chirp':
        ax.set_xlim([ts,te])

#%% TODO: THESE ARE DUPLICATES. COMBINE THEM WITH THE EXISTING FUNCTIONS.


def plot_totcur_compare(ax,na_,k_,p_,na,k,p):
    """Bar chart comparing ion exchange for two pump variants.

    ``na_`` / ``k_`` are drawn at x=0 (group 'Constant') and ``na`` / ``k``
    at x=2 (group 'Voltage\\ndependent'); grey bars at x=1 and x=3 show half
    of the respective K+ value. Values are divided by the elementary charge
    and by 1e11. ``p_`` and ``p`` are not used.
    """
    blues = plt.get_cmap('Blues_r')
    reds = plt.get_cmap('Reds_r')
    elementary_charge = 1.602176634e-19 * brian2.coulomb # C = A*s

    def scaled(quantity):
        # Number of elementary charges, in units of 1e11.
        return quantity/elementary_charge/1e11

    na_tag = 'Na$^+$'
    k_tag = 'K$^+$'

    ax.bar(2, scaled(na), color=reds(0.6), label=na_tag)
    bars = ax.bar(0, scaled(na_), color=reds(0.6), label=na_tag)
    ax.bar_label(bars, label_type='center', fmt=na_tag)

    ax.bar(2, scaled(k), color=blues(0.6), label=k_tag)
    bars = ax.bar(3, scaled(k)/2, color='grey', label=k_tag)
    ax.bar_label(bars, label_type='edge', fmt='Net\npump')

    bars = ax.bar(0, scaled(k_), color=blues(0.6), label=k_tag)
    ax.bar_label(bars, label_type='center', fmt=k_tag)
    ax.bar(1, scaled(k_)/2, color='grey', label=k_tag)

    ax.plot([-1,5],[0,0],'--',c='k')
    ax.set_xticks([])
    ax.set_yticks([-1,0,1])
    ax.set_yticklabels(['-1e11','0','1e11'])
    ax.set_ylabel('Ion exchange\nper action potential')

    for side in ('top', 'right', 'bottom'):
        ax.spines[side].set_visible(False)
    ax.set_xticks([0.5,2.5])
    ax.set_xticklabels(['Constant','Voltage\ndependent'])
    ax.set_xlim([-0.75,3.75])
    # Re-apply the tick labels with a rotation.
    ax.set_xticklabels(ax.get_xticklabels(),rotation=30, ha='right')

def plot_freq_dpump(ax,I_in,f_,mean_stim,pnf):
    """Plot firing rate *f_* against ``mean_stim - I_in`` (shown in uA).

    ``pnf`` is not used.
    """
    # show how the pump current changes the firing rate
    ax.plot((mean_stim-I_in)/brian2.uA,f_,c='k')
    # Raw strings: '\D' and '\m' are not valid escape sequences.
    ax.set_xlabel(r'$\Delta I_\mathrm{pump}$ [$\mu$A]')
    ax.set_ylabel('$r_e$ [Hz]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylim([0,700])

def plot_prcs(ax,prcs,im,cb_lab=r'$\Delta I_\mathrm{pump}$ [$\mu$A]'):
    """Plot every third curve of *prcs* over phase 0..1 and add a colorbar.

    Parameters
    ----------
    ax : axes-like object that receives the curves and hosts the colorbar.
    prcs : sequence of curves; entries whose maximum is 0 are skipped.
    im : mappable handed to ``plt.colorbar``.
    cb_lab : colorbar label.

    The curves are drawn on ``ax``; previously they went to pyplot's current
    axes, which need not be ``ax``.
    """
    # Colour position is relative to the full list; a single-entry list
    # takes the end of the map instead of dividing by zero.
    n_steps = max(len(prcs)-1, 1)
    for i,prc in enumerate(prcs[::3]):
        if np.max(prc)==0:
            continue
        ax.plot(np.linspace(0,1,len(prc)),prc*brian2.mV,c=cmap(1-3*i/n_steps))
    ax.set_xticks([0,1])
    # Raw string: '\p' is not a valid escape sequence.
    ax.set_xlabel(r'$\phi$')
    ax.set_ylabel(r'$Z(\phi)$ [$mV^{-1}$]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    cb = plt.colorbar(im,ax=ax)
    cb.set_label(cb_lab)

def plot_entrainment_range(ax, I_in, f_, ranges, mean_stim, pnf, color='k'):
    """Plot a band between lower and upper bounds against ``mean_stim - I_in``.

    Parameters
    ----------
    ax : axes-like object that receives the band.
    I_in : input currents; the x-axis shows ``(mean_stim - I_in)`` in uA.
    f_ : firing rates belonging to ``I_in``; only entries with ``f_ > 0``
        contribute to the band.
    ranges : one ``(lower, upper)`` offset pair per positive entry of ``f_``;
        the offsets are added to that firing rate.
    mean_stim : reference current for the x-axis.
    pnf : constant level drawn as a horizontal line over all of ``I_in``.
    color : colour of markers, band and line.
    """
    firing = f_>0
    # list + ndarray adds element-wise (numpy handles the reflected add).
    ubs = [bounds[1] for bounds in ranges]+f_[firing]
    lbs = [bounds[0] for bounds in ranges]+f_[firing]

    x_firing = (mean_stim - I_in[firing])/brian2.uA
    ax.scatter(x_firing,ubs,5,color=color,marker='v',label='$r_{pn_{max}}$')
    ax.scatter(x_firing,lbs,5,color=color,marker='^',label='$r_{pn_{min}}$')
    ax.fill_between(x_firing, lbs, ubs, alpha=0.2, color=color)

    ax.plot((mean_stim - I_in)/brian2.uA, np.ones(len(I_in))*pnf,color=color,label='$r_{pn}$')

    ax.set_ylim([0,750])
    # Raw strings: '\D' and '\m' are not valid escape sequences.
    ax.set_xlabel(r'$\Delta I_\mathrm{pump}$ [$\mu$A]')
    ax.set_ylabel(r'$r_\mathrm{pn}$ [Hz]')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.legend(frameon=False)
