import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Ellipse, Rectangle

# Colormaps shared by the plotting helpers in this module: the reversed blue
# and red maps supply the Na+ / K+ colours of the pump schematic, and the
# qualitative 'tab20c' map supplies the colour used in the circuit sketch.
cmapb, cmapr, cmap = (
    plt.get_cmap(name) for name in ('Blues_r', 'Reds_r', 'tab20c')
)

def _draw_ion_channel(ax, pos, h, w, wr, hr, color, outward, label):
    """Draw one membrane channel centred at (pos, 0) with its current arrow.

    The channel is a coloured ellipse with a white rectangular slit on top of
    it; a black vertical line from -h to h carries the arrowhead. With
    `outward` true the arrowhead sits at +h and the label is placed to the
    left of the channel, otherwise the arrowhead sits at -h and the label is
    placed to the right.
    """
    ax.add_patch(Ellipse((pos, 0), width=w, height=h, facecolor=color))

    # White slit; its height is h/hr, so hr < 1 makes it taller than the ellipse.
    ax.add_patch(Rectangle((pos - w*wr/2, -h/hr/2),
                           width=w*wr,
                           height=h/hr,
                           facecolor='w', zorder=10))

    ax.plot([pos, pos], [-h, h], zorder=100, c='k')

    if outward:
        ax.plot(pos, h, marker='^', zorder=100, c='k')
        ax.text(pos - 0.1, h, label,
                verticalalignment='top', horizontalalignment='right')
    else:
        ax.plot(pos, -h, marker='v', zorder=100, c='k')
        ax.text(pos + 0.1, -h, label, zorder=1001,
                verticalalignment='bottom', horizontalalignment='left')


def _scatter_ions(ax, draw, count, y_sign, marg, color):
    """Scatter `count` small dots with x in [-1, 1) and |y| in [0, marg).

    `draw(n)` must return n uniform samples in [0, 1). x is sampled before y
    so that the order of random draws matches a seeded global RNG run.
    """
    xs = draw(count)*2 - 1
    ys = y_sign*draw(count)*marg
    ax.scatter(xs, ys, 1, color=color, zorder=-10)


def plot_pump_schematics(ax,h=0.3,napos=-0.75,kpos=-0.5,ppos=0.5,rmol=0.25,wr=0.4,hr=0.5,marg=0.4,rng=None):
    """Draw a membrane schematic with two ion channels and an ATP-driven pump.

    The membrane is the horizontal line y = 0 on x in [-1, 1]; the upper half
    is labelled 'Extracellular space' and the lower half 'Intracellular
    space'. Axis limits are fixed to x in [-1, 1], y in [-0.5, 0.5] and all
    ticks are removed.

    Parameters
    ----------
    ax : object
        Axes to draw on (uses plot, scatter, text, add_patch, set_xlim,
        set_ylim, set_xticks, set_yticks).
    h : float
        Height of the channel ellipses and half-length of the current arrows;
        the channel width is h/2.
    napos, kpos : float
        x-positions of the two channels. NOTE: the channel drawn at `napos`
        is coloured and labelled as the K+ channel (outward arrow) and the
        one at `kpos` as the Na+ channel (inward arrow). This pairing is kept
        as is for backward compatibility.
    ppos : float
        x-position of the pump.
    rmol : float
        Horizontal offset of the transported ions from the pump centre, in
        units of h; also scales the length of the small transport arrows.
    wr : float
        Width of the white channel slit relative to the channel width.
    hr : float
        The white slit has height h/hr; must be non-zero.
    marg : float
        Vertical extent of the random ion cloud on each side of the membrane.
    rng : optional
        Object with a ``random(size)`` method (e.g. the result of
        ``numpy.random.default_rng(seed)``) used for the ion cloud. When
        None, the global ``np.random.random`` is used as before.
    """
    cna = cmapb(0.6)
    ck = cmapr(0.6)
    cp = 'grey'
    cd = plt.get_cmap('Dark2')(0)

    w = h/2

    # Source of uniform samples for the background ion cloud.
    draw = np.random.random if rng is None else rng.random

    # Membrane.
    ax.plot([-1,1],[0,0],zorder=-1,c='k')

    # Channels (see the NOTE on napos/kpos in the docstring).
    _draw_ion_channel(ax, napos, h, w, wr, hr, ck, True, 'K$^+$\ncurrent')
    _draw_ion_channel(ax, kpos, h, w, wr, hr, cna, False, 'Na$^+$\ncurrent')

    # Pump body, connected by a line to the ATP blob below the membrane.
    ax.plot([ppos,0.75],[0,-0.25],color='k',zorder=-1)
    ax.add_patch(Ellipse((ppos,0), width=w*3, height=h*1.5, facecolor=cp))
    ax.add_patch(Ellipse((0.75,-0.25), width=h, height=w, facecolor=cd))
    ax.text(0.75,-0.25,'ATP',horizontalalignment='center',verticalalignment='center')

    # Net pump current arrow (pointing to the extracellular side).
    ax.plot([ppos,ppos], [-h,h],zorder=2000,c='k')
    ax.plot(ppos,h,marker='^',c='k')
    ax.text(ppos+0.1,h+0.1,'Pump\ncurrent',horizontalalignment='left',verticalalignment='top')

    # Transported ions inside the pump: three on the Na+ side, two on the K+ side.
    x_na = ppos-h*rmol
    x_k = ppos+h*rmol
    ax.scatter([x_na]*3, [i*h/5 for i in range(-1,2)],10,color=cna,zorder=10)
    ax.scatter([x_k]*2, [i*h/7 for i in [-1,1]],10,color=ck,zorder=10)

    # Small transport arrows: Na+ side points up, K+ side points down.
    ax.plot([x_na,x_na],[-h*rmol*2,h*rmol*2],c=cna,linewidth=1)
    ax.plot(x_na,h*rmol*2,c=cna,marker='^',markersize=5)

    ax.plot([x_k,x_k],[-h*rmol*2/3*2,h*rmol*2/3*2],c=ck,linewidth=1)
    ax.plot(x_k,-h*rmol*2/3*2,c=ck,marker='v',markersize=5)

    # Background ion cloud: many Na+-coloured dots above and few below the
    # membrane, many K+-coloured dots below and few above. The call order is
    # fixed so that a seeded RNG reproduces the same picture.
    _scatter_ions(ax, draw, 200, 1, marg, cna)
    _scatter_ions(ax, draw, 20, -1, marg, cna)
    _scatter_ions(ax, draw, 200, -1, marg, ck)
    _scatter_ions(ax, draw, 20, 1, marg, ck)

    ax.text(-1,-0.5,'Intracellular space',horizontalalignment='left',verticalalignment='bottom')
    ax.text(-1,0.5,'Extracellular space',horizontalalignment='left',verticalalignment='top')

    ax.set_xlim([-1,1])
    ax.set_ylim([-0.5,0.5])
    ax.set_xticks([])
    ax.set_yticks([])

