library("rstan")
library("lattice")
library("gridExtra")

# Restore the saved model output and pull its posterior draws into a list.
# The .rda file must supply an object called `output1`, which is handed to
# rstan::extract() on the next line (so presumably a stanfit - confirm).
load("output1.rda")
extracted_output <- rstan::extract(output1)

# NOTE(review): n_mice and n_RndEffs are used further down but are never
# defined in this script; presumably model_input.R provides them - confirm.
source("model_input.R")

# Strain names; the position in this vector is the strain index used by the
# rest of the script.
mouse_strains <- c(
  "129S1/SvImJ", "A/J", "C57BL/6", "CAST/EiJ",
  "NOD/ShiLtJ", "NZO/HILtJ", "PWK/PhJ", "WSB/EiJ"
)

# ---- Strain palette ----------------------------------------------------------
# One fully opaque colour per strain, in the same order as mouse_strains.
# Channels are written on the 0-255 scale, one row per strain (red, green, blue).
cols <- local({
  channels <- matrix(
    c( 10, 100, 225,
      230, 100,   0,
       31, 204, 129,
      158, 226,  27,
      255, 204,   0,
       50, 200, 255,
      171,   8,  38,
        0,  30, 120),
    ncol = 3, byrow = TRUE
  )
  rgb(channels[, 1], channels[, 2], channels[, 3],
      alpha = 255, maxColorValue = 255)
})

######################################################################################
# Outer margins (bottom, left, top, right) for the whole figure. The mfrow
# setting is kept from the original; the panel grid that is actually used is
# the one defined by layout() below.
par(mfrow = c(1, 11), oma = c(2, 3, 3, 7))

# Two rows of panels, numbered left to right: 1-6 in the top row and 7-12 in
# the bottom row (with n_RndEffs columns; the twelve panel numbers assume
# n_RndEffs is 6).
layout.matrix <- matrix(
  c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12),
  nrow = 2, ncol = n_RndEffs, byrow = TRUE
)

# Equal row heights and equal column widths.
layout(
  mat = layout.matrix,
  heights = c(1, 1),
  widths = rep(1, n_RndEffs)
)

# Draw the outlines of the 12 panels as a visual check of the layout.
layout.show(12)

# ---- Panel "a": stacked posterior densities, one panel per parameter ---------

# Plotmath titles, one per parameter panel (reused as axis labels of the bar
# plots further down).
parNames <- c(expression(paste(mu[R], "'")), expression(paste(mu[R], "''")),
              expression(rho),
              expression(psi[N1]), expression(psi[N2]),
              expression(beta))

par(mar = c(3.5, 0.5, 0.5, 0.5))

# Bottom-to-top stacking order of the strains, given as indices into
# mouse_strains / cols.
strain_list <- c(7, 2, 5, 4, 3, 6, 1, 8)

# Every density is divided by stretch_coeff * max(density), so its peak
# reaches 1 / stretch_coeff of the unit-high band reserved for its strain.
stretch_coeff <- 1.5

for (i in seq_along(parNames)) {
  # Empty canvas with a faint reference line at zero and a sparse x axis.
  plot(NA,
       xlim = c(-2.2, 2.2),
       ylim = c(0, 7.5),
       ylab = NA,
       xlab = NA, axes = FALSE)
  abline(v = 0, col = rgb(0.2, 0.2, 0.2, alpha = 0.3))
  axis(1, las = 1, tck = -0.05, at = c(-2.0, 0, 2.0), cex.axis = 0.8)

  # Strain labels, coloured like their densities, written to the right of the
  # last density panel. The guard used to be `i == 9`, which can never be true
  # with six panels, so the labels were silently skipped. xpd = NA lets the
  # text extend past the narrow figure margin into the outer margin.
  if (i == length(parNames)) {
    mtext(mouse_strains[strain_list],
          side = 4, line = 1,
          at = seq_along(strain_list) - 0.75,
          cex = 0.75, las = 1,
          col = cols[strain_list],
          xpd = NA)
  }

  for (s in seq_along(strain_list)) {
    # Draws for parameter i and strain strain_list[s]; the first index of
    # extracted_output$s is kept whole (presumably the posterior draws).
    postSamples <- extracted_output$s[, i, strain_list[s]]
    plotted_density <- density(postSamples, adjust = 3)

    # Rescale the curve and lift it into band s (y from s - 1 upwards).
    tempCeiling <- stretch_coeff * max(plotted_density$y)
    plotted_density$y <- (plotted_density$y / tempCeiling) + (s - 1)

    polygon(plotted_density, col = cols[strain_list[s]], border = NA)
  }

  # Panel letter on the first panel only, then x-axis label and title.
  if (i == 1) mtext("a", side = 2, line = 1, at = 8, cex = 1, las = 1, font = 2)
  mtext("s", side = 1, line = 2, las = 1, cex = 0.8)
  mtext(parNames[i], side = 3, line = 1)
}

