

# Clear the global workspace before running. The model function and the
# analysis below read almost all parameters from global variables, so this
# keeps stale objects from an earlier run from silently leaking in.
# NOTE(review): this also wipes the user's own (non-hidden) workspace objects
# when the script is sourced into an interactive session.
rm(list = ls())

T_model <- function(t, x, params) {
  # Right-hand side of the transplant/infection ODE model, in the
  # func(t, y, parms) form used by deSolve::lsodar().
  #
  # t      : current time (scalar); compared with the schedule values TBI,
  #          Transplant, ART, ATI, twash and tsa.
  # x      : named state vector with entries Tp, Pp, NSp, NSp2, T, P, NS, S,
  #          M, E, Ip, Iu, V, Es, Mp.
  # params : optional named vector/list of parameter values. Entries given
  #          here take precedence over global variables of the same name;
  #          NULL or c() (as this script passes) means every parameter is read
  #          from the global environment, exactly as before.
  #
  # Returns list(der): the 15 derivatives ordered Tp, Pp, NSp, NSp2, T, P, NS,
  # S, M, E, Ip, Iu, V, Es, Mp (the order of init.x in this script).
  with(as.list(c(x, params)), {
    # (a) Total body irradiation on [Transplant - TBI, Transplant): only in
    # this window is psi < 1, which activates the (1 - psi) * kt depletion terms.
    TBItime <- Transplant - TBI
    psi <- if (t >= TBItime && t < Transplant) 1 - Tx else 1

    # (b) ART: from time ART on, a constant source alphaL feeds Ip.
    alphaL <- if (t < ART) 0 else 1 / tsa

    # (c) Transplantation: after Transplant, proliferation, carrying capacity,
    # expansion/saturation of the omega terms and Es decay use shifted values.
    if (t < Transplant) {
      rp <- rp1
      lambdap <- lambdap1
      Kp <- 10^log10Kp
      omega8 <- 10^log10omega8
      omega4 <- 10^log10omega4
      I50 <- 10^log10I50
      dh <- de
    } else {
      rp <- rp1 * exp(crp * Tx)
      lambdap <- lambdap1 * exp(clambdap * Tx)
      Kp <- 10^(log10Kp + cKp * Tx)
      omega8 <- 10^(log10omega8 + cw8)
      omega4 <- 10^(log10omega4 + cw4)
      I50 <- 10^(log10I50 + cI50)
      dh <- de * exp(cde)
    }
    # NOTE(review): lambdap is computed but never used below; the equations
    # use lambdapns / lambdapm instead. Confirm whether that is intended.
    # Ks, Km and Ke derive from the baseline log10Kp, even after Transplant.
    Ks <- 10^(log10Kp - log10Kps)
    Km <- 10^(log10Kp - log10Kpm)
    Ke <- 10^(log10Kp - log10Kpe)

    # (d) epsilon = 1 zeroes the (1 - epsilon) * beta * V * S infection terms
    # from ART until ATI + twash + tsa.
    epsilon <- if (t >= ART && t < (ATI + twash + tsa)) 1 else 0

    # Logistic terms share one total population (P, Pp, T, Tp excluded).
    Total <- NSp + NSp2 + NS + S + M + E + Es + Ip + Iu + Mp
    Infected <- Ip + Iu + Mp

    # Transplant branch fed by Tp (the fp share of the transplant).
    ddt_Tp <- -ke * Tp
    ddt_Pp <- ke * Tp + rp * (1 - (Total) / Kp) * Pp
    ddt_NSp <- lambdapns * Pp - dnsp * NSp + lambdasns * NSp2
    ddt_NSp2 <- lambdanss * NSp + rs * (1 - (Total) / Ks) * NSp2

    # Transplant branch fed by T (the 1 - fp share) and the resident cells.
    ddt_T <- -ke * T
    ddt_P <- ke * T + rp * (1 - (Total) / Kp) * P - condP * (1 - psi) * kt * P

    ddt_NS <- lambdapns * P - dns * NS + lambdasns * S -
      omega4 * NS * (Infected) / (1 + (Infected) / I50) - (1 - psi) * kt * NS
    ddt_S <- lambdanss * NS + rs * (1 - (Total) / Ks) * S -
      (1 - epsilon) * beta * V * S / (1 + phi * Es) +
      omega4 * NS * (Infected) / (1 + (Infected) / I50) - (1 - psi) * kt * S

    ddt_Mp <- fi * xi * (1 - epsilon) * beta * V * S / (1 + phi * Es) -
      deltaM * (1 + kappa * Es) * Mp - (1 - psi) * kt * Mp

    ddt_Ip <- (1 - fi) * xi * (1 - epsilon) * beta * V * S / (1 + phi * Es) -
      deltaP * (1 + kappa * Es) * Ip + alphaL - (1 - psi) * kt * Ip
    ddt_Iu <- (1 - xi) * (1 - epsilon) * beta * V * S / (1 + phi * Es) -
      deltaU * (1 + kappa * Es) * Iu - (1 - psi) * kt * Iu
    ddt_V <- p * (Ip + Mp) / (1 + theta * Es) - c * V

    ddt_M <- lambdapm * (P + Pp) + rm * (1 - (Total) / Km) * M +
      omega8 * (1 - 2 * f) * M * (Infected) / (1 + (Infected) / I50) -
      (1 - psi) * kt * M
    ddt_E <- lambdame * M + re * (1 - (Total) / Ke) * E - (1 - psi) * kt * E

    ddt_Es <- omega8 * f * M * (Infected) / (1 + (Infected) / I50) - dh * Es -
      (1 - psi) * kt * Es

    der <- c(ddt_Tp, ddt_Pp, ddt_NSp, ddt_NSp2, ddt_T, ddt_P, ddt_NS, ddt_S,
             ddt_M, ddt_E, ddt_Ip, ddt_Iu, ddt_V, ddt_Es, ddt_Mp)
    list(der)
  })
}



