# Script contains commands used for running R analyses associated with estimating rates
# of body size evolution for cartilaginous fishes, including the whale shark


# Data derive from Stein et al. 2018 (doi: 10.1038/s41559-017-0448-4)
# tree data available from sharktree.org
# body size data were provided from original author Chris Mull (creeas@gmail.com)

# gigantism analysis
# Load ape, which provides read.nexus() and drop.tip() used below
library(ape)
shark_trees <- read.nexus("Chondrichthyan.610sp.10_fossil_Calibration.500treePLtrees.nex")
size_data <- read.csv("610spp.body.size.for.Milton.csv", header = TRUE, stringsAsFactors = FALSE)

# Flag species with usable body mass data. This script treats a max.mass of 0
# as "no data"; NA is handled the same way so that the trait vector and the
# pruned trees always cover exactly the same set of species.
has_mass <- !is.na(size_data$max.mass) & size_data$max.mass != 0

# Subset data for data that are present
max.mass <- size_data$max.mass[has_mass]
names(max.mass) <- size_data$species[has_mass]

# Subset trees for data that are present:
# drop every species without mass data from each tree, then restore the
# "multiPhylo" class that lapply() strips from the list of trees
shark.trees.prune <- lapply(shark_trees, drop.tip, tip = size_data$species[!has_mass])
class(shark.trees.prune) <- "multiPhylo"

# BAMM/BAMMtools used to identify shifts in body size evolution
# setBAMMpriors() and generateControlFile() come from BAMMtools, so it is
# attached here, before their first use.
library(BAMMtools)

# One BAMM run is prepared per pruned tree
tree_ids <- seq_along(shark.trees.prune)

# Write trait data and trees:
# make sure the output directory exists before anything is written into it
dir.create("bamm_shark_mass", showWarnings = FALSE)
write.table(log(max.mass), file = "bamm_shark_mass/max_mass.txt", quote = FALSE, col.names = FALSE, sep = "\t")
for (x in tree_ids) {
  write.tree(shark.trees.prune[[x]], file = paste0("bamm_shark_mass/tree_", x, ".tre"))
}

# Generate trait control files
# One row of priors per tree; columns 2 and 3 are used below as
# betaInitPrior and betaShiftPrior.
bamm.priors <- do.call(rbind, lapply(tree_ids, function(x) {
  setBAMMpriors(shark.trees.prune[[x]],
    traits = "bamm_shark_mass/max_mass.txt",
    outfile = NULL)
}))

# Some of the numberOfGenerations are 50000000
# Tree and trait file names inside the control file are relative, so BAMM has
# to be run from within bamm_shark_mass/.
for (x in tree_ids) {
  generateControlFile(
    file = paste0("bamm_shark_mass/trait_control_", x, ".txt"),
    type = "trait",
    params = list(
      treefile = paste0("tree_", x, ".tre"),
      traitfile = "max_mass.txt",
      betaInitPrior = bamm.priors[x, 2],
      betaShiftPrior = bamm.priors[x, 3],
      overwrite = 1,
      expectedNumberOfShifts = 1.0,
      useObservedMinMaxAsTraitPriors = 1,
      numberOfGenerations = "30000000",
      mcmcWriteFreq = "10000",
      eventDataWriteFreq = "10000",
      printFreq = 1000,
      acceptanceResetFreq = "10000",
      outName = x
    )
  )
}

###########
#	
#	# IN BASH
#	# Navigate to directory containing the control files: 
#	for i in `ls trait* | cut -f3 -d'_' | cut -f1 -d'.'`; do
#	bamm -c trait_control_"$i".txt && cat "$i"_mcmc_out.txt | tr ',' '\t' > "$i"_mcmc_out.tab
#	done
#	
###########

# Load packages
library(coda)
library(BAMMtools)

# Assess which ones didn't converge using coda effectiveSize()

# Read the MCMC log of every run; the first column of each file becomes the
# row names, so only the remaining columns are scored.
bamm.mcmc <- lapply(seq_along(shark.trees.prune), function(x) {
  read.csv(paste0("bamm_shark_mass/", x, "_mcmc_out.txt"), row.names = 1, header = TRUE)
})

