import pandas as pd 
import numpy as np
import re
from sklearn import datasets, linear_model
from sklearn.feature_selection import f_regression, SelectFdr
from statsmodels.sandbox.stats.multicomp import multipletests
import networkx as nx
# load DEGs


def loadMouseDEGs(fdrThresh, homedir = "/home/andrew/"):
	"""Load mouse DEGs from every sheet of the supplementary workbook and map them to human Entrez IDs.

	Parameters
	----------
	fdrThresh : float
		A gene is kept for a comparison when its FDR value is strictly below this threshold.
	homedir : str, optional
		Directory prefix (ending in a path separator) that contains ``paper1/`` and
		``extData/``. Defaults to the previously hard-coded location.

	Returns
	-------
	dict
		Maps each sheet name to a tuple ``(up, down)``. ``up`` and ``down`` are dicts keyed by
		comparison name (the FDR column name with a leading ``"FDR."`` removed); each value is
		the list of human Entrez IDs of genes with FDR < fdrThresh and a positive (``up``) or
		negative (``down``) log2 fold change. Genes with no entry in the ortholog table are dropped.
	"""
	mouseDat_all = pd.read_excel(homedir + "paper1/YangSup/nn.4256-S3.xlsx", sheet_name = None)

	# Mouse Entrez ('Gene ID') -> human Entrez ('Gene ID.1'); the human column may hold several
	# ';'-separated IDs, of which only the first is used. str()/float() tolerate the column
	# having been parsed as numeric.
	mouse2Human = pd.read_csv(homedir + "extData/hdinhd/orthoConv2.txt", sep = "\t")
	mouse2Human2 = mouse2Human.dropna(subset = ['Gene ID.1']).copy()
	mouse2Human2['human'] = mouse2Human2['Gene ID.1'].apply(lambda x: int(float(str(x).split(';')[0])))
	convDictMH = dict(zip(mouse2Human2['Gene ID'], mouse2Human2['human']))

	def toHuman(genes):
		# Keep only genes that have a human ortholog, preserving input order.
		return [convDictMH[gene] for gene in genes if gene in convDictMH]

	finalDict = {}
	for sheet, mouseDat in mouseDat_all.items():
		fdrCols = [col for col in mouseDat.columns if 'FDR' in col]
		fcCols = [col for col in mouseDat.columns if 'log2FoldChange' in col]
		# Remove the literal "FDR." prefix. str.lstrip("FDR.") would strip any run of leading
		# 'F', 'D', 'R' or '.' characters and mangle comparison names beginning with them.
		keys = [col[len("FDR."):] if col.startswith("FDR.") else col for col in fdrCols]

		mouseUpGenes_human = {}
		mouseDownGenes_human = {}
		# NOTE(review): FDR and fold-change columns are paired by position; this assumes both
		# appear in the same comparison order within each sheet.
		for key, fdr, fc in zip(keys, fdrCols, fcCols):
			significant = mouseDat[fdr] < fdrThresh
			mouseUpGenes_human[key] = toHuman(mouseDat.loc[significant & (mouseDat[fc] > 0), 'Entrez'])
			mouseDownGenes_human[key] = toHuman(mouseDat.loc[significant & (mouseDat[fc] < 0), 'Entrez'])

		finalDict[sheet] = mouseUpGenes_human, mouseDownGenes_human
	return finalDict

