### R script to estimate the 2019 malaria incidence rate and the typical (2014-2019)
### seasonal pattern in incidence for Haiti, based on monthly 2014-2019 health-facility (HF) data.
### Three-stage process: (1) detrend and median-filter the six years of data;
### (2) disaggregate the annual average;
### (3) disaggregate the monthly temporal innovations.

# NOTE(review): setwd() hard-codes a machine-specific project directory; every
# relative path below (Covariates/, adm/, outputs/, ...) resolves against it.
setwd("~/Code/Haiti2019/")

## Step 1: Define 1 km sq. grid system

library(raster)
library(rgeos)
library(rgdal)

# The 2014-01 EVI raster supplies the 1 km grid geometry; its values are
# replaced below by an in-country mask.
reference.image <- raster("Covariates/Monthly/EVI/EVI_v6.2014.01.Mean.1km.Data.tif")
admin0 <- readOGR("adm/HaitiOutline.shp")
# For every pixel centre, add up point.in.polygon() over all rings of the first
# outline feature (0 = outside a ring, >0 = inside / on its boundary).
# NOTE(review): hole rings (if any) are added too, so holes are not excluded.
ref.vals <- numeric(length(reference.image))
coordinates.reference <- coordinates(reference.image)
for (i in 1:length(admin0@polygons[[1]]@Polygons)) {
  ref.vals <- ref.vals + point.in.polygon(coordinates.reference[,1],coordinates.reference[,2],admin0@polygons[[1]]@Polygons[[i]]@coords[,1],admin0@polygons[[1]]@Polygons[[i]]@coords[,2])
}
# Pixels outside every ring become NA; in-country pixels keep a positive count.
ref.vals[ref.vals==0] <- NA
values(reference.image) <- ref.vals

# Crop the grid to the bounding box of in-country pixels, padded by 5% of the
# range on each side.
xrange.ref <- range(coordinates(reference.image)[!is.na(getValues(reference.image)),1])
xrange.ref <- xrange.ref+c(-1,1)*diff(xrange.ref)*0.05
yrange.ref <- range(coordinates(reference.image)[!is.na(getValues(reference.image)),2])
yrange.ref <- yrange.ref+c(-1,1)*diff(yrange.ref)*0.05
reference.image <- crop(reference.image,extent(c(xrange.ref,yrange.ref)))

# Aggregate the high-resolution population raster onto the 1 km reference grid:
# each non-NA high-res pixel is assigned to the reference cell containing its
# centre, and the values are summed per cell.
population.facebook.highres.vals <- getValues(raster("hrsl_hti_pop.tif"))
population.facebook.highres.coords <- coordinates(raster("hrsl_hti_pop.tif"))
population.facebook.highres.coords <- population.facebook.highres.coords[!is.na(population.facebook.highres.vals),]
population.facebook.highres.vals <- population.facebook.highres.vals[!is.na(population.facebook.highres.vals)]
facebook.lookup <- cellFromXY(reference.image,population.facebook.highres.coords)
# aggregate() returns its groups in sorted order, matching sort(unique(...))
# below; NA lookups (points outside the cropped grid) are dropped by both.
facebook.counts <- aggregate(population.facebook.highres.vals,list(facebook.lookup),sum)[,2]
facebook.pop <- reference.image
values(facebook.pop) <- NA
facebook.pop[sort(unique(facebook.lookup))] <- facebook.counts
# Cells with a total population below 1 are treated as unpopulated.
facebook.pop[facebook.pop<1] <- NA

population <- getValues(facebook.pop)

# Modelled pixels: inside the country mask AND populated. bigN is their count.
in.country <- which(!is.na(getValues(reference.image)) & !is.na(population))
in.country.coords <- coordinates(reference.image)[in.country,]
bigN <- length(in.country)

# Population restricted to the modelled pixels (length bigN).
population <- getValues(facebook.pop)[in.country]

buffer.image <- reference.image+NA
buffer.image[in.country] <- population
writeRaster(buffer.image,"reference_grid_population.tif",overwrite=T)
#save.image("postStep1.dat")
#load("postStep1.dat")

## Step 2: Read in case data & identify missingness

case.data <- read.csv("Haiti_2014_2019_monthly_update_270820.csv")
# One MasterID and one (Long, Lat) pair per HF, both in order of first appearance.
hf.ids <- unique(case.data$MasterID)
hf.longlats <- cbind(case.data$Long,case.data$Lat)[!duplicated(case.data$MasterID),]

# Map each distinct HF name to the MasterID of its first matching record
# (used later to re-key the test data by name).
hf.names <- unique(case.data$HF)
hf.name.to.id <- numeric(length(hf.names))
for (i in 1:length(hf.names)) { hf.name.to.id[i] <- case.data$MasterID[case.data$HF==hf.names[i]][1]}

# case.missingness[i, m] is 0 when HF i has at least one record for study
# month m = (year - 2014) * 12 + month (1..72 over 2014-2019), otherwise 1.
case.missingness <- matrix(1,nrow=length(hf.ids),ncol=72)
for (i in 1:length(hf.ids)) {
  for (j in 1:6) {
    for (k in 1:12) {
      if (length(which(case.data$MasterID==hf.ids[i] & case.data$Year==(2014:2019)[j] & case.data$Month==k))>0) {case.missingness[i,(j-1)*12+k] <- 0}
    }
  }
  cat(i,"\n")
}
# Index (1..72) of each HF's last reported month.
# NOTE(review): max() of an empty set is -Inf (with a warning) for an HF with
# no 2014-2019 record; such HFs are then excluded by the filter below.
last.month.reported <- numeric(length(hf.ids))
for (i in 1:length(hf.ids)) {
  last.month.reported[i] <- max(which(case.missingness[i,]==0))
}
# Keep HFs that reported at least once in 2019 (study months 61-72).
hfs.reporting.in.2019 <- which(last.month.reported > 5*12)

nHFs.reporting <- length(hfs.reporting.in.2019)
hf.longlats.reporting <- hf.longlats[hfs.reporting.in.2019,]
hf.ids.reporting <- hf.ids[hfs.reporting.in.2019]

# HFs whose coordinates fall on an NA cell of the reference mask (outside the
# country outline or the grid) are snapped to the nearest modelled pixel
# centre (squared Euclidean distance in the coordinate units).
invalid.longlat <- which(is.na(extract(reference.image,hf.longlats.reporting)))
for (i in invalid.longlat) {
  distances <- (in.country.coords[,1]-hf.longlats.reporting[i,1])^2+(in.country.coords[,2]-hf.longlats.reporting[i,2])^2
  distances <- sort.list(distances)[1]
  hf.longlats.reporting[i,] <- in.country.coords[distances,]
}

# HF x study-month case matrix (72 columns = 12 months x 2014-2019) for the
# HFs reporting in 2019. When an HF has several records for the same month
# the largest Total_cases is kept. Months without a record are flagged with 1
# in monthly.case.matrix.reporting.missing and coded -999 in the case matrix.
monthly.case.matrix.reporting <- matrix(NA, nrow = nHFs.reporting, ncol = 12 * 6)
for (i in 1:6) {
  year.now <- c(2014, 2015, 2016, 2017, 2018, 2019)[i]
  for (j in 1:12) {
    column.now <- (i - 1) * 12 + j
    for (k in 1:nHFs.reporting) {
      record.rows <- case.data$MasterID == hf.ids.reporting[k] &
        case.data$Year == year.now &
        case.data$Month == j
      if (any(record.rows, na.rm = TRUE)) {
        monthly.case.matrix.reporting[k, column.now] <- max(case.data$Total_cases[record.rows])
      }
    }
  }
}
month.is.missing <- is.na(monthly.case.matrix.reporting)
monthly.case.matrix.reporting.missing <- matrix(0, nrow = nHFs.reporting, ncol = 12 * 6)
monthly.case.matrix.reporting.missing[month.is.missing] <- 1
monthly.case.matrix.reporting[month.is.missing] <- -999

# Monthly test counts per reporting HF (nHFs.reporting x 72):
#   Nmic = passive point-of-care microscopy tests;
#   Ntot = microscopy + passive point-of-care RDT + community tests.
test.data <- read.csv("2012_2019 OU Long 2020 06 12.csv")
Nmic <- Ntot <- matrix(NA,nrow=nHFs.reporting,ncol=12*6)
test.data$TestedCOMMUNITYMonthly[is.na(test.data$TestedCOMMUNITYMonthly)] <- 0

# Re-key test rows to case-data MasterIDs by matching HF_Name against the case
# data's HF names. Rows with an unmatched name keep their existing
# HF_Code_Final value (assumes that column is already present in the CSV --
# TODO confirm).
for(i in 1:length(test.data$ADM1Code)) {
  if (test.data$HF_Name[i] %in% hf.names) {
    indx <- which(test.data$HF_Name[i]==hf.names)
    test.data$HF_Code_Final[i] <- hf.name.to.id[indx]
  }
}
# Sum each test type over all rows for the HF, year and month (NAs ignored,
# so HF-months with no rows get 0).
for (i in 1:6) {
  for (j in 1:12) {
    for (k in 1:nHFs.reporting) {
      Nmic[k,(i-1)*12+j] <- sum(test.data$MicroscopytestedPASSIVEPoint[test.data$HF_Code_Final %in% hf.ids.reporting[k] & test.data$Year==c(2014,2015,2016,2017,2018,2019)[i] & test.data$Month==j],na.rm=TRUE)
      Ntot[k,(i-1)*12+j] <- sum(c(test.data$MicroscopytestedPASSIVEPoint[test.data$HF_Code_Final %in% hf.ids.reporting[k] & test.data$Year==c(2014,2015,2016,2017,2018,2019)[i] & test.data$Month==j],test.data$RDTtestedPASSIVEPointofCare[test.data$HF_Code_Final %in% hf.ids.reporting[k] & test.data$Year==c(2014,2015,2016,2017,2018,2019)[i] & test.data$Month==j],test.data$TestedCOMMUNITYMonthly[test.data$HF_Code_Final %in% hf.ids.reporting[k] & test.data$Year==c(2014,2015,2016,2017,2018,2019)[i] & test.data$Month==j]),na.rm=TRUE)
    }
  }
}

library(splines)
# Cubic B-spline basis with 3 degrees of freedom over the 72 study months.
# It is copied into a plain numeric matrix, which drops the "bs" attributes
# and dimnames, and its rows are reversed so row 1 corresponds to month 72.
bsplines <- bs(1:72,df=3)
temporal.spline.matrix <- matrix(as.vector(bsplines[, 1:3]), nrow = 72, ncol = 3)
temporal.spline.matrix <- temporal.spline.matrix[rev(seq_len(72)), ]

# Inputs for the microscopy-fraction smoothing model: per-HF monthly
# microscopy test counts (Nmic) and total test counts (Ntot) over 72 months,
# plus the transposed 3-column temporal B-spline basis.
input.data <- list(
  'nHFs'=nHFs.reporting,
  'nMonths'=72,
  'Nmic'=Nmic,
  'Ntot'=Ntot,
  'temporal_spline_matrix'=t(temporal.spline.matrix)
)

# Starting values: a shared curve (intercept + 3 spline slopes) plus per-HF
# spline intercepts/slopes; log_shrinkage_sd presumably scales the per-HF
# deviations (see haiti_micfrac_smoothing.cpp to confirm).
parameters <- list(
  'mean_intercept'=-1,
  'mean_spline_slopes'=rep(0,3),
  'log_shrinkage_sd'=-1,
  'local_spline_intercepts'=rep(0,nHFs.reporting),
  'local_spline_slopes'=matrix(0,nrow=nHFs.reporting,ncol=3)
)

library(TMB)
# NOTE(review): GCC's option is -Wno-error; "-Wno-errors" is probably ignored
# silently by the compiler -- confirm the intended flag.
compile("haiti_micfrac_smoothing.cpp",flags="-Ofast -Wno-errors")
dyn.load(dynlib("haiti_micfrac_smoothing"))

# Per-HF spline terms are declared random (integrated out by TMB's Laplace
# approximation); nlminb optimises the remaining parameters.
obj <- MakeADFun(input.data,parameters,DLL="haiti_micfrac_smoothing",random=c('local_spline_intercepts','local_spline_slopes'))
obj$fn()
opt <- nlminb(obj$par,obj$fn,obj$gr,control=list(iter.max=100,eval.max=100))
save(opt,file="micsplines.dat")

# Smoothed microscopy fractions reported by the template, indexed below as an
# HF x month matrix (columns 61:72 = 2019).
micfrac_smoothed <- obj$report()$splines

# Per-HF mean smoothed microscopy fraction over 2019, with HF coordinates.
mic.summary <- cbind(hf.longlats.reporting,rowMeans(micfrac_smoothed[,61:72]))
colnames(mic.summary) <- c("Long","Lat","Mic Frac 2019")
write.csv(mic.summary,file="outputs/micsummary.csv")
#save.image("postStep2.dat")
#load("postStep2.dat")

