

# Clear every non-hidden object (ls() skips names starting with ".") from the
# current environment -- at top level, the user's global workspace -- so the
# script starts from a clean slate.
# NOTE(review): this silently discards unrelated objects the user may have.
# The per-animal loop below re-assigns every free variable T_model reads, so
# this looks like a convenience rather than a requirement; consider removing.
rm(list = ls())

T_model <- function(t, x, params = NULL) {
  # Right-hand side of the within-host ODE model, in the form deSolve expects
  # for `func`: returns list(derivatives).
  #
  # Arguments:
  #   t       current time (the driver script works in days).
  #   x       named numeric state vector: Tp, Pp, NSp, NSp2, T, P, NS, S, M,
  #           E, Ip, Iu, V, Es, Mp.
  #   params  optional named vector or list of model parameters. Unnamed
  #           entries are ignored. Any parameter not supplied here is looked
  #           up as a free variable in the environment where T_model was
  #           defined (for the driver script: the global workspace), so
  #           passing NULL / c() reproduces the original behaviour.
  #
  # Returns list(der), where der holds the derivatives in the order
  #   Tp, Pp, NSp, NSp2, T, P, NS, S, M, E, Ip, Iu, V, Es, Mp,
  # which must match the order of the initial state vector.

  # Lookup order inside the model body: state variables, then named
  # `params`, then this call's frame (t, x, ...), then T_model's defining
  # environment.
  supplied <- as.list(params)
  is_named <- if (is.null(names(supplied))) {
    logical(length(supplied))
  } else {
    nzchar(names(supplied))
  }
  param_env <- list2env(supplied[is_named], parent = environment())
  state_env <- list2env(as.list(x), parent = param_env)

  local({
    # (a) Total body irradiation: on [Transplant - TBI, Transplant) psi is
    #     1 - Tx, which switches on the (1 - psi) * kt depletion terms;
    #     outside that window psi = 1 and those terms vanish.
    tbi_start <- Transplant - TBI
    psi <- if (t >= tbi_start && t < Transplant) 1 - Tx else 1

    # (b) ART: a constant source 1 / tsa enters Ip once ART has started.
    alphaL <- if (t < ART) 0 else 1 / tsa

    # (c) Parameters whose values change at transplant.
    if (t < Transplant) {
      rp <- rp1
      Kp <- 10^log10Kp
      omega8 <- 10^log10omega8
      omega4 <- 10^log10omega4
      I50 <- 10^log10I50
      dh <- de
    } else {
      rp <- rp1 * exp(crp * Tx)
      Kp <- 10^(log10Kp + cKp * Tx)
      omega8 <- 10^(log10omega8 + cw8)
      omega4 <- 10^(log10omega4 + cw4)
      I50 <- 10^(log10I50 + cI50)
      dh <- de * exp(cde)
    }
    # NOTE(review): clambdap is never applied -- the P-derived inputs use
    # lambdapns / lambdapm from outside, unchanged at transplant. Ks, Km and
    # Ke below are likewise built from log10Kp without cKp. Confirm both are
    # intended.
    Ks <- 10^(log10Kp - log10Kps)
    Km <- 10^(log10Kp - log10Kpm)
    Ke <- 10^(log10Kp - log10Kpe)

    # (d) Infection blocking: epsilon = 1 (no new infections) from ART until
    #     ATI + twash + tsa, otherwise 0.
    epsilon <- if (t >= ART && t < (ATI + twash + tsa)) 1 else 0

    Total <- NSp + NSp2 + NS + S + M + E + Es + Ip + Iu + Mp
    # Mp is not counted as stimulating the omega4 / omega8 responses, although
    # it does produce virus (see ddt_V). NOTE(review): confirm intended.
    Infected <- Ip + Iu

    # Expressions are kept in their original operand order so results are
    # bit-for-bit reproducible.
    ddt_Tp <- -ke * Tp
    ddt_Pp <- ke * Tp + rp * (1 - Total / Kp) * Pp
    ddt_NSp <- lambdapns * Pp - dnsp * NSp + lambdasns * NSp2
    ddt_NSp2 <- lambdanss * NSp + rs * (1 - Total / Ks) * NSp2

    ddt_T <- -ke * T
    ddt_P <- ke * T + rp * (1 - Total / Kp) * P -
      condP * (1 - psi) * kt * P

    ddt_NS <- lambdapns * P - dns * NS + lambdasns * S -
      omega4 * NS * Infected / (1 + Infected / I50) -
      (1 - psi) * kt * NS
    ddt_S <- lambdanss * NS + rs * (1 - Total / Ks) * S -
      (1 - epsilon) * beta * V * S / (1 + phi * Es) +
      omega4 * NS * Infected / (1 + Infected / I50) -
      (1 - psi) * kt * S

    ddt_Mp <- fi * xi * (1 - epsilon) * beta * V * S / (1 + phi * Es) -
      deltaM * (1 + kappa * Es) * Mp - (1 - psi) * kt * Mp

    ddt_Ip <- (1 - fi) * xi * (1 - epsilon) * beta * V * S / (1 + phi * Es) -
      deltaP * (1 + kappa * Es) * Ip + alphaL - (1 - psi) * kt * Ip
    ddt_Iu <- (1 - xi) * (1 - epsilon) * beta * V * S / (1 + phi * Es) -
      deltaU * (1 + kappa * Es) * Iu - (1 - psi) * kt * Iu
    # `c` here is the numeric rate supplied from outside; the c(...) call at
    # the end still resolves to base::c because R skips non-function bindings
    # when looking up a function name.
    ddt_V <- p * (Ip + Mp) / (1 + theta * Es) - c * V

    ddt_M <- lambdapm * (P + Pp) + rm * (1 - Total / Km) * M +
      omega8 * (1 - 2 * f) * M * Infected / (1 + Infected / I50) -
      (1 - psi) * kt * M
    ddt_E <- lambdame * M + re * (1 - Total / Ke) * E - (1 - psi) * kt * E

    ddt_Es <- omega8 * f * M * Infected / (1 + Infected / I50) - dh * Es -
      (1 - psi) * kt * Es

    list(c(ddt_Tp, ddt_Pp, ddt_NSp, ddt_NSp2, ddt_T, ddt_P, ddt_NS, ddt_S,
           ddt_M, ddt_E, ddt_Ip, ddt_Iu, ddt_V, ddt_Es, ddt_Mp))
  }, envir = state_env)
}

