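# Null model: permute the balance (internal-node) labels of the tree. This randomly reassigns each balance's mean distance to tips while leaving the set of balance variances unchanged.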

library(philr); packageVersion("philr")
library(phyloseq); packageVersion("phyloseq")
library(ggplot2); packageVersion("ggplot2")
library(dplyr); packageVersion("dplyr")
library(tidyr); packageVersion("tidyr")
library(matrixStats); packageVersion('matrixStats')
library(parallel); packageVersion("parallel")
library(magrittr); packageVersion("magrittr")

# The null model is drawn in forked workers via parallel::mclapply. With the
# default Mersenne-Twister generator each worker is re-seeded
# non-deterministically, so set.seed() alone would not make the permutations
# reproducible. L'Ecuyer-CMRG gives every worker its own reproducible stream.
RNGkind("L'Ecuyer-CMRG")
set.seed(4)

# Expected to define `HMP`, which is used below as a phyloseq object with a
# 'groupedsites' sample variable, an OTU table and a phylogenetic tree.
load('../HMP.RData')

# Helper functions ---------------------------------------------------------

# Zero-truncated variance of a single balance.
#
# For the balance `coord`, parts with sbp value 1 form the "up" group and
# parts with -1 form the "down" group. Entries whose absolute value is below
# `zero.tol` are treated as missing, so each row's geometric means use only
# its non-zero parts. The log-ratio of the two geometric means is scaled by
# sqrt(n.up * n.down / (n.up + n.down)), and its variance is taken over the
# rows where it is defined.
#
# Args:
#   coord: column name (or index) of `sbp` selecting the balance.
#   df: abundance table, coerced to a matrix. Rows are samples; column j
#     corresponds to row j of `sbp`.
#   sbp: sign matrix whose columns are balances (1 = up, -1 = down).
#   plot: if TRUE, plot the up vs. down geometric means instead of returning
#     statistics.
#   plot.log: use log-log axes when plotting.
#   zero.tol: absolute values below this are treated as zero.
#
# Returns:
#   c(variance, support), where support is the number of rows with a defined
#   log-ratio. When plot is TRUE, returns invisible NULL.
var.truncate.zeroes <- function(coord, df, sbp, plot = FALSE, plot.log = TRUE,
                                zero.tol = 1e-14) {
  df <- as(df, "matrix")
  up <- which(sbp[, coord] == 1)
  down <- which(sbp[, coord] == -1)
  df.up <- df[, up, drop = FALSE]
  df.down <- df[, down, drop = FALSE]
  # Zeros have no finite log; mask them so the geometric means below are
  # taken over the non-zero parts of each row only.
  df.up[abs(df.up) < zero.tol] <- NA
  df.down[abs(df.down) < zero.tol] <- NA
  log.gm.up <- rowMeans(log(df.up), na.rm = TRUE)
  log.gm.down <- rowMeans(log(df.down), na.rm = TRUE)
  # NaN for rows whose up or down side is entirely (masked) zero
  log.ratio <- log.gm.up - log.gm.down

  # Balance scaling constant
  n.up <- ncol(df.up)
  n.down <- ncol(df.down)
  sc <- sqrt(n.up * n.down / (n.up + n.down))

  v <- var(sc * log.ratio, na.rm = TRUE)
  support <- sum(!is.na(log.ratio))

  if (plot) {
    gm.up <- exp(log.gm.up)
    gm.down <- exp(log.gm.down)
    subtitle <- paste("var:", v, "support:", support, sep = " ")
    if (plot.log) {
      plot(x = gm.up, y = gm.down, main = coord, log = "xy")
    } else {
      plot(x = gm.up, y = gm.down, main = coord)
    }
    mtext(subtitle)
    return(invisible(NULL))
  }
  c(v, support)
}

# Per-balance zero-truncated variances and, optionally, the log-log fit of
# variance on mean distance to tips. Calls var.truncate.zeroes.
#
# Args:
#   sbp: sign matrix whose column names are balance names.
#   bmd: numeric vector of mean distance to tips, indexed by balance name.
#   df: abundance table; closed with compositions::clo before use.
#   n.support: minimum support (rows with a defined log-ratio) a balance
#     needs to enter the fit.
#   return.var: if TRUE, return the per-balance table instead of the fit.
#
# Returns:
#   return.var = TRUE: tibble with columns coord, var.tz, support and
#     mean.dist.to.tips.
#   otherwise: broom::tidy() of lm(log(var.tz) ~ log(mean.dist.to.tips)),
#     with `term` relabelled to the factor levels 'intercept' / 'log_mdtt'.
calc.var <- function(sbp, bmd, df, n.support = 100, return.var = FALSE) {
  df <- as(df, "matrix")
  df <- compositions::clo(df)

  # vapply keeps a 2 x k numeric matrix (columns named by balance) for any k,
  # unlike sapply, whose result type depends on the input.
  var.tz <- t(vapply(colnames(sbp), var.truncate.zeroes, numeric(2), df, sbp))
  colnames(var.tz) <- c("var.tz", "support")
  # Replaces the deprecated dplyr::add_rownames()
  var.tz <- tibble::as_tibble(as.data.frame(var.tz), rownames = "coord")

  # Add distance to tips
  var.tz$mean.dist.to.tips <- bmd[var.tz$coord]
  if (return.var) return(var.tz)

  # Both quantities are logged, so drop zeros, then require enough support.
  fit <- var.tz %>%
    filter(var.tz != 0, mean.dist.to.tips != 0) %>%
    filter(support >= n.support) %>%
    lm(log(var.tz) ~ log(mean.dist.to.tips), data = .) %>%
    broom::tidy()
  # Relabel by content, not position: tidy() omits aliased (NA) coefficients,
  # so a fixed two-element vector would not always match.
  fit$term <- factor(ifelse(fit$term == "(Intercept)", "intercept", "log_mdtt"))

  fit
}

# One draw from the permutation null model.
#
# Shuffles the balance names (column names of `sbp`) without replacement, so
# each balance's variance is paired with another balance's mean distance to
# tips. Then refits the variance-vs-depth regression with calc.var.
#
# Args:
#   i: iteration index supplied by mclapply; unused.
#   sbp, bmd, df, n.support: forwarded to calc.var.
# Returns the tidy coefficient table produced by calc.var.
null.model <- function(i, sbp, bmd, df, n.support){
  shuffled <- sbp
  colnames(shuffled) <- sample(colnames(sbp))
  calc.var(shuffled, bmd, df, n.support, return.var = FALSE)
}

# High level driver: variance-vs-depth analysis for one phyloseq object.
#
# Computes the zero-truncated variance of every balance of the tree and
# regresses log(variance) on log(mean distance to tips). Builds a permutation
# null distribution for that regression (unless `nmr` is supplied) and draws
# diagnostic plots.
#
# Args:
#   phyloseq.filt: phyloseq object providing an OTU table and a tree.
#     NOTE(review): t(otu_table(.)) is assumed to give samples x taxa, i.e.
#     taxa_are_rows = TRUE; confirm for new inputs.
#   n: number of permutations for the null model.
#   ncores: number of cores passed to parallel::mclapply.
#   n.support: number of non-zero/non-missing samples needed for a balance's
#     variance to enter the regression.
#   nmr: precomputed null model results (columns id, intercept, log_mdtt).
#     Computed by permutation when NULL.
#   title: label used in plot titles; defaults to the expression passed as
#     phyloseq.filt.
#
# Returns a list with var.tz (per-balance table), lm (summary.lm of the
# observed fit), nmr, p.null, p.pvd, p.tree and p.mp. p.tree is always NULL
# because no tree plot is produced.
run.var.analysis <- function(phyloseq.filt, n, ncores, n.support, nmr = NULL,
                             title = NULL) {
  if (is.null(title)) {
    title <- paste(deparse(substitute(phyloseq.filt)), collapse = " ")
  }

  df <- t(otu_table(phyloseq.filt))
  tr <- phy_tree(phyloseq.filt)
  bmd <- blw.mean.descendants(tr)
  sbp <- phylo2sbp(tr)

  # Per-balance variance, support and mean distance to tips
  var.tz <- calc.var(sbp, bmd, df, return.var = TRUE)

  # Balances usable in the log-log fit: both quantities non-zero (they are
  # logged) and enough supporting samples.
  supported <- var.tz %>%
    filter(var.tz != 0, mean.dist.to.tips != 0) %>%
    filter(support >= n.support)

  # Fit linear model
  lm.summary <- summary(lm(log(var.tz) ~ log(mean.dist.to.tips),
                           data = supported))
  print(lm.summary)
  coefs <- lm.summary$coefficients
  slope.term <- "log(mean.dist.to.tips)"
  obs.slope <- if (slope.term %in% rownames(coefs)) {
    coefs[slope.term, "Estimate"]
  } else {
    NA_real_
  }

  if (is.null(nmr)) {
    # Now create null model/permute
    null.model.results <- mclapply(seq_len(n), null.model, sbp, bmd, df,
                                   n.support, mc.cores = ncores)
    # mclapply does not stop on worker errors. It returns "try-error" objects
    # (and NULL for workers that died), which would otherwise surface as an
    # obscure bind_rows() error or silently shrink the null distribution.
    errored <- vapply(null.model.results, inherits, logical(1),
                      what = "try-error")
    if (any(errored)) {
      stop(sum(errored), " of ", n, " null-model permutations failed: ",
           trimws(as.character(null.model.results[[which(errored)[1]]])),
           call. = FALSE)
    }
    dropped <- vapply(null.model.results, is.null, logical(1))
    if (any(dropped)) {
      warning(sum(dropped), " of ", n,
              " null-model permutations returned no result and were dropped",
              call. = FALSE)
    }
    nmr <- bind_rows(null.model.results, .id = "id") %>%
      select(id, term, estimate) %>%
      pivot_wider(names_from = term, values_from = estimate)
  }

  # Null distribution of the slope, observed slope in red. The constant is
  # set outside aes() so only one line is drawn, not one per nmr row.
  p.null <- ggplot(nmr, aes(x = log_mdtt)) +
    geom_density() +
    geom_vline(xintercept = obs.slope, color = "red") +
    ggtitle(paste("Null Distribution for Slope", title, sep = ": "))

  # Plot var vs. depth with null models plotted as well. alpha belongs on the
  # points: passed to ggplot() it was silently ignored.
  p.pvd <- ggplot(supported,
                  aes(x = log(mean.dist.to.tips), y = log(var.tz),
                      color = support)) +
    geom_point(alpha = 0.5) +
    geom_abline(aes(intercept = intercept, slope = log_mdtt), data = nmr,
                color = "darkgreen", alpha = 0.01) +
    geom_smooth(method = "lm", formula = y ~ x, se = FALSE) +
    ggtitle(title)

  # Is there a relationship between support and var?
  p.mp <- ggplot(var.tz, aes(x = support, y = log(var.tz),
                             color = log(mean.dist.to.tips))) +
    geom_point() +
    ggtitle(title)

  list(var.tz = var.tz,
       lm = lm.summary,
       nmr = nmr,
       p.null = p.null,
       p.pvd = p.pvd,
       p.tree = NULL,  # never built; kept so the result keeps the same names
       p.mp = p.mp)
}


# By Site - Analyze -------------------------------------------------------
# Split HMP by body site ('groupedsites') and filter each subset.
# NOTE(review): levels() is NULL when groupedsites is not a factor, in which
# case this loop does nothing and BYSITE stays empty.
BYSITE <- list()
for (site in levels(get_variable(HMP, 'groupedsites'))){
  SITE <- prune_samples(sample_data(HMP)$groupedsites == site, HMP)
  # Keep taxa with a count > 1 in more than 20% of this site's samples
  SITE <- filter_taxa(SITE, function(x) sum(x > 1) > (0.2*length(x)) , TRUE)
  # Keep samples whose total count over the retained taxa exceeds 50.
  # NOTE(review): colSums gives per-sample totals only if taxa are rows in
  # the OTU table; confirm taxa_are_rows(HMP).
  SITE <- prune_samples(colSums(otu_table(SITE)) > 50, SITE)
  BYSITE[[site]] <- SITE
}

# Run the full analysis for every site and save the results. The permutation
# count and support threshold are named once, so the output file name cannot
# drift out of sync with the parameters actually used.
n.perm <- 20000L
min.support <- 40L
bysite <- list()
for (site in levels(get_variable(HMP, 'groupedsites'))) {
  bysite[[site]] <- run.var.analysis(BYSITE[[site]], n = n.perm,
                                     n.support = min.support, ncores = 32)
}
save(bysite, file = sprintf("bysite.%d.%d.RData", n.perm, min.support))