## Step 3: Read in traveltime surfaces and construct raw catchment populations

# Download the malariaAtlas 2015 friction surface covering the reference grid.
friction <- malariaAtlas::getRaster(
  surface = "A global friction surface enumerating land-based travel speed for a nominal year 2015",
  extent = bbox(reference.image))
# NOTE(review): every cell whose value equals that of cell 3004 is set to 100;
# cell 3004 is presumably a sea/no-data cell used as a sentinel -- confirm.
friction[which(getValues(friction)==friction[3004])] <- 100
library(gdistance)
library(sp)
library(rgeos)
# 8-neighbour transition layer with conductance 1/mean(friction) between
# neighbouring cells, then geo-corrected for inter-cell distance.
# NOTE(review): assigning to `T` masks the TRUE shorthand for the rest of the script.
T <- gdistance::transition(friction, function(x) 1/mean(x), 8) 
T.GC <- gdistance::geoCorrection(T)   

# traveltime.distance.matrix[p, h]: accumulated cost from HF h to modelled
# pixel p (bigN x nHFs.reporting). Each HF's surface is plotted on a log scale
# as a visual check.
buffer.image <- reference.image
values(buffer.image) <- NA
traveltime.distance.matrix <- matrix(0,nrow=bigN,ncol=nHFs.reporting)
for (i in 1:nHFs.reporting) {
  traveltime.distance.matrix[,i] <- gdistance::accCost(T.GC, hf.longlats.reporting[i,])[in.country]
  buffer.image[in.country] <- traveltime.distance.matrix[,i]
  image(log(buffer.image),zlim=c(log(1),log(700)))
  points(hf.longlats.reporting[i,1],hf.longlats.reporting[i,2])
  cat(i,"\n")
}
#save(traveltime.distance.matrix,file="ttdist.dat")

# Inverse-square travel-time weights: costs are floored at 5, and weights are
# zeroed beyond 120 (presumably minutes -- TODO confirm the accCost units).
traveltime.distance.matrix[traveltime.distance.matrix<5] <- 5
invdistance <- 1/traveltime.distance.matrix^2
invdistance[traveltime.distance.matrix > 120] <- 0

# Row-normalise so each pixel's weights over HFs sum to ~1. The 1e-9 term
# avoids 0/0, so pixels with no HF in range keep an all-zero row.
catchments <- invdistance
catchments <- catchments/rowSums(catchments+0.000000001)

# Expected catchment population per HF: pixel populations weighted by the
# share of each pixel attributed to that HF.
estimated.catchment.pops <- as.numeric(t(catchments)%*%population)

#save.image("postStep3.dat")
#load("postStep3.dat")

## Step 4: Build INLA mesh

library(INLA)
# Triangulated mesh over the national outline (max.edge = inner/outer maximum
# edge lengths, cut = minimum vertex separation, in coordinate units).
haiti.mesh <- inla.mesh.2d(boundary = admin0,max.edge=c(0.1,2),cut=0.05)
# Keep only the M0/M1/M2 sparse matrices of the alpha=2 Matern SPDE for TMB.
haiti.spde <- (inla.spde2.matern(haiti.mesh,alpha=2)$param.inla)[c("M0","M1","M2")]
# Projection matrix from mesh vertices to the (possibly snapped) HF locations.
haiti.A <- inla.mesh.project(haiti.mesh,as.matrix(hf.longlats.reporting))$A

## Step 5: Fit detrending model (5 temporal B-spline basis functions)

library(splines)
# Cubic B-spline basis with 5 degrees of freedom over the 72 study months,
# copied into a plain numeric matrix (no "bs" attributes or dimnames) with
# its rows reversed so row 1 corresponds to month 72.
bsplines <- bs(1:72,df=5)
temporal.spline.matrix <- matrix(as.vector(bsplines[, 1:5]), nrow = 72, ncol = 5)
temporal.spline.matrix <- temporal.spline.matrix[rev(seq_len(72)), ]

# Circular distance between calendar months as a fraction of a year:
# entry [a, b] = min(|a - b|, 12 - |a - b|) / 12, so adjacent months
# (including December-January) are 1/12 apart and the maximum is 0.5.
month.gap <- abs(outer(1:12, 1:12, "-"))
intermonth.distmat <- pmin(month.gap, 12L - month.gap) / 12
rm(month.gap)

# Data for the detrending model: SPDE matrices and HF projection, the HF x
# month case matrix (-999 where missing) with its missingness flags, catchment
# populations, the transposed 5-column spline basis, the circular inter-month
# distances and the smoothed microscopy fractions.
input.data <- list(
  'nHFs'=nHFs.reporting,
  'nMonths'=72,
  'spde'=haiti.spde,
  'A'=haiti.A,
  'HFcases_matrix'=monthly.case.matrix.reporting,
  'HFmissing_matrix'=monthly.case.matrix.reporting.missing,
  'catchment_pop'=estimated.catchment.pops,
  'temporal_spline_matrix'=t(temporal.spline.matrix),
  'intermonth_distmat'=intermonth.distmat,
  'micfracs'=micfrac_smoothed
)

# Starting values. Fields live on mesh vertices: 'field_trend' has one column
# per spline (5) and 'field_seasonal' one per calendar month (12);
# 'miceffect' has one entry per study month and the overdispersion factors
# one per HF.
parameters <- list(
  'intercept_baseline'=-5,
  'log_range_baseline'=-1.0,
  'log_sd_baseline'=2.0,
  'log_range_trend'=1.0,
  'log_sd_trend'=-1.0,
  'log_ar_scale_trend'=0.0,
  'logit_ar_param_trend'=1.0,
  'log_range_seasonal'=1.0,
  'log_sd_seasonal'=-1.0,
  'log_range_seasonaltime'=1.0,
  'log_sd_seasonaltime'=-1.0,
  'log_overdispersion_sd'=1.0,
  'mean_miceffect'=0,
  'log_ar_scale_miceffect'=0.0,
  'logit_ar_param_miceffect'=1.0,
  'miceffect'=rep(0,72),
  'field_baseline'=numeric(haiti.mesh$n),
  'field_trend'=matrix(0,nrow=haiti.mesh$n,ncol=5),
  'field_seasonal'=matrix(0,nrow=haiti.mesh$n,ncol=12),
  'log_overdispersion_factors'=rep(0,nHFs.reporting)
)

library(TMB)
# NOTE(review): "-Wno-errors" is not a GCC flag (-Wno-error is) -- confirm intent.
compile("haiti_seasonality_detrending.cpp",flags="-Ofast -Wno-errors")
dyn.load(dynlib("haiti_seasonality_detrending"))

# Fields, the monthly microscopy effect and the per-HF overdispersion factors
# are random effects. nlminb runs once (100 iterations) and is then restarted
# from that optimum (300 iterations).
obj <- MakeADFun(input.data,parameters,DLL="haiti_seasonality_detrending",random=c('miceffect','field_baseline','field_trend','field_seasonal','log_overdispersion_factors'))
obj$fn()
opt <- nlminb(obj$par,obj$fn,obj$gr,control=list(iter.max=100,eval.max=100))
opt <- nlminb(opt$par,obj$fn,obj$gr,control=list(iter.max=300,eval.max=300))
# Joint precision over fixed and random effects, used for posterior draws.
# NOTE(review): `rep` masks the name of base::rep (function calls still work).
rep <- sdreport(obj,getJointPrecision = TRUE)
save(opt,rep,file="detrending.final.dat")

# Overwrite each fitted parameter block in `parameters` with its estimate.
# Fixed effects come from rep$par.fixed and random effects from
# rep$par.random. Blocks are selected by the names used in
# rep$jointPrecision, and other entries of `parameters` are left unchanged.
# Uses `[[<-` directly instead of building and eval()-ing code strings.
# NOTE: `rep` here is the sdreport object; it masks the name of base::rep,
# but calls such as rep(...) still resolve to the function.
parameter.estimates <- c(rep$par.fixed,rep$par.random)
parnames <- unique(names(rep$jointPrecision[1,]))
for (parname in parnames) {
  parameters[[parname]] <- parameter.estimates[names(parameter.estimates)==parname]
}

library(sparseMVN)
library(Matrix)
# 50 joint draws from the Gaussian approximation centred on unlist(parameters)
# with precision rep$jointPrecision.
# NOTE(review): assumes unlist(parameters) is ordered like the rows/columns of
# rep$jointPrecision -- TODO confirm.
r.draws <- rmvn.sparse(50,unlist(parameters),Cholesky(rep$jointPrecision),prec=TRUE)
trend_incidence_draws <- total_incidence_field <- miceffect_draws <- list()
for (i in 1:50) {
  outputs <- obj$report(r.draws[i,])
  # Trend re-centred so that each HF's mean over months 61-72 (2019) is zero.
  trend_incidence_draws[[i]] <- outputs$trend_incidence_rate-rowMeans(outputs$trend_incidence_rate[,61:72])
  total_incidence_field[[i]] <- outputs$total_incidence_field
  miceffect_draws[[i]] <- outputs$miceffect_matrix
  cat(i,"\n")
}

# 50 x 72 matrix: each draw's mean_miceffect (repeated over the 72 months)
# plus that draw's monthly miceffect values.
miceffect_posterior <- t(matrix(rep(r.draws[,which(names(unlist(parameters))=="mean_miceffect.mean_miceffect")],each=72),ncol=50))+r.draws[,which(names(unlist(parameters))=="miceffect.miceffect")]
save(miceffect_posterior,file="outputs/miceffect_posterior.dat")

# For each of the 50 posterior draws:
#   1. fill missing HF-months with modelled expected cases (incidence field x
#      catchment population);
#   2. divide out the draw's trend (exp(-trend)) and microscopy effect
#      (exp(-log(m)) = 1/m);
#   3. take the median over the six years for each calendar month, rounded.
# The result is an HF x 12 matrix per draw.
monthly.case.matrix.detrended.list <- list()
for (k in 1:50) {
  monthly.case.matrix.reporting.imputed <- monthly.case.matrix.reporting
  monthly.case.matrix.reporting.imputed[monthly.case.matrix.reporting.missing==1] <- (total_incidence_field[[k]]*matrix(rep(estimated.catchment.pops,72),nrow=nHFs.reporting))[monthly.case.matrix.reporting.missing==1]
  monthly.case.matrix.detrended.raw <- monthly.case.matrix.reporting.imputed*exp(-trend_incidence_draws[[k]])*exp(-log(miceffect_draws[[k]]))
  monthly.case.matrix.detrended <- matrix(NA,nrow=nHFs.reporting,ncol=12)
  for (i in 1:nHFs.reporting) {for (j in 1:12) {monthly.case.matrix.detrended[i,j] <- median(monthly.case.matrix.detrended.raw[i,j+12*(0:5)])}}
  monthly.case.matrix.detrended <- round(monthly.case.matrix.detrended)
  monthly.case.matrix.detrended.list[[k]] <- monthly.case.matrix.detrended
}

# Detrended annual cases per HF (sum of the 12 monthly medians), averaged over
# the 50 draws and written with HF coordinates.
mean.annual.cases.detrended <- numeric(nHFs.reporting)
for (i in 1:50) {
  mean.annual.cases.detrended <- mean.annual.cases.detrended+rowSums(monthly.case.matrix.detrended.list[[i]])
}
mean.annual.cases.detrended <- mean.annual.cases.detrended/50
mean.annual.cases.detrended <- cbind(hf.longlats.reporting,mean.annual.cases.detrended)
colnames(mean.annual.cases.detrended) <- c("Long","Lat","Cases")
write.csv(mean.annual.cases.detrended,file="outputs/meandetrended.csv")

save.image("detrended.dat")

# Posterior mean (over the 50 draws) of the 2019-centred trend, HF x month.
trends <- Reduce(function(acc, draw) acc + draw / 50,
                 trend_incidence_draws[1:50],
                 matrix(0, nrow = nHFs.reporting, ncol = 72))

# For each HF, the calendar-year mean trend (2014..2019) with the largest
# absolute value, keeping its sign.
annual.trends <- numeric(nHFs.reporting)
for (hf in 1:nHFs.reporting) {
  yearly.means <- vapply(1:6, function(yr) mean(trends[hf, (yr - 1) * 12 + 1:12]), numeric(1))
  annual.trends[hf] <- yearly.means[which.max(abs(yearly.means))]
}
annual.trends <- cbind(hf.longlats.reporting,annual.trends)
colnames(annual.trends) <- c("Long","Lat","Max Log Trend")
write.csv(annual.trends,file="outputs/annualtrends.csv")

## Step 6: Clustering of HFs for full catchment model

# Hierarchical clustering (hclust default: complete linkage) of the reporting
# HF coordinates, cut into 450 clusters. Each cluster's coordinates are the
# means of its members.
clustering <- hclust(dist(hf.longlats.reporting))
N.clustered <- 450
clusterCut <- cutree(clustering, N.clustered)
hf.longlats.clustered <- aggregate(hf.longlats.reporting,list(clusterCut),mean)[,2:3]