# Effective sample size per logged column (rows = runs), computed on the rows
# from the 10% mark onwards. The start index is floored to a whole row number:
# a fractional start such as 300.1 would make `:` step through 300.1, 301.1,
# ... and silently leave out the final row of the log.
bamm.mcmc.ess <- do.call(rbind, lapply(bamm.mcmc, function(x) {
  first_kept <- max(1, floor(0.1 * nrow(x)))
  effectiveSize(x[first_kept:nrow(x), , drop = FALSE])
}))

# Check if any failed:
# a run counts as failed when any logged column has an effective size below 200
which(apply(bamm.mcmc.ess, MARGIN = 1, function(x) any(x < 200)))

# Continue BAMM runs that failed, if any:
# Generate new control files:
# a run has failed when any of its effective sample sizes is below 200
failed_bamm_runs <- which(apply(bamm.mcmc.ess, MARGIN = 1, function(x) any(x < 200)))

# NOTE(review): the continuation files are written to "bamm_shark_sizes" and
# reference "max_sizes.txt", whereas the first round of runs in this script
# used "bamm_shark_mass" and "max_mass.txt". The names are kept as they were;
# confirm that this directory holds the matching trees, trait file and event
# data before running BAMM.
# NOTE(review): loadEventData = 1 is set without an eventDataInfile entry;
# confirm that BAMM picks up the intended event data file.
rerun_dir <- "bamm_shark_sizes"

# Nothing is generated when no run failed (the loop body is skipped)
for (x in failed_bamm_runs) {
  generateControlFile(
    file = file.path(rerun_dir, paste0("trait_control_", x, ".txt")),
    type = "trait",
    params = list(
      treefile = paste0("tree_", x, ".tre"),
      traitfile = "max_sizes.txt",
      betaInitPrior = bamm.priors[x, 2],
      betaShiftPrior = bamm.priors[x, 3],
      loadEventData = 1,
      overwrite = 0,
      expectedNumberOfShifts = 1.0,
      useObservedMinMaxAsTraitPriors = 1,
      numberOfGenerations = "50000000",
      mcmcWriteFreq = "10000",
      eventDataWriteFreq = "10000",
      printFreq = 1000,
      acceptanceResetFreq = "10000",
      outName = x
    )
  )
}

# Write out failed BAMM runs
# One run index per line, next to the new control files, so that the bash loop
# below can read it as failed_bamms.txt.
write.table(failed_bamm_runs, file = file.path(rerun_dir, "failed_bamms.txt"),
  row.names = FALSE, col.names = FALSE)

###########
#	
#	# IN BASH
#	# Run:
#	for i in `cat failed_bamms.txt`; do
#	bamm -c trait_control_"$i".txt && cat "$i"_mcmc_out.txt | tr ',' '\t' > "$i"_mcmc_out.tab
#	done
#	
###########

# Subsample event data for analyses:
# each run's event data file is read against its own pruned tree, with
# burnin = 0.1 and nsamples = 2000
bamm.edata <- lapply(1:500, function(run) {
  event_file <- paste("bamm_shark_mass/", run, "_event_data.txt", sep = "")
  getEventData(shark.trees.prune[[run]], event_file,
    burnin = 0.1, type = "trait", nsamples = 2000)
})

# Get whale shark rates
# For every run: find the whale shark tip, take the node directly above it, and
# collect the beta values getCladeRates() returns for that node.
bamm.cr.ws <- t(sapply(bamm.edata, function(edata) {
  ws_tip <- which(edata$tip.label == "Rhincodon_typus")
  ws_edge <- which(edata$edge[, 2] == ws_tip)
  ws_parent <- edata$edge[ws_edge, 1]
  clade_rates <- getCladeRates(edata, node = ws_parent, nodetype = "include")
  clade_rates$beta
}))


# Get background clade rate:

# To identify background (bg) rates, you need to select nodes for which there is
# low support for a shift.

