import pandas as pd 
import numpy as np
import re
import pickle

from loadData2 import *
from loadMouse import *
from loadHuman import *

import matplotlib
matplotlib.use('Agg')
import pylab as plt
#from matplotlib_venn import venn3, venn2
import multiprocessing as mp
import itertools
import networkx as nx
from plotInterMatch import *
from matplotlib import rc
import os
import csv
from itertools import zip_longest

# load datasets

humanTab = loadHodgesList()
print("Loading Mouse...")

mouseTab, mouseGraph = loadMouseTable()
print("Loading Fly...")

deseq, flyGraph = loadFlyTable()


def _columnMask(columns, *tokens):
	"""Return a bool ndarray flagging each column name that contains any of *tokens*.

	Matching is plain substring containment (``token in name``). The result is
	always a NumPy bool array -- also for an empty column list -- so all masks
	combine element-wise with ``&`` / ``|``.
	"""
	return np.array([any(token in col for token in tokens) for col in columns], dtype = bool)


# Mouse: statistic type (FDR / log2 fold change), tissue, age and design
padj_m = _columnMask(mouseTab.columns, 'FDR')
fc_m = _columnMask(mouseTab.columns, 'log2')

striatumCase = _columnMask(mouseTab.columns, 'Striatum')
cortexCase = _columnMask(mouseTab.columns, 'Cortex')
liverCase = _columnMask(mouseTab.columns, 'Liver')
case2 = _columnMask(mouseTab.columns, '2 months')
case6 = _columnMask(mouseTab.columns, '6 months')
case10 = _columnMask(mouseTab.columns, '10 months')
contCase = _columnMask(mouseTab.columns, 'continuous')



# Fly
# NOTE(review): the single-letter tokens ('E', 'R', 'N') match any column name
# containing that capital letter; confirm deseq's column names keep them unambiguous.
padj_f = _columnMask(deseq.columns, 'padj')
fc_f = _columnMask(deseq.columns, 'log2')
neuronCase = _columnMask(deseq.columns, 'E')
gliaCase = _columnMask(deseq.columns, 'R')
ntCase = _columnMask(deseq.columns, 'N')
flCase = _columnMask(deseq.columns, 'EF', 'RF')


earlyStrings = ['RN5', 'RF18','EN7','EF18']
middleStrings = ['RN7','RF20', 'EN9','EF20']
lateStrings = ['RN8','RF22','EN11', 'EF22']

# Kept as bool ndarrays like every other mask (not pandas Index objects from
# Index.map): before pandas 2.0, `|` and `&` on an Index are set
# union/intersection rather than element-wise logic.
earlyCase = _columnMask(deseq.columns, *earlyStrings)
middleCase = _columnMask(deseq.columns, *middleStrings)
lateCase = _columnMask(deseq.columns, *lateStrings)


# create flyMouseGraph
# G holds one node per fly gene (species='fly') and per mouse gene
# (species='mouse', integer ID), with an edge for every ortholog row of the table.
flyMouseTab = pd.read_csv('/home/andrew/BotasRNASeq/orthologsFM.csv')
# rows without a mouse ortholog cannot contribute an edge
flyMouseTab = flyMouseTab.dropna(subset = ['Mouse GeneID'], axis = 0)
#flyMouseTab = flyMouseTab[flyMouseTab['DIOPT Score'] > 1]
G = nx.Graph()
orthologPairs = zip(flyMouseTab['Fly GeneID'], flyMouseTab['Mouse GeneID'])
for flyId, mouseId in orthologPairs:
	# the mouse ID column is float after dropna-able NaNs; store nodes as int
	mouseNode = int(mouseId)
	G.add_node(flyId, species = 'fly')
	G.add_node(mouseNode, species = 'mouse')
	G.add_edge(flyId, mouseNode)