#load R library for ordinary differential equation solvers
library(deSolve)
library(rstudioapi)
###########################
# Reading the data
###########################

# Work relative to this script's folder so the source-data files below are
# found. getActiveDocumentContext() only works inside a running RStudio
# session; otherwise fall back to the --file= argument that Rscript sets
# (which encodes spaces as "~+~"). If neither yields a path (e.g. an unsaved
# RStudio document), stay in the current working directory.
if (rstudioapi::isAvailable()) {
  script_path <- rstudioapi::getActiveDocumentContext()$path
} else {
  file_arg <- grep("^--file=", commandArgs(trailingOnly = FALSE), value = TRUE)
  script_path <- if (length(file_arg) > 0) {
    gsub("~+~", " ", sub("^--file=", "", file_arg[1]), fixed = TRUE)
  } else {
    ""
  }
}
if (nzchar(script_path)) {
  setwd(dirname(script_path))
}
data <- read.csv("Figure 4-source data.csv", header = TRUE,
                 stringsAsFactors = FALSE)

##

IDs <- unique(data[, "ID"])
ytypes <- unique(data[, "ytype_def"])

# Plot style per animal, indexed by the position of its ID in IDs.
coloresID <- c("steelblue4", "slateblue3", "navy", "lightsteelblue4", "dodgerblue2",
               "salmon4", "indianred3", "brown2", "coral1", "chocolate2",
               "aquamarine3", "darkolivegreen3", "cadetblue4", "chartreuse3", "forestgreen", "seagreen3",
               "palegreen3", "springgreen4", "darkseagreen4", "green4", "greenyellow", "lightgreen")
# Filled symbols 21-25 cycled, one per colour (21, 22, ..., 25, 21, ...).
pchID <- rep(21:25, length.out = length(coloresID))


###########################
# Reading estimates
###########################

estimates <- read.table("Figure 5-source data 4.txt", header = TRUE,
                        stringsAsFactors = FALSE, sep = ",")

# Per-animal values; Rt = W * Rt_vals[i] below is the amount added to T and Tp
# at Transplant (split by fp). NOTE(review): indexed by i, so assumed to follow
# the order of IDs.
Rt_vals <- c(0, 0, 0, 0, 0,
             6.45e6, 4.14e6, 2.08e6, 2.39e6, 6.28e6,
             1.06e7, 1.6e7, 6.7e6, 1.48e7, 1.2e7, 6.6e7,
             6.2e6, 8.25e6, 2.3e6, 4.82e6, 6.6e6, 6.6e6)

