import math
import os

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy.integrate import odeint


###----------------- TO REPRODUCE FIGURE 4 (based on Equation 1)
###                  Takes more than 48 hours to run; for a faster run, lower the resolution (see parameters below).

##------------------ PARAMETERS 

d = 0.1                     # basal mortality rate (also embedded in the output file names)

gamma = 1                   # exponent of n in the nonlinear mortality d/(1 + l*n**gamma) (supplementary analyses); unused when linearFunction == 1
linearFunction = 1          # 1: linear mortality d*(1 - l*n), floored at 0 (see Figure S1); any other value: nonlinear mortality d/(1 + l*n**gamma)
                            # NOTE(review): getN1eq/getN2eq solve the equilibrium of d/(1 + l*n) (nonlinear model, gamma = 1);
                            # confirm that linearFunction = 1 is intended when comparing the numerical and analytical figures.

alphaVecMin = 0             # minimum value of alpha (competitive effect of species 1 on species 2; competition for resources)
alphaVecMax = 1             # maximum value of alpha


# sensitivity of mortality to changes in density (s in the manuscript; l in this code)
# /!\ log scale: the two values below are base-10 exponents, here from s = 10^-2 to s = 10^4

lVecMin = -2                # base-10 exponent of the minimum value of s
lVecMax = 4                 # base-10 exponent of the maximum value of s


## Resolution and numerical integration

nbIntervalAnalytical = 201  # grid points per axis (alpha and s) for the analytical figures (number of pixels per axis)
sizeMarkerAnalytical = 2    # marker (pixel) size in the analytical figures

nbIntervalNumerical = 201   # grid points per axis (alpha and s) for the numerical figure (number of pixels per axis)
sizeMarkerNumerical = 2     # marker (pixel) size in the numerical figure

nbIntervalY0 = 20           # initial densities per species, evenly spaced from 1.0 to 0.0001; all nbIntervalY0**2 pairs are tried (see full numerical method in Appendix S3)

tMax = 5e6                  # length of each run searching for coexistence
epsCoex = 1e-5              # both densities must exceed this threshold at tMax for coexistence to be recorded (below it a species is treated as extinct)

freqInv = 1e-12             # density at which the rare species 2 is introduced to test whether coexistence is a global attractor
TBeforeInv = 1000           # burn-in period with species 1 alone before introducing the rare competitor
TInv = 1                    # length of the run before assessing invasion success (invasion fails if species 2 declines)


## ## ## ## ## ## ## ## DO NOT CHANGE THE CODE BELOW






#------------------ differential equation describing Equation (1) in the manuscript

def dydt(y, t, d, l, c):
    """Right-hand side of Equation (1) for two competing species.

    The signature follows scipy's ``odeint(func, y0, t, args=(d, l, c))``
    convention.

    Parameters
    ----------
    y : sequence of two numbers
        Current densities (N1, N2). Negative entries are treated as 0.
    t : float
        Current time (unused: the system is autonomous).
    d : float
        Basal mortality rate.
    l : float
        Sensitivity of mortality to density (s in the manuscript).
    c : float
        Competitive effect of species 1 on species 2 (alpha).

    Returns
    -------
    list
        ``[dN1/dt, dN2/dt]``.

    The mortality form is chosen by the module-level settings
    ``linearFunction`` and ``gamma``.
    """
    def nonNegative(x):
        # Floor at 0; NaN is left unchanged because NaN < 0.0 is False.
        return 0.0 if x < 0.0 else x

    n1, n2 = y
    n1 = nonNegative(n1)
    n2 = nonNegative(n2)

    # n**gamma for the nonlinear mortality (Figure S1); NaN is mapped to 0.
    p1, p2 = n1**gamma, n2**gamma
    p1 = 0.0 if math.isnan(p1) else p1
    p2 = 0.0 if math.isnan(p2) else p2

    # Mortality D(n).
    if linearFunction == 1:
        # Linear decrease with density, never below zero.
        mort1 = nonNegative(d * (1 - l * n1))
        mort2 = nonNegative(d * (1 - l * n2))
    else:
        mort1 = d / (1 + l * p1)
        mort2 = d / (1 + l * p2)

    growth1 = n1 * (1 - n1 - mort1)
    growth2 = n2 * (1 - c * n1 - n2 - mort2)
    return [growth1, growth2]




#------------------ mathematical formulas derived in Appendix S2