def loadFlyDEGs(tissue, fdrThresh):
	"""Load fly DESeq results for one tissue and return the significant genes as human IDs.

	``tissue == "Neuron"`` selects the neuron workbook and the EN/EF-vs-EW comparisons;
	any other value selects the glia workbook and the RN/RF-vs-RW comparisons.

	Returns a tuple ``(up, down)`` of dicts keyed by comparison name. Each value is the list
	of human gene IDs (from orthologs2.xlsx) for genes whose adjusted p-value is below
	*fdrThresh* and whose log2 fold change is positive (``up``) or negative (``down``).
	Fly genes with no human ortholog are dropped.
	"""
	if tissue == "Neuron":
		comparisons = (['EN{0}vsEW{0}'.format(str(a)) for a in (7, 9, 11)]
			+ ['EF{0}vsEW{0}'.format(str(a)) for a in (18, 20, 22)])
		table = pd.read_excel("/home/andrew/BotasRNASeq/Report/DESeq_all.neuron_v2.xlsx")
	else:
		comparisons = (['RN{0}vsRW{0}'.format(str(a)) for a in (5, 7, 8)]
			+ ['RF{0}vsRW{0}'.format(str(a)) for a in (18, 20, 22)])
		table = pd.read_excel("/home/andrew/BotasRNASeq/Report/DESeq_all.glia.xlsx")

	# The FlyBase ID is the second ':'-separated field of the 'gene' column; keep the first
	# row for each ID.
	table['fbgene'] = table['gene'].apply(lambda g: g.split(':')[1])
	uniq = table.drop_duplicates(subset = 'fbgene', keep = 'first')

	padjCols = [c for c in uniq.columns if '_padj' in c]
	lfcCols = [c for c in uniq.columns if 'log2' in c]

	# Comparisons are paired with padj / log2 columns by position.
	upFly = dict.fromkeys(comparisons)
	downFly = dict.fromkeys(comparisons)
	for name, padj, lfc in zip(comparisons, padjCols, lfcCols):
		hit = uniq[padj] < fdrThresh
		upFly[name] = uniq[hit & (uniq[lfc] > 0)]['fbgene']
		downFly[name] = uniq[hit & (uniq[lfc] < 0)]['fbgene']

	orth = pd.read_excel('/home/andrew/BotasRNASeq/orthologs2.xlsx')
	orth = orth.dropna(subset = ['Human GeneID'])
	flyToHuman = dict(zip(orth['Fly GeneID'], orth['Human GeneID']))

	def humanize(flyGenes):
		return [flyToHuman[g] for g in flyGenes if g in flyToHuman]

	upHuman = {}
	downHuman = {}
	for name in comparisons:
		upHuman[name] = humanize(upFly[name])
		downHuman[name] = humanize(downFly[name])
	return upHuman, downHuman

def loadFlyDEGs2(tissue, fdrThresh):
	"""Load fly DESeq results for one tissue; return significant genes and their fold changes as human IDs.

	``tissue == "Neuron"`` selects the neuron workbook and the EN/EF-vs-EW comparisons;
	any other value selects the glia workbook and the RN/RF-vs-RW comparisons.

	Returns ``((up, down), (upFC, downFC))``. All four are dicts keyed by comparison name:
	- ``up`` / ``down``: lists of human gene IDs (from orthologs2.xlsx) for genes with adjusted
	  p-value below *fdrThresh* and positive / negative log2 fold change;
	- ``upFC`` / ``downFC``: ``{human ID: log2 fold change}`` for the same genes. When several
	  fly genes map to one human ID, the last one encountered wins.
	Fly genes with no human ortholog are dropped.
	"""
	if tissue == "Neuron":
		comparisons = (['EN{0}vsEW{0}'.format(str(a)) for a in (7, 9, 11)]
			+ ['EF{0}vsEW{0}'.format(str(a)) for a in (18, 20, 22)])
		table = pd.read_excel("/home/andrew/BotasRNASeq/Report/DESeq_all.neuron_v2.xlsx")
	else:
		comparisons = (['RN{0}vsRW{0}'.format(str(a)) for a in (5, 7, 8)]
			+ ['RF{0}vsRW{0}'.format(str(a)) for a in (18, 20, 22)])
		table = pd.read_excel("/home/andrew/BotasRNASeq/Report/DESeq_all.glia.xlsx")

	# The FlyBase ID is the second ':'-separated field of the 'gene' column; keep the first
	# row for each ID so every ID maps to exactly one fold change.
	table['fbgene'] = table['gene'].apply(lambda g: g.split(':')[1])
	uniq = table.drop_duplicates(subset = 'fbgene', keep = 'first')

	padjCols = [c for c in uniq.columns if '_padj' in c]
	lfcCols = [c for c in uniq.columns if 'log2' in c]

	upFly = dict.fromkeys(comparisons)
	downFly = dict.fromkeys(comparisons)
	upFC = dict.fromkeys(comparisons)
	downFC = dict.fromkeys(comparisons)

	# Comparisons are paired with padj / log2 columns by position.
	for name, padj, lfc in zip(comparisons, padjCols, lfcCols):
		hit = uniq[padj] < fdrThresh
		upRows = uniq[hit & (uniq[lfc] > 0)]
		downRows = uniq[hit & (uniq[lfc] < 0)]
		upFly[name] = upRows['fbgene']
		downFly[name] = downRows['fbgene']
		# .values yields numpy scalars, matching a per-gene .loc lookup on the unique index.
		upFC[name] = dict(zip(upRows['fbgene'], upRows[lfc].values))
		downFC[name] = dict(zip(downRows['fbgene'], downRows[lfc].values))

	orth = pd.read_excel('/home/andrew/BotasRNASeq/orthologs2.xlsx')
	orth = orth.dropna(subset = ['Human GeneID'])
	flyToHuman = dict(zip(orth['Fly GeneID'], orth['Human GeneID']))

	upHuman = {}
	downHuman = {}
	upFCHuman = {}
	downFCHuman = {}
	for name in comparisons:
		upHuman[name] = [flyToHuman[g] for g in upFly[name] if g in flyToHuman]
		downHuman[name] = [flyToHuman[g] for g in downFly[name] if g in flyToHuman]
		upFCHuman[name] = {flyToHuman[g]: v for g, v in upFC[name].items() if g in flyToHuman}
		downFCHuman[name] = {flyToHuman[g]: v for g, v in downFC[name].items() if g in flyToHuman}

	genes = (upHuman, downHuman)
	foldChanges = (upFCHuman, downFCHuman)
	return (genes, foldChanges)