# defining quantComp function
def quantComps(net, how = 'max', numSpec = 2):
	"""Score every connected component of a species-annotated gene graph.

	Nodes are classified by their ``species`` node attribute ('human',
	'mouse' or 'fly'); nodes without that attribute are not counted. A
	component is scored only when it contains genes from at least
	``numSpec`` species. Its score is the per-species gene counts (species
	present only) aggregated with ``how``. All other components score 0.

	Parameters
	----------
	net : networkx.Graph
		Graph whose connected components are scored.
	how : {'max', 'min', 'armean', 'geomean'}
		Aggregation applied to the non-zero per-species gene counts.
	numSpec : int
		Minimum number of species a component must contain to be scored.

	Returns
	-------
	total : int
		Sum of all component scores, converted with ``int`` (truncates
		fractional means).
	result : pandas.DataFrame
		One row per component with columns 'Human', 'Mouse', 'Fly' (lists of
		node IDs) and 'Count' (score), sorted by 'Count' descending. The
		index is the component's rank by size (0 = largest).

	Raises
	------
	ValueError
		If ``how`` is not a supported aggregation (previously such values
		silently scored every component 0).
	"""
	aggregators = {
		'max': np.max,
		'min': np.min,
		'armean': np.mean,
		# NumPy has no geomean (np.geomean raised AttributeError). Counts are
		# strictly positive here, so the exp-mean-log form is well defined.
		'geomean': lambda counts: np.exp(np.mean(np.log(counts))),
	}
	if how not in aggregators:
		raise ValueError("how must be one of %s, got %r" % (sorted(aggregators), how))
	aggregate = aggregators[how]

	# nx.connected_component_subgraphs() was removed in networkx 2.4; build the
	# same component subgraphs (as copies) from nx.connected_components().
	netComps = sorted(
		(net.subgraph(nodes).copy() for nodes in nx.connected_components(net)),
		key = len, reverse = True)

	records = []
	for comp in netComps:
		speciesDict = nx.get_node_attributes(comp, 'species')
		genes_h = [k for k, v in speciesDict.items() if v == 'human']
		genes_m = [k for k, v in speciesDict.items() if v == 'mouse']
		genes_f = [k for k, v in speciesDict.items() if v == 'fly']

		# Gene counts of the species actually present in this component.
		nonZeroNums = np.array([n for n in (len(genes_h), len(genes_m), len(genes_f)) if n > 0])

		# The size == 0 guard keeps a species-less component at 0 even when
		# numSpec <= 0 (aggregating an empty array would raise).
		if nonZeroNums.size == 0 or nonZeroNums.size < numSpec:
			count = 0
		else:
			count = aggregate(nonZeroNums)

		records.append({'Human': genes_h, 'Mouse': genes_m, 'Fly': genes_f, 'Count': count})

	# Build the table once; growing a DataFrame row by row with .loc is quadratic.
	geneTab = pd.DataFrame(records, columns = ['Human', 'Mouse', 'Fly', 'Count'])
	# Stable sort keeps the size ordering among components with equal scores.
	result = geneTab.sort_values(by = 'Count', ascending = False, kind = 'mergesort')

	return int(geneTab['Count'].sum()), result

# process geneTab files into lists
def getGeneLists(geneTab):
	"""Flatten the gene lists of the components that received a positive score.

	``geneTab`` is a table shaped like the one returned by ``quantComps``:
	per-row lists in 'Fly', 'Mouse' and 'Human' plus a numeric 'Count'.
	Rows with ``Count <= 0`` are dropped. Each species' lists are then
	concatenated in row order.

	Returns
	-------
	tuple of (fly genes, mouse genes, human genes), each a flat list.
	"""
	scored = geneTab.loc[geneTab['Count'] > 0]

	def flatten(column):
		return list(itertools.chain.from_iterable(scored[column]))

	return tuple(flatten(species) for species in ('Fly', 'Mouse', 'Human'))


# calculating gene sets
# label Human
# A human gene is kept when adjp < 0.05 and its row is flagged in 'validList';
# its direction comes from the 'M' column with a symmetric 0.263 cutoff
# (presumably log2(1.2) on a log2 fold change -- TODO confirm).
sigHuman = humanTab['adjp'] < 0.05
goodConv = humanTab['validList']
humanMask = sigHuman & goodConv
humanUp = humanTab['M'] > 0.263
humanDown = humanTab['M'] < -0.263

