# R commands for inferring patterns of gene family origin and loss from
# OrthoFinder results. A second section uses KinFin orthogroup annotations to
# test whether the gene families gained or lost along selected branches
# (determined in the first section) are enriched for any GO or Pfam annotations.


######
#
#	Gene Family Origin and Losses
#
######

# Load the gene-family copy-number table written by OrthoFinder
# (Orthogroups.GeneCount.csv). The first column becomes the row names
# (orthogroup IDs); the remaining columns are expected to be named by the
# taxon codes defined below.
orthogroups <- read.table(
	file = "Orthogroups.GeneCount.csv",
	header = TRUE,
	row.names = 1
)

# Taxon codes for every species analysed; each nested clade below is carved out
# of its parent by removing the sister lineage(s).
chor <- c(
	"dnov", "bacu", "bmys", "ttru", "btau", "sscr", "clup", "hsap", "mmus",
	"lafr", "pcap", "mdom", "oana", "amis", "ggal", "acar", "xtro", "lcha",
	"locu", "drer", "dnig", "trub", "mmol", "gacu", "olat", "onil", "gmor",
	"eluc", "sfor", "cpun", "rtyp", "ccar", "stor", "cmil", "pmar", "bflo",
	"cint"
)
ceph <- "bflo"
olfa <- setdiff(chor, ceph)
uroc <- "cint"
vert <- setdiff(olfa, uroc)
agna <- "pmar"
gnat <- setdiff(vert, agna)

# Single-taxon cartilaginous fish lineages and the clades built from them
rtyp <- "rtyp"
cpun <- "cpun"
stor <- "stor"
ccar <- "ccar"
orec <- c(rtyp, cpun)
lach <- c(stor, ccar)
sela <- c(orec, lach)
chim <- "cmil"
chon <- c(chim, sela)

# Bony vertebrates: everything in vert except lampreys and cartilaginous fishes
oste <- setdiff(vert, c(agna, chon))
acti <- c("dnig", "drer", "eluc", "gacu", "gmor", "locu", "mmol", "olat", "onil", "sfor", "trub")
tele <- setdiff(acti, "locu")
sarc <- setdiff(oste, acti)
tetr <- setdiff(sarc, "lcha")
amph <- "xtro"
amni <- setdiff(tetr, amph)

# Tree topology used by the analyses below, in a fixed order. An internal clade
# maps to a named list of its two subclades (each a character vector of taxon
# codes); a single-taxon clade maps to a length-1 character vector.
# NOTE(review): "tele" maps to a flat character vector of 10 taxa rather than a
# two-element list, so code that reads x[[1]] / x[[2]] as subclades will treat
# "dnig" and "drer" as the teleost subclades -- confirm this is intended.
clades_all <- list(
	chor = list(ceph = ceph, olfa = olfa),
	ceph = ceph,
	olfa = list(uroc = uroc, vert = vert),
	uroc = uroc,
	vert = list(agna = agna, gnat = gnat),
	agna = agna,
	gnat = list(chon = chon, oste = oste),
	chon = list(chim = chim, sela = sela),
	chim = chim,
	sela = list(orec = orec, lach = lach),
	orec = list(rtyp = rtyp, cpun = cpun),
	lach = list(ccar = ccar, stor = stor),
	rtyp = rtyp,
	cpun = cpun,
	stor = stor,
	ccar = ccar,
	oste = list(acti = acti, sarc = sarc),
	acti = list(locu = "locu", tele = tele),
	tele = tele,
	sarc = list(lcha = "lcha", tetr = tetr),
	tetr = list(amph = amph, amni = amni)
)

# Clades for which losses can be split by origin. Each entry lists the ancestral
# branches on which a lost gene family may have originated: "chor" (present at
# the root) followed by the internal branches leading to the clade. All of these
# lists are prefixes of the path chor > olfa > vert > gnat > chon > sela.
clades_anc <- local({
	to_olfa <- c("chor", "olfa")
	to_vert <- c(to_olfa, "vert")
	to_gnat <- c(to_vert, "gnat")
	to_chon <- c(to_gnat, "chon")
	to_sela <- c(to_chon, "sela")
	list(
		uroc = "chor",
		vert = "chor",
		agna = to_olfa,
		gnat = to_olfa,
		chon = to_vert,
		chim = to_gnat,
		sela = to_gnat,
		orec = to_chon,
		lach = to_chon,
		rtyp = to_sela,
		cpun = to_sela,
		stor = to_sela,
		ccar = to_sela,
		oste = to_vert,
		acti = to_gnat,
		sarc = to_gnat
	)
})

# Divide clades into those defined by a single tip (one taxon) and non-tip clades.
# clades_tip is a named character vector (clade name -> taxon code).
clades_tip <- unlist(clades_all[lengths(clades_all) == 1L])
clades_nti <- clades_all[lengths(clades_all) != 1L]

# First, determine the orthogroups present at each tip (more than 0 copies observed).
# lapply (not sapply) keeps the result a named list even if every tip happened to
# contain the same number of orthogroups, in which case sapply would return a matrix.
clades_tip_pres <- lapply(clades_tip, function(x) {
	rownames(orthogroups)[which(orthogroups[[x]] > 0)]
})
clades_tip_pres_num <- lengths(clades_tip_pres)

#	Determine the orthogroups that are present at the MRCA of each non-tip clade.
#	Note, orthogroups present at the clade MRCA include two groups:
#		present in any member of the clade (even one) and in an outgroup, or
#		present in both subclades (including orthogroups not in the outgroup).
#	Each non-tip clade is read as a list whose first two elements are its subclades' taxa.
#	NOTE(review): "tele" is a flat character vector in clades_all, so x[[1]] / x[[2]]
#		below are the single taxa "dnig" and "drer" -- confirm this is intended.
clades_nti_pres <- local({
	# Presence/absence matrix (orthogroups x columns), computed once instead of
	# calling a closure per orthogroup via apply(); NA counts are treated as absent.
	present <- as.matrix(orthogroups) > 0
	present[is.na(present)] <- FALSE
	# TRUE for orthogroups with >0 copies in at least one of `taxa`;
	# all FALSE when `taxa` is empty (the root has no outgroup).
	present_in_any <- function(taxa) rowSums(present[, taxa, drop = FALSE]) > 0
	lapply(clades_nti, function(x) {
		clade1 <- x[[1]]
		clade2 <- x[[2]]
		outgroup <- chor[! chor %in% c(clade1, clade2)]
		in_clade1 <- present_in_any(clade1)
		in_clade2 <- present_in_any(clade2)
		present_in_outgroup_and_clade <- present_in_any(outgroup) & (in_clade1 | in_clade2)
		present_in_both_subclades <- in_clade1 & in_clade2
		rownames(orthogroups)[which(present_in_outgroup_and_clade | present_in_both_subclades)]
	})
})
clades_nti_pres_num <- lengths(clades_nti_pres)

# Combine tip and non-tip results, ordered like clades_all
clades_pres <- c(clades_tip_pres, clades_nti_pres)[names(clades_all)]
clades_pres_num <- c(clades_tip_pres_num, clades_nti_pres_num)[names(clades_all)]


#	Determine gene families that are present at a node and are also not lost in any descendant:
#	these are sometimes considered to be "core" genes. Computed for non-tip clades only.
#	Depends on clades_pres (above). A family counts as lost if any taxon in the clade has 0 copies.
clades_pres_noloss <- local({
	present <- as.matrix(orthogroups) > 0
	present[is.na(present)] <- FALSE
	nti_names <- names(clades_pres)[names(clades_pres) %in% names(clades_nti)]
	# lapply keeps a named list regardless of result lengths (sapply could return a matrix)
	lapply(setNames(nti_names, nti_names), function(x) {
		clade_taxa <- unlist(clades_all[[x]])
		absent_somewhere <- rowSums(!present[, clade_taxa, drop = FALSE]) > 0
		any_absence <- rownames(orthogroups)[which(absent_somewhere)]
		clade_pres <- clades_pres[[x]]
		clade_pres[! clade_pres %in% any_absence]
	})
})
clades_pres_noloss_num <- lengths(clades_pres_noloss)


#	Determine losses leading to the MRCA of each clade: families present at the parent node but not at the clade's MRCA.
#	Depends on gene families being present in the ancestor, so losses cannot be inferred for the root (there is no deeper ancestral node).
#	They also cannot be determined for the clades immediately descending from the root (Branchiostoma vs. Olfactores),
#		because absence in either clade would mean that those gene families are not inferred to be present at the root either.
#	Note, this is not presented in the paper, but represents the sum of the losses separated by clade
clades_loss <- local({
	loss_clades <- names(clades_all)[! names(clades_all) %in% c("chor", "ceph", "olfa")]
	# lapply keeps a named list regardless of result lengths (sapply could return a matrix)
	lapply(setNames(loss_clades, loss_clades), function(x) {
		# The parent is the clade whose list of subclades names x
		parent <- names(which(vapply(clades_all, function(y) x %in% names(y), logical(1))))
		parent_pres <- clades_pres[[parent]]
		parent_pres[! parent_pres %in% clades_pres[[x]]]
	})
})
clades_loss_num <- lengths(clades_loss)

#	Determine gains leading to the MRCA of each clade
#	Depends on gene families being absent in the ancestor. Again, cannot be inferred for the root.

#	Gains for each tip: present in that taxon and absent (0 copies) in every other taxon
clades_gain_tip <- local({
	present <- as.matrix(orthogroups) > 0
	present[is.na(present)] <- FALSE
	lapply(clades_tip, function(x) {
		# Computed once per tip rather than once per orthogroup
		allothertaxa <- chor[! chor %in% x]
		absent_in_allothertaxa <- rowSums(present[, allothertaxa, drop = FALSE]) == 0
		rownames(orthogroups)[which(present[, x] & absent_in_allothertaxa)]
	})
})

#	Gains for each non-tip clade except the root: present in both subclades and
#	absent from the sister clade and from every outgroup taxon
clades_gain_nti <- local({
	present <- as.matrix(orthogroups) > 0
	present[is.na(present)] <- FALSE
	present_in_any <- function(taxa) rowSums(present[, taxa, drop = FALSE]) > 0
	gain_clades <- names(clades_nti)[! names(clades_nti) %in% "chor"]
	# lapply keeps a named list regardless of result lengths (sapply could return a matrix)
	lapply(setNames(gain_clades, gain_clades), function(x) {
		clade1_taxa <- clades_all[[x]][[1]]
		clade2_taxa <- clades_all[[x]][[2]]
		clade_taxa <- c(clade1_taxa, clade2_taxa)
		parent_clade_name <- names(which(vapply(clades_all, function(y) x %in% names(y), logical(1))))
		parent_clade_subclades <- clades_all[[parent_clade_name]]
		# For sister clade taxa, technically don't have to do it this way for a fully-resolved tree, but this handles clades that are polytomies.
		# Technically the clade1_taxa and clade2_taxa commands above are NOT robust to polytomies
		sister_clade_taxa <- unlist(parent_clade_subclades[! names(parent_clade_subclades) %in% x])
		outgroup_taxa <- chor[! chor %in% c(clade_taxa, sister_clade_taxa)]
		absent_in_outgroup_sister <- ! present_in_any(c(outgroup_taxa, sister_clade_taxa))
		present_in_both_subclades <- present_in_any(clade1_taxa) & present_in_any(clade2_taxa)
		rownames(orthogroups)[which(absent_in_outgroup_sister & present_in_both_subclades)]
	})
})
clades_gain <- c(clades_gain_tip, clades_gain_nti)[names(clades_all)[! names(clades_all) %in% "chor"]]
clades_gain_num <- lengths(clades_gain)

#	Determine gene families that were gained at a node and not lost in any descendant.
#	Computed for non-tip clades only; a family counts as lost if any taxon in the clade has 0 copies.
clades_gain_noloss <- local({
	present <- as.matrix(orthogroups) > 0
	present[is.na(present)] <- FALSE
	nti_names <- names(clades_gain)[names(clades_gain) %in% names(clades_nti)]
	# lapply keeps a named list regardless of result lengths (sapply could return a matrix)
	lapply(setNames(nti_names, nti_names), function(x) {
		clade_taxa <- unlist(clades_all[[x]])
		absent_somewhere <- rowSums(!present[, clade_taxa, drop = FALSE]) > 0
		any_absence <- rownames(orthogroups)[which(absent_somewhere)]
		clade_gain <- clades_gain[[x]]
		clade_gain[! clade_gain %in% any_absence]
	})
})
clades_gain_noloss_num <- lengths(clades_gain_noloss)


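# Summary of the gene family categories tabulated in this section:
# - gene families present at a node and never lost (clades_pres_noloss)
# - gene families gained at a node and never lost (clades_gain_noloss)
# - gene families lost, grouped by the branch on which they originated; the candidate
#   origin branches that precede each clade are listed in clades_anc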

#	Determine gene families that are lost, and separate them by the branch on which they originated.
#	Possible origins for each clade are listed in clades_anc: "chor" (present at the root)
#	or an internal branch (gained on that branch).
clades_loss_byorigin <- lapply(setNames(names(clades_anc), names(clades_anc)), function(x) {
	lost_here <- clades_loss[[x]]
	origins <- clades_anc[[x]]
	by_origin <- lapply(origins, function(y) {
		# Gains are not defined for the root, so root-origin families come from root presence
		origin_set <- if (y %in% "chor") clades_pres[[y]] else clades_gain[[y]]
		origin_set[origin_set %in% lost_here]
	})
	names(by_origin) <- origins
	by_origin
})

# Per clade: number of losses per origin, their sum, and the total loss count
# (sum should equal total_loss if every loss was assigned an origin).
clades_loss_byorigin_num <- lapply(setNames(names(clades_loss_byorigin), names(clades_loss_byorigin)), function(x) {
	per_origin <- lengths(clades_loss_byorigin[[x]])
	c(per_origin,
		"sum" = sum(per_origin),
		# unname() so the element is named "total_loss" rather than "total_loss.<clade>"
		"total_loss" = unname(clades_loss_num[x]))
})

# Same split by origin, but keeping only families that are not also lost on any other branch
clades_loss_byorigin_noelse <- lapply(setNames(names(clades_anc), names(clades_anc)), function(x) {
	# Invariant across origins, so computed once per clade
	all_other_loss <- unlist(clades_loss[! names(clades_loss) %in% x])
	by_origin <- lapply(clades_anc[[x]], function(y) {
		clade_loss <- clades_loss_byorigin[[x]][[y]]
		clade_loss[! clade_loss %in% all_other_loss]
	})
	names(by_origin) <- clades_anc[[x]]
	by_origin
})

clades_loss_byorigin_noelse_num <- lapply(clades_loss_byorigin_noelse, lengths)

# Put all of the counts in a single named list (one entry per clade in clades_all):
#	[[1]]        named vector pres, pres_noloss, gain, gain_noloss (NA where not computed for that clade)
#	loss         losses split by origin, with "sum" and "total_loss" (NULL for clades not in clades_anc)
#	loss_noelse  losses split by origin, excluding families also lost on any other branch
orthogroup_evo <- lapply(setNames(names(clades_all), names(clades_all)), function(x) {
	# unname() keeps element names as "pres", "gain", ... instead of "pres.<clade>" / "pres_noloss.NA"
	count_for <- function(v) unname(v[x])
	list(
		c("pres" = count_for(clades_pres_num),
			"pres_noloss" = count_for(clades_pres_noloss_num),
			"gain" = count_for(clades_gain_num),
			"gain_noloss" = count_for(clades_gain_noloss_num)),
		"loss" = clades_loss_byorigin_num[[x]],
		"loss_noelse" = clades_loss_byorigin_noelse_num[[x]]
	)
})

# Figure made manually based on results



######
#
#	Enrichment Tests
#
######

# Load necessary functions
library(dplyr)
library(tidyr)
# biomaRt will also be needed, not loaded because of conflicts in command names

# Define a function to generate one 2x2 contingency table per annotation term.
#	func.table: annotation data.frame; column 1 (X.cluster_id) holds orthogroup IDs and a
#		"domain_id" (or "domain_ids") column holds one term per row.
#	group: orthogroup IDs in the foreground set (e.g. gene families gained on a branch).
#	background: orthogroup IDs in the universe; terms are collected from rows in the background.
# Returns a list named by term of integer 2x2 matrices:
#	row 1 = orthogroups with the term, row 2 = without;
#	column 1 = group, column 2 = background orthogroups not in group.
cont.table <- function(func.table, group, background) {
	# Use the column by its exact name instead of `$domain_id` partial matching
	# (which warns on data frames when the column is actually "domain_ids").
	domain_col <- if ("domain_id" %in% names(func.table)) "domain_id" else "domain_ids"
	cluster_ids <- func.table[[1]]
	domain_ids <- func.table[[domain_col]]
	in_background <- func.table[["X.cluster_id"]] %in% background
	domains <- unique(domain_ids[in_background])
	# Unannotated rows (NA) are not a term; testing them would add a spurious "NA"
	# entry and inflate the number of tests used for FDR adjustment.
	domains <- domains[!is.na(domains)]
	nongroup <- setdiff(func.table[["X.cluster_id"]][in_background], group)
	X <- lapply(domains, function(x) {
		wfxn <- cluster_ids[domain_ids %in% x]
		wofxn <- cluster_ids[! cluster_ids %in% wfxn]
		matrix(c(
			length(intersect(group, wfxn)),
			length(intersect(nongroup, wfxn)),
			length(intersect(group, wofxn)),
			length(intersect(nongroup, wofxn))
		), byrow = TRUE, nrow = 2)
	})
	names(X) <- domains
	X
}

# Define a function that performs the enrichment test.
#	func.tables: list of annotation tables (made from KinFin annotation output; corresponds to the
#		supplementary functional annotation file), each passed to cont.table().
#	group / background: orthogroup IDs of the foreground set and of the universe it is compared with.
# Returns list("Contigency.Tables" = named list of 2x2 tables across all func.tables,
#	"Fisher.Test" = data.frame with columns domains, p.value, conf.int1, conf.int2,
#	estimate.odds.ratio, null.value.odds.ratio, alternative, method, fisher.padj),
# where fisher.padj is the p.adjust(method = "fdr") value across every tested term.
enrich.test <- function(func.tables, group, background) {
	conttables <- lapply(func.tables, function(x) {
		y <- cont.table(x, group = group, background = background)
		print("Contingency table created.")
		y
	})
	# One flat, term-named list of tables across all annotation categories
	conttables <- do.call(c, conttables)
	tests <- lapply(conttables, fisher.test)
	# Extract fields directly from each htest object; unlist() would coerce every field to
	# character (rounding p-values to 15 significant digits) before converting back.
	num_field <- function(get) vapply(tests, get, numeric(1), USE.NAMES = FALSE)
	chr_field <- function(get) vapply(tests, get, character(1), USE.NAMES = FALSE)
	fisher <- data.frame(
		p.value = num_field(function(ft) ft$p.value),
		conf.int1 = num_field(function(ft) ft$conf.int[[1]]),
		conf.int2 = num_field(function(ft) ft$conf.int[[2]]),
		estimate.odds.ratio = num_field(function(ft) ft$estimate[[1]]),
		null.value.odds.ratio = num_field(function(ft) ft$null.value[[1]]),
		alternative = chr_field(function(ft) ft$alternative),
		method = chr_field(function(ft) ft$method),
		stringsAsFactors = FALSE
	)
	domains <- gsub(pattern = "GO\\.", "GO:", as.character(names(conttables)))
	rownames(fisher) <- domains
	fisher.padj <- p.adjust(fisher$p.value, method = "fdr")
	fisher <- data.frame("domains" = domains, fisher, fisher.padj, stringsAsFactors = FALSE)
	enrich_results <- list("Contigency.Tables" = conttables, "Fisher.Test" = fisher)
	print("Enrichment test complete.")
	enrich_results
}

# Determine enriched terms: summarise the terms significantly over-represented
# (fisher.padj < 0.05 and odds ratio > 1) among the foreground orthogroups.
#	clusters: annotation table with columns X.cluster_id, domain_ids, domain_description.
#	foreground: orthogroup IDs used as the tested group (e.g. gains on a branch).
#	enrichment.test.output: list returned by enrich.test(); element 2 is its Fisher table.
# Returns a data.frame with one row per enriched term: domain_ids, n (number of foreground
# annotation rows with the term), domain_description, p.value, estimate.odds.ratio and
# fisher.padj, sorted by increasing odds ratio.
enriched.terms <- function(clusters, foreground, enrichment.test.output) {
	fisher <- enrichment.test.output[[2]]
	enriched_ids <- fisher %>%
		filter(fisher.padj < .05, estimate.odds.ratio > 1) %>%
		pull(domains)
	# Filter the `clusters` argument (not a global annotation table) so the function
	# works on whichever table it is given.
	clusters %>%
		filter(X.cluster_id %in% foreground, domain_ids %in% enriched_ids) %>%
		select(X.cluster_id, domain_ids, domain_description) %>%
		count(domain_ids) %>%
		left_join(unique(select(clusters, domain_ids, domain_description)), by = "domain_ids") %>%
		# Columns 1, 2, 5, 9 of the Fisher table: domains, p.value, estimate.odds.ratio, fisher.padj
		left_join(fisher[c(1:2, 5, 9)], by = c("domain_ids" = "domains")) %>%
		arrange(estimate.odds.ratio) %>%
		as.data.frame()
}

# Build the final enriched-terms result for a Supplementary Table.
#	clusters: annotation table with columns X.cluster_id, domain_ids, domain_description.
#	foreground: orthogroup IDs used as the tested group (e.g. gains on a branch).
#	enrichment.test.output: list returned by enrich.test(); element 2 is its Fisher table
#		(columns 1, 2, 5, 9 = domains, p.value, estimate.odds.ratio, fisher.padj).
# Keeps terms with fisher.padj < 0.05 and odds ratio > 1 that annotate foreground orthogroups.
# Returns one row per term, sorted by fisher.padj, then odds ratio, then term ID.
# NOTE(review): reads the global `cluster_ensembl_genename_collapse4print` (columns clusters,
#	cluster_name), which must exist before this is called.
enriched.terms.table <- function(clusters, foreground, enrichment.test.output) {
    # One row per (foreground orthogroup, enriched term); n = number of distinct
    # foreground orthogroups carrying the term; human gene names joined per orthogroup.
    result <- clusters %>% filter(X.cluster_id %in% foreground,domain_ids %in% (enrichment.test.output[[2]] %>% filter(fisher.padj < .05, estimate.odds.ratio > 1) %>% pull(domains))) %>%
        select(X.cluster_id,domain_ids,domain_description) %>% group_by(domain_ids) %>% mutate(n = n_distinct(X.cluster_id)) %>%
        left_join(enrichment.test.output[[2]][c(1:2,5,9)], by = c("domain_ids" = "domains")) %>% 
        select(domain_ids,domain_description,n,p.value,estimate.odds.ratio,fisher.padj,X.cluster_id) %>% 
        left_join(y = cluster_ensembl_genename_collapse4print, by = c("X.cluster_id" = "clusters")) %>% arrange(domain_description) %>%
        as.data.frame()
    # Collapse to one row per term (result[,1] is domain_ids). paste0(unique(x)) has no
    # `collapse`, so each cell holds the unique values as a character vector, which may
    # span several entries per term; the numeric columns are converted back with as.numeric.
    result <- aggregate.data.frame(result, list(result[,1]), function(x) paste0(unique(x))) %>% select(-Group.1) %>% mutate_at(vars(fisher.padj,estimate.odds.ratio),as.numeric) %>% arrange(fisher.padj,estimate.odds.ratio,domain_ids) %>% as.data.frame()
	# Split the ", "-joined gene names, de-duplicate, drop "NA" entries, and re-join with ", "
	result$cluster_name <- sapply(result$cluster_name, function(x) unique(unlist(lapply(x, function(y) strsplit(y, ", ")))) ) 
	result$cluster_name <- sapply(result$cluster_name, function(x) x[!x %in% "NA"])
	result$cluster_name <- sapply(result$cluster_name, function(x) paste(x, collapse = ", "))
	result
}

# Write an enriched-terms table (from enriched.terms.table) to `filename` as tab-separated
# text without quotes or row names. Each X.cluster_id entry (one or more orthogroup IDs)
# is joined with ", " first, and the columns are bound together with cbind before writing.
write.enrich.terms.table <- function(enrich.terms.table, filename) {
	out <- enrich.terms.table
	out$X.cluster_id <- lapply(out$X.cluster_id, paste, collapse = ", ")
	write.table(do.call(cbind, out), file = filename, quote = FALSE, row.names = FALSE, sep = "\t")
}

# Read in the Orthogroups.csv, an output file from OrthoFinder
# (tab-separated with a header; the first column is renamed "clusters", presumably orthogroup IDs).
orthogroups_csv <- read.table("Orthogroups.csv", stringsAsFactors = FALSE, header = TRUE, sep = "\t")
colnames(orthogroups_csv)[1] <- "clusters"
# One row per (orthogroup, human protein ID): the hsap column holds ", "-separated IDs, which are
# split into a list column and unnested. Orthogroups with an empty hsap field presumably yield no rows.
ensembl_hsap_ids <- orthogroups_csv %>% select(clusters,hsap) %>% transform(hsap = strsplit(hsap, ", ")) %>% unnest(hsap)
# Query Ensembl BioMart (requires network access) for gene ID, gene name and description
# of each versioned human peptide ID.
mart <- biomaRt::useMart("ensembl",dataset="hsapiens_gene_ensembl")
BM_hsap_ids <- biomaRt::getBM(filters = "ensembl_peptide_id_version", attributes=c("ensembl_peptide_id_version","ensembl_gene_id","external_gene_name","description"), values = ensembl_hsap_ids$hsap, mart = mart)
# Attach gene ID and gene name (description is dropped); full_join keeps proteins with no BioMart hit, with NA gene fields.
cluster_ensembl_genename <- full_join(ensembl_hsap_ids, BM_hsap_ids[,c(1:3)], by = c("hsap" = "ensembl_peptide_id_version"))
# Per orthogroup, "GENE_n" entries joined by ";", where n is the number of proteins with that gene name.
# NOTE(review): not used elsewhere in the visible script.
cluster_ensembl_genename_collapse <- cluster_ensembl_genename %>% select(clusters,external_gene_name) %>% count(clusters,external_gene_name) %>% group_by(clusters) %>% mutate(cluster_name = paste(external_gene_name,n,sep = "_",collapse = ";")) %>% select(clusters,cluster_name) %>% unique()

# Per orthogroup, ", "-joined gene names (one entry per matched protein, so names can repeat);
# rows with any NA are dropped first. This table is read by enriched.terms.table().
cluster_ensembl_genename_collapse4print <- cluster_ensembl_genename %>% drop_na() %>% select(clusters,external_gene_name) %>% group_by(clusters) %>% mutate(cluster_name = paste(external_gene_name,collapse = ", ")) %>% select(clusters,cluster_name) %>% unique()

# Supp File 7
# NOTE(review): write.table's default separator is a space, and cluster_name itself contains
# spaces (", "), so the two columns are not cleanly separable -- confirm sep = "\t" was not intended.
write.table(cluster_ensembl_genename_collapse4print, "Supp_File_7_orthogroup_human_gene_names.txt", row.names = FALSE, col.names = FALSE, quote = FALSE)


# Read in kinfin annotation result ("None" entries are read as NA)
# Note, KinFin outputs multiple annotation files, one for each annotation category (e.g. GO, Pfam); this is a single combined table of the results
cluster_ann <- read.table("cluster_functional_annotation.tsv",comment.char = "",header=TRUE,sep="\t",quote="",na.strings = "None",stringsAsFactors = FALSE)

# Split annotations by category so that they're easier to process.
# NOTE(review): assumes the first `cluster_n` rows (one per distinct orthogroup) are the GO
#	block and all remaining rows are Pfam -- confirm against how the combined table was built.
cluster_n <- length(unique(cluster_ann$X.cluster_id))
# seq_len()/logical indexing stay correct when there are no Pfam rows; the former
# `(cluster_n+1):nrow(cluster_ann)` counted backwards there, selecting an NA row and a GO row.
cluster_GO <- cluster_ann[seq_len(cluster_n), ]
cluster_Pfam <- cluster_ann[seq_len(nrow(cluster_ann)) > cluster_n, ]

# Enrichment tests: each call compares a foreground set of gene families (gained or lost
# on a branch) against a background set, using both the GO and Pfam annotation tables.

# Gained in Olfactores vs. present in Olfactores
enrich_golfa <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_gain[["olfa"]],
	background = clades_pres[["olfa"]]
)
# Gained in vertebrates vs. present in vertebrates
enrich_gvert <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_gain[["vert"]],
	background = clades_pres[["vert"]]
)
# Gained in gnathostomes vs. present in gnathostomes
enrich_ggnat <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_gain[["gnat"]],
	background = clades_pres[["gnat"]]
)
# Gained in Osteichthyes vs. present in Osteichthyes
enrich_goste <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_gain[["oste"]],
	background = clades_pres[["oste"]]
)
# Gained in Chondrichthyes vs. present in Chondrichthyes
enrich_gchon <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_gain[["chon"]],
	background = clades_pres[["chon"]]
)
# Lost in Osteichthyes vs. present in gnathostomes
enrich_loste <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_loss[["oste"]],
	background = clades_pres[["gnat"]]
)
# Lost in Chondrichthyes vs. present in gnathostomes
enrich_lchon <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_loss[["chon"]],
	background = clades_pres[["gnat"]]
)
# Gained in whale shark vs. present in whale shark
enrich_grtyp <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_gain[["rtyp"]],
	background = clades_pres[["rtyp"]]
)
# Lost on the branch leading to whale shark vs. present in the MRCA of whale shark and brownspotted bamboo shark
enrich_lrtyp <- enrich.test(
	func.tables = list(cluster_GO, cluster_Pfam),
	group = clades_loss[["rtyp"]],
	background = clades_pres[["orec"]]
)

# Tabulate enriched terms and write each as a supplementary table.

# Gained in Olfactores
enrich_golfa_tab <- enriched.terms.table(clusters = cluster_ann, foreground = clades_gain[["olfa"]], enrichment.test.output = enrich_golfa)
write.enrich.terms.table(enrich.terms.table = enrich_golfa_tab, filename = "Supp_Table_gain_olfactores.txt")

# Gained in vertebrates
enrich_gvert_tab <- enriched.terms.table(clusters = cluster_ann, foreground = clades_gain[["vert"]], enrichment.test.output = enrich_gvert)
write.enrich.terms.table(enrich.terms.table = enrich_gvert_tab, filename = "Supp_Table_gain_vert.txt")

# Gained in gnathostomes
enrich_ggnat_tab <- enriched.terms.table(clusters = cluster_ann, foreground = clades_gain[["gnat"]], enrichment.test.output = enrich_ggnat)
write.enrich.terms.table(enrich.terms.table = enrich_ggnat_tab, filename = "Supp_Table_gain_gnat.txt")

# Gained in chondrichthyans
enrich_gchon_tab <- enriched.terms.table(clusters = cluster_ann, foreground = clades_gain[["chon"]], enrichment.test.output = enrich_gchon)
write.enrich.terms.table(enrich.terms.table = enrich_gchon_tab, filename = "Supp_Table_gain_chon.txt")

# Lost in chondrichthyans
enrich_lchon_tab <- enriched.terms.table(clusters = cluster_ann, foreground = clades_loss[["chon"]], enrichment.test.output = enrich_lchon)
write.enrich.terms.table(enrich.terms.table = enrich_lchon_tab, filename = "Supp_Table_loss_chon.txt")

# Gained in osteichthyans
enrich_goste_tab <- enriched.terms.table(clusters = cluster_ann, foreground = clades_gain[["oste"]], enrichment.test.output = enrich_goste)
write.enrich.terms.table(enrich.terms.table = enrich_goste_tab, filename = "Supp_Table_gain_oste.txt")

# Lost in osteichthyans
enrich_loste_tab <- enriched.terms.table(clusters = cluster_ann, foreground = clades_loss[["oste"]], enrichment.test.output = enrich_loste)
write.enrich.terms.table(enrich.terms.table = enrich_loste_tab, filename = "Supp_Table_loss_oste.txt")



# Count the gene families gained in Chondrichthyes, then in Osteichthyes, that carry at
# least one GO or Pfam annotation (non-NA domain_ids). Divide by length(clades_gain$chon)
# or length(clades_gain$oste) to obtain the proportion annotated.

length(intersect(
	clades_gain$chon,
	union(
		cluster_GO$X.cluster_id[!is.na(cluster_GO$domain_ids)],
		cluster_Pfam$X.cluster_id[!is.na(cluster_Pfam$domain_ids)]
	)
))

length(intersect(
	clades_gain$oste,
	union(
		cluster_GO$X.cluster_id[!is.na(cluster_GO$domain_ids)],
		cluster_Pfam$X.cluster_id[!is.na(cluster_Pfam$domain_ids)]
	)
))