# Pixel x cluster travel time = minimum over the cluster's member HFs.
traveltime.distance.matrix.clustered <- t(as.matrix(aggregate(t(traveltime.distance.matrix),list(clusterCut),min)[,2:(bigN+1)]))
nHFs <- N.clustered
#save(traveltime.distance.matrix.clustered,file="ttdist_clustered.dat")

invdistance <- 1/traveltime.distance.matrix.clustered^2

# NOTE(review): this normalised `catchments` is overwritten below before any
# use, so these two lines appear to be dead code.
catchments <- invdistance
catchments <- catchments/rowSums(catchments+0.000000001)

# Keep only each pixel's nearest clusters: those under 20 plus the next 4
# nearest (the 5 nearest when none is under 20); all other weights are zeroed.
for (i in 1:bigN) {
  catchment.ordering <- sort.list(traveltime.distance.matrix.clustered[i,],decreasing=F)
  n.less.than.20 <- length(which(traveltime.distance.matrix.clustered[i,]<20))
  catchment.list <- catchment.ordering[1:(max(n.less.than.20,1)+4)] # supposes attendance at one of nearest 5 HFs (counting those within a 20 min travel time as a single facility)
  invdistance[i,!(1:nHFs %in% catchment.list)] <- 0
}

invdistance[traveltime.distance.matrix.clustered>120] <- 0

catchments <- invdistance
catchments <- catchments/rowSums(catchments+0.000000001)

# Drop weights whose normalised share is below exp(-3) (about 5%).
# `catchments` is not recomputed after this step.
invdistance[catchments < exp(-3)] <- 0

#save.image("postStep6.dat")
#load("postStep6.dat")

## Step 7: Construct pixel-focussed mesh

library(INLA)
# Coarse mesh (cut=0.12): used later for the spatially varying covariate-slope
# offsets. Projected onto the modelled pixel centres.
haiti.mesh.coarse <- inla.mesh.2d(boundary=admin0,max.edge=c(0.1,3),cut=0.12)
haiti.spde.coarse <- (inla.spde2.matern(haiti.mesh.coarse,alpha=2)$param.inla)[c("M0","M1","M2")]
haiti.A.coarse <- inla.mesh.project(haiti.mesh.coarse,in.country.coords)$A
# Fine mesh (cut=0.02): used later for the residual spatial field. Projected
# onto the same pixel centres.
haiti.mesh.fine <- inla.mesh.2d(boundary=admin0,max.edge=c(0.1,3),cut=0.02)
haiti.spde.fine <- (inla.spde2.matern(haiti.mesh.fine,alpha=2)$param.inla)[c("M0","M1","M2")]
haiti.A.fine <- inla.mesh.project(haiti.mesh.fine,in.country.coords)$A

## Step 8: Read in temporal covariates

# Build a standardised (pixel x calendar-month) covariate matrix from monthly
# rasters. This replaces five copy-pasted loops that differed only in their paths.
#   * For each month k, average the six yearly rasters (year index z = 1..6)
#     over the modelled pixels (`in.country`).
#   * Fill NA pixels with the value of the nearest pixel (squared lon/lat
#     distance) that is neither NA nor +/-Inf.
#   * z-score the whole bigN x 12 matrix with one global mean/sd, then clip
#     to [-3, 3].
# Arguments:
#   path_for: function(z, k) returning the raster path for year index z and
#             calendar month k.
# Returns a bigN x 12 numeric matrix. Prints the NA count and draws a
# histogram for every month.
build.monthly.covariate <- function(path_for) {
  covariate.matrix <- matrix(NA, nrow = bigN, ncol = 12)
  for (k in 1:12) {
    cov.current <- numeric(bigN)
    for (z in 1:6) {
      cov.current <- cov.current + crop(raster(path_for(z, k)), reference.image)[in.country] / 6
    }
    nas <- which(is.na(cov.current))
    cat('NAs = ', length(nas), '\n')
    if (length(nas) > 0) {
      validp <- which(!(is.na(cov.current) | cov.current == -Inf | cov.current == Inf))
      valid.coords <- in.country.coords[validp, , drop = FALSE]
      for (i in nas) {
        nearestv <- which.min((valid.coords[, 1] - in.country.coords[i, 1])^2 +
                                (valid.coords[, 2] - in.country.coords[i, 2])^2)
        cov.current[i] <- cov.current[validp[nearestv]]
      }
    }
    hist(cov.current)
    covariate.matrix[, k] <- cov.current
  }
  covariate.matrix <- (covariate.matrix - mean(covariate.matrix, na.rm = TRUE)) / sd(covariate.matrix, na.rm = TRUE)
  covariate.matrix[covariate.matrix > 3] <- 3
  covariate.matrix[covariate.matrix < -3] <- -3
  covariate.matrix
}

# EVI file names use "Mean" for 2014-2017 and "mean" for 2018-2019.
evi.covariate.matrix <- build.monthly.covariate(function(z, k) {
  paste0("Covariates/Monthly/EVI/EVI_v6.", (2014:2019)[z], ".", sprintf('%02i', k), ".",
         c("M", "M", "M", "M", "m", "m")[z], "ean.1km.Data.tif")
})

LSTday.covariate.matrix <- build.monthly.covariate(function(z, k) {
  paste0("Covariates/Monthly/LST_Day/LST_Day_v6.", (2014:2019)[z], ".", sprintf('%02i', k),
         ".mean.1km.Data.tif")
})

# NOTE(review): the year vector repeats 2014 and omits 2015, so the 2014
# diurnal-difference raster is averaged in twice. This may be a deliberate
# stand-in for missing 2015 files or a typo for 2015 -- confirm before
# changing. The file names drop ".Data" for 2014-2017.
LSTdiff.covariate.matrix <- build.monthly.covariate(function(z, k) {
  paste0("Covariates/Monthly/LST_Delta/LST_DiurnalDiff_v6.", c(2014, 2014, 2016, 2017, 2018, 2019)[z],
         ".", sprintf('%02i', k), ".mean.1km.", c("", "", "", "", "Data.", "Data.")[z], "tif")
})

TCB.covariate.matrix <- build.monthly.covariate(function(z, k) {
  paste0("Covariates/Monthly/TCB/TCB_v6.", (2014:2019)[z], ".", sprintf('%02i', k),
         ".mean.1km.Data.tif")
})

TCW.covariate.matrix <- build.monthly.covariate(function(z, k) {
  paste0("Covariates/Monthly/TCW/TCW_v6.", (2014:2019)[z], ".", sprintf('%02i', k),
         ".mean.1km.Data.tif")
})

# Static (time-invariant) covariates, read from Covariates/<name>.tif, and the
# transformation applied to each ("Normal", "Exponential" or "None") when they
# are processed.
covariate.names <- c("accessibility_to_cities_2015_v1.0",
                     "AI",
                     "DistToWater",
                     "Elevation",
                     "Landcover_forest",
                     "Landcover_grass_savanna",
                     "Landcover_urban_barren",
                     "Landcover_woodysavanna",
                     "OSM_v32",
                     "PET",
                     "Slope",
                     "TWI"
)
transformation.types <- c("Exponential",
                          "Normal",
                          "Exponential",
                          "Exponential",
                          "None",
                          "None",
                          "None",
                          "None",
                          "None",
                          "Normal",
                          "Exponential",
                          "Normal"
)
# The two vectors are paired by position, so they must have the same length.
stopifnot(length(transformation.types) == length(covariate.names))
Ncovariates <- length(covariate.names)
# Echo the name/transformation pairs. The newline is passed as an ordinary
# argument; cat() has no `col` parameter.
for (i in 1:Ncovariates) {cat(covariate.names[i],"\t\t",transformation.types[i],"\n")}

# Read, transform and gap-fill every static covariate over the modelled pixels.
# Each covariate becomes a length-bigN vector. The vectors are stacked into an
# Ncovariates x bigN matrix and clipped to [-3, 3].

# Replace values[targets] with the value of the nearest donor pixel (squared
# lon/lat distance between pixel centres). Donors and targets are disjoint,
# so filled values never feed later fills. drop = FALSE keeps a single donor
# as a one-row matrix.
fill.from.nearest <- function(values, targets, donors) {
  donor.coords <- in.country.coords[donors, , drop = FALSE]
  for (i in targets) {
    nearestv <- which.min((donor.coords[, 1] - in.country.coords[i, 1])^2 +
                            (donor.coords[, 2] - in.country.coords[i, 2])^2)
    values[i] <- values[donors[nearestv]]
  }
  values
}

covariates <- list()

for (k in 1:Ncovariates) {
  cat("Processing covariate:",covariate.names[k],"...\n")
  cov.current <- crop(raster(paste0("Covariates/",covariate.names[k],".tif")),extent(reference.image))
  if (!all(dim(cov.current)==dim(reference.image))) {stop("Mismatched dimensions!\n")}
  cov.current <- getValues(cov.current)[in.country]
  if (transformation.types[k]=="Normal") {
    cov.current <- (cov.current-mean(cov.current,na.rm=TRUE))/sd(cov.current,na.rm=TRUE)
  }
  if (transformation.types[k]=="Exponential") {
    # NOTE(review): this adds 2 x (smallest finite positive value) to every
    # pixel. Values below minus that offset stay negative and map to
    # qnorm(0) = -Inf, which is filled further down. If the aim was a shift to
    # strictly positive values, subtracting min(cov.current) may have been
    # intended -- confirm.
    min.positive <- min(cov.current[cov.current>0 & cov.current!=Inf & cov.current!=-Inf & !is.nan(cov.current)],na.rm=TRUE)
    cov.current <- cov.current+min.positive+abs(min.positive)
    # Probability-integral transform: exponential CDF (rate 1/mean), then the
    # standard-normal quantile. Tail values can become +/-Inf.
    cov.current <- qnorm(pexp(cov.current,1/mean(cov.current,na.rm=TRUE)))
  }
  nas <- which(is.na(cov.current))
  cat('NAs = ',length(nas),'\n')
  if (length(nas)>0) {
    cov.current <- fill.from.nearest(cov.current, nas,
                                     which(!(is.na(cov.current) | cov.current==-Inf | cov.current==Inf)))
  }
  infs <- which(cov.current==Inf)
  cat('Infs = ',length(infs),'\n')
  if (length(infs)>0) {
    cov.current <- fill.from.nearest(cov.current, infs, which(!(cov.current==-Inf | cov.current==Inf)))
  }
  neginfs <- which(cov.current==-Inf)
  # Report the count of -Inf values (not the +Inf count).
  cat('-Infs = ',length(neginfs),'\n')
  if (length(neginfs)>0) {
    cov.current <- fill.from.nearest(cov.current, neginfs, which(!(cov.current==-Inf)))
  }
  hist(cov.current)
  cat("range: ",range(cov.current),"\n")
  # Short pause, presumably so the histogram can be inspected interactively.
  Sys.sleep(3)
  covariates[[length(covariates)+1]] <- cov.current
}

covariates <- do.call(rbind,covariates)
covariates[covariates > 3] <- 3
covariates[covariates < -3] <- -3

# Per-pixel access to treatment: the shortest travel time to any HF cluster,
# floored at 1 (NA values and names pass through unchanged).
access.to.treatment <- pmax(apply(traveltime.distance.matrix.clustered, 1, min), 1)

# Inverse logit (logistic) function.
invlogit <- function(x) {
  1 / (1 + exp(-x))
}

# Per-pixel treatment weight 0.3 + 0.65 * invlogit(-t / 360). It falls from
# just under 0.625 (t near 1) towards 0.3 as the travel time t grows.
# Presumably a probability of receiving treatment -- TODO confirm.
treatment <- 0.3 + 0.65 * invlogit(-access.to.treatment / 360)

# Flatten the pixel x HF-cluster weight matrix `invdistance` into parallel
# vectors (weight, pixel index, cluster index), one entry per strictly
# positive weight. Entries are ordered by pixel and, within a pixel, by
# cluster index. which(arr.ind = TRUE) on the transpose walks the matrix in
# exactly that order, which avoids growing vectors with c() inside a bigN-long
# loop (quadratic cost). NA weights are skipped, as `> 0` is NA for them.
# NOTE: this overwrites the earlier `hf.ids` (HF MasterIDs) with cluster indices.
invdistance.by.pixel <- t(invdistance)
positive.entries <- which(invdistance.by.pixel > 0, arr.ind = TRUE)
invdists <- as.numeric(invdistance.by.pixel[positive.entries])
hf.ids <- as.numeric(positive.entries[, 1])
pixel.ids <- as.numeric(positive.entries[, 2])
rm(invdistance.by.pixel, positive.entries)

save.image("prefit.dat")

gc()

## Step 9: Fit spatio-temporal baseline model in TMB