humanTab['geneIn'] = humanTab['Entrez'] == '1994'


def _uniqueFirstEntrez(rowMask):
	"""Return the distinct first entries of 'entrezList_filt', as ints, for the rows in rowMask."""
	firstIds = (int(entrezIds[0]) for entrezIds in humanTab[rowMask]['entrezList_filt'])
	return list(set(firstIds))


humanUpGenes = _uniqueFirstEntrez(humanMask & humanUp)
humanDownGenes = _uniqueFirstEntrez(humanMask & humanDown)
#humanGenes = [int(gene) for gene in humanTab[humanMask]['Entrez2']]
#countHuman = len(set(humanuGenes))
# the gene lists are already de-duplicated, so their lengths are the unique counts
countHumanUp = len(humanUpGenes)
countHumanDown = len(humanDownGenes)

# genes whose first Entrez ID appears in both the up and the down set
bothUpDown = set(humanUpGenes) & set(humanDownGenes)

# Disabled: write humanDownGenes to a text file, one gene per line.
# filename = '/home/andrew/paper1/IsolatedMakeNetwork/humanDownGenes.txt'
# genelist = humanDownGenes
# with open(filename,'w') as f:
# 	[f.write(str(gene) + '\n') for gene in genelist]


# premouse calculations
# mouse columns used in every run: striatum, continuous design
mouseCase = striatumCase & contCase

# Output directory for this analysis; created if missing (exist_ok avoids the
# check-then-create race of os.path.exists + os.makedirs).
outDir = '/home/andrew/paper1/hmf_allNeuronGlia_040818/'
os.makedirs(outDir, exist_ok = True)

# Named fly genotype / time-point masks. The parameter lists below refer to
# masks by variable name, and those names become the resultDict keys.
neuronNTcase = neuronCase & ntCase
neuronFLcase = neuronCase & flCase
gliaNTcase = gliaCase & ntCase
gliaFLcase = gliaCase & flCase

allTimeCase = earlyCase | middleCase | lateCase

allFlyCase = neuronCase | gliaCase

# One run per combination of:
# 1) concordance direction: 'up' or 'down'
# 2) drosophila time point (dtp): name of a fly time mask, e.g. 'earlyCase' or 'allTimeCase'
# 3) drosophila genotype (flyCase): name of a fly genotype mask, e.g. 'neuronNTcase' or 'allFlyCase'
# 4) mouse time point (mtp): 'case2', 'case6' or 'case10'
concDirs = ['up','down']
dtps = ['allTimeCase']
flyCases = ['allFlyCase']
mtps = ['case6']

paramList = list(itertools.product(concDirs,dtps,flyCases,mtps))
resultDict = dict.fromkeys(paramList)

graphDict = {}

# Masks addressable by the names used in dtps / flyCases / mtps. A plain
# lookup replaces eval() of the parameter strings; the names themselves stay
# the resultDict / graphDict keys, which later code inspects.
namedMasks = {
	'earlyCase': earlyCase, 'middleCase': middleCase, 'lateCase': lateCase,
	'allTimeCase': allTimeCase,
	'neuronCase': neuronCase, 'gliaCase': gliaCase,
	'neuronNTcase': neuronNTcase, 'neuronFLcase': neuronFLcase,
	'gliaNTcase': gliaNTcase, 'gliaFLcase': gliaFLcase,
	'allFlyCase': allFlyCase,
	'case2': case2, 'case6': case6, 'case10': case10,
}


