# Series of scripts to recreate analysis and key figures of 
# ilr_var_truncateZeroes_permTree analysis
# Specifically this adds the analysis with tax2tree.R utilities to the plots

library(philr); packageVersion("philr")
library(phyloseq); packageVersion("phyloseq")
library(ggplot2); packageVersion("ggplot2")
library(dplyr); packageVersion("dplyr")
library(tidyr); packageVersion("tidyr")
library(ggtree); packageVersion("ggtree")
library(matrixStats); packageVersion('matrixStats')
library(parallel); packageVersion("parallel")
library(magrittr); packageVersion("magrittr")
library(purrr); packageVersion("purrr")
library(ggrepel)
library(gridExtra)
source('update_taxonomy_and_tax2tree/tax2tree.R')

# Fix the random number generator seed for this script
set.seed(4)

# NOTE(review): presumably provides the phyloseq object `HMP` used below -- confirm
load('HMP_norarify.RData')

# Path to the updated taxonomy assignments (read further down into `newtax`
# and also passed to tax2tree())
newtax.fp <- 'update_taxonomy_and_tax2tree/taxonomy_results/rep_set_v35_phiselected_tax_assignments.txt'


# Load Results from ilr_var_truncateZeroes_permTree.R ---------------------

# Build one filtered phyloseq object per level of 'groupedsites'
BYSITE <- list()
for (site in levels(get_variable(HMP, 'groupedsites'))){
  # Keep only the samples that belong to this site
  SITE <- prune_samples(sample_data(HMP)$groupedsites == site, HMP)
  # Keep taxa for which the filter function is TRUE: more than 20% of the
  # values handed to it are greater than 1
  SITE <- filter_taxa(SITE, function(x) sum(x > 1) > (0.2*length(x)) , TRUE)
  # Keep samples whose column sum in the OTU table exceeds 50
  # (assumes taxa are rows of the OTU table -- TODO confirm)
  SITE <- prune_samples(colSums(otu_table(SITE)) > 50, SITE)
  BYSITE[[site]] <- SITE
}

# NOTE(review): presumably provides the per-site results list `bysite`
# that is used throughout the rest of this script -- confirm
load('bysite.20000.40.RData')


# Utility Functions -------------------------------------------------------

# Collapse rare phyla into "Other" and select a colour for every phylum kept.
#
# tax    - vector (or factor) of labels at the phylum level
# thresh - a phylum is kept only if it has more than `thresh` entries in `tax`
#
# Returns a list with
#   tax           - factor of simplified labels, "Other" being the last level
#   phylum.colors - named character vector of colours for the kept phyla
#                   and "Other"; it never contains NA names or NA values
simplify.phylum <- function(tax, thresh=50){
  # Define color presets
  phylum.colors <- c('Actinobacteria'='#BFBBFF', 'Bacteroidetes'='#CB4D42', 
                     'Cyanobacteria'='#66CDAA', 'Fusobacteria'='#EEFF9A',
                     'Firmicutes'='#FFB756','Proteobacteria'='#00569D',
                     'Spirochaetes'='#9B00A6', 'Synergistetes'='#CDAF95', 
                     'Tenericutes'='#FFEA95', 'Verrucomicrobia'='#68688E', 
                     'Other'='#FFAAAA')
  
  # Phyla that have more than thresh representatives in the given vector
  tax.counts <- table(tax)
  abundant <- names(tax.counts)[tax.counts > thresh]
  
  # Only phyla with a preset colour can be kept: indexing the presets with an
  # unknown name would add an NA-named, NA-valued colour. intersect() also
  # removes a duplicated "Other". Everything else is lumped in with "Other".
  phylum.to.keep <- intersect(c(abundant, 'Other'), names(phylum.colors))
  phylum.colors <- phylum.colors[phylum.to.keep]
  
  # Simplify the tax vector: levels that are not kept are merged into "Other"
  if (!is.factor(tax)) tax <- factor(tax) # Make sure we are dealing with a factor
  lt <- levels(tax)
  lt[!(lt %in% phylum.to.keep)] <- 'Other'
  levels(tax) <- lt
  
  # Reorder levels so Other is at the end
  lt <- c(lt[lt != 'Other'], 'Other')
  tax <- factor(tax, levels=lt)
  
  list(tax=tax, phylum.colors=phylum.colors)
}


