import numpy as np
from scipy import stats
import scipy as sci
import pandas as pd
import matplotlib.pyplot as plt
def v(bio_n, bio_p, num_pfs):
    """Draw the mutant-subunit count of ``num_pfs`` pfs.

    Each count is one sample from a Binomial(``bio_n``, ``bio_p``)
    distribution, i.e. ``bio_n`` subunits that are each mutant with
    probability ``bio_p``.

    Parameters
    ----------
    bio_n : number of binomial trials (subunits per pf).
    bio_p : success probability (fraction of mutant subunits).
    num_pfs : number of samples to draw (passed as the ``size`` of the draw).

    Returns
    -------
    Array of ``num_pfs`` integer counts, each between 0 and ``bio_n``.
    """
    return stats.binom.rvs(bio_n, bio_p, size=num_pfs)
def f(x, threshold, incor_prob):
    """Randomly decide whether one pf is incorporated into the ring.

    ``x`` is the pf's count of mutant subunits.

    * ``x < threshold``  -> incorporated with probability ``incor_prob``.
    * ``x >= threshold`` -> incorporated with probability ``1 - incor_prob``.

    Returns ``x`` if the pf is incorporated and ``0`` otherwise. Because 0 is
    the "rejected" marker, an incorporated pf with ``x == 0`` cannot be told
    apart from a rejected one by the caller.
    """
    accept_prob = incor_prob if x < threshold else 1 - incor_prob
    # Exactly one uniform draw per call, whichever side of the threshold x is on.
    return x if np.random.rand() < accept_prob else 0


# Element-wise version of ``f`` (3 inputs, 1 output) so it can be applied to
# whole arrays of counts; ``np.frompyfunc`` ufuncs return object-dtype arrays.
f_ufunc = np.frompyfunc(f, 3, 1)
def anaEachtrial(a, num_pfs_ring):
    """Return the mutant-subunit total of the ring formed in one trial.

    Parameters
    ----------
    a : array of per-pf values for one trial, where a non-zero entry is the
        mutant-subunit count of a pf that can incorporate into the ring and
        0 marks a pf that cannot.
    num_pfs_ring : number of pfs required to form a ring.

    Returns
    -------
    The sum of the first ``num_pfs_ring`` non-zero entries of ``a`` (in
    array order) when at least ``num_pfs_ring`` such entries exist, else
    ``0.0`` (no ring formed).

    A trial with *exactly* ``num_pfs_ring`` incorporable pfs forms a ring:
    the slice below needs only that many entries, so the comparison is
    ``>=`` rather than a strict ``>``.
    """
    incorporable = a[a != 0]  # pfs that could incorporate into the ring
    if incorporable.size >= num_pfs_ring:
        return np.sum(incorporable[:num_pfs_ring])
    return 0.0
def main(Ln, Frac, Totpfs, pfsRing, threMut, incPro, trialTimes):
    """Run ``trialTimes`` independent trials and summarise the ring composition.

    For every trial, ``Totpfs`` pfs are generated with ``v(Ln, Frac, Totpfs)``,
    each pf is tested for incorporation with ``f_ufunc(..., threMut, incPro)``,
    and ``anaEachtrial(..., pfsRing)`` gives the mutant-subunit total of the
    ring (0 when no ring forms). That total is divided by the trial's total
    mutant-subunit count over all pfs.

    Parameters
    ----------
    Ln : subunits per pf.
    Frac : fraction of mutant subunits.
    Totpfs : number of pfs generated per trial.
    pfsRing : number of pfs required to form a ring.
    threMut : mutant-subunit threshold passed to ``f_ufunc``.
    incPro : incorporation probability passed to ``f_ufunc``.
    trialTimes : number of trials.

    Returns
    -------
    (mean, std, formed) where ``mean``/``std`` are taken over the per-trial
    ratios and ``formed`` is the fraction of trials with a non-zero ratio.
    """
    # trialTimes x Totpfs array: one row per trial, one column per pf.
    a = np.array([v(Ln, Frac, Totpfs) for _ in range(trialTimes)])
    background = np.sum(a, axis=1)

    # Test whether each pf could incorporate into the ring.
    a_ring = f_ufunc(a, threMut, incPro)
    ring = np.array([anaEachtrial(row, pfsRing) for row in a_ring], dtype=float)

    # A trial with no mutant subunits at all has background == 0; a plain
    # division would give NaN there, which would turn the mean/std into NaN
    # and be counted as a "non-zero" (ring formed) trial. Such trials are
    # given a ratio of 0 instead.
    res = np.divide(ring, background, out=np.zeros_like(ring), where=background != 0)

    return np.mean(res), np.std(res), np.count_nonzero(res) / trialTimes
