#!/usr/bin/env python3

"""
Usage: ./volcano.py combine4_2HK.txt

Generate a volcano plot of qRT-PCR data.

Input: tab-separated text with one header line, then one row per gene:
gene name followed by 4 pairs of dCt values (Hs578T, MDA-MB-453), one pair per trial.

1. The .txt file was normalized by subtracting the avg of 2 HK genes' Ct values from each gene's Ct value (dCt).
2. ddCt is calculated for each of the 4 trials by subtracting the MDA-MB-453 dCt value from the Hs578T dCt value.
3. 2^(-ddCt) is calculated for each of the 4 trials and averaged; log2 of that average is the x-value.
4. A p-value is determined by a t-test on the dCt values of MDA-MB-453 vs. Hs578T; -log10 of it is the y-value.
5. Values are plotted, and significant ones are marked: log2(fold-change) >= 2 or <= -2, and -log10(p-val) >= 2.
"""

import sys
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
#import seaborn as sns

# Number of paired (Hs578T, MDA-MB-453) dCt columns that follow the gene name.
N_TRIALS = 4


def volcano_coordinates(hs_dct, mda_dct):
    """Return the volcano-plot coordinates for one gene.

    Parameters:
        hs_dct:  per-trial dCt values for Hs578T.
        mda_dct: per-trial dCt values for MDA-MB-453 (same length/order).

    Returns:
        (log2_change, log10_p) where log2_change is log2 of the mean of the
        per-trial 2^(-ddCt) values and log10_p is -log10 of the p-value of
        an independent two-sample t-test between the two dCt groups.
    """
    hs = np.asarray(hs_dct, dtype=float)
    mda = np.asarray(mda_dct, dtype=float)
    # ddCt per trial = Hs578T dCt - MDA-MB-453 dCt; fold change = 2^(-ddCt).
    mean_fold = np.power(2.0, -(hs - mda)).mean()
    _, p_val = stats.ttest_ind(hs, mda)
    return np.log2(mean_fold), np.log10(p_val) * -1


def classify(log2_change, log10_p, fold_cutoff=2.0, p_cutoff=2.0):
    """Label one point as "up", "down" or "ns" (not significant).

    A point is significant when log10_p >= p_cutoff and |log2_change| is at
    least fold_cutoff. The three labels are mutually exclusive. NaN values
    fail every comparison and therefore yield "ns".
    """
    if log10_p >= p_cutoff:
        if log2_change >= fold_cutoff:
            return "up"
        if log2_change <= -fold_cutoff:
            return "down"
    return "ns"


def load_points(lines, fold_cutoff=2.0, p_cutoff=2.0):
    """Parse the dCt table and group genes by significance.

    Parameters:
        lines: iterable of text lines; the first one is a header and is
            skipped. Each remaining non-blank line is tab-separated: gene
            name, then N_TRIALS pairs of (Hs578T, MDA-MB-453) dCt values.
            Any further columns are ignored.
        fold_cutoff, p_cutoff: thresholds forwarded to classify().

    Returns:
        dict with keys "up", "down" and "ns"; each value is a tuple of three
        parallel lists (gene names, log2 fold changes, -log10 p-values).

    Raises:
        ValueError: if a data row has too few columns or a non-numeric value.
    """
    groups = {label: ([], [], []) for label in ("up", "down", "ns")}
    n_needed = 1 + 2 * N_TRIALS
    for line_no, line in enumerate(lines):
        if line_no == 0:
            continue  # header row
        row = line.rstrip("\r\n")
        if not row.strip():
            continue  # tolerate blank lines, e.g. at the end of the file
        fields = row.split("\t")
        if len(fields) < n_needed:
            raise ValueError(
                "line {}: expected at least {} tab-separated columns, got {}".format(
                    line_no + 1, n_needed, len(fields)
                )
            )
        # Columns are interleaved: odd indices Hs578T, even indices MDA-MB-453.
        hs = [float(value) for value in fields[1:n_needed:2]]
        mda = [float(value) for value in fields[2:n_needed:2]]
        log2_change, log10_p = volcano_coordinates(hs, mda)
        names, folds, pvals = groups[classify(log2_change, log10_p, fold_cutoff, p_cutoff)]
        names.append(fields[0])
        folds.append(log2_change)
        pvals.append(log10_p)
    return groups


def plot_volcano(groups, out_path='volcano2_ann2.pdf'):
    """Draw the volcano plot, show it, then save it to out_path.

    groups is the mapping returned by load_points(). Significant genes are
    coloured and annotated with their names.
    """
    ds = 40
    fs = 14
    fig, ax = plt.subplots(figsize=(10, 6.67))
    hfont = {'fontname': 'Arial', 'fontsize': fs}
    series = (
        ("ns", 'black', 'Not significant'),
        ("up", 'magenta', 'Significantly upregulated in Hs578T'),
        ("down", 'green', 'Significantly downregulated in Hs578T'),
    )
    for label, color, legend_text in series:
        _, folds, pvals = groups[label]
        ax.scatter(folds, pvals, s=ds, color=color, label=legend_text)
    ax.set_xlabel('log2(Fold change)', **hfont)
    ax.set_ylabel('-log10(p-value)', **hfont)
    ax.set_xlim([-8.5, 5.5])
    for label in ("up", "down"):
        for gene, x, y in zip(*groups[label]):
            # Offset the text slightly so it does not sit on top of the marker.
            ax.annotate(gene, (x, y), (x + 0.1, y - 0.05))
    ax.legend(loc='lower left')
    ax.grid(alpha=0.25)
    # Shown before saving, as in the original script, so interactive
    # adjustments made in the window end up in the saved file.
    plt.show()
    fig.savefig(out_path)
    plt.close(fig)


def main(path):
    """Load the table at path, print the significant genes and plot everything."""
    with open(path) as handle:
        groups = load_points(handle)

    for label in ("up", "down"):
        names, folds, pvals = groups[label]
        print(names)
        print(folds)
        print(pvals)

    plot_volcano(groups)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(__doc__)
    main(sys.argv[1])