# ---- Per-mouse summary table -------------------------------------------------
# One row per mouse: six parameter columns (filled with posterior medians
# below), an unused muz_hat column left as NA, and three metadata columns.
dataset <- as.data.frame(matrix(NA, nrow = n_mice, ncol = 10))
names(dataset) <- c("muR_dash", "muR_dash2",
                    "rho",
                    "psi_N1", "psi_N2",
                    "beta", "muz_hat", "strain", "tag", "status")

# Metadata is laid out in blocks of 10 consecutive rows per strain
# (8 strains x 10 rows, so n_mice has to be 80 for these assignments to fit).
dataset$strain <- rep(
  c("129S1/SvImJ", "A/J", "C57BL/6", "CAST/EiJ",
    "NOD/ShiLtJ", "NZO/HILtJ", "PWK/PhJ", "WSB/EiJ"),
  each = 10
)
dataset$tag <- rep(c("h", "d", "g", "c", "b", "f", "a", "e"), each = 10)
dataset$status <- rep(
  c("Resilient", "Non-resilient", "Resilient", "Non-resilient",
    "Non-resilient", "Resilient", "Non-resilient", "Resilient"),
  each = 10
)

# Column k receives, for every mouse, the median over the first index of
# extracted_output$theta[, mouse, k].
for (effect in seq_len(n_RndEffs)) {
  dataset[[effect]] <- vapply(
    seq_len(n_mice),
    function(mouse) median(extracted_output$theta[, mouse, effect]),
    numeric(1)
  )
}

# Colour for every row of `dataset`: look up the position of the row's strain
# in the ordered list of strain names and take the palette entry at that
# position. A strain that is not in the list yields NA.
color_strain <- cols[match(
  dataset$strain,
  c("129S1/SvImJ", "A/J", "C57BL/6", "CAST/EiJ",
    "NOD/ShiLtJ", "NZO/HILtJ", "PWK/PhJ", "WSB/EiJ")
)]


# ---- Panel "b": per-mouse medians as sorted horizontal bars ------------------
par(mar = c(2.5, 1, 0, 1))

invisible(local({
  # One entry per bar-plot panel: the dataset column to draw, the factor
  # applied to the largest value to get the upper x limit, and the x-axis
  # tick positions.
  # NOTE(review): psi_N2 is the only panel without the 10% headroom; this is
  # reproduced as-is from the original calls - confirm it is intentional.
  panels <- list(
    list(column = "muR_dash",  headroom = 1.1, ticks = seq(0, 0.05, 0.025)),
    list(column = "muR_dash2", headroom = 1.1, ticks = seq(0, 5e-03, 2.5e-03)),
    list(column = "rho",       headroom = 1.1, ticks = seq(0, 0.3, 0.15)),
    list(column = "psi_N1",    headroom = 1.1, ticks = seq(0, 0.5, 0.25)),
    list(column = "psi_N2",    headroom = 1,   ticks = seq(0, 8, 4)),
    list(column = "beta",      headroom = 1.1, ticks = seq(0, 12, 6))
  )

  for (k in seq_along(panels)) {
    panel <- panels[[k]]

    # Sort the mice by this parameter, smallest first, and carry the strain
    # colours along in the same order.
    ranking <- order(dataset[[panel$column]])
    sorted_values <- dataset[[panel$column]][ranking]

    barplot(sorted_values,
            xlim = c(0, max(sorted_values) * panel$headroom),
            border = NA, horiz = TRUE,
            col = color_strain[ranking],
            space = rep(0, n_mice),
            cex.axis = 0.8, axes = FALSE)
    axis(1, las = 1, tck = -0.05, cex.axis = 0.8, at = panel$ticks)

    # Panel letter on the first bar plot only.
    if (k == 1) {
      mtext("b", side = 2, line = 1, at = 80, cex = 1, las = 1, font = 2)
    }
    mtext(parNames[k], side = 1, line = 2.5, las = 1, cex = 0.8)
  }
}))