# Fit the static baseline model once per detrended case draw (k = 1..50) and
# keep 2 joint-posterior draws from each fit, so `baseline_replicants` ends
# with 100 entries. It is saved after every fit so partial progress survives.
baseline_replicants <- list()
library(TMB)
library(sparseMVN)
library(Matrix)
# NOTE(review): "-Wno-errors" is not a GCC flag (-Wno-error is) -- confirm intent.
compile("haiti_static_spatialvar.cpp",flags="-Ofast -Wno-errors")
dyn.load(dynlib("haiti_static_spatialvar"))

for (k in 1:50) {

  # Annual cases per HF cluster for draw k: monthly detrended medians summed
  # within each cluster, then over the 12 months.
  input.data <- list(
    'bigN'=bigN,
    'static_covariate_matrix'=t(covariates),
    'spde_fine'=haiti.spde.fine,
    'A_fine'=haiti.A.fine,
    'spde_coarse'=haiti.spde.coarse,
    'A_coarse'=haiti.A.coarse,
    'nHFs'=nHFs,
    'population'=population,
    'HFcases'=rowSums(aggregate(monthly.case.matrix.detrended.list[[k]],list(clusterCut),sum)[2:13]),
    'Nunwrapped'=length(invdists),
    'invdists'=invdists,
    'hf_ids'=hf.ids,
    'pixel_ids'=pixel.ids,
    'treatment'=treatment
  )

  parameters <- list(
    'intercept'=-5,
    'static_slopes'=rep(0,dim(covariates)[1]),
    # NOTE(review): previously written as `--3`, which R parses as +3. The
    # value is kept at +3; -3 may have been intended -- confirm.
    'log_range'=3,
    'log_sd'=2,
    'field'=numeric(haiti.mesh.fine$n),
    'log_range_slopes'=-1,
    'log_sd_slopes'=0,
    'static_slopes_offsets'=matrix(0,nrow=haiti.mesh.coarse$n,ncol=(dim(covariates)[1])),
    'log_masses'=rep(0,nHFs),
    'log_overdispersion_scale'=3
  )

  obj <- MakeADFun(input.data,parameters,DLL="haiti_static_spatialvar",random=c('intercept','static_slopes','field','static_slopes_offsets','log_masses'))
  obj$fn()

  # Warm-start every fit after the first from the previous optimum.
  start.par <- if (length(baseline_replicants)==0) obj$par else opt$par
  opt <- nlminb(start.par,obj$fn,obj$gr,control=list(iter.max=300,eval.max=300))
  # NOTE: `rep` (the sdreport object) masks the name of base::rep; calls such
  # as rep(...) still resolve to the function.
  rep <- sdreport(obj,getJointPrecision = TRUE)

  # Replace each fitted parameter block with its estimate (fixed + random),
  # selected by the names in the joint precision matrix.
  parameter.estimates <- c(rep$par.fixed,rep$par.random)
  parnames <- unique(names(rep$jointPrecision[1,]))
  for (parname in parnames) {
    parameters[[parname]] <- parameter.estimates[names(parameter.estimates)==parname]
  }

  # Two joint draws around the estimates. Each draw's reported outputs are
  # stored as one replicate.
  draw.names <- names(unlist(parameters))
  r.draws <- rmvn.sparse(2,unlist(parameters),Cholesky(rep$jointPrecision),prec=TRUE)
  for (i in 1:2) {
    outputs <- obj$report(r.draws[i,])
    # Per pixel: index of the covariate term with the largest |prediction|.
    abs.argmax <- apply(abs(outputs$full_cov_preds),1,which.max)
    replicant <- list()
    replicant$field.draws <- outputs$predicted_surface_malaria
    replicant$staticfield.draws <- outputs$static_field
    replicant$gp.draws <- outputs$baseline_field
    replicant$covar.draws <- outputs$static_field_offsets
    replicant$catchments <- outputs$catchments
    replicant$maxpred <- abs.argmax
    replicant$maxpred_pos <- apply((outputs$full_cov_preds),1,which.max)
    replicant$maxpred_neg <- apply((outputs$full_cov_preds),1,which.min)
    replicant$maxpredsign <- sign(outputs$full_cov_preds)[cbind(1:bigN,abs.argmax)]
    replicant$baseline_log_masses <- r.draws[i,draw.names=="log_masses.log_masses"]
    replicant$slopes <- r.draws[i,draw.names=="static_slopes.static_slopes"]
    replicant$par.fixed <- rep$par.fixed
    replicant$cov.fix <- rep$cov.fixed
    baseline_replicants[[length(baseline_replicants)+1]] <- replicant
  }
  save(baseline_replicants,file="baseline_replicants.dat")
}

# Write Out Key Raster Summaries

# Collect the per-replicate outputs into bigN x (number of replicates) matrices.
field.draws <- gp.draws <- covar.draws <- staticfield.draws <- matrix(NA,ncol=length(baseline_replicants),nrow=bigN)
for (i in 1:length(baseline_replicants)) {
  field.draws[,i] <- baseline_replicants[[i]]$field.draws
  gp.draws[,i] <- baseline_replicants[[i]]$gp.draws
  covar.draws[,i] <- baseline_replicants[[i]]$covar.draws
  staticfield.draws[,i] <- baseline_replicants[[i]]$staticfield.draws
}
# Pixel-wise summaries across replicates. field.draws is compared against
# log-rate thresholds below, so it is on the log scale; pointwise.mean.cases
# adds log(population) to give log expected cases.
pointwise.mean.caserate <- apply(field.draws,1,mean)
pointwise.stddev.caserate <- apply(field.draws,1,sd)
pointwise.mean.gp <- apply(gp.draws,1,mean)
pointwise.mean.covar <- apply(covar.draws,1,mean)
pointwise.mean.covarstatic <- apply(staticfield.draws,1,mean)
pointwise.mean.cases <- log(population)+pointwise.mean.caserate
# buffer.image was last fully reset to NA in Step 3 and only ever written at
# the in.country cells, so non-modelled cells are NA in every raster below.
buffer.image[in.country] <- pointwise.mean.caserate
writeRaster(buffer.image,file="outputs/final_baseline.tif",overwrite=TRUE)
buffer.image[in.country] <- pointwise.stddev.caserate
writeRaster(buffer.image,file="outputs/final_baseline_stddev.tif",overwrite=TRUE)
buffer.image[in.country] <- pointwise.mean.gp
writeRaster(buffer.image,file="outputs/final_baseline_gp.tif",overwrite=TRUE)
buffer.image[in.country] <- pointwise.mean.covar
writeRaster(buffer.image,file="outputs/final_baseline_covs.tif",overwrite=TRUE)
buffer.image[in.country] <- pointwise.mean.covarstatic
writeRaster(buffer.image,file="outputs/final_baseline_staticcovs.tif",overwrite=TRUE)
buffer.image[in.country] <- pointwise.mean.cases
writeRaster(buffer.image,file="outputs/final_baseline_counts.tif",overwrite=TRUE)

# Fraction of replicates in which each pixel's rate exceeds 1/1000 and
# 50/1000, or falls below 0.1/1000 (per person-year, per the file names).
exceedance.prob <- apply((field.draws>log(1/1000)),1,mean)
buffer.image[in.country] <- exceedance.prob
writeRaster(buffer.image,file="outputs/final_baseline_prob_exceed_1_per_1000_PYO.tif",overwrite=TRUE)
exceedance.prob <- apply((field.draws>log(50/1000)),1,mean)
buffer.image[in.country] <- exceedance.prob
writeRaster(buffer.image,file="outputs/final_baseline_prob_exceed_50_per_1000_PYO.tif",overwrite=TRUE)
nonexceedance.prob <- apply((field.draws<log(0.1/1000)),1,mean)
buffer.image[in.country] <- nonexceedance.prob
writeRaster(buffer.image,file="outputs/final_baseline_prob_nonexceed_1_per_10000_PYO.tif",overwrite=TRUE)

# National population at risk. For each replicate, total the population of
# pixels whose log rate exceeds log(1/1000); then summarise across replicates
# as the median and central 95% interval. `population` (length bigN) recycles
# down each column of field.draws.
pop.at.risk <- colSums((field.draws>log(1/1000))*population)
pop.at.risk.median <- median(pop.at.risk)
pop.at.risk.lower <- quantile(pop.at.risk,0.025)
pop.at.risk.upper <- quantile(pop.at.risk,0.975)
pop.output <- as.matrix(c(pop.at.risk.median,pop.at.risk.lower,pop.at.risk.upper))
# Row labels must be set with rownames(): names() on a matrix only adds an
# attribute that write.csv ignores.
rownames(pop.output) <- c("Median","Lower 95%","Upper 95%")
write.csv(pop.output,file="outputs/population_at_risk_national_1_per_1000_PYO.csv")

# Population at risk (> 1 per 1000) per department (admin-1). Results are
# summarised across replicates, written as a CSV, and attached to the admin-1
# shapefile.
reference.coords <- in.country.coords
admin1 <- readOGR("adm/hti_admbnda_adm1_cnigs_20181129.shp")
pop.at.risk.summary <- list()
for (i in 1:length(admin1)) {
  # Per-pixel membership weight: sum of point.in.polygon() over the
  # department's rings. NOTE(review): boundary/vertex pixels score 2-3 and
  # hole rings add rather than subtract -- confirm this is acceptable.
  in.admin.sector <- numeric(bigN)
  for (j in 1:length(admin1@polygons[[i]]@Polygons)) {
    in.admin.sector <- in.admin.sector + point.in.polygon(reference.coords[,1],reference.coords[,2],admin1@polygons[[i]]@Polygons[[j]]@coords[,1],admin1@polygons[[i]]@Polygons[[j]]@coords[,2])
  }
  pop.at.risk <- colSums((field.draws>log(1/1000))*population*in.admin.sector)
  pop.at.risk.median <- median(pop.at.risk)
  pop.at.risk.lower <- quantile(pop.at.risk,0.025)
  pop.at.risk.upper <- quantile(pop.at.risk,0.975)
  pop.at.risk.summary[[i]] <- c(pop.at.risk.median,pop.at.risk.lower,pop.at.risk.upper)
}
pop.at.risk.summary <- do.call(rbind,pop.at.risk.summary)
pop.at.risk.summary <- cbind(as.character(admin1@data$ADM1_FR),as.character(pop.at.risk.summary[,1]),as.character(pop.at.risk.summary[,2]),as.character(pop.at.risk.summary[,3]))
colnames(pop.at.risk.summary) <- c("Department","Median","Lower 95%","Upper 95%")
# Write the per-department table.
write.csv(pop.at.risk.summary,file="outputs/population_at_risk_department_1_per_1000_PYO.csv")
admin1@data$MedianPopAtRisk <- as.numeric(pop.at.risk.summary[,2])
writeOGR(admin1, ".", "outputs/popatrisk_department_1_per_1000_PYO", driver="ESRI Shapefile",overwrite_layer = TRUE)

admin2 <- readOGR("adm/hti_admbnda_adm2_cnigs_20181129.shp")
pop.at.risk.summary <- list()
for (i in 1:length(admin2)) {
  in.admin.sector <- numeric(bigN)
  for (j in 1:length(admin2@polygons[[i]]@Polygons)) {
    in.admin.sector <- in.admin.sector + point.in.polygon(reference.coords[,1],reference.coords[,2],admin2@polygons[[i]]@Polygons[[j]]@coords[,1],admin2@polygons[[i]]@Polygons[[j]]@coords[,2])
  }
  pop.at.risk <- colSums((field.draws>log(1/1000))*matrix(rep(population,length(baseline_replicants)),ncol=length(baseline_replicants))*matrix(rep(in.admin.sector,length(baseline_replicants)),ncol=length(baseline_replicants)))
  pop.at.risk.median <- median(pop.at.risk)
  pop.at.risk.lower <- quantile(pop.at.risk,0.025)
  pop.at.risk.upper <- quantile(pop.at.risk,0.975)
  pop.at.risk.summary[[i]] <- c(pop.at.risk.median,pop.at.risk.lower,pop.at.risk.upper)  
}
pop.at.risk.summary <- do.call(rbind,pop.at.risk.summary)
pop.at.risk.summary <- cbind(as.character(admin2@data$ADM1_FR),as.character(pop.at.risk.summary[,1]),as.character(pop.at.risk.summary[,2]),as.character(pop.at.risk.summary[,3]))
colnames(pop.at.risk.summary) <- c("Commune","Median","Lower 95%","Upper 95%")
write.csv(pop.output,file="outputs/population_at_risk_commune_1_per_1000_PYO.csv")
admin2@data$MedianPopAtRisk <- as.numeric(pop.at.risk.summary[,2])
writeOGR(admin2, ".", "outputs/popatrisk_commune_1_per_1000_PYO", driver="ESRI Shapefile",overwrite_layer = TRUE)

## Per-pixel modal facility assignment and number of "local" facilities across
## posterior replicants. Each replicant's catchments matrix has one row per
## facility and one column per in-country pixel.

# Most frequent value of `v`; ties go to the value that appears first in `v`.
getmode <- function(v) {
  uniqv <- unique(v)
  uniqv[which.max(tabulate(match(v, uniqv)))]
}