# Updated Taxonomy
# Read the first two columns of the assignment file, split the second one
# into one column per rank, and remove every '<any character>__' match
# (presumably the rank prefixes -- note that '.' matches any character).
newtax <- read.delim(newtax.fp, header = FALSE)[, 1:2] %>%
  separate(V2,
           c('Kingdom', 'Phylum', 'Class', 'Order',
             'Family', 'Genus', 'Species'),
           sep = '; ', fill='right') %>%
  map(~gsub('.__','', .x)) %>%
  as.data.frame()

# The first column becomes the row names and is then dropped
rownames(newtax) <- newtax[['V1']]
newtax[['V1']] <- NULL

# Empty strings are treated as missing
newtax[newtax==''] <- NA

# Thin wrapper around name.balance() for one coordinate `c` of the tree
# belonging to `site`, using the updated taxonomy table `newtax`.
# return.votes is forwarded unchanged.
nb.new <- function(c, site, return.votes=NULL){
  site.tree <- phy_tree(BYSITE[[site]])
  name.balance(site.tree, newtax, c, return.votes = return.votes)
}
# Apply nb.new() to every coordinate in `cs` for a single site;
# sapply() simplifies the collected results where it can.
nbfxn.new <- function(site, cs){
  sapply(cs, function(coord) nb.new(coord, site))
}

# From 'Exploratory Functions section'
# One entry per site: the transposed OTU table as a plain matrix, passed
# through compositions::clo()
df.clo <- lapply(BYSITE, function(ps){
  compositions::clo(as(t(otu_table(ps)), 'matrix'))
})
# One entry per site: result of phylo2sbp() on the site's tree
sbp <- lapply(BYSITE, function(ps) phylo2sbp(phy_tree(ps)))


# Variance of a single balance, ignoring (near-)zero parts.
#
# coord  - column of `sbp` identifying the balance
# df     - data with one column per part (coerced to a matrix)
# sbp    - sign matrix; in column `coord`, 1 marks the "up" parts and
#          -1 the "down" parts
# return.gm.only      - if TRUE return list(gm.up, gm.down), the per-row
#                       geometric means of the non-zero up / down parts
# return.balance.only - if TRUE (and return.gm.only is not) return
#                       list(balance), the scaled log-ratio per row
# plot, plot.type, plot.log - if plot is TRUE draw either the pair of
#                       geometric means ('pair', optionally on log axes) or a
#                       histogram of the balance ('hist') and return NULL
#
# Otherwise returns c(variance of the balance, number of rows with a
# non-missing log-ratio).
var.truncate.zeroes <- function(coord, df, sbp, return.gm.only=FALSE, 
                                return.balance.only=FALSE,
                                plot=FALSE, plot.type='pair', plot.log=TRUE){
  # Row-wise mean of the logs, with values closer to zero than 1e-14 masked
  masked.log.mean <- function(part){
    part[abs(part-0) < 1e-14] <- NA
    rowMeans(log(part), na.rm=TRUE)
  }
  
  df <- as(df, 'matrix')
  signs <- sbp[,coord]
  df.up <- df[,which(signs==1), drop=F]
  df.down <- df[,which(signs==-1), drop=F]
  log.gm.up <- masked.log.mean(df.up)
  log.gm.down <- masked.log.mean(df.down)
  log.ratio <- log.gm.up - log.gm.down
  
  # Scaling constant built from the number of parts on each side
  n.up <- ncol(df.up)
  n.down <- ncol(df.down)
  sc <- sqrt(n.up*n.down/(n.up+n.down))
  
  # Variance of the balance and the number of rows supporting it
  v <- var(sc*log.ratio, na.rm=T)
  support <- sum(!is.na(log.ratio))
  
  # Early exits for callers that only want intermediate quantities
  if (return.gm.only==TRUE){
    return(list(gm.up=exp(log.gm.up),
                gm.down=exp(log.gm.down)))
  }
  if (return.balance.only==TRUE){
    return(list(balance=sc*log.ratio))
  }
  
  if (plot==TRUE){
    if (plot.type=='pair'){
      # The names gm.up / gm.down become the default axis labels
      gm.up <- exp(log.gm.up)
      gm.down <- exp(log.gm.down)
      if (plot.log==TRUE){
        plot(x=gm.up, y=gm.down, main=coord, log='xy')
      } else {
        plot(x=gm.up, y=gm.down, main=coord)
      }
    } else if(plot.type == 'hist'){
      hist(na.exclude(sc*log.ratio), main=coord)
    }
    mtext(paste('var:', v, 'support:', support, sep=' '))
    return(NULL)
  }
  
  c(v, support)
}