# Animal to simulate; change the index to run another animal.
ID <- IDs[11]
i <- which(IDs %in% ID)

##########################
# 1. Reading data
##########################
# Rows of the data file belonging to this animal (computed once).
id_rows <- which(ID == data[, "ID"])
# Viral-load observations; y holds log10 values, hence the 10^.
vl_rows <- which(ID == data[, "ID"] & data[, "ytype_def"] == "VL")
Vdata <- 10^as.numeric(data[vl_rows, "y"])
tVdata <- as.numeric(data[vl_rows, "time"])

##########################
# 2. Reading parameters
##########################
# Schedule values for this animal, taken from its first data row.
# NOTE(review): columns are selected by position (10-14); confirm they still
# match the CSV header.
fp100 <- as.numeric(data[id_rows, 11])[1]
ART <- as.numeric(data[id_rows, 12])[1]
Transplant <- as.numeric(data[id_rows, 13])[1]
ATI <- as.numeric(data[id_rows, 10])[1]
Tx <- as.numeric(data[id_rows, 14])[1]

# T_model branches on these values, so a missing one would otherwise only
# surface as an opaque error inside the ODE solver.
if (anyNA(c(ART, Transplant, ATI, Tx))) {
  stop("Missing ART, Transplant, ATI or Tx in the data for ID ", ID,
       call. = FALSE)
}

tstart <- 0
# Simulate 1.5 x 52 weeks past ATI (assuming time is in days, as the /7
# "weeks" axes of the plots suggest).
tend <- ATI + 7 * 52 * 1.5
t.out <- seq(tstart, tend, by = 1)

W <- 5
pb <- 1
# VolB appears to be a blood volume in uL: V / (VolB / 1e3) is plotted as
# copies/ml and cells / VolB as cells/uL.
VolB <- 60 * W * 1e3

Rt <- W * Rt_vals[i]


# Point estimates ("_mode" columns) for animal i. Each value comes from the
# single row i, so no aggregation is applied.
# NOTE(review): assumes the rows of `estimates` are ordered like IDs.
if (length(i) != 1) {
  stop("Expected exactly one estimates row for ID ", ID, ", got ", length(i),
       call. = FALSE)
}

ke <- estimates[i, "ke_mode"]
rp1 <- estimates[i, "rp1_mode"]
crp <- estimates[i, "crp_mode"]
rs <- estimates[i, "rs_mode"]
rm <- estimates[i, "rm_mode"]
re <- estimates[i, "re_mode"]
qdns <- estimates[i, "qdns_mode"]

lambdap1 <- estimates[i, "lambdap1_mode"]
clambdap <- estimates[i, "clambdap_mode"]
lambdanss <- estimates[i, "lambdanss_mode"]
lambdasns <- estimates[i, "lambdasns_mode"]
lambdame <- estimates[i, "lambdame_mode"]

log10Kp <- estimates[i, "log10Kp_mode"]
cKp <- estimates[i, "cKp_mode"]
log10Kps <- estimates[i, "log10Kps_mode"]
log10Kpm <- estimates[i, "log10Kpm_mode"]
log10Kpe <- estimates[i, "log10Kpe_mode"]

condP <- estimates[i, "condP_mode"]
kt <- estimates[i, "kt_mode"]
V0 <- estimates[i, "V0_mode"]

xi <- estimates[i, "xi_mode"]
fi <- estimates[i, "fi_mode"]
deltaM <- estimates[i, "deltaM_mode"]

log10beta <- estimates[i, "log10beta_mode"]
tsa <- estimates[i, "tsa_mode"]
twash <- estimates[i, "twash_mode"]
deltaP <- estimates[i, "deltaP_mode"]

log10p1 <- estimates[i, "log10p1_mode"]
log10kappa <- estimates[i, "log10kappa_mode"]
log10phi <- estimates[i, "log10phi_mode"]
log10theta <- estimates[i, "log10theta_mode"]