def getN1eq(d, l, c):
    """Return the positive equilibrium density of species 1.

    This is the larger root of  l*n**2 - (l - 1)*n - (1 - d) = 0,
    equivalently 1 - n - d/(1 + l*n) = 0 (nonlinear mortality, gamma = 1).
    ``c`` is accepted for a uniform signature but not used.
    Works element-wise on NumPy arrays.
    """
    b = l - 1
    discriminant = b**2 + 4*l*(1 - d)
    return (b + np.sqrt(discriminant)) / (2*l)

def getN2eq(d, l, c, n1eq):
    """Return the positive equilibrium density of species 2 given ``n1eq``.

    With A = 1 - c*n1eq, this is the larger root of
    l*n**2 - (l*A - 1)*n - (A - d) = 0, i.e. A - n - d/(1 + l*n) = 0
    (nonlinear mortality, gamma = 1). A negative discriminant yields NaN.
    """
    b = l*(1 - c*n1eq) - 1
    discriminant = b**2 + 4*l*(1 - c*n1eq - d)
    return (b + np.sqrt(discriminant)) / (2*l)

def getDerivativeN1eq(d, l, c):
    """Return d(getN1eq)/dl in closed form (``c`` is unused).

    With S = sqrt((l - 1)**2 + 4*l*(1 - d)), the result is
    (1 + (l + 1 - 2*d)/S) / (2*l)  -  (l - 1 + S) / (2*l**2).
    Not called elsewhere in this script.
    """
    root = np.sqrt((l - 1)**2 - 4*(d - 1)*l)
    slopeTerm = -1/2.0*((2*d - l - 1)/root - 1)/l
    quotientTerm = 1/2.0*(l + root - 1)/l**2
    return slopeTerm - quotientTerm