# Collapse rare phyla into "Other" and select a colour for every phylum kept.
# NOTE: simplify.phylum is defined twice in this file; this later definition
# is the one in effect once the whole script has been run.
#
# tax    - vector (or factor) of labels at the phylum level
# thresh - a phylum is kept only if it has more than `thresh` entries in `tax`
#
# Returns a list with
#   tax           - factor of simplified labels, "Other" being the last level
#   phylum.colors - named character vector of colours for the kept phyla
#                   and "Other"; it never contains NA names or NA values
simplify.phylum <- function(tax, thresh=50){
  # Define color presets
  phylum.colors <- c('Actinobacteria'='#BFBBFF', 'Bacteroidetes'='#CB4D42', 
                     'Cyanobacteria'='#66CDAA', 'Fusobacteria'='#EEFF9A',
                     'Firmicutes'='#FFB756','Proteobacteria'='#00569D',
                     'Spirochaetes'='#9B00A6', 'Synergistetes'='#CDAF95', 
                     'Tenericutes'='#FFEA95', 'Verrucomicrobia'='#68688E', 
                     'Other'='#FFAAAA')
  
  # Phyla that have more than thresh representatives in the given vector
  tax.counts <- table(tax)
  abundant <- names(tax.counts)[tax.counts > thresh]
  
  # Only phyla with a preset colour can be kept: indexing the presets with an
  # unknown name would add an NA-named, NA-valued colour. intersect() also
  # removes a duplicated "Other". Everything else is lumped in with "Other".
  phylum.to.keep <- intersect(c(abundant, 'Other'), names(phylum.colors))
  phylum.colors <- phylum.colors[phylum.to.keep]
  
  # Simplify the tax vector: levels that are not kept are merged into "Other"
  if (!is.factor(tax)) tax <- factor(tax) # Make sure we are dealing with a factor
  lt <- levels(tax)
  lt[!(lt %in% phylum.to.keep)] <- 'Other'
  levels(tax) <- lt
  
  # Reorder levels so Other is at the end
  lt <- c(lt[lt != 'Other'], 'Other')
  tax <- factor(tax, levels=lt)
  
  list(tax=tax, phylum.colors=phylum.colors)
}

# Cleanup Sites Variables -------------------------------------------------

# Make extra variable for plotting labels: take the part of 'groupedsites'
# after the ':' and replace underscores with spaces
l <- gsub('_', ' ',
          map(strsplit(as.character(get_variable(HMP, 'groupedsites')), ':'), 2))
l[l=='Attached Keratinized gingiva'] <- 'Keratinized gingiva'

# Store the labels as a factor with an explicit level order
sample_data(HMP)$groupedsites.plot <- factor(
  l,
  levels=c('Stool','Retroauricular crease',
           'Anterior nares', 'Antecubital fossa',
           'Supragingival plaque', 'Subgingival plaque',
           'Saliva', 'Tongue dorsum',
           'Palatine Tonsils', 'Throat',
           'Hard palate', 'Buccal mucosa',
           'Keratinized gingiva', 'Vaginal introitus',
           'Mid vagina', 'Posterior fornix'))


# pvd - figure ------------------------------------------------------------

# For a vector of coordinates `cs` and a given `site`, returns a data frame
# with columns c('coord', 'gm.up', 'gm.down'): the geometric means produced by
# var.truncate.zeroes(..., return.gm.only=TRUE), stacked over coordinates.
calc.gm.coords <- function(site, cs){
  ps <- BYSITE[[site]]
  closed <- compositions::clo(as(t(otu_table(ps)), 'matrix'))
  signs <- phylo2sbp(phy_tree(ps))
  
  # Named list of coordinates so that bind_rows() can fill the 'coord' column
  coord.list <- array_branch(cs)
  names(coord.list) <- cs
  
  gm.frames <- map(coord.list, function(coord){
    as.data.frame(var.truncate.zeroes(coord, closed, signs, return.gm.only=TRUE))
  })
  bind_rows(gm.frames, .id='coord')
}