def _loadFlyTableFrom(baseDir):
	"""Merge the neuron and glia DESeq tables under *baseDir* and build a fly-human ortholog graph.

	Parameters
	----------
	baseDir : str
		Directory prefix (ending in a path separator) containing ``BotasRNASeq/``.

	Returns
	-------
	deseq : pandas.DataFrame
		Outer merge, on ``fbgene``, of the neuron and glia DESeq tables, indexed by ``fbgene``.
		``fbgene`` is ``'FB'`` plus the text after ``':FB'`` in the ``gene`` column. Within each
		table, only the first row per ``fbgene`` is kept.
	G : networkx.Graph
		One node per fly gene (``species='fly'``) and per human gene (``species='human'``,
		keyed by int ID), each with a ``symbol`` attribute. There is one edge per row of
		orthologs2.xlsx that has a human gene ID.
	"""
	deseq_n = pd.read_excel(baseDir + "BotasRNASeq/Report/DESeq_all.neuron_v2.xlsx")
	deseq_g = pd.read_excel(baseDir + "BotasRNASeq/Report/DESeq_all.glia.xlsx")

	deduped = []
	for tab in (deseq_n, deseq_g):
		tab['fbgene'] = tab['gene'].apply(lambda x: 'FB' + x.split(':FB')[1])
		deduped.append(tab.drop_duplicates(subset = 'fbgene', keep = 'first'))

	deseq = pd.merge(left = deduped[0], right = deduped[1], on = "fbgene", how = 'outer')
	deseq.index = deseq['fbgene']

	flyConvTable = pd.read_excel(baseDir + 'BotasRNASeq/orthologs2.xlsx')
	flyConvTable2 = flyConvTable.dropna(subset = ['Human GeneID'])

	G = nx.Graph()
	rows = zip(flyConvTable2['Fly GeneID'], flyConvTable2['Fly Symbol'],
		flyConvTable2['Human GeneID'], flyConvTable2['Human Symbol'])
	for flyId, flySymbol, humanId, humanSymbol in rows:
		# The column contained NaNs before dropna, so IDs are floats. Use the same int key
		# for the node and its edge.
		humanId = int(humanId)
		G.add_node(flyId, species = 'fly', symbol = flySymbol)
		G.add_node(humanId, species = 'human', symbol = humanSymbol)
		G.add_edge(flyId, humanId)

	return deseq, G