# Get all the nodes from the event data (edata) object.
# The code is based on the getCladeRates function in BAMMtools and
# identifies the background nodes separately for each posterior sample.
# Mean "background" rate for every posterior sample of a bammdata object.
#
# For each sample the background is every branch whose descendant node is
# neither a shift node nor descended from one. The rate of a sample is the
# branch-length-weighted mean of the time-integrated rates over those branches.
#
# Arguments:
#   ephy    - object of class "bammdata"
#   verbose - if TRUE, print the index of each posterior sample while it is
#             being processed
#
# Returns:
#   list(lambda = , mu = ) when ephy$type is "diversification",
#   list(beta = ) when ephy$type is "trait"; every element holds one value per
#   posterior sample. A sample with no background branch at all gets 0.
#
# NOTE(review): getDescendants() and as.phylo() are not defined in this file;
# presumably they come from attached packages (getDescendants looks like the
# phytools function) - confirm they are available before calling.
getBackgroundRates <- function(ephy, verbose = FALSE) {
    if (!inherits(ephy, "bammdata")) {
        stop("Object ephy must be of class bammdata\n")
    }

    # Integral of the rate between times t1 and t2 for a process with initial
    # value p1 and shift parameter p2 (vectorised over segments).
    timeIntegratedBranchRate <- function(t1, t2, p1, p2) {
        tol <- 0.00001
        res <- vector(mode = "numeric", length = length(t1))
        # constant rate
        zero <- which(abs(p2) < tol)
        p1s <- p1[zero]
        t1s <- t1[zero]
        t2s <- t2[zero]
        res[zero] <- p1s * (t2s - t1s)
        # declining rate
        nonzero <- which(p2 < -tol)
        p1s <- p1[nonzero]
        p2s <- p2[nonzero]
        t1s <- t1[nonzero]
        t2s <- t2[nonzero]
        res[nonzero] <- (p1s / p2s) * (exp(p2s * t2s) - exp(p2s * t1s))
        # increasing rate
        nonzero <- which(p2 > tol)
        p1s <- p1[nonzero]
        p2s <- p2[nonzero]
        t1s <- t1[nonzero]
        t2s <- t2[nonzero]
        res[nonzero] <- (p1s / p2s) * (2 * p2s * (t2s - t1s) + exp(-p2s * t2s) - exp(-p2s * t1s))
        return(res)
    }

    # The topology is the same for every sample, so convert it only once
    tree <- as.phylo(ephy)
    sample_ids <- seq_along(ephy$eventData)

    # Shift nodes of each sample: rows 2..numberEvents of the event table (the
    # first row is skipped, as in the original code). seq_len(n)[-1] is empty
    # when a sample has a single event, whereas 2:n would count down to c(2, 1).
    eventnodes <- lapply(sample_ids, function(x) {
        shift_rows <- seq_len(ephy$numberEvents[x])[-1]
        ephy$eventData[[x]][shift_rows, 1]
    })

    # All descendants of the shift nodes, flattened into one plain vector per
    # sample. A nested sapply() returns a list when the shift clades differ in
    # size, and setdiff() then fails to remove those descendants.
    descendantnodes <- lapply(eventnodes, function(x) {
        unlist(lapply(x, function(y) getDescendants(tree, node = y)))
    })

    # Background nodes: everything that is not a shift node or below one
    nodeset <- lapply(sample_ids, function(x) {
        setdiff(ephy$edge[, 2], c(eventnodes[[x]], descendantnodes[[x]]))
    })
#    nodeset <- lapply(1:length(ephy$eventData), function(x) setdiff(ephy$edge[,2], descendantnodes[[x]]))

    lambda_vector <- numeric(length(ephy$eventBranchSegs))
    mu_vector <- numeric(length(ephy$eventBranchSegs))

    for (i in seq_along(ephy$eventBranchSegs)) {
        if (verbose) {
            cat("Processing sample ", i, "\n")
        }
        # Branch segments (node, start, end, event row) limited to background
        # branches; drop = FALSE keeps a single remaining row as a matrix
        esegs <- ephy$eventBranchSegs[[i]]
        esegs <- esegs[esegs[, 1] %in% nodeset[[i]], , drop = FALSE]

        events <- ephy$eventData[[i]]
        event_row <- esegs[, 4]
        # Segment times relative to the start of the event governing them
        relsegmentstart <- esegs[, 2] - events$time[event_row]
        relsegmentend <- esegs[, 3] - events$time[event_row]
        lam1 <- events$lam1[event_row]
        lam2 <- events$lam2[event_row]
        mu1 <- events$mu1[event_row]
        mu2 <- events$mu2[event_row]

        # Weight every segment by its share of the total background length
        seglengths <- esegs[, 3] - esegs[, 2]
        wts <- seglengths / sum(seglengths)
        lamseg <- timeIntegratedBranchRate(relsegmentstart, relsegmentend,
            lam1, lam2) / seglengths
        museg <- timeIntegratedBranchRate(relsegmentstart, relsegmentend,
            mu1, mu2) / seglengths
        lambda_vector[i] <- sum(lamseg * wts)
        mu_vector[i] <- sum(museg * wts)
    }
    if (ephy$type == "diversification") {
        return(list(lambda = lambda_vector, mu = mu_vector))
    }
    if (ephy$type == "trait") {
        return(list(beta = lambda_vector))
    }
}