log10omega4 <- estimates[i, "log10omega4_mode"]
cw4 <- estimates[i, "cw4_mode"]
log10omega8 <- estimates[i, "log10omega8_mode"]
cw8 <- estimates[i, "cw8_mode"]
log10I50 <- estimates[i, "log10I50_mode"]
cI50 <- estimates[i, "cI50_mode"]
de <- estimates[i, "de_mode"]
cde <- estimates[i, "cde_mode"]

# Fixed constants. TBI is the length of the irradiation window before
# Transplant; c is the clearance rate of V in T_model's ddt_V.
# NOTE: the global `c` masks base::c as a value (calls to c() still work).
TBI <- 5
c <- 23
f <- 0.9
deltaU <- deltaP

dns <- (qdns + 1) * lambdanss
dnsp <- qdns * lambdanss

# NOTE(review): both stay at the pre-transplant lambdap1; T_model's
# post-transplant lambdap is not used by its equations.
lambdapns <- lambdap1
lambdapm <- lambdap1

# Baseline carrying capacities (T_model recomputes Kp after Transplant).
Kp <- 10^log10Kp
Ks <- 10^(log10Kp - log10Kps)
Km <- 10^(log10Kp - log10Kpm)
Ke <- 10^(log10Kp - log10Kpe)

# Rates estimated on log10 scale; the per-volume ones are divided by VolB.
p <- 10^log10p1
beta <- 10^log10beta / VolB
kappa <- 10^log10kappa / VolB
phi <- 10^log10phi / VolB
theta <- 10^log10theta / VolB

# Pre-infection compartment ratios used to build the initial state.
qNSS <- (rs / (lambdanss)) * (Kp / Ks - 1)
qPS <- (1 / lambdapns) * (dns * qNSS - lambdasns)
qPM <- (1 / lambdapm) * (rm * (Kp / Km - 1))
qME <- (re / lambdame) * (Kp / Ke - 1)

##########################
# 3. Initial values
##########################
# Seed the infection with Vzero; the infected-cell seeds are c * Vzero / p,
# split between Ip and Iu by xi.
Vzero <- V0 * 1e3
V_0 <- Vzero
Ip_0 <- xi * c * Vzero / p
Iu_0 <- (1 - xi) * c * Vzero / p

# Uninfected compartments start in the ratios qME, qPM/qPS, qNSS and qPS,
# scaled so that NS_0 + S_0 + M_0 + E_0 equals Kp.
E_0 <- Kp / (qNSS * qPM * qME / qPS + qPM * qME / qPS + qME + 1)
M_0 <- qME * E_0
S_0 <- (qPM / qPS) * M_0
NS_0 <- qNSS * S_0
P_0 <- qPS * S_0

# All remaining compartments (transplant-derived ones, Es, Mp) start empty.
Tp_0 <- T_0 <- Pp_0 <- NSp_0 <- NSp2_0 <- Es_0 <- Mp_0 <- 0

######################################################
# 4. Running the model for the different fp values
######################################################
# For each fraction fp of the transplant sent to the Tp branch (1 - fp goes to
# T), simulate the model, overlay the predicted viral load, and compute the
# effective reproductive number Reff at the post-ATI steady state.
fp_vals <- c(0, 0.75, 0.8, 0.85, 0.95, 1)
n_fp <- length(fp_vals)
R0_vals <- numeric(n_fp)

params <- c()
init.x <- c(Tp = Tp_0, Pp = Pp_0, NSp = NSp_0, NSp2 = NSp2_0,
            T = T_0, P = P_0, NS = NS_0, S = S_0, M = M_0, E = E_0,
            Ip = Ip_0, Iu = Iu_0, V = V_0, Es = Es_0, Mp = Mp_0)

## Reff ingredients that depend only on the fitted parameters (using the
## post-transplant omega4, omega8, I50 and dh), not on fp or the simulation,
## so they are computed once before the loop.
w4 <- 10^(log10omega4 + cw4)
w8 <- 10^(log10omega8 + cw8)
dh <- de * exp(cde)
I50c <- 10^(log10I50 + cI50)

alphaL <- 1 / tsa
I_ati <- alphaL / deltaP

qsp <- lambdapns / (((dns + w4 * I_ati / (1 + I_ati / I50c)) * rs * (Kp / Ks - 1) /
                       (lambdanss + w4 * I_ati / (1 + I_ati / I50c))) - lambdasns)