# Facility with the largest catchment weight in each pixel, per replicant; the
# modal facility is drawn with one random colour value per facility.
nearest.hf <- matrix(NA, nrow = bigN, ncol = length(baseline_replicants))
for (i in seq_along(baseline_replicants)) {
  nearest.hf[, i] <- apply(as.matrix(baseline_replicants[[i]]$catchments), 2, which.max)
}
nearest.hf <- apply(nearest.hf, 1, getmode)
buffer.image[in.country] <- runif(nHFs)[nearest.hf]
writeRaster(buffer.image, file = "outputs/nearest_hf_visualization.tif", overwrite = TRUE)

# Number of facilities with catchment weight above 0.1 in a pixel.
length.nontrivial <- function(x) {length(which(x > 0.1))}
n.local.hfs <- matrix(NA, nrow = bigN, ncol = length(baseline_replicants))
for (i in seq_along(baseline_replicants)) {
  n.local.hfs[, i] <- apply(as.matrix(baseline_replicants[[i]]$catchments), 2, length.nontrivial)
}
n.local.hfs <- apply(n.local.hfs, 1, getmode)
buffer.image[in.country] <- n.local.hfs
writeRaster(buffer.image, file = "outputs/nlocal_hf_visualization.tif", overwrite = TRUE)

## Most important covariate per pixel. Each replicant stores covariate indices
## (maxpred_pos, maxpred_neg, maxpred) and a sign (maxpredsign); each raster
## shows the modal value over replicants.
n.rep <- length(baseline_replicants)
maxpredpos <- matrix(NA, ncol = n.rep, nrow = bigN)
maxpredneg <- matrix(NA, ncol = n.rep, nrow = bigN)
for (r in seq_len(n.rep)) {
  replicant <- baseline_replicants[[r]]
  maxpredpos[, r] <- replicant$maxpred_pos
  maxpredneg[, r] <- replicant$maxpred_neg
}
buffer.image[in.country] <- apply(maxpredpos, 1, getmode)
writeRaster(buffer.image, file = "outputs/final_most_important_pos_covariate.tif", overwrite = TRUE)
buffer.image[in.country] <- apply(maxpredneg, 1, getmode)
writeRaster(buffer.image, file = "outputs/final_most_important_neg_covariate.tif", overwrite = TRUE)

maxpred <- matrix(NA, ncol = n.rep, nrow = bigN)
maxpredsign <- matrix(NA, ncol = n.rep, nrow = bigN)
for (r in seq_len(n.rep)) {
  replicant <- baseline_replicants[[r]]
  maxpred[, r] <- replicant$maxpred
  maxpredsign[, r] <- replicant$maxpredsign
}
modal.covariate <- apply(maxpred, 1, getmode)
buffer.image[in.country] <- modal.covariate
writeRaster(buffer.image, file = "outputs/final_most_important_covariate.tif", overwrite = TRUE)

# Sign of the modal covariate: modal sign among the replicants that chose it.
buffer.pix <- modal.covariate
for (px in seq_len(bigN)) {
  agrees <- maxpred[px, ] == modal.covariate[px]
  buffer.pix[px] <- getmode(maxpredsign[px, agrees])
}
buffer.image[in.country] <- buffer.pix
writeRaster(buffer.image, file = "outputs/final_sign_most_important_covariate.tif", overwrite = TRUE)

## Posterior summary of the baseline covariate slopes, i.e. par.fixed entries
## 2..(Ncovariates + 1), over the odd-numbered replicants 1, 3, 5, ...
## Only those replicants are collected, so the matrix has exactly one column per
## replicant used; unfilled zero columns would otherwise bias both the mean and the SD.
slope.replicants <- seq_len(length(baseline_replicants) %/% 2) * 2 - 1
mean_slopes <- matrix(
  vapply(slope.replicants,
         function(r) baseline_replicants[[r]]$par.fixed[1 + seq_len(Ncovariates)],
         numeric(Ncovariates)),
  nrow = Ncovariates
)
mean_slopes_mean <- rowMeans(mean_slopes)
# Population SD over replicants; the centred form cannot go negative through
# rounding the way E[x^2] - E[x]^2 can.
mean_slopes_stddev <- sqrt(rowMeans((mean_slopes - mean_slopes_mean)^2))
issignif <- as.integer(mean_slopes_mean + mean_slopes_stddev * 3 < 0 |
                         mean_slopes_mean - mean_slopes_stddev * 3 > 0)
cov.stat <- cbind(covariate.names, as.character(mean_slopes_mean),
                  as.character(mean_slopes_stddev), as.character(issignif))
colnames(cov.stat) <- c("Covariate Name", "Post Mean Slope", "Post Std Dev Slope", "3 sig signif.?")
write.csv(cov.stat, file = "outputs/covstats.csv")

save.image("postfit_static.dat")

## Catchment population of each aggregated HF across posterior replicants:
## posterior median, 2.5% and 97.5% quantiles, first weighting pixels by
## population * treatment, then by population alone.
catchment.quantiles <- function(pixel.pop) {
  per.replicant <- matrix(NA, nrow = nHFs, ncol = length(baseline_replicants))
  for (r in seq_along(baseline_replicants)) {
    per.replicant[, r] <- as.numeric(baseline_replicants[[r]]$catchments %*% pixel.pop)
  }
  t(apply(per.replicant, 1, quantile, probs = c(0.5, 0.025, 0.975)))
}
pop.summ <- cbind(seq_len(nHFs),
                  catchment.quantiles(population * treatment),
                  catchment.quantiles(population))
colnames(pop.summ) <- c("Aggregated HF ID Number","Posterior Median Catchment Pop Est (Will Seek Treatment)","Posterior 2.5% Catchment (Will Seek Treatment)","Upper 97.5% Catchment Pop Est (Will Seek Treatment)","Posterior Median Catchment Pop Est (Ignoring Treatment Seeking)","Posterior 2.5% Catchment (Ignoring Treatment Seeking)","Upper 97.5% Catchment Pop Est (Ignoring Treatment Seeking)")
write.csv(pop.summ, file = "outputs/aggregated_catchment_pops.csv")

# Reporting HF codes and locations alongside their aggregated HF number/location.
hf.agg <- cbind(as.character(hf.ids.reporting),
                as.character(hf.longlats.reporting[, 1]),
                as.character(hf.longlats.reporting[, 2]),
                as.character(clusterCut),
                as.character(hf.longlats.clustered[clusterCut, 1]),
                as.character(hf.longlats.clustered[clusterCut, 2]))
colnames(hf.agg) <- c("HF ID Code", "HF Long", "HF Lat", "Aggregated HF ID Number", "Agg HF Long", "Agg HF Lat")
write.csv(hf.agg, file = "outputs/hf_aggregation_codes.csv")

## Catchment visualisation
##
## For each aggregated HF: the mean catchment weight over replicants times
## `cases` (exp of the baseline counts raster, times treatment), normalised to
## sum to one. This is the probability that a case reported at the HF originated
## in each pixel. One raster is written per facility.
cases <- exp(raster("outputs/final_baseline_counts.tif")[in.country]) * treatment
mean.catchments <- matrix(0, nrow = nHFs, ncol = bigN)
for (replicant in baseline_replicants) {
  mean.catchments <- mean.catchments + as.matrix(replicant$catchments) / length(baseline_replicants)
}

for (i in seq_len(nHFs)) {
  catchx <- mean.catchments[i, ] * cases
  catchx <- catchx / sum(catchx)
  values(buffer.image) <- NA
  buffer.image[in.country] <- catchx
  # overwrite = TRUE, as for every other raster output, so the script can be re-run.
  writeRaster(buffer.image,
              file = paste0("outputs/prob_case_origin_given_reported_at_aggregated_HF_number_", sprintf("%03i", i), ".tif"),
              overwrite = TRUE)
}

## Journey visualisation
##
## Least-cost routes (shortestPath over the transition layer T.GC) run from
## origin pixels to each aggregated HF. Routes are split into consecutive
## two-point segments; identical segments are merged and carry the summed weight
## of every route through them. Two shapefile layers are written:
##   "journeys"    - origins: pixels with > 0.1 expected reported cases
##                   (mean catchment weight x cases), weighted by those cases;
##   "alljourneys" - origins: catchment pixels ranked by population * treatment,
##                   kept while their cumulative share stays below 90%,
##                   weighted by that share.
## Segment weights are stored on the log scale. cIds is a random colour id per
## facility (HFids), which is also written with the HF locations to journeyHFs.csv.
## Facilities with fewer than two origin pixels are skipped.

# Least-cost route from in-country pixel number `pixel` to aggregated facility
# `facility.num`, as an sp Lines object.
trace.journey <- function(pixel, facility.num) {
  shortestPath(T.GC,
               as.numeric(xyFromCell(reference.image, in.country[pixel])),
               as.numeric(hf.longlats.clustered[facility.num, ]),
               output = "SpatialLines")@lines[[1]]
}

# Merge the consecutive two-point segments of `journeys`, adding weights[i] to
# every segment that journey i passes through. Segments are directed and merge
# only when their coordinates are identical. Each segment is counted once per
# traversal; there is no seeding.
# Returns list(segments = list of 2x2 coordinate matrices, weights = numeric).
merge.journey.segments <- function(journeys, weights) {
  segments <- list()
  segment.weights <- numeric(0)
  position.of <- new.env(hash = TRUE)  # segment key -> index in `segments`
  for (i in seq_along(journeys)) {
    coords <- journeys[[i]]@Lines[[1]]@coords
    for (j in seq_len(nrow(coords) - 1)) {
      segment <- coords[j:(j + 1), ]
      # 17 significant digits round-trip a double, so keys are equal exactly
      # when the (finite) coordinates are equal.
      key <- paste(sprintf("%.17g", segment), collapse = " ")
      k <- position.of[[key]]
      if (is.null(k)) {
        k <- length(segments) + 1
        segments[[k]] <- segment
        segment.weights[k] <- weights[i]
        position.of[[key]] <- k
      } else {
        segment.weights[k] <- segment.weights[k] + weights[i]
      }
    }
  }
  list(segments = segments, weights = segment.weights)
}

# Merged route segments for one facility, from origin pixels `casepix` (indices
# into in.country) with per-origin `weights`. Returns NULL when there are fewer
# than two origins or no route has a segment. Otherwise returns sp Lines objects
# with IDs "<facility>_<segment>" (unique across facilities), their summed
# weights, and the facility number once per segment.
facility.journey.segments <- function(facility.num, casepix, weights) {
  if (length(casepix) <= 1) return(NULL)
  journeys <- lapply(casepix, trace.journey, facility.num = facility.num)
  # Routes with a single vertex have no segments.
  has.segment <- vapply(journeys, function(l) nrow(l@Lines[[1]]@coords) > 1, logical(1))
  if (!any(has.segment)) return(NULL)
  merged <- merge.journey.segments(journeys[has.segment], weights[has.segment])
  lines <- lapply(seq_along(merged$segments), function(i) {
    Lines(Line(merged$segments[[i]]), ID = paste0(facility.num, "_", i))
  })
  list(lines = lines, weights = merged$weights, ids = rep(facility.num, length(lines)))
}

# Write per-facility results as ESRI shapefile layer `layer` in the working
# directory, with log segment weights and colour ids indexed by facility number.
write.journey.layer <- function(results, layer, colour.ids) {
  lines <- unlist(lapply(results, `[[`, "lines"), recursive = FALSE)
  if (length(lines) == 0) {
    warning("No journey segments to write for layer ", layer, call. = FALSE)
    return(invisible(NULL))
  }
  link.ids <- vapply(lines, function(l) l@ID, character(1))
  ids <- data.frame(linkId = link.ids, row.names = link.ids, stringsAsFactors = FALSE)
  splndf <- SpatialLinesDataFrame(SpatialLines(lines), data = ids, match.ID = TRUE)
  splndf$weights <- log(unlist(lapply(results, `[[`, "weights")))
  splndf$cIds <- colour.ids[unlist(lapply(results, `[[`, "ids"))]
  writeOGR(splndf, dsn = ".", layer = layer, driver = "ESRI Shapefile", overwrite_layer = TRUE)
  invisible(splndf)
}

# Pass 1: origins weighted by expected reported cases.
first.pass <- list()
for (facility.num in seq_len(nHFs)) {
  expected.casepix <- mean.catchments[facility.num, ] * cases
  casepix <- which(expected.casepix > 0.1)
  segments <- facility.journey.segments(facility.num, casepix, expected.casepix[casepix])
  if (!is.null(segments)) {
    first.pass[[length(first.pass) + 1]] <- segments
    cat(facility.num, "\n")
  }
}
# One random colour id per aggregated facility, indexed by facility number.
HFids <- runif(nHFs)
write.journey.layer(first.pass, "journeys", HFids)
hfcols <- cbind(xyFromCell(reference.image, cellFromXY(reference.image, hf.longlats.clustered)), HFids)
colnames(hfcols) <- c("long", "lat", "colid")
write.csv(hfcols, "journeyHFs.csv")