# Load the ODE solver (deSolve provides lsodar) and the RStudio API helpers.
library(deSolve)
library(rstudioapi)
###########################
# Reading the data
###########################

# Resolve the relative data paths against the script's own folder. The
# editor lookup only works inside an RStudio session with a saved active
# document; otherwise (e.g. Rscript) keep the current working directory
# instead of failing.
if (rstudioapi::isAvailable()) {
  script_path <- getActiveDocumentContext()$path
  if (nzchar(script_path)) {
    setwd(dirname(script_path))
  }
}
data <- read.csv("Figure 4-source data.csv", header = TRUE,
                 stringsAsFactors = FALSE)

# Animal identifiers and measurement types present in the data.
IDs <- unique(data[, "ID"])
ytypes <- unique(data[, "ytype_def"])

# Per-animal plotting symbols (filled shapes 21-25, cycled) and fill colours,
# both indexed by the animal's position in IDs.
pchID <- rep(c(21, 22, 23, 24, 25), length.out = 22)

coloresID <- c("steelblue4", "slateblue3", "navy", "lightsteelblue4", "dodgerblue2",
               "salmon4", "indianred3", "brown2", "coral1", "chocolate2",
               "aquamarine3", "darkolivegreen3", "cadetblue4", "chartreuse3", "forestgreen", "seagreen3",
               "palegreen3", "springgreen4", "darkseagreen4", "green4", "greenyellow", "lightgreen")

###########################
# Reading estimates
###########################

# Fitted parameter values; the loop below reads row i for the animal IDs[i]
# (presumably rows are ordered like IDs -- confirm against the source file).
estimates <- read.table("Figure 5-source data 4.txt", header = TRUE,
                        stringsAsFactors = FALSE, sep = ",")