# Trees are rows, posterior samples are columns
bamm.rate.bg <- t(sapply(bamm.edata, function(edata) {
  background <- getBackgroundRates(edata)
  background$beta
}))

# Run marginalOddsRatioBranches() on every run's event data
bamm.margodds <- lapply(bamm.edata, function(edata) {
  marginalOddsRatioBranches(edata, expectedNumberOfShifts = 1)
})

# For every run, take the edge.length entry of the branch that ends in the
# whale shark tip
bamm.margodds.ws <- sapply(bamm.margodds, function(odds_tree) {
  ws_tip <- which(odds_tree$tip.label == "Rhincodon_typus")
  ws_edge <- which(odds_tree$edge[, 2] == ws_tip)
  odds_tree$edge.length[ws_edge]
})

# Plot together
# Modified from Mick Watson's code: https://gist.github.com/mw55309/7faddf5b0804f70dbc4f7cecb63fd202

# Draw the kernel density estimates of two numeric vectors on shared axes.
#
# Each density is drawn as an outlined, semi-transparent polygon and the mean
# of each vector is marked with a vertical line.
#
# Arguments:
#   vec1 - numeric vector; filled with RGB (253, 228, 140), mean drawn with lty 5
#   vec2 - numeric vector; filled with RGB (37, 34, 130), mean drawn with lty 3
#   xlim - x-axis limits, c(min, max); defaults to the range c(0, 0.8)
#
# Returns NULL invisibly (the value of the final abline() call); called for its
# plot.
#
# NOTE(review): the name follows the S3 method pattern plot.<class>, so
# plot() would dispatch here for an object of class "histograms" - the name is
# kept because callers use it.
plot.histograms <- function(vec1, vec2, xlim = c(0, 0.8)) {

	# Build a semi-transparent colour.
	#   rgb.val = red, green and blue values on a 0-255 scale
	#   percent = % transparency (0 = opaque, 100 = fully transparent)
	t_col <- function(rgb.val, percent = 50) {
		rgb(rgb.val[1], rgb.val[2], rgb.val[3],
			maxColorValue = 255,
			alpha = (100 - percent) * 255 / 100)
	}

	dws <- density(vec1)
	dbg <- density(vec2)

	# The y axis spans the full range of both density curves
	ylim <- range(c(dws$y, dbg$y))

	# Empty canvas first, then the two curves are layered on top of it
	plot(1,1, pch="", xlim=xlim, ylim=ylim, bty="n", xlab="Rate (log-grams per million yr)", ylab="Frequency")
	lines(dws$x, dws$y)
	polygon(dws$x, dws$y, col = t_col(c(253, 228, 140)))
	abline(v = mean(vec1), lty = 5)
	lines(dbg$x, dbg$y)
	polygon(dbg$x, dbg$y, col = t_col(c(37, 34, 130)))
	abline(v = mean(vec2), lty = 3)
}

# Densities of the row means (one mean per row of each rate matrix)
ws_row_means <- apply(bamm.cr.ws, MARGIN = 1, mean)
bg_row_means <- apply(bamm.rate.bg, MARGIN = 1, mean)
plot.histograms(ws_row_means, bg_row_means)

# Densities of all values of each rate matrix pooled together
ws_all_rates <- as.vector(bamm.cr.ws)
bg_all_rates <- as.vector(bamm.rate.bg)
plot.histograms(ws_all_rates, bg_all_rates)