# Scatterplot of balance variance against mean distance to tips, on log-log
# axes, with a loess fit and a linear fit on top.
#
# var.tz     - data frame with columns var.tz, mean.dist.to.tips and support
# labels     - accepted for compatibility; the label layer is currently
#              disabled, so this argument has no effect on the plot
# n.support  - balances with support below this value are dropped
# blw.byrank - optional data frame with columns rank and mean.dist.to.tips
#              used for the rug and the per-rank median lines
# plot.rug          - add a rug for rows whose rank is 'species'
# plot.rank.lines   - add a dashed vertical line at the median
#                     mean.dist.to.tips of each rank
# rank.species.only - restrict the rank lines to rank 'species'
# point.size        - point size for the scatter layer
#
# Returns the ggplot object.
plot.pvd <- function(var.tz, labels=NULL, n.support=40, blw.byrank=NULL, 
                     plot.rug=FALSE, plot.rank.lines=TRUE, rank.species.only=FALSE, 
                     point.size=1){
  # Zeros cannot be shown on log axes; also drop poorly supported balances
  df <- var.tz %>%  
    filter(var.tz !=0, mean.dist.to.tips !=0) %>%
    filter(support >= n.support)
  
  # Plot basic scatterplot
  p <- ggplot(df, aes(x=mean.dist.to.tips, y=var.tz)) +
    geom_point(size=point.size) +
    scale_x_log10() +
    scale_y_log10()
  
  # Add layer of taxonomic rank info if needed
  if (!is.null(blw.byrank)){
    if (plot.rug){
      p <- p + geom_rug(data=(blw.byrank %>% filter(rank %in% c('species'))),
                        aes(x=mean.dist.to.tips, y=NULL, color=rank), lwd=1, 
                        show.legend = FALSE)
    }
    if (plot.rank.lines){
      if (rank.species.only){
        blw.byrank %<>% filter(rank %in% c('species'))
      }
      blw.byrank  %<>%  group_by(rank)  %>% summarise(median=median(mean.dist.to.tips))
      p <- p + geom_vline(data=blw.byrank, aes(xintercept=median), linetype='dashed')
    }
  }
  
  # Now add loess regression and linear regression. These are layers 2 and 3
  # when no rank layer was added above.
  p <- p + geom_smooth(method='loess', color='darkgreen') +
    geom_smooth(method='lm', se=FALSE)
  
  # Labelling of the balances in `labels` is disabled; it used to be
  #   geom_label_repel(aes(label=coord), data=<rows of var.tz in labels>)
  
  # Some final stylizing: no axis titles, larger tick labels
  p +
    theme_bw()+
    theme(axis.title.y=element_blank(), 
          axis.title.x=element_blank(), 
          axis.text.x=element_text(size=16), 
          axis.text.y=element_text(size=16))
}

# Build the variance-vs-distance plot for one site.
# The site's tree is passed through tax2tree() with the updated taxonomy
# file, extract.rank.data() supplies the per-rank data, and plot.pvd() draws
# the site's var.tz table from `bysite`. The remaining arguments are
# forwarded to plot.pvd() (to.plot goes in as `labels`).
figure.pvd <- function(site, to.plot, plot.rug=F, plot.rank.lines=T, rank.species.only=F, 
                       point.size=1){
  site.tree <- tax2tree(phy_tree(BYSITE[[site]]), newtax.fp)
  rank.data <- extract.rank.data(site.tree)
  plot.pvd(bysite[[site]]$var.tz,
           labels=to.plot,
           blw.byrank=rank.data,
           plot.rug=plot.rug,
           plot.rank.lines=plot.rank.lines,
           rank.species.only=rank.species.only,
           point.size=point.size)
}

# Scatter gm.up against gm.down (log-log axes) for the coordinates in
# `high` and `low` of one site, one facet row per coordinate with the
# `high` coordinates first. Rows with missing values are dropped.
# The plot is returned invisibly.
figure.gms <- function(site, low, high){
  coords <- c(high, low)
  gm <- calc.gm.coords(site, coords)
  gm$coord <- factor(gm$coord, levels=coords)
  
  gm.complete <- na.omit(gm)
  gms <- ggplot(gm.complete, aes(x=gm.down, y=gm.up)) +
    geom_point() +
    scale_x_log10() +
    scale_y_log10() +
    facet_grid(coord~., scales = 'free') +
    theme_bw() +
    theme(axis.title.y=element_blank(),
          axis.title.x=element_blank(),
          axis.text.x=element_text(size=16),
          axis.text.y=element_text(size=16))
  invisible(gms)
}