# Indices (into IDs) of the animals to simulate and plot. The simulation
# window (tstart / tend / t.out) is set per animal inside the loop.
animals <- c(1, 7, 11)

# One row of three panels (viral load, CCR5+, CCR5-) per animal.
nrows <- 3
ncolumns <- 3

par(mfrow = c(nrows, ncolumns),
    oma = c(0, 0, 0, 0),
    mar = c(3.5, 4.5, 0.0, 0.5),
    mgp = c(2.5, 1, 0),
    xpd = FALSE)

# Per-animal Rt (indexed like IDs); the loop scales it by W and adds it to
# the T / Tp compartments at Transplant. The first five entries are zero.
Rt_vals <- c(0, 0, 0, 0, 0,
             6.45e6, 4.14e6, 2.08e6, 2.39e6, 6.28e6,
             1.06e7, 1.6e7, 6.7e6, 1.48e7, 1.2e7, 6.6e7,
             6.2e6, 8.25e6, 2.3e6, 4.82e6, 6.6e6, 6.6e6)



# Helpers for the per-animal loop -------------------------------------------

# Observations of one measurement type for one animal: `time` as recorded and
# `value` converted from log10 (column "y") back to linear scale.
get_series <- function(id, ytype) {
  rows <- which(id == data[, "ID"] & data[, "ytype_def"] == ytype)
  list(time = as.numeric(data[rows, "time"]),
       value = 10^as.numeric(data[rows, "y"]))
}

# First entry of column `col` of `data` for animal `id`, as numeric.
get_animal_value <- function(id, col) {
  as.numeric(data[which(id == data[, "ID"]), col])[1]
}

# Mode estimate of parameter `name` (column "<name>_mode") in row `row` of
# `estimates`.
get_estimate <- function(row, name) {
  estimates[row, paste0(name, "_mode")]
}

# Bold week-number x axis shared by every panel.
draw_week_axis <- function() {
  ticks <- axTicks(1)
  axis(1, at = ticks, labels = as.character(ticks),
       cex.axis = 1.4, font = 2, lwd = 2, las = 1)
}

# Vertical markers (converted from days to weeks) at ART start, ATI and
# transplant.
mark_events <- function(art, ati, transplant) {
  abline(v = art / 7, lty = 4)
  abline(v = ati / 7, lty = 3)
  abline(v = transplant / 7, lty = 4)
}

# Constants shared by every animal. T_model reads TBI, c and f as free
# variables, so they must live at top level under these names.
W <- 5                 # scales Rt and VolB
pb <- 1                # scales the plotted blood counts
VolB <- 60 * W * 1e3   # divisor giving per-uL cells (and, / 1e3, per-mL virus)
TBI <- 5               # length (days) of the TBI window before Transplant
c <- 23                # rate in the -c * V term of dV/dt
f <- 0.9               # split used in the omega8 terms of dM/dt and dEs/dt
tstart <- 0

stopifnot(all(animals >= 1 & animals <= length(IDs)))