def plot_image(ax, image):
    """Show `image` on `ax` via ``imshow`` and switch the axis display off."""
    ax.imshow(image)
    # Pure picture panel: no ticks, labels or frame.
    ax.axis('off')

def plot_circuit(ax,my=0.25,ms=200,lr=True):
    """Sketch the PN -> Electrocyte connection on `ax` and hide the axis.

    Parameters
    ----------
    ax : object
        Axes to draw on (uses scatter, plot, text, set_xlim, set_ylim, axis).
    my : float
        Only used when `lr` is false: y-position of the arrowhead, which is
        also where the connecting line coming down from 'PN' ends.
    ms : float
        Marker size. With `lr` true it sizes the two node markers; with `lr`
        false it sizes the arrowhead.
    lr : bool
        True draws a left-to-right layout (PN at x=0, Electrocyte at x=1).
        False draws a top-to-bottom layout (PN at y=1, Electrocyte at y=0)
        with the labels in coloured text boxes.
    """
    pn_color = cmap(1)

    if lr:
        ax.scatter(0,0,ms,color=pn_color,clip_on=False)
        ax.scatter(1,0,ms,color='k',clip_on=False)
        ax.plot([0,0.8],[0,0],color=pn_color)
        # Arrowhead has a fixed size here; `ms` only scales the two nodes.
        ax.scatter(0.8,0,200,marker='<',color=pn_color)
        ax.text(0,0,'PN',color='k',horizontalalignment='center',verticalalignment='center')
        ax.text(1,0,'Electrocyte',color='w',horizontalalignment='center',verticalalignment='center')
        ax.set_ylim([-0.25,0.25])
        ax.set_xlim([-0.25,1.25])
    else:
        # TODO: add background with perturbation (not implemented).

        # End the line at `my` so it always meets the arrowhead; the previous
        # hard-coded 0.25 only matched the default value of `my`.
        ax.plot([0,0],[my,1],color=pn_color)
        ax.scatter(0,my,ms,marker='^',color=pn_color)
        pn_label = ax.text(0,1,'PN',color='k',horizontalalignment='center',verticalalignment='center',clip_on=False)
        pn_label.set_bbox(dict(facecolor=pn_color))
        cell_label = ax.text(0,0,'Electrocyte',color='w',horizontalalignment='center',verticalalignment='center',clip_on=False)
        cell_label.set_bbox(dict(facecolor='k'))
        ax.set_ylim([-0.25,1.25])
        ax.set_xlim([-0.1,0.1])
    ax.axis('off')