# Produce everything needed for one site: the pvd plot, the result of
# nbfxn.new() for the selected coordinates, and the geometric-mean plot.
# Returns list(pvd, names, gms); the elements are computed in that order.
driver <- function(site, low, high, plot.rug=F, plot.rank.lines=T){
  coords <- c(low, high)
  list(pvd=figure.pvd(site, coords, plot.rug, plot.rank.lines),
       names=nbfxn.new(site, coords),
       gms=figure.gms(site, low, high))
}


# Stool with the species rug switched on and the rank lines switched off
driver('STOOL:Stool', c('n1068'), c('n3'), TRUE, FALSE)

# For each site below: show the names and both plots, then write each plot
# to its own PDF. The plot is handed to ggsave() explicitly so the file does
# not depend on which plot happened to be displayed last.
stool <- driver('STOOL:Stool', c('n1068'), c('n3'))
stool$names
stool$pvd
ggsave('stool.pvd.pdf', plot = stool$pvd,
       width = 3.25*3, height=1.3333*3, units = 'in')
stool$gms
ggsave('stool.gms.pdf', plot = stool$gms,
       width = 1.625*3, height=1.3333*3, units = 'in')

bm <- driver('ORAL:Buccal_mucosa', c('n3225'), c('n2634'))
bm$names
bm$pvd
ggsave('bm.pvd.pdf', plot = bm$pvd,
       width = 3.25*3, height=1.3333*3, units = 'in')
bm$gms
ggsave('bm.gms.pdf', plot = bm$gms,
       width = 1.625*3, height=1.3333*3, units = 'in')

mv <- driver('GU:Mid_vagina', 'n1572', c("n1406"))
mv$names
mv$pvd
ggsave('mv.pvd.pdf', plot = mv$pvd,
       width = 3.25*3, height=1.3333*3, units = 'in')
mv$gms
ggsave('mv.gms.pdf', plot = mv$gms,
       width = 1.625*3, height=1.3333*3, units = 'in')


# Trees -------------------------------------------------------------------



# Draw the tree of one site coloured by balance variance.
#
# site    - name of the site in BYSITE / bysite
# newtax  - taxonomy table; kept for compatibility with existing callers.
#           It is no longer used because the phyla bar is not drawn
#           (the phylum simplification it fed was computed and then discarded).
# to.plot - node labels to highlight with a point, or NULL for none
#
# Returns the ggtree plot.
figure.tree <- function(site, newtax, to.plot){
  # Create tree and colour it with the site's var.tz table
  tr <- phy_tree(BYSITE[[site]])
  p <- ggtree(tr, aes(color=var.tz, size=isTip)) %<+% as.data.frame(bysite[[site]]$var.tz) +
    scale_size_manual(values=c(1.2, .7)) +
    labs(color='Balance Variance')
  
  # Mark the highlighted nodes
  if (!is.null(to.plot)){
    data <- p$data %>% filter(label %in% to.plot)
    p <- p + geom_point2(aes(label=label), size=3, data=data)
  }
  
  # Legend at the bottom, log10 diverging colour scale, no size legend,
  # plus a tree scale bar. No phyla bar is added.
  p +
    theme(legend.position='bottom') +
    scale_color_gradient2(low='red', mid='darkgrey', high='blue', trans='log10',
                          midpoint = 0.5, limits=c(0.1, 110)) +
    guides(size='none') +
    geom_treescale()
}

# Extract legend
# Splits a plot into its legend grob (the grob named "guide-box" in the built
# gtable) and the same plot with the colour guide switched off.
# Returns list(legend, tree).
g_separate_legend<-function(a.gplot){
  grob.table <- ggplot_gtable(ggplot_build(a.gplot))
  grob.names <- sapply(grob.table$grobs, function(g) g$name)
  legend.idx <- which(grob.names == "guide-box")
  list(legend=grob.table$grobs[[legend.idx]],
       tree=a.gplot + guides(color='none'))
}

# For each site: build the tree, split off the legend, display the tree and
# save it. The tree is handed to ggsave() explicitly so the file does not
# depend on which plot happened to be displayed last.

# Stool
p <- figure.tree('STOOL:Stool', newtax, c('n1068', 'n3'))
parts <- g_separate_legend(p)

# The legend is written to its own file once, from the stool tree
pdf('tree_legend.pdf')
grid.arrange(parts$legend)
dev.off()