qmp <- lambdapm / (rm * (Kp / Km - 1) + (2 * f - 1) * w8 * I_ati / (1 + I_ati / I50c))
qnp <- (lambdapns + lambdasns * qsp) / (dns + w4 * I_ati / (1 + I_ati / I50c))

qnpp <- lambdapns / dnsp
qep <- lambdame * qmp / (re * (Kp / Ke - 1))
qehp <- w8 * f * I_ati * qmp / (dh * (1 + I_ati / I50c))

a <- qmp + qep + qnp + qsp + qehp
b <- qnpp + qmp + qep + qehp

for (j in seq_len(n_fp)) {
  fp <- fp_vals[j]

  eventdat <- data.frame(var = c("T", "Tp"),
                         time = c(Transplant, Transplant),
                         value = Rt * c(1 - fp, fp),
                         method = c("add", "add"))
  out <- as.data.frame(lsodar(init.x, t.out, T_model, params,
                              events = list(data = eventdat)))

  # Read columns by state name. t.out itself is left untouched, so every
  # iteration is solved on the same requested time grid.
  t_weeks <- out[, "time"] / 7
  P <- out[, "P"]
  V <- out[, "V"]
  V_b <- V / (VolB / 10^3)

  ## Reff for this fp: R is the smallest simulated P, D the transplant amount.
  R <- min(P)
  D <- Rt
  fps <- fp
  P_ati <- Kp / (a + b * fps * D / (R + D * (1 - fps)))
  Pp_ati <- Kp / (a * ((1 - fps) * D + R) / (fps * D) + b)
  S_ati <- qsp * P_ati
  E_sati <- qehp * (P_ati + Pp_ati)

  R0 <- xi * beta * S_ati * p / (c * deltaP * (1 + theta * E_sati) * (1 + phi * E_sati)) # Reff
  R0_vals[j] <- R0

  #########################
  # 5. Plotting Viral load
  #########################
  line_col <- paste0("gray", round(fp * 50, digits = 0))
  if (j == 1) {
    par(mfrow = c(2, 1),
        oma = c(0, 0, 0, 0),
        mar = c(3.8, 4.0, 0.0, 0.5),
        mgp = c(2.5, 1, 0),
        xpd = FALSE)

    plot(t_weeks, V_b,
         ylab = "Viral load (copies/ml)",
         xlab = "",
         type = "l", log = "y", ylim = c(10^-1, 10^8), xlim = c(0, tend / 7),
         cex.main = 1.2, cex.axis = 1.2, cex.lab = 1.2,
         bty = "n", xaxt = "n", yaxt = "n", col = line_col)
    text(0.8 * tend / 7, 0.3, ID)
    points(tVdata / 7, Vdata, cex = 1.0, type = "p",
           lwd = 1, pch = pchID[i], bg = coloresID[i])
    abline(h = 30, lty = 3, col = "blue")

    axis(2, at = c(0.1, 10, 1e3, 1e5, 1e7),
         labels = expression(bold("0.1"), bold("10"), bold("10"^"3"),
                             bold("10"^"5"), bold("10"^"7")),
         cex.axis = 1.2, font = 2, lwd = 2)
    axis(1, at = axTicks(1), labels = as.character(axTicks(1)),
         cex.axis = 1.2, font = 2, lwd = 2, las = 1)
  } else {
    lines(t_weeks, V_b, col = line_col, lty = j)
  }
}

# Schedule markers (in weeks): ART and ATI dotted, Transplant dot-dash.
# Drawn once, on top of all curves.
abline(v = ART / 7, lty = 3)
abline(v = ATI / 7, lty = 3)
abline(v = Transplant / 7, lty = 4)

# One legend entry per fp. As in the original hand-written legend, the first
# and last entries carry extra padding and the last Reff uses one decimal.
# base::c is named explicitly because the global `c` is a number.
legend_labels <- do.call(base::c, lapply(seq_len(n_fp), function(k) {
  spacer <- if (k == 1 || k == n_fp) "   (" else " ("
  reff_digits <- if (k == n_fp) 1 else 2
  as.expression(bquote(
    italic(R[eff]) == .(round(R0_vals[k], reff_digits)) ~ .(spacer) ~
      italic(f[p]) == .(round(fp_vals[k], 2)) ~ ")"
  ))
}))

