library("rstan")
library("lattice")
library("gridExtra")

# ---- Load the fitted model and prepare storage for the simulations ----

# Restore the objects saved in output1.rda; `output1`, used just below, is
# presumably one of them -- TODO confirm.
load("output1.rda")
# Run model_input.R. Names that are used later but never assigned in this
# script (n_ds, n_mice, n_Eq, ...) presumably come from here -- TODO confirm.
source("model_input.R")
# Pull the posterior draws out of the fit, one list element per parameter.
extracted_output <- rstan::extract(output1)

# Make the functions declared in plot_simulate.stan callable from R.
expose_stan_functions("plot_simulate.stan")

# 200 draw indices spread evenly from the first draw to the last; floor()
# turns the fractional positions returned by seq() into usable indices.
thinned_index <- floor(
  seq(from = 1, to = length(extracted_output$rho), length.out = 200)
)
n_rep <- length(thinned_index)
n_var <- 4

# Simulated series, indexed as storage[time, variable, mouse, draw].
storage <- array(NA, dim = c(n_ds, n_var, n_mice, n_rep))

# ---- Posterior-predictive simulation ----
# For every retained posterior draw and every mouse, evaluate malaria_fit()
# and fill storage[time, variable, mouse, draw] with four series:
#   1, 2: components 1 and 2 of the prediction, copied unchanged
#         (labelled N1 and N2 in the original comments)
#   3   : components 3 + 4, plus Gaussian noise with sd phi[1]
#   4   : log10(component 4 + 1), plus Gaussian noise with sd phi[2]
# The draws from rnorm() are not seeded here, so results vary between runs.

# Integer inputs for malaria_fit(); they do not depend on draw or mouse, so
# they are built once instead of once per iteration.
x_i <- c(n_ds = n_ds,
         n_Eq = n_Eq)

for (i in seq_along(thinned_index)) {
  draw_idx <- thinned_index[i]

  # phi depends on the draw only, not on the mouse.
  phi <- c(extracted_output$phi[draw_idx, 1],
           extracted_output$phi[draw_idx, 2])

  for (m in seq_len(n_mice)) {
    x_r <- c(max_iRBC = max_iRBC,
             kappa = kappa,
             muM = muM,
             R0 = R0[m],
             I0 = I0[m])
    theta <- extracted_output$theta[draw_idx, m, ]
    # Indexed as predictedFit[[time]][variable].
    predictedFit <- malaria_fit(phi, theta, x_r, x_i)

    # storage has n_ds rows, so iterate over n_ds rather than a literal 14;
    # the cap guards against a prediction shorter than n_ds.
    for (t in seq_len(min(n_ds, length(predictedFit)))) {
      fit_t <- predictedFit[[t]]
      storage[t, 1, m, i] <- fit_t[1]
      storage[t, 2, m, i] <- fit_t[2]
      # The two rnorm() calls stay in this order so the random-number stream
      # is consumed exactly as before.
      storage[t, 3, m, i] <- rnorm(1, sum(fit_t[3:4]), phi[1])
      storage[t, 4, m, i] <- rnorm(1, log10(fit_t[4] + 1), phi[2])
    }
  }
}

# ---- Summaries over draws ----
# Collapse the draw dimension (4th) of storage: for every
# (time, variable, mouse) cell take the 2.5% and 97.5% quantiles and the
# median of the simulated values, ignoring NAs. Each result is an array of
# dimension c(n_ds, n_var, n_mice).
lower <- apply(storage, 1:3, quantile, probs = 0.025, na.rm = TRUE)
upper <- apply(storage, 1:3, quantile, probs = 0.975, na.rm = TRUE)
median_outcome <- apply(storage, 1:3, median, na.rm = TRUE)

# ---- Observed data and panel layout for the model-fit figure ----
mice <- unique(mouse)

# Recode zero counts as missing (iRBC is later plotted on a log10 scale).
y_RBC[y_RBC == 0] <- NA
y_iRBC[y_iRBC == 0] <- NA

par(oma = c(3.5, 3, 2, 3), mar = c(0.5, 1, 0.5, 1))

# 10 x 8 grid of panel numbers. Each column holds ten consecutive panels
# (offset + 1 ... offset + 10, top to bottom); the offsets give the
# left-to-right order of those blocks: 71-80, 1-10, 51-60, 21-30, 31-40,
# 41-50, 11-20, 61-70.
column_offsets <- c(70, 0, 50, 20, 30, 40, 10, 60)
layout.matrix <- outer(1:10, column_offsets, "+")
layout(mat = layout.matrix,
       heights = rep(1, 10), # one equal height per row (10 rows)
       widths = rep(1, 8))   # one equal width per column (8 columns)

# Draw the outlines of all 80 panels to check the arrangement.
layout.show(80)
mouse_list <- seq_len(80)

### Model fit
# One panel per entry of mouse_list. Each panel shows the band between the
# `lower` and `upper` simulated quantiles with the observations on top:
# RBC against the left axis, then log10(iRBC) overlaid against the right axis.