def _sigGenes(table, colMask, padjMask, fcMask, concDir):
	"""Return the index labels of rows significant in any selected column pair.

	The padj columns (``colMask & padjMask``) and fold-change columns
	(``colMask & fcMask``) are paired positionally, i-th with i-th.
	NOTE(review): this assumes both groups list the conditions in the same
	column order; confirm against the table loaders. A row qualifies when, for
	at least one pair, padj < 0.05 and the fold change is > 0 (``concDir ==
	'up'``) or < 0 (any other value). Each row is returned once, in table order.

	Raises ValueError if the two column selections differ in shape.
	"""
	# Compare raw arrays: NumPy ufuncs on two DataFrames align on column
	# labels in recent pandas, and the padj / log2 columns never share a label.
	padjVals = table.loc[:, colMask & padjMask].values
	fcVals = table.loc[:, colMask & fcMask].values
	if padjVals.shape != fcVals.shape:
		raise ValueError('padj and fold-change selections differ in shape: %s vs %s'
			% (padjVals.shape, fcVals.shape))
	fcOk = fcVals > 0 if concDir == 'up' else fcVals < 0
	hits = (padjVals < 0.05) & fcOk
	# any(axis=1): a gene significant in several conditions is listed once
	# (np.where on the 2-D mask repeated it once per significant column).
	return list(table.index[np.flatnonzero(hits.any(axis = 1))])


# The merged fly + mouse + ortholog graph does not depend on the run parameters.
mergedGraph = nx.compose(nx.compose(flyGraph,mouseGraph),G)

for i, (concDir_n, dtp_n, flyCase_n, mtp_n) in enumerate(paramList):
	print(concDir_n,dtp_n,flyCase_n,mtp_n)
	concDir = concDir_n
	dtp = namedMasks[dtp_n]
	flyCase = namedMasks[flyCase_n]
	mtp = namedMasks[mtp_n]

	# setting human genes
	humanSubset = humanUpGenes if concDir == 'up' else humanDownGenes

	# setting mouse genes
	caseMask = mouseCase & mtp
	mouseSubset = _sigGenes(mouseTab, caseMask, padj_m, fc_m, concDir)

	# setting drosophila genes
	caseMask_f = flyCase & dtp
	flySubset = _sigGenes(deseq, caseMask_f, padj_f, fc_f, concDir)

	# restrict the merged graph to this run's DEGs and score components
	# spanning all three species
	subMerged = mergedGraph.subgraph(mouseSubset + flySubset + humanSubset)
	countFMH, geneTabFMH = quantComps(subMerged,numSpec = 3)

	key = (concDir_n, dtp_n, flyCase_n, mtp_n)
	graphDict[key] = subMerged
	resultDict[key] = (countFMH,geneTabFMH,(flySubset,mouseSubset,humanSubset))

# saving numbers and gene lists to file
masterChart = pd.DataFrame(columns = ['Direction', 'Drosophila genotype', 
	'Drosophila time point', 'Mouse time point','Intersection count','Drosophila DEG count', 
	'Mouse DEG count','Human DEG count'])

colnames = ['Human DEGs','Mouse DEGs','Fly DEGs','Human overlap DEGs','Mouse overlap DEGs','Fly overlap DEGs']

for i, key in enumerate(resultDict):
	count, componentTab, (flySubset, mouseSubset, humanSubset) = resultDict[key]
	concDir_n, dtp_n, flyCase_n, mtp_n = key
	# chart column order puts the fly genotype (key[2]) before the time point (key[1])
	masterChart.loc[i] = [concDir_n, flyCase_n, dtp_n, mtp_n, count,
		len(flySubset), len(mouseSubset), len(humanSubset)]

	human_degs = [str(gene) for gene in humanSubset]
	mouse_degs = [str(gene) for gene in mouseSubset]
	fly_degs = list(flySubset)

	(fly_degs_o, mouse_degs_o, human_degs_o) = getGeneLists(componentTab)
	fly_degs_overlap = fly_degs_o
	mouse_degs_overlap = [str(gene) for gene in mouse_degs_o]
	human_degs_overlap = [str(gene) for gene in human_degs_o]

	d = [human_degs, mouse_degs, fly_degs, human_degs_overlap, mouse_degs_overlap,fly_degs_overlap]

	# One column per list; shorter lists are padded with empty cells.
	# newline='' is what the csv module requires for files passed to csv.writer.
	with open(outDir + 'case' + str(i) + '_genelists.csv', 'w', newline = '') as f:
		writer = csv.writer(f)
		writer.writerow(colnames)
		writer.writerows(zip_longest(*d))