for (i in animals) {
  ID <- IDs[i]

  ##########################
  # Observed data
  ##########################
  vl <- get_series(ID, "VL")
  ccr5_pos <- get_series(ID, "cd4+ccr5+")
  ccr5_neg <- get_series(ID, "cd4+ccr5-")

  # Simulation window in days; the first five animals get a shorter one.
  tend <- if (i <= 5) 7 * (25 + 55 + 52 * 0.3) else 7 * (25 + 55 + 52 * 1)
  t.out <- seq(tstart, tend, by = 1)

  ##########################
  # Parameters
  ##########################

  # Per-animal schedule, taken from columns 10-14 of the data.
  ATI <- get_animal_value(ID, 10)
  fp100 <- get_animal_value(ID, 11)
  ART <- get_animal_value(ID, 12)
  Transplant <- get_animal_value(ID, 13)
  Tx <- get_animal_value(ID, 14)

  Rt <- W * Rt_vals[i]

  # Fitted values. T_model reads these as free variables, so they are
  # assigned at top level under exactly these names.
  ke <- get_estimate(i, "ke")
  rp1 <- get_estimate(i, "rp1")
  crp <- get_estimate(i, "crp")
  rs <- get_estimate(i, "rs")
  rm <- get_estimate(i, "rm")
  re <- get_estimate(i, "re")
  qdns <- get_estimate(i, "qdns")

  lambdap1 <- get_estimate(i, "lambdap1")
  clambdap <- get_estimate(i, "clambdap")
  lambdanss <- get_estimate(i, "lambdanss")
  lambdasns <- get_estimate(i, "lambdasns")
  lambdame <- get_estimate(i, "lambdame")

  log10Kp <- get_estimate(i, "log10Kp")
  cKp <- get_estimate(i, "cKp")
  log10Kps <- get_estimate(i, "log10Kps")
  log10Kpm <- get_estimate(i, "log10Kpm")
  log10Kpe <- get_estimate(i, "log10Kpe")

  condP <- get_estimate(i, "condP")
  kt <- get_estimate(i, "kt")
  V0 <- get_estimate(i, "V0")

  xi <- get_estimate(i, "xi")
  fi <- get_estimate(i, "fi")
  deltaM <- get_estimate(i, "deltaM")

  log10beta <- get_estimate(i, "log10beta")
  tsa <- get_estimate(i, "tsa")
  twash <- get_estimate(i, "twash")
  deltaP <- get_estimate(i, "deltaP")

  log10p1 <- get_estimate(i, "log10p1")
  log10kappa <- get_estimate(i, "log10kappa")
  log10phi <- get_estimate(i, "log10phi")
  log10theta <- get_estimate(i, "log10theta")

  log10omega4 <- get_estimate(i, "log10omega4")
  cw4 <- get_estimate(i, "cw4")
  log10omega8 <- get_estimate(i, "log10omega8")
  cw8 <- get_estimate(i, "cw8")
  log10I50 <- get_estimate(i, "log10I50")
  cI50 <- get_estimate(i, "cI50")
  de <- get_estimate(i, "de")
  cde <- get_estimate(i, "cde")

  # Derived parameters (also read by T_model).
  deltaU <- deltaP
  dns <- (qdns + 1) * lambdanss
  dnsp <- qdns * lambdanss
  lambdapns <- lambdap1
  lambdapm <- lambdap1

  Kp <- 10^log10Kp
  Ks <- 10^(log10Kp - log10Kps)
  Km <- 10^(log10Kp - log10Kpm)
  Ke <- 10^(log10Kp - log10Kpe)

  p <- 10^log10p1
  beta <- 10^log10beta / VolB
  kappa <- 10^log10kappa / VolB
  phi <- 10^log10phi / VolB
  theta <- 10^log10theta / VolB

  # Compartment ratios; the initial values below satisfy
  # NS_0 + S_0 + M_0 + E_0 = Kp.
  qNSS <- (rs / lambdanss) * (Kp / Ks - 1)
  qPS <- (1 / lambdapns) * (dns * qNSS - lambdasns)
  qPM <- (1 / lambdapm) * (rm * (Kp / Km - 1))
  qME <- (re / lambdame) * (Kp / Ke - 1)

  fp <- fp100 / 100

  ##########################
  # Initial values
  ##########################
  Vzero <- V0 * 1e3
  E_0 <- Kp / (qNSS * qPM * qME / qPS + qPM * qME / qPS + qME + 1)
  M_0 <- qME * E_0
  S_0 <- (qPM / qPS) * M_0
  NS_0 <- qNSS * S_0
  P_0 <- qPS * S_0
  Ip_0 <- xi * c * Vzero / p
  Iu_0 <- (1 - xi) * c * Vzero / p

  # Must follow the derivative order returned by T_model.
  init.x <- c(Tp = 0, Pp = 0, NSp = 0, NSp2 = 0,
              T = 0, P = P_0, NS = NS_0, S = S_0, M = M_0, E = E_0,
              Ip = Ip_0, Iu = Iu_0, V = Vzero, Es = 0, Mp = 0)

  # At Transplant, add Rt split between T (fraction 1 - fp) and Tp (fp).
  eventdat <- data.frame(var = c("T", "Tp"),
                         time = c(Transplant, Transplant),
                         value = Rt * c(1 - fp, fp),
                         method = c("add", "add"))

  params <- c()
  out <- as.data.frame(lsodar(init.x, t.out, T_model, params,
                              events = list(data = eventdat)))

  # Read solver output by column name (not position), without creating
  # globals such as `T` that would mask TRUE. deSolve may insert event
  # times into the output grid, so the time axis is taken from `out`.
  t.out <- out[["time"]]
  S_b <- pb * (out[["S"]] + out[["Ip"]] + out[["Iu"]] + out[["Mp"]]) / VolB
  NS_b <- pb * (out[["NS"]] + out[["NSp"]] + out[["NSp2"]]) / VolB
  V_b <- out[["V"]] / (VolB / 10^3)

  ###########################
  # Plotting Viral load
  ###########################
  plot(t.out / 7, V_b,
       ylab = "Viral load (copies/ml)",
       xlab = "Weeks after challenge",
       type = "l", log = "y", ylim = c(10^0, 10^8), xlim = c(0, tend / 7),
       cex.main = 1.4, cex.axis = 1.4, cex.lab = 1.4,
       bty = "n", xaxt = "n", yaxt = "n")
  text((ART + 50) / 7, 10^6, ID, adj = 0, cex = 1.1)
  points(vl$time / 7, vl$value, cex = 1.2, type = "p",
         lwd = 1, pch = pchID[i], bg = coloresID[i])
  # Horizontal reference at 30 copies/ml (presumably the assay limit).
  abline(h = 30, lty = 3)
  axis(2, at = c(10, 1e3, 1e5, 1e7),
       labels = expression(bold("10"), bold("10"^"3"), bold("10"^"5"), bold("10"^"7")),
       cex.axis = 1.4, font = 2, lwd = 2)
  draw_week_axis()
  lines(t.out / 7, V_b, lty = 1)
  mark_events(ART, ATI, Transplant)

  ###########################
  # Plotting CCR5+
  ###########################
  plot(t.out / 7, S_b,
       ylab = expression(paste("CD4"^"+", "CCR5"^"+", " T cells/", mu, "L")),
       xlab = "Weeks after challenge", xlim = c(0, tend / 7),
       type = "l", log = "y", ylim = c(10^0, 10^3),
       cex.main = 1.4, cex.axis = 1.4, cex.lab = 1.4,
       yaxt = "n", xaxt = "n", bty = "n")
  points(ccr5_pos$time / 7, ccr5_pos$value, cex = 1.2, type = "p",
         lwd = 1, pch = pchID[i], bg = coloresID[i])
  axis(2, at = c(1, 10, 100, 1000),
       labels = expression(bold("1"), bold("10"), bold("10"^"2"), bold("10"^"3")),
       cex.axis = 1.4, font = 2, lwd = 2)
  draw_week_axis()
  lines(t.out / 7, S_b, lty = 1)
  mark_events(ART, ATI, Transplant)

  ###########################
  # Plotting CCR5-
  ###########################
  plot(t.out / 7, NS_b,
       ylab = expression(paste("CD4"^"+", "CCR5"^"-", " T cells/", mu, "L")),
       xlab = "Weeks after challenge", xlim = c(0, tend / 7),
       type = "l", log = "y", ylim = c(10^1, 10^4),
       cex.main = 1.4, cex.axis = 1.4, cex.lab = 1.4,
       yaxt = "n", xaxt = "n", bty = "n")
  points(ccr5_neg$time / 7, ccr5_neg$value, cex = 1.2, type = "p",
         lwd = 1, pch = pchID[i], bg = coloresID[i])
  axis(2, at = c(1, 10, 100, 1000, 1e4),
       labels = expression(bold("1"), bold("10"), bold("10"^"2"), bold("10"^"3"), bold("10"^"4")),
       cex.axis = 1.4, font = 2, lwd = 2)
  draw_week_axis()
  mark_events(ART, ATI, Transplant)
}