fit_band_col <- rgb(19/255, 19/255, 19/255, alpha = 0.3)
fit_day_ticks <- seq(0, 15, 5)
fit_rbc_ticks <- seq(0, 1.2*10^7.0, length.out = 3)
fit_irbc_ticks <- seq(0, 6.5, length.out = 3)

# Column titles, keyed by the panel that sits at the top of each column.
fit_titles <- c("1" = "129S1/SvImJ",
                "11" = "A/J",
                "21" = "C57BL/6",
                "31" = "CAST/EiJ",
                "41" = "NOD/ShiLtJ",
                "51" = "NZO/HILtJ",
                "61" = "PWK/PhJ",
                "71" = "WSB/EiJ")

for (m in mouse_list) {
  is_m <- mouse == mice[m]
  # Time indices from 1 up to the position of this mouse's last non-missing
  # RBC value.
  fit_days <- 1:max(which(!is.na(y_RBC[is_m])))
  band_x <- c(fit_days, rev(fit_days))

  # --- RBC (left axis) ---
  plot(NA, type = "l", ylim = c(0, 1.2*10^7.0), xlim = c(0, 16),
       axes = FALSE, ylab = NA, xlab = NA)
  abline(v = fit_day_ticks, col = "grey70", lwd = 0.2)
  axis(1, las = 1, labels = FALSE, at = fit_day_ticks, tck = -0.025)
  axis(2, las = 1, labels = FALSE, at = fit_rbc_ticks, tck = -0.025)
  polygon(x = band_x,
          y = c(lower[fit_days, 3, m], rev(upper[fit_days, 3, m])),
          col = fit_band_col, border = FALSE)
  points(y_RBC[is_m] ~ day[is_m],
         pch = 4, col = "tomato", type = "p", lwd = 0.5, cex = 0.75)

  # Panel 80 is the bottom-left cell of the layout: it carries the labelled
  # RBC axis.
  if (m == 80) {
    mtext("RBC", side = 2, line = 2.5, cex = 0.75, col = "tomato")
    axis(2, tick = FALSE, at = fit_rbc_ticks, cex.axis = 0.8,
         labels = c(0, expression(6%*%10^6), expression(1.2%*%10^7)))
  }

  # --- log10(iRBC) (right axis), drawn over the same panel ---
  par(new = TRUE)
  plot(NA, type = "l", ylim = c(0, 6.5), xlim = c(0, 16),
       axes = FALSE, ylab = NA, xlab = NA)
  axis(4, labels = FALSE, at = fit_irbc_ticks, tck = -0.025)
  polygon(x = band_x,
          y = c(lower[fit_days, 4, m], rev(upper[fit_days, 4, m])),
          col = fit_band_col, border = FALSE)
  points(log10(y_iRBC[is_m]) ~ day[is_m],
         pch = 4, col = "royalblue", type = "p", lwd = 0.5, cex = 0.75)

  # Strain title above the first panel of each block of ten.
  title_key <- as.character(m)
  if (title_key %in% names(fit_titles)) {
    mtext(fit_titles[[title_key]], side = 3, line = 1, cex = 0.8)
  }

  # Panel 70 is the bottom-right cell: it carries the labelled iRBC axis and
  # the x-axis title.
  if (m == 70) {
    mtext(expression(paste("log"[10],"(iRBC)")), side = 4, line = 2.5,
          cex = 0.8, col = "royalblue")
    axis(4, tick = FALSE, at = fit_irbc_ticks, cex.axis = 0.8)
    mtext("Day post infection", side = 1, line = 2.5, cex = 0.8)
  }

  # Every tenth panel is in the bottom row and gets day labels.
  if (m %in% seq(10, 80, by = 10)) {
    axis(1, las = 1, tick = FALSE, at = fit_day_ticks, cex.axis = 0.8)
  }
}

### Residuals
# For every mouse and every day in time_u, compute a weighted average over
# the retained draws of (observation - simulated value):
#   storage_sse[1, mouse, day]: RBC vs. storage variable 3
#   storage_sse[2, mouse, day]: log10(iRBC) vs. storage variable 4
# Cells without a finite observation stay NA.
time_u <- seq(3, 14, 1)
ts <- seq(1, 14, 1)
storage_sse <- array(NA, dim = c(2, n_mice, length(time_u)))

# One weight per retained draw: its share of the summed lp__ values. These
# do not depend on mouse, day or outcome, so they are computed once instead
# of inside the innermost loop.
# NOTE(review): lp__ is a log density, so these are shares of log-densities,
# not of densities -- confirm this weighting is intended.
lp_thinned <- extracted_output$lp__[thinned_index]
draw_weights <- lp_thinned / sum(lp_thinned)

# Row of `storage` (indexed by ts) that corresponds to each entry of time_u.
storage_row <- match(time_u, ts)