# Pass 2: the most populous treatment-seeking pixels covering < 90% of the
# catchment total. Pixels and weights are reordered together so they stay aligned.
second.pass <- list()
for (facility.num in seq_len(nHFs)) {
  casepix <- which(mean.catchments[facility.num, ] > 0)
  weights <- (population * treatment)[casepix]
  ord <- order(weights, decreasing = TRUE)
  casepix <- casepix[ord]
  normweights <- weights[ord] / sum(weights)
  keep <- which(cumsum(normweights) < 0.90)
  segments <- facility.journey.segments(facility.num, casepix[keep], normweights[keep])
  if (!is.null(segments)) {
    second.pass[[length(second.pass) + 1]] <- segments
    cat(facility.num, "\n")
  }
}
write.journey.layer(second.pass, "alljourneys", HFids)

## Step 10: Fit seasonality model
##
## For k = 1..K (K = half the number of baseline replicants), fit the TMB
## seasonality model to detrended dataset k, conditional on baseline replicant
## 2k-1. Then draw n.seasonality.draws joint posterior samples and store each
## sample's monthly predicted surface. seasonality_outputs therefore holds
## consecutive groups of n.seasonality.draws entries per fit, each sharing that
## fit's sdreport in $rep. Outputs are checkpointed to disk after every fit.

library(TMB)
library(sparseMVN)
library(Matrix)
# "-Wno-error" is the GCC/Clang flag; "-Wno-errors" is not a recognised option.
compile("haiti_seasonality_lags.cpp", flags = "-Ofast -Wno-error")
dyn.load(dynlib("haiti_seasonality_lags"))

n.seasonality.draws <- 3

# Subtract each pixel's (row's) mean over the months; the vector of row means
# recycles down the columns.
centre.months <- function(x) x - rowMeans(x)

seasonality_outputs <- list()
for (k in seq_len(length(baseline_replicants) %/% 2)) {
  baseline <- baseline_replicants[[k * 2 - 1]]

  input.data <- list(
    'bigN' = bigN,
    'evi_covariate_matrix' = centre.months(evi.covariate.matrix),
    'LSTday_covariate_matrix' = centre.months(LSTday.covariate.matrix),
    'LSTdiff_covariate_matrix' = centre.months(LSTdiff.covariate.matrix),
    'TCB_covariate_matrix' = centre.months(TCB.covariate.matrix),
    'TCW_covariate_matrix' = centre.months(TCW.covariate.matrix),
    'spde' = haiti.spde.coarse,
    'A' = haiti.A.coarse,
    'nHFs' = nHFs,
    'population' = population,
    'HFcases_matrix' = as.matrix(aggregate(monthly.case.matrix.detrended.list[[k]], list(clusterCut), sum)[2:13]),
    'log_masses' = baseline$baseline_log_masses,
    'baseline_surface' = baseline$field.draws,
    'Nunwrapped' = length(invdists),
    'invdists' = invdists,
    'hf_ids' = hf.ids,
    'pixel_ids' = pixel.ids,
    'treatment' = treatment,
    'intermonth_distmat' = intermonth.distmat
  )

  parameters <- list(
    'intercept' = 0,
    'log_overdispersion_field_sd' = -2,
    'log_shrinkage_temporal_slopes' = 0,
    'evi_slope_lag0' = 0,
    'evi_slope_lag1' = 0,
    'evi_slope_lag2' = 0,
    'LSTday_slope_lag0' = 0,
    'LSTday_slope_lag1' = 0,
    'LSTday_slope_lag2' = 0,
    'LSTdiff_slope_lag0' = 0,
    'LSTdiff_slope_lag1' = 0,
    'LSTdiff_slope_lag2' = 0,
    'TCB_slope_lag0' = 0,
    'TCB_slope_lag1' = 0,
    'TCB_slope_lag2' = 0,
    'TCW_slope_lag0' = 0,
    'TCW_slope_lag1' = 0,
    'TCW_slope_lag2' = 0,
    'log_range_seasonal' = -2,
    'log_sd_seasonal' = 3,
    'log_range_seasonaltime' = -2,
    'log_sd_seasonaltime' = 3,
    'field_seasonal' = matrix(0, nrow = haiti.mesh.coarse$n, ncol = 12)
  )

  obj <- MakeADFun(input.data, parameters, DLL = "haiti_seasonality_lags",
                   random = c('intercept', 'evi_slope_lag0', 'evi_slope_lag1', 'evi_slope_lag2',
                              'LSTday_slope_lag0', 'LSTday_slope_lag1', 'LSTday_slope_lag2',
                              'LSTdiff_slope_lag0', 'LSTdiff_slope_lag1', 'LSTdiff_slope_lag2',
                              'TCB_slope_lag0', 'TCB_slope_lag1', 'TCB_slope_lag2',
                              'TCW_slope_lag0', 'TCW_slope_lag1', 'TCW_slope_lag2',
                              'field_seasonal'))
  obj$fn()
  opt <- nlminb(obj$par, obj$fn, obj$gr, control = list(iter.max = 300, eval.max = 300))
  # Named sd.rep so that base::rep is not shadowed.
  sd.rep <- sdreport(obj, getJointPrecision = TRUE)

  # Joint mode of all parameters in the template's parameter order. This is the
  # order of jointPrecision's rows/columns and of the vector obj$report() takes.
  # MakeADFun may reorder the `parameters` list to match the template, so
  # unlist(parameters) is not guaranteed to be in that order.
  joint.mode <- obj$env$par
  joint.mode[obj$env$random] <- sd.rep$par.random
  joint.mode[-obj$env$random] <- sd.rep$par.fixed

  r.draws <- rmvn.sparse(n.seasonality.draws, joint.mode, Cholesky(sd.rep$jointPrecision), prec = TRUE)
  for (z in seq_len(n.seasonality.draws)) {
    outputs <- obj$report(r.draws[z, ])
    seasonality_outputs[[length(seasonality_outputs) + 1]] <-
      list(predicted_surface = outputs$predicted_surface_malaria, rep = sd.rep)
  }
  save(seasonality_outputs, file = "seasonality_outputs.dat")
}
  
## Seasonal effect per calendar month. For each draw, subtract each pixel's mean
## over the surface's months from month m. The pointwise mean over draws is kept
## as pointwise.mean.field<m> (draws as field.draws<m>) and written to
## outputs/final_seasonality_effect<m>.tif.
n.season.draws <- length(seasonality_outputs)
centred.draws <- array(0, dim = c(bigN, n.season.draws, 12))
for (d in seq_len(n.season.draws)) {
  surface <- seasonality_outputs[[d]]$predicted_surface
  centred.draws[, d, ] <- (surface - rowMeans(surface))[, 1:12]
  cat(d, "\n")
}

for (month in 1:12) {
  month.draws <- matrix(centred.draws[, , month], nrow = bigN)
  assign(paste0("field.draws", month), month.draws)
  assign(paste0("pointwise.mean.field", month), apply(month.draws, 1, mean))
}

for (month in 1:12) {
  buffer.image[in.country] <- get(paste0("pointwise.mean.field", month))
  writeRaster(buffer.image, file = paste0("outputs/final_seasonality_effect", month, ".tif"), overwrite = TRUE)
}

## Peak month, trough month (modal over draws) and amplitude (mean over draws of
## max - min across the monthly surface) per pixel.

# bigN x n.draws matrix of a per-pixel summary of each draw's predicted surface.
per.draw <- function(summary.fn) {
  vapply(seasonality_outputs,
         function(s) apply(s$predicted_surface, 1, summary.fn),
         numeric(bigN))
}

max.month.draws <- apply(per.draw(which.max), 1, getmode)
buffer.image[in.country] <- max.month.draws
writeRaster(buffer.image, file = "outputs/final_seasonality_peak_month.tif", overwrite = TRUE)

min.month.draws <- apply(per.draw(which.min), 1, getmode)
buffer.image[in.country] <- min.month.draws
writeRaster(buffer.image, file = "outputs/final_seasonality_trough_month.tif", overwrite = TRUE)

seasonality.amplitude <- rowMeans(per.draw(function(row) max(row) - min(row)))
buffer.image[in.country] <- seasonality.amplitude
writeRaster(buffer.image, file = "outputs/final_seasonality_amplitude.tif", overwrite = TRUE)

## Biphasic factor per pixel. Regress the 12 mean monthly effects on first- and
## second-harmonic sine/cosine terms, take the peak absolute value of each
## harmonic's fitted component, and return second / (first + second):
## 0 = first harmonic only, 1 = second harmonic only.
sinx <- sin(1:12/12*2*pi)
sin2x <- sin(1:12/12*2*2*pi)
cosx <- cos(1:12/12*2*pi)
cos2x <- cos(1:12/12*2*2*pi)
xd <- cbind(sinx, sin2x, cosx, cos2x)

pointwise.mean.matrix <- do.call(cbind, mget(paste0("pointwise.mean.field", 1:12), envir = environment()))

# Coefficients are (Intercept), sinx, sin2x, cosx, cos2x, i.e. positions 1-5.
AB <- t(vapply(seq_len(bigN), function(px) {
  coefs <- lm(pointwise.mean.matrix[px, ] ~ xd)$coefficients
  c(max(abs(xd[, c(1, 3)] %*% coefs[c(2, 4)])),
    max(abs(xd[, c(2, 4)] %*% coefs[c(3, 5)])))
}, numeric(2)))
biphasic.factor <- AB[, 2] / (AB[, 1] + AB[, 2])
buffer.image[in.country] <- biphasic.factor
writeRaster(buffer.image, file = "outputs/final_seasonality_biphasic_factor.tif", overwrite = TRUE)

## Mean monthly detrended case counts per aggregated HF, averaged over the
## datasets used by the seasonality fits (one per fit; seasonality_outputs
## stores 3 draws per fit).
draws.per.fit <- 3
n.fits <- length(seasonality_outputs) %/% draws.per.fit
stopifnot(n.fits > 0)
monthly.hf.sums <- lapply(seq_len(n.fits), function(k) {
  aggregate(monthly.case.matrix.detrended.list[[k]], list(clusterCut), sum)
})
# Average only the month columns. The first column (aggregated HF id) is kept
# exact rather than being summed and rescaled along with the counts.
mean.monthly.cases.imputed <- monthly.hf.sums[[1]]
mean.monthly.cases.imputed[, -1] <- Reduce(`+`, lapply(monthly.hf.sums, function(a) a[, -1] / n.fits))
write.csv(cbind(hf.longlats.clustered, mean.monthly.cases.imputed), file = "outputs/monthly_imputed.csv")

## Posterior summary of the 15 temporal (covariate x lag) slopes, averaged over
## the seasonality fits. seasonality_outputs stores 3 draws per fit that share
## one sdreport ($rep), so only the first entry of each fit is used. This weights
## every fit equally.
# NOTE(review): assumes par.random[2:16] are the 15 temporal slopes (intercept first).
draws.per.fit <- 3
fit.entries <- seq(1, by = draws.per.fit, length.out = length(seasonality_outputs) %/% draws.per.fit)
n.fits <- length(fit.entries)
stopifnot(n.fits > 0)
mean.slopes <- sd.slopes <- numeric(15)
for (entry in fit.entries) {
  fit <- seasonality_outputs[[entry]]$rep
  mean.slopes <- mean.slopes + fit$par.random[2:16] / n.fits
  # diag.cov.random holds the marginal posterior variances of the random effects.
  # 1 / diag(jointPrecision) would be the smaller conditional variances.
  sd.slopes <- sd.slopes + fit$diag.cov.random[2:16]
}
sd.slopes <- sqrt(sd.slopes / n.fits)
slope.names <- names(seasonality_outputs[[fit.entries[1]]]$rep$par.random)[2:16]
summ.slopes <- cbind(slope.names, as.character(mean.slopes), as.character(sd.slopes),
                     as.character(as.integer(mean.slopes + 3 * sd.slopes < 0 | mean.slopes - 3 * sd.slopes > 0)),
                     deparse.level = 0)
summ.slopes <- summ.slopes[sort.list(abs(mean.slopes), decreasing = TRUE), ]
write.csv(summ.slopes, file = "outputs/temporal_slopes.dat")

## Dominant temporal covariate term per pixel. For each seasonality fit, each
## (covariate, lag) term is the covariate at month m - lag times its fitted
## slope, giving 15 blocks of 12 monthly columns. The column with the largest
## (smallest) value is mapped back to its term index 1-15. The modal index over
## fits is written as a raster. seasonality_outputs stores 3 draws per fit
## sharing one sdreport, so only the first entry of each fit is used.
# NOTE(review): assumes par.random[2:16] are the slopes in the order evi, LSTday,
# LSTdiff, TCB, TCW, each at lags 0, 1, 2 (intercept first).
draws.per.fit <- 3
fit.entries <- seq(1, by = draws.per.fit, length.out = length(seasonality_outputs) %/% draws.per.fit)
temporal.covariates <- list(evi.covariate.matrix, LSTday.covariate.matrix, LSTdiff.covariate.matrix,
                            TCB.covariate.matrix, TCW.covariate.matrix)

