#!/usr/bin/python

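# this script calculates a modified traveling ratio (mTR) for a list of genes. it requires a configured database
# containing loci information for the queried genes as well as information about enhancer marks (H3K27Ac).
# the promoter window runs from 50 bp upstream to 1000 bp downstream of the gene start (chr_start on the + strand,
# chr_end on the - strand); the gene body is the remainder of the gene. the "modification": body segments whose
# signal exceeds 5x the mean body density and that fall inside an H3K27Ac peak (padded by 500 bp on each side)
# are counted as promoter (paused) signal instead of body (traveling) signal.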
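# output is one tab-separated line per gene isoform. pPolII and tPolII are the corrected signal densities
# (signal summed over the window divided by the window length) of the promoter and body windows;
# the ratio is NaN when tPolII is 0:
# GeneName	PausedPolIIDensity(pPolII)	TravelingPolIIDensity(tPolII)	pPolII/tPolII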

# syntax: calculateTR.py <gene file.txt> <bedgraph file.bdg>
# <gene file.txt> is simply a list of genes you wish to analyze
# <bedgraph file.bdg> is a bedgraph containing ChIP-seq read densities for the mark you are analysing (usually pol II)


import bisect
import csv
import itertools
import math
import sys

import MySQLdb

# Command-line arguments (see the syntax note at the top of the file).
if len(sys.argv) < 3:
	sys.exit("usage: calculateTR.py <gene file.txt> <bedgraph file.bdg>")
genefile = sys.argv[1]
bdgfile = sys.argv[2]

# Index the bedgraph by chromosome: bdgdict[chrom] = [first_line, last_line], the
# 0-based, inclusive line numbers of that chromosome's records in bdgfile.
# Assumes the bedgraph is sorted so each chromosome's records are contiguous; if a
# chromosome name reappears after a different one, only its last run is kept.
bdgdict = {}
lastchr = None
with open(bdgfile) as bdg:
	for lineno, line in enumerate(bdg):
		chrom = line.split("\t", 1)[0]
		if chrom != lastchr:
			# first record of a new chromosome: it both starts and (so far) ends here
			bdgdict[chrom] = [lineno, lineno]
			lastchr = chrom
		else:
			bdgdict[chrom][1] = lineno
# Read the gene list: one or more tab-separated gene names per line. Surrounding
# whitespace and empty fields (blank lines, trailing tabs) are dropped so that no
# empty or padded name is sent to the database.
genes = []
with open(genefile) as genehandle:
	for generow in csv.reader(genehandle, delimiter='\t'):
		genes.extend(field.strip() for field in generow if field.strip())
# in order to use this script, you'll have to have a database that contains a
# GENE_ISOFORM table containing address info for the genes you're analyzing.
# this db will also need to have peak information regarding H3K27Ac in a CHIP_PEAK table
# (the per-gene peak query further down only uses rows with source '20141219MACS2n').
if not genes:
	# an empty "IN ()" list is a SQL syntax error, and there would be nothing to report
	sys.exit("no gene names found in " + genefile)
db = MySQLdb.connect(host="localhost",user="TATuser",passwd="TATpass",db="TATdb")
cur = db.cursor()
# Pass the gene names as query parameters (one %s placeholder per name) instead of
# pasting them into the SQL text, so a name containing a quote cannot break or alter the query.
placeholders = ",".join(["%s"] * len(genes))
cur.execute("SELECT chr, chr_start, chr_end, strand,name FROM GENE_ISOFORM where name in (" + placeholders + ") ORDER BY chr, chr_start", tuple(genes))

def _load_bedgraph(path):
	"""Parse a bedgraph into {chrom: (starts, ends, values)} parallel lists.

	Records with a value <= 0 are skipped because they add nothing to any area, and
	lines with fewer than four tab-separated fields (e.g. "track"/"browser" headers)
	are ignored. The bedgraph is assumed to be sorted and non-overlapping within each
	chromosome, so each ``ends`` list is ascending and can be searched with bisect.
	"""
	intervals = {}
	with open(path) as handle:
		for line in handle:
			fields = line.split("\t")
			if len(fields) < 4:
				continue
			value = float(fields[3])
			if value <= 0:
				continue
			starts, ends, values = intervals.setdefault(fields[0], ([], [], []))
			starts.append(int(fields[1]))
			ends.append(int(fields[2]))
			values.append(value)
	return intervals


def _window_signal(intervals, promoter, body):
	"""Sum signal (value * overlapping bp) over the promoter and body windows.

	Each record is clipped to each window, so records straddling a window edge
	contribute only their overlapping part. Returns (promoter_area, body_area,
	bodydump), where bodydump holds (start, end, value) for every clipped body segment.
	"""
	starts, ends, values = intervals
	totalstart = min(promoter[0], body[0])
	totalend = max(promoter[1], body[1])
	promarea = 0.0
	bodyarea = 0.0
	bodydump = []
	# first record whose end lies past totalstart, i.e. the first that can overlap
	i = bisect.bisect_right(ends, totalstart)
	while i < len(starts) and starts[i] <= totalend:
		start, end, value = starts[i], ends[i], values[i]
		promlen = min(end, promoter[1]) - max(start, promoter[0])
		if promlen > 0:
			promarea += promlen * value
		bodystart = max(start, body[0])
		bodyend = min(end, body[1])
		if bodyend > bodystart:
			bodyarea += (bodyend - bodystart) * value
			bodydump.append((bodystart, bodyend, value))
		i += 1
	return promarea, bodyarea, bodydump


def _inside_k27_peak(start, end, peaks):
	"""Return True if [start, end] lies strictly inside some peak padded by 500 bp per side.

	Each peak is a (chr, chr_start, chr_end) row. NOTE(review): the peak's chromosome
	is not compared with the gene's; the peak query filters by gene_name only.
	"""
	return any(start > peak[1] - 500 and end < peak[2] + 500 for peak in peaks)


def _fmt(value):
	"""Format a density or ratio for output; undefined (NaN) values print as NaN."""
	return "NaN" if math.isnan(value) else str(value)


# Parse the bedgraph once instead of re-reading the whole file for every isoform.
bdgintervals = _load_bedgraph(bdgfile)
nosignal = ([], [], [])

areadict = {}
for chrom, chrstart, chrend, strand, name in cur.fetchall():
	chrstart = int(chrstart)
	chrend = int(chrend)
	# promoter: 50 bp upstream to 1000 bp downstream of the gene start; body: the rest
	if strand == "+":
		promoter = [chrstart-50, chrstart+1000]
		body = [chrstart+1001, chrend]
	else:
		promoter = [chrend-1000, chrend+50]
		body = [chrstart, chrend-1001]
	promlength = promoter[1] - promoter[0]
	bodylength = body[1] - body[0]  # <= 0 for genes of about 1 kb or shorter

	if chrom not in bdgintervals:
		sys.stderr.write("warning: no positive signal on %s in %s; %s is treated as uncovered\n" % (chrom, bdgfile, name))
	totalprom, totalbody, bodydump = _window_signal(bdgintervals.get(chrom, nosignal), promoter, body)

	# '%' is a query-parameter marker in MySQLdb, so the LIKE pattern is passed as a parameter too
	cur.execute("SELECT chr, chr_start, chr_end FROM CHIP_PEAK where gene_name = %s and mark like %s and source = %s ORDER BY chr, chr_start", (name, 'H3K27Ac%', '20141219MACS2n'))
	K27peaks = cur.fetchall()

	# mean signal per bp in each window before the H3K27Ac correction
	bodydensity = totalbody / bodylength if bodylength > 0 else float("nan")
	areadict[name] = [totalprom / promlength, bodydensity]

	# Body segments brighter than 5x the mean body density that sit on an H3K27Ac peak
	# are moved to the promoter: both their signal and their length.
	addpromoter = 0.0
	lengthdelta = 0
	correctedbody = 0.0
	for segstart, segend, value in bodydump:
		length = segend - segstart
		if value > bodydensity * 5 and _inside_k27_peak(segstart, segend, K27peaks):
			addpromoter += length * value
			lengthdelta += length
		else:
			correctedbody += length * value

	promdensity = (totalprom + addpromoter) / (promlength + lengthdelta)
	remaining = bodylength - lengthdelta
	corrbodydensity = correctedbody / remaining if remaining > 0 else float("nan")
	# the ratio is undefined when the corrected body has no signal (or no length left)
	ratio = promdensity / corrbodydensity if corrbodydensity > 0 else float("nan")
	print("\t".join([str(name), _fmt(promdensity), _fmt(corrbodydensity), _fmt(ratio)]))