def getDerivativeN2eq(d, l, c):
    """Closed-form derivative of the species-2 equilibrium with respect to l.

    NOTE(review): the expression appears to be machine-generated (symbolic
    algebra output) for d/dl of getN2eq(d, l, c, getN1eq(d, l, c)), i.e. it
    includes the dependence of n1eq on l. Verify it against a finite
    difference before relying on it.

    Recurring sub-expressions, for reading:
      np.sqrt((l - 1)**2 - 4*(d - 1)*l)   -> the square root used in getN1eq
      c*(l + np.sqrt(...) - 1)/l           -> 2*c*n1eq

    Not called elsewhere in this script (as far as the visible code shows).
    """
    deriv = 1/4.0*(l*(c*((2*d - l - 1)/np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l + c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l**2) - c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l - (((c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l - 2)*l + 2)*(l*(c*((2*d - l - 1)/np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l + c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l**2) - c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l + 2) - 4*l*(c*((2*d - l - 1)/np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l + c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l**2) + 8*d + 4*c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l - 8)/np.sqrt(((c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l - 2)*l + 2)**2 - 8*(2*d + c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l - 2)*l) + 2)/l + 1/4.0*((c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l - 2)*l - np.sqrt(((c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l - 2)*l + 2)**2 - 8*(2*d + c*(l + np.sqrt((l - 1)**2 - 4*(d - 1)*l) - 1)/l - 2)*l) + 2)/l**2
    return deriv

def existenceEqCoexistence(d, l, c, n1eq):
    """Classify the coexistence equilibrium from the analytical criteria (Appendix S2).

    Parameters
    ----------
    d : float
        Basal mortality rate. Must be non-zero, because ``1/d`` is evaluated.
    l : float
        Sensitivity of mortality to density (s in the manuscript).
    c : float
        Competition coefficient alpha.
    n1eq : float
        Unused; kept so that existing call sites remain valid.

    Returns
    -------
    int
        0 -- the existence conditions for a coexistence equilibrium fail;
        4 -- they hold, c + d > 1 and l > c*(c - 1)/(1 - c - d);
        5 -- they hold otherwise.
        The numerical sweep in this script uses the same codes
        (5: global attractor, 4: local attractor, 0: no coexistence).
    """
    excess = c + d - 1  # its sign selects which bound on l applies

    if l <= 1/d:
        exists = (excess <= 0 and l >= 0) or (excess > 0 and l < c*(1-c)/excess)
    else:
        exists = c < (2*(l+1-2*np.sqrt(l*d))) / (l-1+np.sqrt((l-1)**2+4*l*(1-d)))

    if not exists:
        return 0

    # Same l threshold as the "analytical condition invasion" curve c*(c-1)/(1-c-d).
    if excess > 0 and l > c*(c-1)/(1-c-d):
        return 4
    return 5





#------------------ Numerical integration

# All initial conditions tried when searching for coexistence: every (N1, N2)
# pair from the grid y0Test x y0Test, with N1 varying slowest.
y0Test = np.linspace(1.0, 0.0001, nbIntervalY0)
y0Test_1 = [start1 for start1 in y0Test for _ in y0Test]
y0Test_2 = [start2 for _ in y0Test for start2 in y0Test]


# Time grids: odeint only needs to report the start and the end of each run.
t = np.linspace(0, tMax, 2)                  # main run, 0 -> tMax
tBefore = np.linspace(0, TBeforeInv, 2)      # burn-in with species 1 alone
tInvasion = np.linspace(0, TInv, 2)          # run after introducing the rare species 2

alphaVec = np.linspace(alphaVecMin, alphaVecMax, nbIntervalNumerical)  # alpha values tested
lVec = np.logspace(lVecMin, lVecMax, nbIntervalNumerical)              # s values tested (log-spaced)

# One slot per (alpha, s) pixel, filled in sweep order below.
cVecPlot = [0 for _ in range(nbIntervalNumerical * nbIntervalNumerical)]  # alpha of each pixel
lVecPlot = [0 for _ in range(nbIntervalNumerical * nbIntervalNumerical)]  # s of each pixel
eqPlot = [0 for _ in range(nbIntervalNumerical * nbIntervalNumerical)]    # outcome code of each pixel

r_ = 0
for c in alphaVec:
    print(c)  # progress: one line per alpha value
    for l in lVec:
        # Avoid exact zeros (alpha = 0 occurs when alphaVecMin = 0).
        l = 0.00001 if l == 0 else l
        c = 0.00001 if c == 0 else c

        # Step 1: try each starting point until one ends with both species
        # above epsCoex. 0 = extinction / no coexistence found, 5 = coexistence.
        eqState = 0
        for start1, start2 in zip(y0Test_1, y0Test_2):
            final = odeint(dydt, [start1, start2], t, args=(d, l, c))[1]
            if final[0] > epsCoex and final[1] > epsCoex:
                eqState = 5
                break

        # Step 2: is coexistence a global attractor? Let species 1 settle alone,
        # add species 2 at density freqInv, and check whether it declines.
        if eqState == 5:
            resident = odeint(dydt, [1, 0], tBefore, args=(d, l, c))[1]
            resident[1] = freqInv
            afterInvasion = odeint(dydt, resident, tInvasion, args=(d, l, c))[1]
            if afterInvasion[1] < resident[1]:
                eqState = 4  # invasion fails: coexistence is only a local attractor

        cVecPlot[r_], lVecPlot[r_], eqPlot[r_] = c, l, eqState
        r_ += 1






#------------------ plot

# Grey-scale map of the numerical outcome for each (alpha, s) pixel.
# eqState only takes the values 0, 4 and 5, so the colour range is pinned to
# [0, 5] with vmin/vmax. This replaces the former trick of appending three
# off-axis sentinel points (alpha = s = -1000) carrying the values 0, 4 and 5.
f1 = plt.figure(figsize=(5, 4.5))
plt.scatter(cVecPlot, lVecPlot, c=eqPlot, s=sizeMarkerNumerical, lw=0, marker="s", cmap='gray', vmin=0, vmax=5)
plt.xlim(0, 1)
plt.ylim(min(lVec), max(lVec))
plt.yscale('log')
plt.rc('xtick', labelsize=30)
plt.rc('ytick', labelsize=30)
plt.xticks(fontsize=13, rotation=0)
plt.yticks(fontsize=14, rotation=0)

# Store the plot in the folder "Graphes". Create the folder if it is missing,
# so that a multi-day run does not fail at the very last step.
os.makedirs('Graphes', exist_ok=True)
f1.savefig('Graphes/sensitivityNumericalRes_D'+str(d)+'_gamma'+str(gamma)+'linearFunction'+str(linearFunction)+'.png', bbox_inches='tight', dpi=300)





#------------------ Analytical integration

# Analytical invasion condition: l = c*(c-1)/(1-c-d) as a function of c.
# Not used further in this script.
cVec_a = np.linspace(0, 1, 100)
lVec_a = cVec_a*(cVec_a-1)/(1-cVec_a-d)

# Analytical value of c beyond which there is no coexistence when l = 1/d.
# Not used further in this script.
c_1d = (-(1-d)+np.sqrt((1-d)*(1+3*d))) / (2*d)

# Analytical threshold on l, drawn as a red dashed line in the analytical figure.
Llim2 = (1 + np.sqrt(8*d + 1) + np.sqrt(2)*np.sqrt(-8*d**2 + np.sqrt(8*d + 1) + 4*d + 1))/(4*d)

# Existence of the coexistence equilibrium over the (alpha, s) grid.
alphaVec = np.linspace(alphaVecMin, alphaVecMax, nbIntervalAnalytical)
lVec = np.logspace(lVecMin, lVecMax, nbIntervalAnalytical)

cVecPlot = [0] * nbIntervalAnalytical * nbIntervalAnalytical    # alpha of each pixel
lVecPlot = [0] * nbIntervalAnalytical * nbIntervalAnalytical    # s of each pixel
eqPlot = [0] * nbIntervalAnalytical * nbIntervalAnalytical      # existenceEqCoexistence code (0/4/5)
diffNPlot = [-1] * nbIntervalAnalytical * nbIntervalAnalytical  # -N2* where an equilibrium exists, else -1

r_ = 0
for c in alphaVec:
    for l in lVec:
        # Avoid exact zeros (alpha = 0 occurs when alphaVecMin = 0).
        if l == 0:
            l = 0.00001
        if c == 0:
            c = 0.00001
        n1eq = getN1eq(d, l, c)
        eqState = existenceEqCoexistence(d, l, c, n1eq)
        cVecPlot[r_] = c
        lVecPlot[r_] = l
        eqPlot[r_] = eqState
        # Colour value for the second analytical figure (computed once per pixel).
        diffNPlot[r_] = -getN2eq(d, l, c, n1eq) if eqState != 0 else -1
        r_ = r_ + 1




#------------------ plot

# Grey-scale map of the analytical classification (codes 0/4/5), with the
# Llim2 threshold drawn as a red dashed line. The colour range is pinned to
# [0, 5] so that the grey levels match the numerical figure even when one of
# the codes is absent from the grid.
f1 = plt.figure(figsize=(5, 4.5))
plt.scatter(cVecPlot, lVecPlot, c=eqPlot, s=sizeMarkerAnalytical, lw=0, marker="s", cmap='gray', vmin=0, vmax=5)
plt.plot((0, 100), (Llim2, Llim2), 'red', linestyle='--')
plt.xlim(0, 1)
plt.ylim(min(lVec), max(lVec))
plt.yscale('log')
plt.rc('xtick', labelsize=30)
plt.rc('ytick', labelsize=30)
plt.xticks(fontsize=13, rotation=0)
plt.yticks(fontsize=14, rotation=0)
os.makedirs('Graphes', exist_ok=True)  # savefig fails if the folder is missing
f1.savefig('Graphes/sensitivityAnalyticalRes_D'+str(d)+'.png', bbox_inches='tight', dpi=300)

# Pixels where no coexistence equilibrium exists (eqState == 0); they are drawn
# in black on top of the density map. A single O(n) pass replaces the former
# delete-while-indexing loop, which was O(n**2).
cVecPlot2 = [cv for cv, eq in zip(cVecPlot, eqPlot) if eq == 0]
lVecPlot2 = [lv for lv, eq in zip(lVecPlot, eqPlot) if eq == 0]
eqPlot2 = [eq for eq in eqPlot if eq == 0]

# Plot minus the species-2 equilibrium density (diffNPlot) on a blue-white-red
# scale over [-1, 0], then overlay the no-coexistence pixels in black.
f1 = plt.figure(figsize=(5, 4.5))
plt.scatter(cVecPlot, lVecPlot, c=diffNPlot, s=sizeMarkerAnalytical, lw=0, marker="s", cmap='bwr', vmin=-1., vmax=0.)
plt.xlim(0, 1)
plt.ylim(min(lVec), max(lVec))
plt.yscale('log')
plt.rc('xtick', labelsize=30)
plt.rc('ytick', labelsize=30)
plt.xticks(fontsize=13, rotation=0)
plt.yticks(fontsize=14, rotation=0)
plt.scatter(cVecPlot2, lVecPlot2, c='k', s=sizeMarkerAnalytical, lw=0, marker="s")

# Store the plot in the folder "Graphes" (created if missing).
os.makedirs('Graphes', exist_ok=True)
f1.savefig('Graphes/sensitivityAnalytical2Res_D'+str(d)+'.png', bbox_inches='tight', dpi=300)