def loadFlyTable(baseDir = "/home/andrew/"):
	"""Return ``(deseq, G)`` built by _loadFlyTableFrom from files under *baseDir* (Linux layout by default)."""
	return _loadFlyTableFrom(baseDir)



def loadFlyTable2():
	"""Return ``(deseq, G)`` built by _loadFlyTableFrom from files under '/Users/andrew/rna/' (Mac layout)."""
	return _loadFlyTableFrom('/Users/andrew/rna/')



'''
# testing code
testnew[0]
'''
def _parseHodgesCharacteristics(charstring, controlGrade):
	"""Parse one GSE3790 sample characteristics string into ``(cag, grade, age)``.

	cag   -- integer that follows the first '/' in the string (used as the 'cag' regressor)
	age   -- integer that follows 'Age = '
	grade -- digit that follows 'grade ', or *controlGrade* when the string contains no 'grade'
	"""
	cag = int(re.search(r'/(\d*)', charstring).group(1))
	age = int(re.search(r'Age = (\d*)', charstring).group(1))
	if re.search(r'grade', charstring) is None:
		grade = controlGrade
	else:
		grade = int(re.search(r'grade (\d)', charstring).group(1))
	return cag, grade, age

def _hodgesRegressionGenes(controlGrade, maxGrade = None, fdrthresh = 0.01):
	"""Return human Entrez IDs of caudate probes associated with 'cag' or 'grade' in GSE3790 (GPL97).

	For each probe, log(expression + 1) is tested against each covariate ('cag', 'grade', 'age')
	separately with sklearn's univariate ``f_regression``. The 'cag' and 'grade' p-values are
	Benjamini-Hochberg corrected across probes, and a probe is kept when either corrected
	p-value is below *fdrthresh*. Each kept probe maps to the first Entrez ID in the platform
	annotation; probes with no annotation are dropped.

	controlGrade -- grade assigned to samples whose characteristics string mentions no grade.
	maxGrade     -- if not None, only samples with grade < maxGrade are used.
	"""
	seriesPath = "/home/andrew/GSE3790/GSE3790-GPL97_series_matrix.txt"
	gsemat = pd.read_csv(seriesPath, sep = "\t", header = 81, index_col = 0)
	descmatall = pd.read_csv(seriesPath, sep = "\t", header = 50, index_col = 0, nrows = 15)

	# Header row 7 is parsed for per-sample covariates and row 5 is matched against the region
	# name. Rows are selected by position so duplicate row labels cannot break the lookup.
	characteristics = descmatall.iloc[7]
	covariates = pd.DataFrame(
		[_parseHodgesCharacteristics(s, controlGrade) for s in characteristics],
		index = characteristics.index, columns = ['cag', 'grade', 'age'])

	mask = descmatall.iloc[5] == "Caudate Nucleus"
	if maxGrade is not None:
		mask = mask & (covariates['grade'] < maxGrade)

	# samples x probes, probes with any missing value removed
	logExpr = np.log(gsemat.dropna(axis = 0).T[mask] + 1)
	# Align covariate rows to the expression rows so each y value is paired with its own sample.
	X = np.asarray(covariates.loc[logExpr.index])

	pvals = np.array(
		[f_regression(X, logExpr[probe], center = True)[1] for probe in logExpr.columns],
		dtype = float).reshape(-1, 3)
	fpvals = pd.DataFrame(pvals, index = logExpr.columns, columns = ['cag', 'grade', 'age'])

	corpvals = pd.DataFrame({
		'cag': multipletests(fpvals['cag'], alpha = fdrthresh, method = "fdr_bh")[1],
		'grade': multipletests(fpvals['grade'], alpha = fdrthresh, method = "fdr_bh")[1],
	}, index = fpvals.index)
	significant = (corpvals['cag'] < fdrthresh) | (corpvals['grade'] < fdrthresh)
	threshGenes = corpvals.index[significant.values]

	affyconv2 = pd.read_csv("/home/andrew/paper1/HodgesSup/GPL97-17394.txt", header = 16, sep = "\t")
	affyconv3 = affyconv2.dropna(subset = ['ENTREZ_GENE_ID'])
	probeToEntrez = dict(zip(affyconv3['ID'], affyconv3['ENTREZ_GENE_ID']))
	# Probes annotated to several genes list them separated by ' /// '; keep the first.
	return [int(probeToEntrez[probe].split(' /// ')[0]) for probe in threshGenes if probe in probeToEntrez]