# Lagged covariate contributions (15 blocks of 12 month columns) for 15 slopes.
temporal.contributions <- function(slopes) {
  blocks <- vector("list", 15)
  for (cov in seq_along(temporal.covariates)) {
    for (lag in 0:2) {
      term <- (cov - 1) * 3 + lag + 1
      lagged.months <- (1:12 - 1 - lag) %% 12 + 1  # month m uses month m - lag (wrapping)
      blocks[[term]] <- temporal.covariates[[cov]][, lagged.months] * slopes[term]
    }
  }
  do.call(cbind, blocks)
}

greatest.temporal.covariate <- least.temporal.covariate <- matrix(NA, nrow = bigN, ncol = length(fit.entries))
for (f in seq_along(fit.entries)) {
  pmatrix <- temporal.contributions(seasonality_outputs[[fit.entries[f]]]$rep$par.random[2:16])
  greatest.temporal.covariate[, f] <- apply(pmatrix, 1, which.max)
  least.temporal.covariate[, f] <- apply(pmatrix, 1, which.min)
}
term.of.column <- rep(1:15, each = 12)

greatest.temporal.covariate <- term.of.column[apply(greatest.temporal.covariate, 1, getmode)]
buffer.image[in.country] <- greatest.temporal.covariate
writeRaster(buffer.image, file = "outputs/max_temporal_cov.tif", overwrite = TRUE)

least.temporal.covariate <- term.of.column[apply(least.temporal.covariate, 1, getmode)]
buffer.image[in.country] <- least.temporal.covariate
writeRaster(buffer.image, file = "outputs/min_temporal_cov.tif", overwrite = TRUE)

## Step 11: Model validation against TAS
##
## Compare predicted baseline incidence with AMA1/MSP1 serology among 5-7 year
## olds in the 2015-2016 TAS survey, aggregated by pixel and then by bins of
## predicted log incidence.

tas.data <- read.csv("Haiti TAS 2015 2016 ALL CHILDREN malaria EU 1 to 23 and RDT.csv")
valid.tas.sample <- which(!is.na(tas.data$MSP1pos) & !is.na(tas.data$AMA1pos) & !is.na(tas.data$School.Code) &
                            !is.na(tas.data$GPS_long) & !is.na(tas.data$GPS_Lat) &
                            tas.data$AGE <= 7 & tas.data$AGE >= 5)
tas.schools <- unique(tas.data$School)
n.tas.schools <- length(tas.schools)
tas.longlats <- cbind(tas.data$GPS_long, tas.data$GPS_Lat)[valid.tas.sample, ]

# Pixel of each valid child. Pixels outside the in-country mask are snapped to
# the nearest in-country pixel centre.
tas.cells <- cellFromXY(reference.image, tas.longlats)
unique.tas.cells <- unique(tas.cells)
invalid.cells <- unique.tas.cells[which(!(unique.tas.cells %in% in.country))]
n.invalid <- length(invalid.cells)
invalid.lookup <- cbind(invalid.cells, rep(0, n.invalid))
# seq_len(): nothing to snap when every TAS pixel is already in-country
# (1:0 would index a non-existent row).
for (i in seq_len(n.invalid)) {
  invalid.xy <- xyFromCell(reference.image, invalid.cells[i])
  cell.replacement <- in.country[which.min((in.country.coords[, 1] - invalid.xy[1])^2 +
                                             (in.country.coords[, 2] - invalid.xy[2])^2)]
  invalid.lookup[i, 2] <- cell.replacement
  unique.tas.cells[unique.tas.cells == invalid.lookup[i, 1]] <- invalid.lookup[i, 2]
  tas.cells[tas.cells == invalid.lookup[i, 1]] <- invalid.lookup[i, 2]
}
unique.tas.cells <- unique(unique.tas.cells)
n.tas.cells <- length(unique.tas.cells)

# Per pixel: children positive on either antigen, and negative on both.
tas.any.pos <- tas.any.neg <- numeric(n.tas.cells)
for (i in seq_len(n.tas.cells)) {
  child.rows <- valid.tas.sample[tas.cells == unique.tas.cells[i]]
  tas.any.pos[i] <- length(which(tas.data$AMA1pos[child.rows] == 1 | tas.data$MSP1pos[child.rows] == 1))
  tas.any.neg[i] <- length(which(tas.data$AMA1pos[child.rows] == 0 & tas.data$MSP1pos[child.rows] == 0))
}

tas.predicted <- raster("outputs/final_baseline.tif")[unique.tas.cells]
tas.predicted.stddev <- raster("outputs/final_baseline_stddev.tif")[unique.tas.cells]
# Log incidence per person-year -> log incidence per 1000 PYO.
tas.predicted <- tas.predicted + log(1000)

# Unit-width bins of predicted log incidence. For each bin: pooled pos/neg
# counts, a precision-weighted mean prediction, and the RMS predictive SD
# divided by sqrt(number of pixels).
tas.bins.low <- (-4):(5)
tas.bins.high <- (-3):(6)
tas.bins.mid <- tas.bins.low/2 + tas.bins.high/2
tas.bins.pos <- tas.bins.neg <- tas.bins.sd <- tas.bins.mean <- numeric(length(tas.bins.mid))
for (i in seq_along(tas.bins.pos)) {
  in.bin <- tas.predicted >= tas.bins.low[i] & tas.predicted < tas.bins.high[i]
  tas.bins.pos[i] <- sum(tas.any.pos[in.bin])
  tas.bins.neg[i] <- sum(tas.any.neg[in.bin])
  tas.bins.mean[i] <- sum((tas.predicted/tas.predicted.stddev^2)[in.bin]) / sum((1/tas.predicted.stddev^2)[in.bin])
  tas.bins.sd[i] <- sqrt(mean((tas.predicted.stddev^2)[in.bin])) / sqrt(length(which(in.bin)))
}

## TAS validation figure. For each TAS pixel, draw a horizontal 95% interval of
## predicted log incidence (per 1000 PYO) at the Jeffreys-prior median of
## seroprevalence (logit scale), and a vertical 95% interval of seroprevalence.
## Binned summaries come next (all but the last two bins), then a dashed linear
## fit of bin medians on bin mid-points.
library(extrafont)
loadfonts()
png("~/Desktop/test.png", width = 12, height = 8, units = "in", res = 600)

logit <- function(x) {
  log(x/(1-x))
}
par(mai = c(1.5, 1.5, 0.1, 0.1))
# Empty canvas: the dummy point lies outside the plotting window.
plot(-10000, -1000, xlim = c(-4, 6), ylim = c(-3.5, 1), xlab = "", ylab = "",
     xaxt = "n", yaxt = "n", family = "Merriweather Sans")

for (cell in seq_len(n.tas.cells)) {
  sero.median <- logit(qbeta(0.5, tas.any.pos[cell] + 0.5, tas.any.neg[cell] + 0.5))
  pred.halfwidth <- tas.predicted.stddev[cell] * 1.96
  lines(c(tas.predicted[cell] - pred.halfwidth, tas.predicted[cell] + pred.halfwidth),
        rep(sero.median, 2), col = hsv(0.8, alpha = 0.2), lwd = 1.5)
  lines(rep(tas.predicted[cell], 2),
        logit(qbeta(c(0.025, 0.975), tas.any.pos[cell] + 0.5, tas.any.neg[cell] + 0.5)),
        col = hsv(0.8, alpha = 0.1), lwd = 1.5)
}

bin.median <- logit(qbeta(0.5, tas.bins.pos + 0.5, tas.bins.neg + 0.5))
for (b in seq_len(length(tas.bins.pos) - 2)) {
  lines(rep(tas.bins.mean[b], 2),
        logit(qbeta(c(0.025, 0.975), tas.bins.pos[b] + 0.5, tas.bins.neg[b] + 0.5)),
        col = hsv(0.66, v = 0.8), lwd = 2)
  points(tas.bins.mean[b], bin.median[b], pch = 19, cex = 1, col = hsv(0.66, v = 0.8))
  lines(c(tas.bins.mean[b] - 1.96 * tas.bins.sd[b], tas.bins.mean[b] + 1.96 * tas.bins.sd[b]),
        rep(bin.median[b], 2), col = hsv(0.66, v = 0.8), lwd = 2)
}
abline(lm(bin.median ~ tas.bins.mid), col = hsv(0.66, v = 0.8), lty = 2, lwd = 2)

axis(1, at = log(c(0.01, 0.1, 1, 10, 100)), labels = c(0.01, 0.1, 1, 10, 100), tck = -0.0125,
     padj = -0.25, family = "Merriweather Sans", cex.axis = 1.3, lwd = 1.5)
axis(2, at = logit(c(0.01, 0.05, 0.1, 0.5)), labels = c("0.01 ", "0.05 ", 0.1, 0.5), tck = -0.0125,
     hadj = 0.75, las = 2, family = "Merriweather Sans", cex.axis = 1.3, lwd = 1.5)
#mtext("Predicted Incidence Rate (Cases Per 1000 PYO) [Model]",side=1,line=3,family="Merriweather Sans",cex=2)
#mtext("Estimated Sero-Prevalence [TAS]",side=2,line=4,family="Merriweather Sans",cex=2)
box(lwd = 1.5)
dev.off()

## Export TAS cell centroids with the logit of the posterior-median
## sero-prevalence (Beta(pos + 0.5, neg + 0.5)) to outputs/serop.csv.
serop <- local({
  cell.xy <- xyFromCell(reference.image, unique.tas.cells)
  logit.p <- logit(qbeta(0.5, tas.any.pos + 0.5, tas.any.neg + 0.5))
  out <- cbind(cell.xy, logit.p)
  colnames(out) <- c("Long", "Lat", "Logit P")
  out
})
write.csv(serop, file = "outputs/serop.csv")

## Bootstrap correlation: model prediction vs TAS sero-prevalence ----
library(boot)

#tas.predicted <- raster("../Haiti2019 - Linear Distance/outputs/final_baseline_lineardist.tif")[unique.tas.cells]

# x: model prediction per TAS cell; y: posterior-median sero-prevalence under
# a Beta(pos + 0.5, neg + 0.5) posterior.
x <- tas.predicted
y <- qbeta(0.5,tas.any.pos+0.5,tas.any.neg+0.5)

n <- length(x)
Brep <- 10000

xy <- data.frame(x = x, y = y)
# boot() statistic: Pearson correlation of the resampled rows `f` of data `x`.
xcor <- function(x,f) {cor(x[f,])[1,2]}

# Fix the RNG so the resampling can be reproduced (this also makes later
# random draws in the script repeatable).
set.seed(2019)
bootcorr <- boot(data=xy,statistic=xcor,R=Brep)
print(bootcorr)
# xcor returns no bootstrap variance, so studentized intervals cannot be
# formed; request the remaining interval types explicitly rather than the
# default type = "all", which warns and skips "stud".
print(boot.ci(bootcorr,conf=.95,type=c("norm","basic","perc","bca")))
# Preferred Model
#     original       bias    std. error
#t1*0.4230968 0.0001728913  0.03641959
# Linear Dist
# original       bias    std. error
# t1* 0.4071753 0.0003404935  0.04484396
# Flat Treatment
# original        bias    std. error
# t1* 0.4108078 -1.819288e-05  0.04375269

## Observed 2018 weekly confirmed cases per reporting facility ----
seasonality_2018 <- read.csv("Phase I Seasonality MAP GA 2017-18.csv")
# Facility codes to compare: drop "." placeholders and NAs, and keep only
# facilities present in hf.ids.reporting.
unique.hfs <- unique(seasonality_2018$hf_code_final[as.character(seasonality_2018$hf_code_final)!="."])
unique.hfs <- unique.hfs[!is.na(unique.hfs)]
unique.hfs <- unique.hfs[unique.hfs %in% hf.ids.reporting]
n.unique <- length(unique.hfs)

# counts2018[h, w] = sum of caseconfirmed (NAs dropped) over rows with
# hf_code_final == unique.hfs[h] and epiweek == w (w in 1..52); cells with no
# rows stay 0. One grouped sum replaces a facility-by-week double loop that
# rescanned the whole table 52 * n.unique times.
hf.row <- match(seasonality_2018$hf_code_final, unique.hfs)
week.col <- match(seasonality_2018$epiweek, 1:52)
in.grid <- !is.na(hf.row) & !is.na(week.col)
counts2018 <- matrix(0,nrow=n.unique,ncol=52)
if (any(in.grid)) {
  # Column-major linear index of cell (hf.row, week.col) in counts2018.
  cell.index <- hf.row[in.grid] + (week.col[in.grid] - 1) * n.unique
  cell.sums <- rowsum(seasonality_2018$caseconfirmed[in.grid], cell.index, na.rm = TRUE)
  counts2018[as.numeric(rownames(cell.sums))] <- cell.sums[, 1]
}