# ---- Simulation parameters ----
trialTimes = 1000                 # number of trials per (fraction, threshold) pair
LN = 50                           # subunits in a pf
fs = np.arange(0.1, 1.0, 0.1)     # fractions of mutant subunits to scan
Totpfs = 200                      # total number of pfs
pfsRing = 20                      # number of pfs required to form a ring
tM = np.arange(10, LN, 5)         # threshold numbers to scan
incPro = 0.99                     # incorporation probability for a pf below the threshold

# ---- Result grids: rows follow ``fs``, columns follow ``tM`` ----
grid_shape = (len(fs), len(tM))
aver_res = np.empty(grid_shape)
std_res = np.empty(grid_shape)
ringform_res = np.empty(grid_shape)

# Scan every (fraction, threshold) combination, fraction-major.
for i, frac in enumerate(fs):
    for j, thr in enumerate(tM):
        aver_res[i, j], std_res[i, j], ringform_res[i, j] = main(
            LN, frac, Totpfs, pfsRing, thr, incPro, trialTimes
        )

# Labelled views of the mean and standard deviation, used for plotting.
adf = pd.DataFrame(aver_res, index=fs, columns=tM)
sdf = pd.DataFrame(std_res, index=fs, columns=tM)
import matplotlib.patches as mpatches
# Figure 1: mean ring/total ratio against the mutant fraction, drawn as one
# error-bar curve per threshold (error bars are the per-trial std).
colors = [
    'gold', 'coral', 'darkorange', 'violet', 'mediumpurple',
    'darkcyan', 'olive', 'yellowgreen', 'lightgreen',
]
fig, ax1 = plt.subplots()
fig.set_size_inches(20, 14)

step = 1  # stride used when walking the thresholds from largest to smallest
legends = []
labels = []
for i in range(len(tM) - 1, -1, -step):
    t = tM[i]
    curve_color = colors[i]
    ax1.errorbar(
        fs, adf[t], yerr=sdf[t], fmt='o-.',
        color=curve_color, linewidth=5.0, markersize=20.0,
    )
    # The legend is built from colour patches rather than from the curves.
    legends.append(mpatches.Patch(color=curve_color, fill=True))
    labels.append('Threshold is %s' % t)

ax1.tick_params(which='both', labelsize=28)
ax1.set_xlabel("The fraction of laterally disruptive subunits.\n", fontsize=35)
ax1.set_ylabel("Laterally disruptive subunits in ring / Total\n", fontsize=35)

# Frameless, tick-less helper axes to the right of the plot that only hosts the legend.
ax2 = fig.add_axes([0.95, 0.35, 0.20, 0.3], frameon=False)
for legend_axis in (ax2.xaxis, ax2.yaxis):
    legend_axis.set_ticks([])
ax2.legend(handles=legends, labels=labels, loc='center', fontsize=25, markerscale=10.0)

fig.savefig(
    "LN%s_Totpf%s_iIncPro%s_step%s.png" % (LN, Totpfs, incPro, step),
    dpi=400, format='png', bbox_inches='tight',
)
# Figure 2: share of trials in which a ring formed, against the mutant
# fraction, one curve per threshold.
colors = ['gold', 'coral', 'darkorange', 'violet', 'mediumpurple',
          'darkcyan', 'olive', 'yellowgreen', 'lightgreen']
fig, ax1 = plt.subplots(1, 1)
fig.set_size_inches(20, 14)
legends = []
labels = []
# Walk the thresholds from largest to smallest; ``step`` is the stride set
# for the first figure.
for i in range(len(tM) - 1, -1, -step):
    t = tM[i]
    # ringform_res stores fractions in [0, 1]; scale to percent so the
    # plotted values agree with the "percentage" y-axis label.
    ax1.plot(fs, 100.0 * ringform_res[:, i], '-.', color=colors[i], linewidth=6, markersize=20)
    legend_patch = mpatches.Patch(color=colors[i], fill=True)
    legends.append(legend_patch)
    labels.append('Threshold is %s' % t)

# Frameless helper axes to the right of the plot that only hosts the legend.
ax2 = fig.add_axes([0.98, 0.35, 0.20, 0.3], frameon=False)
ax1.tick_params(which='both', labelsize=28)
ax1.set_xlabel("The fraction of laterally disruptive subunits\n", fontsize=28)
ax1.set_ylabel("The percentage of trials that Z-ring forms\n", fontsize=28)
ax1.xaxis.set_ticks(fs)  # one tick per simulated fraction
ax2.xaxis.set_ticks([])
ax2.yaxis.set_ticks([])
ax2.legend(handles=legends, labels=labels, loc='center', fontsize=25, markerscale=10.0)
# NOTE(review): the "Pencentage" spelling is kept so the output path stays the same.
fig.savefig("Pencentage_LN%s_Totpf%s_iIncPro%s_step%s.png" % (LN, Totpfs, incPro, step),
            dpi=400, format='png', bbox_inches='tight')