(stool.tree <- parts$tree)
ggsave('stool.tree.pdf', plot = stool.tree,
       width = 1.625*3, height=1.333*3, units = 'in')

# Buccal mucosa
p <- figure.tree('ORAL:Buccal_mucosa', newtax, c('n3225', 'n2634'))
parts <- g_separate_legend(p)
grid.arrange(parts$legend)
(bm.tree <- parts$tree)
ggsave('bm.tree.pdf', plot = bm.tree,
       width = 1.625*3, height=1.333*3, units = 'in')

# Mid vagina
p <- figure.tree('GU:Mid_vagina', newtax, c('n1572', 'n1406'))
parts <- g_separate_legend(p)
grid.arrange(parts$legend)
(mv.tree <- parts$tree)
ggsave('mv.tree.pdf', plot = mv.tree,
       width = 1.625*3, height=1.333*3, units = 'in')

# Supplemental Figures (pvd) --------------------------------------------------

# Display names for the sites: the part after ':' with underscores replaced
# by spaces. The attached keratinized gingiva site is shortened by name, in
# the same way as for the plotting labels of HMP, instead of by position.
site.names <- names(bysite)  %>% strsplit(':')  %>% map(2)  %>% as_vector()  %>% gsub('_', ' ', .)
site.names[site.names == 'Attached Keratinized gingiva'] <- 'Keratinized gingiva'
names(site.names) <- names(bysite)

# One small pvd panel per site, species rank line only
p.list <- list()
for (site in names(bysite)){
  p.list[[site]] <- figure.pvd(site, NULL, plot.rug=FALSE, plot.rank.lines=TRUE, 
                               rank.species.only=TRUE, point.size=0.1) + 
    theme(axis.title.x=element_blank(), 
          axis.title.y=element_blank(), 
          axis.text.x=element_text(size=8), 
          axis.text.y=element_text(size=8), 
          title=element_text(size=10)) + 
    ggtitle(site.names[site])
}

# Lay the panels out on a 6 x 3 grid. The arranged pages are printed
# explicitly so they are drawn into the PDF even when the script is run
# without top-level auto-printing (e.g. via source()).
pdf('extended.pvds.pdf', width = 6.93, height=1.4*6+.2)
print(marrangeGrob(p.list, ncol=3,nrow = 6, 
                   widths=unit(rep(2.31,3), 'in'), 
                   heights=unit(rep(1.4,6), 'in'), 
                   top=''))
dev.off()



# Supplemental Figures (PVD Breakdown) ------------------------------------

# Just going to do this for stool 
# Stool pvd plot without rug and rank lines; the x axis title is restored
# ('Mean Distance to Tips') while the y axis title stays blank
p.pvd <- figure.pvd('STOOL:Stool', NULL, plot.rug=F, plot.rank.lines = F) +
  theme(axis.title.x=element_text(size=16), 
        axis.title.y=element_text(size=16, angle = 90),
        axis.text.x=element_text(size=12),
        axis.text.y=element_text(size=12)) + 
  xlab('Mean Distance to Tips') + ylab('') +
  theme(axis.title.y=element_blank())
# With no rug or rank-line layer present, layers 2 and 3 are the two
# geom_smooth layers added by plot.pvd()
p.pvd$layers[c(2,3)]  <- NULL # Remove Loess line and regression Line 


# Get balances in the top and bottom 10 for variance
# Bottom: the 10 smallest var.tz among balances with mean.dist.to.tips > 0,
# sorted by increasing var.tz (top_n keeps ties, so more than 10 rows are
# possible)
bottom.10 <- bysite$`STOOL:Stool`$var.tz %>% 
  filter(mean.dist.to.tips > 0) %>% 
  top_n(-10, var.tz) %>% 
  arrange(var.tz)

# Top: the 10 largest var.tz, also sorted by increasing var.tz; no filter on
# mean.dist.to.tips is applied here
top.10 <- bysite$`STOOL:Stool`$var.tz %>% 
  top_n(10, var.tz) %>% 
  arrange(var.tz)

# Stack bottom then top and move the row names into an 'id' column, which is
# used as the label text below
bottom.top.10 <- bind_rows(bottom.10, top.10) %>% 
  add_rownames(var = 'id')