# close the pickle file explicitly instead of relying on garbage collection
with open(outDir + 'resultDict.p', 'wb') as f:
	pickle.dump(resultDict, f)

masterChart.to_csv(outDir + 'masterChart.csv')



# reskeys: one row per run key (direction, fly time, fly genotype, mouse time)
reskeys = np.array(list(resultDict.keys()))
mouseCaseMask = np.array([key[3] == 'case6' for key in reskeys])
neuronCaseMask = np.array(['neuron' in key[2] for key in reskeys])
gliaCaseMask = np.array(['glia' in key[2] for key in reskeys])
upCaseMask = np.array([key[0] == 'up' for key in reskeys])
downCaseMask = np.array([key[0] == 'down' for key in reskeys])


# Fly time case masks
earlyCaseMask = np.array([key[1] == 'earlyCase' for key in reskeys])
middleCaseMask = np.array([key[1] == 'middleCase' for key in reskeys])
lateCaseMask = np.array([key[1] == 'lateCase' for key in reskeys])

neuronUp = mouseCaseMask & neuronCaseMask & upCaseMask
neuronDown = mouseCaseMask & neuronCaseMask & downCaseMask
gliaUp = mouseCaseMask & gliaCaseMask & upCaseMask
gliaDown = mouseCaseMask & gliaCaseMask & downCaseMask


cases = ['upCaseMask','downCaseMask']

# Run-selection masks addressable by the names listed in `cases`
# (replaces eval(case); the names are also the keys of the pickled dicts).
caseMasks = {
	'upCaseMask': upCaseMask, 'downCaseMask': downCaseMask,
	'neuronUp': neuronUp, 'neuronDown': neuronDown,
	'gliaUp': gliaUp, 'gliaDown': gliaDown,
}


combList = {case : [] for case in cases}
combList_f = {case : [] for case in cases}
combList_m = {case : [] for case in cases}


flyBefore = {case : [] for case in cases}

# Pool, per case, the overlap genes of every selected run (and the fly DEGs
# before overlap filtering).
for case in cases:
	for key in reskeys[caseMasks[case]]:
		runResult = resultDict[tuple(key)]
		(fly_degs_o, mouse_degs_o, human_degs_o) = getGeneLists(runResult[1])
		combList[case].extend(human_degs_o)
		combList_f[case].extend(fly_degs_o)
		combList_m[case].extend(mouse_degs_o)
		flyBefore[case].extend(runResult[2][0])



combList2 = {case : set(combList[case]) for case in cases}
combList2_m = {case : set(combList_m[case]) for case in cases}
combList2_f = {case : set(combList_f[case]) for case in cases}

# `with` closes (and flushes) each pickle file deterministically.
for obj, filename in ((combList2, 'combList.p'), (combList2_m, 'combList_m.p'), (combList2_f, 'combList_f.p')):
	with open(outDir + filename, 'wb') as f:
		pickle.dump(obj, f)


# NOTE(review): disabled scratch code kept as an inert string literal; it never
# runs. If revived, note that it re-imports pickle and points outDir at a
# different run directory (hmf_allGlia_040618). It also indexes combList2_m
# with 'gliaDown' / 'gliaUp', which exist only when `cases` lists those mask
# names (it currently lists 'upCaseMask' and 'downCaseMask').
''' random test code here
import pickle
outDir = '/home/andrew/paper1/hmf_allGlia_040618/'
combList_m = pickle.load(open(outDir + 'combList_m.p','rb'))

outDir2 = '/home/andrew/paper1/fig1objects/wgcna/'

def listToFile(myList, filename):
	with open(outDir2 + filename, 'w') as f:
		[f.write(str(gene) + '\n') for gene in myList]

listToFile(list(combList2_m['gliaDown']),'gliaDown.txt')
listToFile(list(combList2_m['gliaUp']),'gliaUp.txt')



'''