for (m in seq_len(n_mice)) {
  y_RBC_m <- y_RBC[mouse == mice[m]]
  y_iRBC_m <- y_iRBC[mouse == mice[m]]

  # Treat zero counts as missing.
  y_RBC_m[y_RBC_m == 0] <- NA
  y_iRBC_m[y_iRBC_m == 0] <- NA

  for (t in seq_along(time_u)) {
    # NOTE(review): indexing by time_u[t] assumes the k-th observation of a
    # mouse belongs to time point k -- TODO confirm against the data layout.
    obs_RBC <- y_RBC_m[time_u[t]]
    obs_iRBC <- y_iRBC_m[time_u[t]]

    # is.finite() is FALSE for NA, NaN and +/-Inf, i.e. exactly the cases
    # that must be skipped.
    if (is.finite(obs_RBC)) {
      sim_RBC <- storage[storage_row[t], 3, m, ]
      storage_sse[1, m, t] <- sum(draw_weights * (obs_RBC - sim_RBC))
    }

    if (is.finite(obs_iRBC)) {
      sim_iRBC <- storage[storage_row[t], 4, m, ]
      storage_sse[2, m, t] <- sum(draw_weights * (log10(obs_iRBC) - sim_iRBC))
    }
  }
}

# ---- Residual figure: 2 rows (RBC, iRBC) x 8 strain columns ----
# Strain names, indexed by the strain number `s` used below.
mouse_strains <- c("129S1/SvImJ", "A/J", "C57BL/6", "CAST/EiJ",
                   "NOD/ShiLtJ", "NZO/HILtJ", "PWK/PhJ", "WSB/EiJ")

# Divisor for the Bonferroni-style adjustment of the 0.05 level below.
# NOTE(review): presumably the 16 terms are per-strain, per-outcome
# comparison counts -- TODO confirm.
compa <- (11+6+11+5+6+11+6+11+
          11+6+11+5+6+10+6+10)
# NOTE(review): not referenced anywhere in the code visible here; kept in
# case it is used elsewhere.
strain_id_mod <- c(rep(8,10),rep(1,10),rep(6,10),rep(3,10),rep(4,10),rep(5,10),rep(2,10),rep(7,10))

# mfcol fills column by column: both panels of a strain, then the next strain.
par(mfcol = c(2, 8), mar = c(1, 1, 0, 0), oma = c(3, 5.5, 3, 0))

# Quantities that are the same for every panel.
bonf_alpha <- 0.05/compa
z_score <- qnorm(1 - (bonf_alpha/2), 0, 1)
# Scale used to standardise the residuals, one entry per outcome
# (1 = RBC, 2 = iRBC).
sd_scale <- c(median(exp(extracted_output$sd_RBC)) * 5*10^5,
              median(exp(extracted_output$sd_iRBC)) * 0.2)

for (s in c(8, 1, 6, 3, 4, 5, 2, 7)) {
  # Mice belonging to strain s.
  # NOTE(review): strain_id is not assigned in this script; presumably it
  # comes from model_input.R with one entry per mouse -- TODO confirm.
  strain_mice <- mice %in% mice[strain_id == s]

  for (j in 1:2) {
    plot(NA, xlim = c(4, 14), ylim = c(-4, 4),
         ylab = NA, xlab = NA,
         axes = FALSE)

    if (j == 1) {
      # Top row: strain title, plus the row label on the leftmost column.
      mtext(mouse_strains[s], side = 3, line = 1, cex = 0.8)
      if (s == 8) {
        mtext("RBC", side = 2, line = 3.5, cex = 0.8, las = 1)
      }
    }
    medianSD <- sd_scale[j]

    if (s == 8 && j == 2) {
      # Bottom-left panel carries the labelled axes and axis titles.
      axis(2, cex.axis = 0.8, tck = -0.025, las = 1)
      axis(1, cex.axis = 0.8, tck = -0.025)
      mtext("Day post infection", side = 1, line = 2.5, cex = 0.8)
      mtext("Std. residuals", side = 2, line = 2, cex = 0.8)
      mtext("iRBC", side = 2, line = 3.5, cex = 0.8, las = 1)
    } else {
      axis(2, cex.axis = 0.8, tck = -0.025, labels = FALSE)
      axis(1, cex.axis = 0.8, tck = -0.025, labels = FALSE)
    }

    # Number of positive observations of this strain on each day.
    y_obs <- if (j == 1) y_RBC else y_iRBC
    n_per_day <- vapply(
      time_u,
      function(d) sum(y_obs[strain == s & day == d] > 0, na.rm = TRUE),
      integer(1)
    )

    # Dotted +/- z / sqrt(n) bounds; days with no observations give an
    # infinite bound and are left out.
    boundary <- rep(z_score, length(time_u))/sqrt(n_per_day)
    ok <- is.finite(boundary)
    lines(time_u[ok], boundary[ok], lty = "dotted", col = "royalblue")
    lines(time_u[ok], -boundary[ok], lty = "dotted", col = "royalblue")

    abline(h = 0, lty = "dashed")

    # Grey crosses: one scaled residual per mouse; red dot: their mean.
    for (t in seq_along(time_u)) {
      out <- storage_sse[j, strain_mice, t]/medianSD
      if (length(out) > 0) {
        points(rep(time_u[t], length(out)), out, pch = 4, col = "grey60")
        points(time_u[t], mean(out, na.rm = TRUE), pch = 16, col = "red")
      }
    }
  }
}