legend(ART / 7 + 2.4, 10^8.6, cex = 0.75, bty = "n",
       legend = legend_labels,
       lty = seq_len(n_fp),
       col = paste0("gray", round(fp_vals * 50, digits = 0)))
text(120, 1e7, expression(italic(D) == 10^7 ~ HSPCs / kg), adj = 0, cex = 0.8)
text(120, 1.5e6, expression(italic(P[r]) == 6 ~ x ~ 10^6 ~ HSPCs), adj = 0, cex = 0.8)


###################################################
# 6. Plotting CD4+ T cell prediction when controlling
###################################################
# Re-simulate a single edited fraction fp and plot blood CD4+ T cells split
# into CCR5-edited (NSp + NSp2), CCR5- (NS) and CCR5+ (S + Ip + Iu + Mp).
fp <- 0.95

eventdat <- data.frame(var = c("T", "Tp"),
                       time = c(Transplant, Transplant),
                       value = Rt * c(1 - fp, fp),
                       method = c("add", "add"))

params <- c()
init.x <- c(Tp = Tp_0, Pp = Pp_0, NSp = NSp_0, NSp2 = NSp2_0,
            T = T_0, P = P_0, NS = NS_0, S = S_0, M = M_0, E = E_0,
            Ip = Ip_0, Iu = Iu_0, V = V_0, Es = Es_0, Mp = Mp_0)
out <- as.data.frame(lsodar(init.x, t.out, T_model, params,
                            events = list(data = eventdat)))

# Read columns by state name rather than by position.
t_weeks <- out[, "time"] / 7
NS <- out[, "NS"]
NSp <- out[, "NSp"]
NSp2 <- out[, "NSp2"]
S <- out[, "S"]
Ip <- out[, "Ip"]
Iu <- out[, "Iu"]
Mp <- out[, "Mp"]

S_b <- pb * (S + Ip + Iu + Mp) / VolB

# NSp / (NS + NSp) over time, kept for inspection; previously it was a bare
# top-level expression that auto-printed the whole vector.
edited_fraction <- NSp / (NS + NSp)

plot(t_weeks, S_b,
     ylab = expression(paste("CD4"^"+", " T cells/", mu, "L")),
     xlab = "Weeks after challenge", lwd = 1,
     cex.main = 1.2, cex.axis = 1.2, cex.lab = 1.2,
     type = "l", log = "y", ylim = c(10^0, 10^3.3), lty = 3,
     yaxt = "n", xaxt = "n", bty = "n")

# Same blood scaling (pb / VolB) as S_b, so all three curves are comparable.
lines(t_weeks, pb * NS / VolB, lty = 2, lwd = 1)
lines(t_weeks, pb * (NSp + NSp2) / VolB, lty = 1, lwd = 1)

axis(2, at = c(1, 10, 100, 1000),
     labels = expression(bold("1"), bold("10"), bold("10"^"2"), bold("10"^"3")),
     cex.axis = 1.2, font = 2, lwd = 2)
axis(1, at = axTicks(1), labels = as.character(axTicks(1)),
     cex.axis = 1.2, font = 2, lwd = 2, las = 1)

abline(v = ART / 7, lty = 3)
abline(v = ATI / 7, lty = 3)
abline(v = Transplant / 7, lty = 4)

# Annotate with the Reff computed in section 4 for this fp instead of a
# hard-coded number; skipped if fp is not one of fp_vals.
reff_fp <- R0_vals[match(fp, fp_vals)]
if (!is.na(reff_fp)) {
  text(120, 1000,
       as.expression(bquote(italic(R[eff]) ~ .(paste0("=", round(reff_fp, 1))))),
       adj = 0, cex = 0.8)
}
legend(90, 12, c(expression(paste(Delta, "CCR5-edited")),
                 expression("CCR5"^"-"),
                 expression("CCR5"^"+")),
       lty = c(1, 2, 3), lwd = 1, bty = "n", cex = 0.8)