def loadHodgesGenes():
	"""Human Entrez IDs of GSE3790 caudate probes associated with 'cag' or 'grade' (BH FDR < 0.01).

	Only caudate samples with grade < 3 are used; samples without a grade are coded as grade 0.
	"""
	return _hodgesRegressionGenes(controlGrade = 0, maxGrade = 3)

def loadAllHodgesGenes():
	"""Human Entrez IDs of GSE3790 caudate probes associated with 'cag' or 'grade' (BH FDR < 0.01).

	All caudate samples are used; samples without a grade are coded as grade -1.
	"""
	return _hodgesRegressionGenes(controlGrade = -1)


#missingHuman = [probe.upper() for probe in threshGenes if probe.upper() not in affydict.keys()]

'''
#### converting mouse to human
mouse2Human = pd.read_csv("/home/andrew/paper1/HMD_HumanPhenotype.rpt", sep = "\t", header = None, index_col = False)

convDictMH = dict(zip(mouse2Human.iloc[:,2],mouse2Human.iloc[:,1]))


stri10_175 = list(stri10[0]['Q.175.vs.20']) + list(stri10[1]['Q.175.vs.20'])
stri10_175h = [convDictMH[gene] for gene in stri10_175 if gene in convDictMH.keys()]


allStri = []
for stri in stri2,stri6,stri10:
	for key, val in stri[0]:
		allStri.extend(val)
	for key, val in stri[1]:
		allStri.append(val)


allstri_h = [convDictMH[gene] for gene in allStri if gene in convDictMH.keys()]
#converting fly to Human
flyConvTable = pd.read_excel('/home/andrew/BotasRNASeq/orthologs2.xlsx')

convDictFH = dict(zip(flyConvTable['Fly GeneID'], flyConvTable['Human GeneID']))

keys = convDictFH.keys()
for gene in keys:
	if np.isnan(convDictFH[gene]):
		convDictFH[gene] = 0

Nall = {}
for i in range(len(neuron_comp)):
	Nall[i] = list(flyN[0][neuron_comp[i]]) + list(flyN[1][neuron_comp[i]])

Nallh= {}
for i in range(len(neuron_comp)):
	Gallh[i] = np.unique([int(convDictFH[gene]) for gene in Nall[i] if gene in convDictFH.keys()])

Gall = {}
for i in range(len(glia_comp)):
	Gall[i] = list(flyG[0][glia_comp[i]]) + list(flyG[1][glia_comp[i]])

Gallh= {}
for i in range(len(glia_comp)):
	Gallh[i] = np.unique([int(convDictFH[gene]) for gene in Gall[i] if gene in convDictFH.keys()])

flyset = list(Gallh[5]) + list(Gallh[4]) + list(Gallh[3])
mouseset = stri10_175h
#only human
onlyh = [gene for gene in entrezHuman if (gene not in flyset) and (gene not in mouseset)]
onlym = [gene for gene in mouseset if (gene not in entrezHuman) and (gene not in flyset)]
onlyf = [gene for gene in flyset if (gene not in entrezHuman) and (gene not in mouseset)]

hm = set(entrezHuman) & set(mouseset)
hf = set(entrezHuman) & set(flyset)
mf = set(mouseset) & set(flyset)

hmf = set(mouseset) & set(flyset) & set(entrezHuman)

for gene in affygenes:
	print(gene)
	y_pre = gsemat2[gene]
	y = (y_pre - np.mean(y_pre))/np.std(y_pre)
	regr.fit(np.array(features),y)
	regcoef.loc[gene] = regr.coef_

# function for plotting
def plotVD(flyset,mouseset,humanset):
	# input flygenes, mousegenes, humangenes as lists
'''