# Sum facility rows into their cluster (clusterCut): one row per cluster,
# ordered by cluster id (aggregate() sorts its groups), 52 week columns.
agg.2018.counts <- aggregate(counts2018,list(clusterCut[match(unique.hfs,hf.ids.reporting)]),sum)[2:53]

## Posterior predictive monthly counts for clusters with 2018 data ----
# Clusters (clusterCut ids) containing at least one facility in unique.hfs.
# sort() makes the row order match agg.2018.counts, whose rows come from
# aggregate() in increasing cluster-id order; unique() alone keeps
# first-appearance order and would misalign the two in the comparison plots.
matched.facs <- sort(unique(clusterCut[match(unique.hfs,hf.ids.reporting)]))
n.matched <- length(matched.facs)

# post.draws.counts[[i]]: n.matched x 12 matrix of expected monthly cases for
# replicate i = (catchment-weighted monthly seasonal surface) * (annual total).
# NOTE(review): the seasonal term pairs baseline_replicants[[i]] with
# seasonality_outputs[[(i-1)*3+1]], while the annual total uses
# baseline_replicants[[i*2-1]]; confirm this replicate pairing is intended.
post.draws.counts <- list()
for (i in 1:10) {
  seasonal.part <- baseline_replicants[[i]]$catchments[matched.facs,]%*%
    (exp(seasonality_outputs[[(i-1)*3+1]]$predicted_surface)/12*treatment*population)
  annual.part <- (rowSums(aggregate(monthly.case.matrix.detrended.list[[i]],list(clusterCut),sum)[2:13])+
    as.numeric(baseline_replicants[[i*2-1]]$catchments%*%(exp(baseline_replicants[[i*2-1]]$field.draws)*population*treatment)))[matched.facs]
  post.draws.counts[[i]] <- seasonal.part*matrix(rep(annual.part,12),ncol=12)
}

# Per cluster (row i) and month (j): one negative-binomial draw per replicate
# with mean mu and size = mu / exp(par.fixed[5]), i.e. variance
# mu * (1 + exp(par.fixed[5])); summarised by median and 2.5/97.5% quantiles.
median.cases <- upper.cases <- lower.cases <- matrix(0,nrow=n.matched,ncol=12)
for (i in seq_len(n.matched)) {
  for (j in 1:12) {
    buffer <- numeric(10)
    for (k in 1:10) {
      # Use the single cell [i, j]: passing the whole column with n = 1 made
      # rnbinom use only its first element, so every row drew cluster 1.
      mu.ijk <- post.draws.counts[[k]][i,j]
      buffer[k] <- rnbinom(1,mu=mu.ijk,size=mu.ijk/exp(baseline_replicants[[k]]$par.fixed[5]))
    } ### TODO (author): correct for new NB structure!!!!
    median.cases[i,j] <- quantile(buffer,0.5)
    lower.cases[i,j] <- quantile(buffer,0.025)
    upper.cases[i,j] <- quantile(buffer,0.975)
  }
}

## Weekly -> monthly apportioning matrix ----
## day.months: calendar month of days 2..365 of a non-leap year (Jan 1 is
## skipped, so week 1 = Jan 2-8); day.weeks: epiweek (1..52) of those days.
## week.month.matrix[m, w] = fraction of week w's 7 days falling in month m,
## so every column sums to 1.
day.months <- rep(1:12, times = c(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31))[1 + seq_len(7 * 52)]
day.weeks <- rep(seq_len(52), each = 7)
week.month.matrix <- matrix(
  tabulate(day.months + 12 * (day.weeks - 1), nbins = 12 * 52),
  nrow = 12
) / 7
# Apportion weekly 2018 counts to calendar months: one row per cluster,
# 12 month columns.
agg.2018.counts.monthly <- as.matrix(agg.2018.counts)%*%t(week.month.matrix)

## Quick-look panels, one per cluster row of median.cases, on a square-root
## scale: magenta = posterior predictive median with lower/upper band,
## blue open circles = observed 2018 monthly counts. The panel count follows
## nrow(median.cases) instead of a hard-coded 7; layout() keeps one extra,
## unused row below the panels.
n.panels <- nrow(median.cases)
layout(seq_len(n.panels + 1))
par(mai=c(0,0.5,0,0.05))
for (i in seq_len(n.panels)) {
  plot(1:12,median.cases[i,]^0.5,ylim=c(0,max(c(upper.cases[i,],agg.2018.counts.monthly[i,])))^0.5,col="magenta",xlab="",ylab="",xaxt='n',yaxt='n',pch=19,cex=2)
  for (k in 1:12) {lines(rep(k,2),c(lower.cases[i,k],upper.cases[i,k])^0.5,col="magenta",lwd=2)}
  points(1:12,agg.2018.counts.monthly[i,]^0.5,col="blue",pch=21)
}

## Posterior predictive for the monthly total over all matched clusters ----
## For each month, replicate k contributes one draw: the sum of independent
## negative-binomial draws, one per cluster row of post.draws.counts[[k]].
median.cases <- upper.cases <- lower.cases <- numeric(12)
for (j in 1:12) {
  buffer <- numeric(10)
  for (k in 1:10) {
    mu.kj <- post.draws.counts[[k]][,j]
    # One draw per cluster row (count taken from the data, not a hard-coded 7).
    buffer[k] <- sum(rnbinom(length(mu.kj),mu=mu.kj,size=mu.kj/exp(baseline_replicants[[k]]$par.fixed[5])))
  }
  # 16th/84th percentiles: a ~68% (+/- 1 SD) band; the per-cluster summaries
  # use 2.5%/97.5% instead.
  median.cases[j] <- quantile(buffer,0.5)
  lower.cases[j] <- quantile(buffer,0.16)
  upper.cases[j] <- quantile(buffer,0.84)
}

## Figure: monthly total cases, posterior predictive vs observed 2018 ----
## Written to ~/Desktop/testx.png; y values are on a square-root scale.
## Magenta = posterior predictive median with 16-84% band (lower/upper.cases).
library(extrafont)
loadfonts()
png("~/Desktop/testx.png",width=12,height=8,units='in',res=600)

observed.monthly <- colSums(agg.2018.counts.monthly)
# Rescale the observed 2018 series so its geometric mean equals that of the
# posterior medians (a pure level shift; the monthly shape is unchanged).
# A zero month sends log() to -Inf and the shifted points to NaN/Inf.
if (any(observed.monthly <= 0) || any(median.cases <= 0)) {
  warning("Zero monthly counts: geometric-mean shift of observed 2018 series is undefined")
}
shift.factor <- exp(-mean(log(observed.monthly))+mean(log(median.cases)))

par(mai=c(1.5,1.5,0.1,0.1))
# ylim = c(0, 30) on the sqrt scale displays up to 900 cases per month.
plot(1:12,(median.cases)^0.5,ylim=c(0,30),col="magenta",xlab="",ylab="",xaxt='n',yaxt='n',pch=19,cex=2,family="Merriweather Sans")
for (k in 1:12) {lines(rep(k,2),(c(lower.cases[k],upper.cases[k]))^0.5,col="magenta",lwd=2)}
points(1:12,observed.monthly^0.5,col="blue",pch=21,cex=2,lwd=2)
points(1:12,(observed.monthly*shift.factor)^0.5,col="darkgreen",pch=22,cex=2,lwd=2)
axis(1,at=1:12,labels=c("Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"),tck=-0.0125,padj=-0.25,family="Merriweather Sans",cex.axis=1.6,lwd=1.5)
axis(2,at=sqrt(c(0,5,25,75,125,250,500,1000)),labels=c(0,5,25,75,125,250,500,1000),tck=-0.0125,hadj=0.75,las=2,family="Merriweather Sans",cex.axis=1.6,lwd=1.5)
mtext("Month",side=1,line=3,family="Merriweather Sans",cex=2)
mtext("Cases Reported Per Month",side=2,line=4,family="Merriweather Sans",cex=2)
# lty = 0 (blank) suppresses the legend line for the two point-only series;
# line widths must be non-negative.
legend("top",c("Posterior Predictive [2014-2019 Fit]","Observed 2018","Observed 2018 - Mean Shifted to 2019 Benchmark"),pch=c(19,21,22),pt.lwd=c(1,2,2),lty=c(1,0,0),lwd=c(1.5,1,1),pt.cex=2,col=c("magenta","blue","darkgreen"),bty='n',cex=2)
box(lwd=1.5)
dev.off()

## Harmonic regressors over epiweeks 1..52 ----
## Annual (sinx, cosx) and semi-annual (sin2x, cos2x) terms. lm(y ~ xd)
## coefficients are read by position elsewhere ([2] sinx, [3] sin2x,
## [4] cosx, [5] cos2x after the intercept), so the column order matters.
sinx  <- sin(1:52 / 52 * 2 * pi)
sin2x <- sin(1:52 / 52 * 2 * 2 * pi)
cosx  <- cos(1:52 / 52 * 2 * pi)
cos2x <- cos(1:52 / 52 * 2 * 2 * pi)
xd <- cbind(sinx = sinx, sin2x = sin2x, cosx = cosx, cos2x = cos2x)

## Per-facility semi-annual share of the fitted 2018 seasonality ----
## For each facility row of counts2018, fit the two-harmonic model on xd and
## compare peak absolute amplitudes: A = annual part, B = semi-annual part;
## seasonfac = 1 - A / (A + B) (NaN when both are 0, e.g. an all-zero row).
## Loops over all n.unique facilities (previously hard-coded to 10), so
## seasonfac has the same length as the other columns bound into sumd.
seasonfac <- numeric(n.unique)
for (i in seq_len(n.unique)) {
  xfit <- lm(counts2018[i,]~xd)
  annual.amp <- max(abs(xd[,c(1,3)]%*%xfit$coefficients[c(2,4)]))
  semiannual.amp <- max(abs(xd[,c(2,4)]%*%xfit$coefficients[c(3,5)]))
  seasonfac[i] <- 1-annual.amp/(annual.amp+semiannual.amp)
}
sumd <- cbind(hf.longlats.reporting[match(unique.hfs,hf.ids.reporting),],rowSums(counts2018),seasonfac)
colnames(sumd) <- c("long","lat","count","seasonality")
write.csv(sumd,file="outputs/2018comparisondata.csv")

## Per-row biphasic factor ----
## For each of the bigN rows of pointwise.mean.matrix (each a 52-value profile
## regressed on xd), AB[, 1] is the peak absolute amplitude of the fitted
## annual harmonic and AB[, 2] that of the semi-annual harmonic.
## biphasic.factor = AB[, 2] / (AB[, 1] + AB[, 2]) (NaN when both are 0).
AB <- matrix(nrow=bigN,ncol=2)
for (i in 1:bigN) {
  xfit <- lm(pointwise.mean.matrix[i,]~xd)
  AB[i,] <- c(max(abs(xd[,c(1,3)]%*%coef(xfit)[c(2,4)])),
              max(abs(xd[,c(2,4)]%*%coef(xfit)[c(3,5)])))
}
biphasic.factor <- AB[,2]/(AB[,1]+AB[,2])







## 2019 reporting summaries per health facility ----
test.data <- read.csv("2012_2019 OU Long 2020 06 12.csv")
hfnames <- unique(test.data$HF_Name)
Nhfnames <- length(hfnames)

# Row mask for facility hfnames[k] in Year 2019 (NA where HF_Name or Year is
# NA; which()/na.rm below treat those rows as excluded). Shared by all four
# summaries instead of repeating the filter in each loop.
hf.rows.2019 <- function(k) {test.data$HF_Name==hfnames[k] & test.data$Year==2019}

# Number of 2019 rows per facility.
reported.in.2019 <- vapply(seq_len(Nhfnames),function(k) {length(which(hf.rows.2019(k)))},numeric(1))

# NOTE(review): several column names below look truncated (e.g.
# "...PASSIVEPoint", "...POCMonthl"); `$` on a data.frame silently falls back
# to partial matching (with a warning), so confirm they match the CSV header.
RDTtested <- as.numeric(as.character(test.data$RDTtestedPASSIVEPointofCare))
# Number of 2019 rows per facility with a non-missing RDT-tested value.
reported.in.2019.RDTtested <- vapply(seq_len(Nhfnames),function(k) {length(which(!is.na(RDTtested[hf.rows.2019(k)])))},numeric(1))

MICtested <- as.numeric(as.character(test.data$MicroscopytestedPASSIVEPoint))
# Number of 2019 rows per facility with a non-missing microscopy-tested value.
reported.in.2019.MICtested <- vapply(seq_len(Nhfnames),function(k) {length(which(!is.na(MICtested[hf.rows.2019(k)])))},numeric(1))

RDTcases <- as.numeric(as.character(test.data$ConfirmedcasesPOCMonthl))
# Total 2019 confirmed cases per facility (NA values ignored).
reported.in.2019.RDTcases <- vapply(seq_len(Nhfnames),function(k) {sum(RDTcases[hf.rows.2019(k)],na.rm=TRUE)},numeric(1))