# Add Labels to pvd figure
p.pvd <- p.pvd + 
  geom_label_repel(aes(label=id, x=mean.dist.to.tips, y=var.tz), 
                  data=bottom.top.10)


# Scatter gm.up against gm.down (log-log axes) for the given coordinates of
# one site, one facet per coordinate in a 5 x 2 layout.
#
# site   - name of the site in BYSITE
# coords - coordinates to plot; also fixes the facet order
# labels - facet titles, one per coordinate, replacing the coordinate names
#
# Rows with missing values are dropped. The plot is returned invisibly.
sup.figure.gms <- function(site, coords, labels){
  gm <- calc.gm.coords(site, coords)
  coord.factor <- factor(gm$coord, levels=coords)
  levels(coord.factor) <- labels 
  gm$coord <- coord.factor
  
  gm.complete <- na.omit(gm)
  gms <- ggplot(gm.complete, aes(x=gm.down, y=gm.up)) +
    geom_point(size=.5) +
    scale_x_log10() +
    scale_y_log10() +
    facet_wrap(~coord, scales = 'free', nrow=5, ncol=2) +
    theme_bw() +
    theme(axis.title.y=element_blank(),
          axis.title.x=element_blank(),
          axis.text.x=element_text(size=10),
          axis.text.y=element_text(size=10))
  invisible(gms)
}
# Graph gms for top and bottom separately
# Facets are titled 1:10 for the bottom set and 11:20 for the top set
# (presumably matching the 'id' labels on p.pvd; this relies on each set
# having exactly 10 rows -- TODO confirm no ties)
(p.bottom.10 <- sup.figure.gms('STOOL:Stool',  bottom.10$coord, 1:10))
(p.top.10 <- sup.figure.gms('STOOL:Stool',  top.10$coord, 11:20))

# Arrange results 
# Layout: the pvd plot spans the first row; the bottom and top panels sit
# side by side across the second and third rows
pdf('stool.pvd.expanded.pdf', 
    width=7, height=9)
grid.arrange(p.pvd, p.bottom.10, p.top.10, 
             layout_matrix=rbind(c(1,1), c(2,3), c(2,3)), 
             widths=unit(c(3.5, 3.5), 'in'), 
             heights=unit(c(3, 3, 3), 'in'))
dev.off()


# Supplemental Figure - Null Model Results --------------------------------

# Display names for the sites: the part after ':' with underscores replaced
# by spaces. The attached keratinized gingiva site is shortened by name, in
# the same way as for the plotting labels of HMP, instead of by position.
site.names <- names(bysite)  %>% strsplit(':')  %>% map(2)  %>% as_vector()  %>% gsub('_', ' ', .)
site.names[site.names == 'Attached Keratinized gingiva'] <- 'Keratinized gingiva'
names(site.names) <- names(bysite)


# Histogram of the log_mdtt column of a site's `nmr` table with a red
# vertical line at bysite[[site]]$lm$coefficients[2,1] (presumably the
# observed coefficient the null values are compared with -- confirm).
#
# site - name of the site in `bysite`; site.names[site] is used as the title
#
# Returns the ggplot object.
plot.null.dist <- function(site){
  title <- site.names[site]
  # Passed as a fixed value rather than through aes(), so a single line is
  # drawn instead of one identical line per row of the data
  observed <- bysite[[site]]$lm$coefficients[2,1]
  bysite[[site]]$nmr %>% 
    ggplot(aes(x=log_mdtt)) +
    geom_histogram() +
    geom_vline(xintercept=observed, color='red') +
    # xlab(expression(beta)) +
    ggtitle(title) +
    theme_bw() +
    theme(axis.title.x=element_blank(),
          axis.title.y=element_blank(),
          axis.text.x=element_text(size=8),
          axis.text.y=element_text(size=8),
          title=element_text(size=10))
}

# One null-distribution panel per site
p.list <- list()
for (site in names(bysite)){
  p.list[[site]] <- plot.null.dist(site)
}

# Lay the panels out on a 6 x 3 grid. The arranged pages are printed
# explicitly so they are drawn into the PDF even when the script is run
# without top-level auto-printing (e.g. via source()).
pdf('extended.nulls.pdf', width = 6.93, height=1.4*6+.2)
print(marrangeGrob(p.list, ncol=3,nrow = 6,
                   widths=unit(rep(2.31,3), 'in'),
                   heights=unit(rep(1.4,6), 'in'),
                   top=''))
dev.off()


