#Created by Alicia Rogers

import sys

# Chromosome names that are processed, in the order their genes are written
# to the output file. Rows on any other chromosome are ignored.
CHROMOSOMES = ('I', 'II', 'III', 'IV', 'V', 'X', 'MtDNA')

# A progress message is printed every time this many reads of one
# chromosome have been processed.
PROGRESS_INTERVAL = 1000000


def read_gene_table(path):
    """Read the gene table and return a list of [gene, chrm, start, end].

    Each non-blank line must hold exactly four whitespace-separated columns:
    gene, chromosome, start, end. Start and end are converted to int.
    Blank lines are skipped.

    Raises:
        ValueError: if a line does not have four columns or start/end are
            not integers.
    """
    genes = []
    with open(path, 'r') as handle:
        for line in handle:
            if not line.strip():
                continue
            gene, chrm, start, end = line.split()
            genes.append([gene, chrm, int(start), int(end)])
    return genes


def read_coverage(path, sign=1):
    """Read a coverage file and return a list of [chrom, start, end, rpm].

    Lines containing the text "track" and blank lines are skipped. Every
    other line must hold four whitespace-separated columns: chromosome,
    start, end, rpm. Start and end are kept as the strings read from the
    file; rpm is converted to float and multiplied by ``sign``.

    Raises:
        ValueError: if a line does not have four columns or rpm is not a
            number.
    """
    records = []
    with open(path, 'r') as handle:
        for line in handle:
            if "track" in line or not line.strip():
                continue
            chrom, start, end, rpm = line.split()
            records.append([chrom, start, end, float(rpm) * sign])
    return records


def group_by_chromosome(rows, column):
    """Bucket rows by the chromosome name found in ``rows[i][column]``.

    Returns a dict with one list per entry of CHROMOSOMES (possibly empty).
    Row order inside each bucket follows the input order; rows whose
    chromosome is not in CHROMOSOMES are dropped.
    """
    groups = dict((name, []) for name in CHROMOSOMES)
    for row in rows:
        bucket = groups.get(row[column])
        if bucket is not None:
            bucket.append(row)
    return groups


def sum_rpm_per_gene(genes, reads):
    """Return, for each gene, the summed rpm of the reads it contains.

    A read counts toward a gene when read start >= gene start and
    read end <= gene end (the read lies entirely inside the gene, bounds
    inclusive). A read may count toward several genes.

    The result is a list parallel to ``genes``. A gene without any read
    keeps the int 0, so it is later written as "0" rather than "0.0".
    """
    totals = [0] * len(genes)
    for count, read in enumerate(reads, 1):
        if count % PROGRESS_INTERVAL == 0:
            print("processed " + str(count) + " reads")
        # Parse the coordinates once per read instead of once per gene.
        read_start = int(read[1])
        read_end = int(read[2])
        rpm = read[3]
        for index, gene in enumerate(genes):
            if read_start >= gene[2] and read_end <= gene[3]:
                totals[index] += rpm
    return totals


def format_gene_row(gene, total):
    """Return the tab-separated output line for one gene and its rpm sum."""
    fields = [gene[0], gene[1], str(gene[2]), str(gene[3]), str(total)]
    return '\t'.join(fields) + '\n'


def main(argv):
    """Sum rpm per gene and write one row per gene to the output file.

    ``argv`` follows the sys.argv layout:
        argv[1]  gene table (gene, chromosome, start, end)
        argv[2]  positive coverage file
        argv[3]  negative coverage file; its rpm values are multiplied by -1
                 (presumably the opposite strand -- confirm with the author)
        argv[4]  output file, written as gene, chromosome, start, end, sum

    Returns the grand total of all written sums, which is also printed.
    """
    if len(argv) < 5:
        sys.exit("usage: " + argv[0] +
                 " <gene_table> <positive_file> <negative_file> <output_file>")

    table = read_gene_table(argv[1])
    print(len(table))

    print("READING IN POSITIVE FILE")
    data = read_coverage(argv[2])

    print("READING IN NEGATIVE WIG")
    data.extend(read_coverage(argv[3], sign=-1))

    print("SPLITTING DATA BASED ON CHR")
    data_by_chrom = group_by_chromosome(data, 0)
    del data  # the per-chromosome buckets are all that is needed from here on
    table_by_chrom = group_by_chromosome(table, 1)

    totalsum = 0
    # The output file is only created once all inputs were read successfully.
    with open(argv[4], 'w') as output:
        for chrom in CHROMOSOMES:
            print("SORTING THROUGH DATA ON CHR " + chrom)
            genes = table_by_chrom[chrom]
            # pop() releases the reads of a chromosome once it is finished.
            totals = sum_rpm_per_gene(genes, data_by_chrom.pop(chrom))

            print("SUMMING RPMS OF CHR " + chrom)
            for gene, total in zip(genes, totals):
                output.write(format_gene_row(gene, total))
                totalsum += total

    print("Total RPMS: " + str(totalsum))
    return totalsum


if __name__ == "__main__":
    main(sys.argv)
