"""Extracts annotated trees and mutational paths from BEAST *.tree files.

Written by Jesse Bloom, 2011."""

import os
import re
import copy
import gzip
import fasta
import tree
import parse_tree
import dotgraphs


def MergeTreeFiles(gzipped_treeinfiles, gzipped_merged_treeinfile, burnin):
    """Merges treefiles.

    'gzipped_treeinfiles' gives a list of gzipped *.trees files. These files
        are assumed (but not checked) to each correspond to a *.trees file
        for a different run of BEAST on the identically same sequence set.
        The first 'burnin' (an integer >= 0) trees from each are
        ignored. The remaining are written to the new gzip file
        'gzipped_merged_treeinfile'. The trees are renumbered to go
        1, 2, 3, ...
    Only the non-tree lines that precede the first tree line of the first
        file are copied to the merged file. Non-tree lines that follow the
        trees in any input file (such as the closing 'End;') are dropped,
        and a single 'End;' is written after the last merged tree.
    Returns the total number of trees in the merged file.
    Raises ValueError if 'burnin' is negative, or if a tree line does not
        begin with the expected 'tree STATE_<number> ' prefix.
    """
    if burnin < 0:
        raise ValueError("burnin must be an integer >= 0, got %r" % burnin)
    prefixmatch = re.compile(r'^(tree STATE\_)(\d+) ')
    itotal = 0
    f_out = gzip.open(gzipped_merged_treeinfile, 'w')
    try:
        for (ifile, treefile) in enumerate(gzipped_treeinfiles):
            print("Reading trees from %s" % treefile)
            filetotal = filenonburnin = 0
            f_in = gzip.open(treefile)
            try:
                for line in f_in:
                    if line[ : 5] != 'tree ':
                        # Copy the header of the first file only. Non-tree
                        # lines after that file's trees have started (its
                        # closing 'End;') must not be copied, or the trees of
                        # the following files would land after a premature
                        # 'End;'.
                        if ifile == 0 and filetotal == 0:
                            f_out.write(line)
                        continue
                    filetotal += 1
                    if filetotal <= burnin:
                        continue
                    itotal += 1
                    filenonburnin += 1
                    m = prefixmatch.search(line)
                    if not m:
                        raise ValueError("Failed to match:\n%s" % line[ : 100])
                    # renumber the tree so the merged file counts 1, 2, 3, ...
                    f_out.write("%s%d %s" % (m.group(1), itotal, line[len(m.group(0)) : ]))
            finally:
                f_in.close()
            print("Read a total of %d trees from %s, of which %d were included in the merge (after removing %d as burnin)." % (filetotal, treefile, filenonburnin, burnin))
        f_out.write('End;')
    finally:
        f_out.close()
    return itotal






def GetPathsAndTrees(treeinfile, burnin, startname, endname, paths_out, treesites_out, treemuts_out, includesites, gzipped, assign_names, i_to_sample, sample_prefix):
    """Extracts mutational paths from a *.trees file, makes annotated trees.

    'treeinfile' is the name of a *.trees file generated by BEAST.  This file
        can be gzipped if the gzipped switch is set to True
    'burnin' is an integer >= 0 specifying the number of trees in
        'treeinfile' that are treated as burnin, and not used to generate
        a path.  Note that this the number of trees, not the number of
        steps, treated as burnin.
    'startname' is the name of the sequence for which we start the path
        tracing.
    'endname' is the name of the sequence for which we end the path tracing.
    'paths_out' is a writeable file-like object to which we write the paths.
    'treesites_out' is a writeable file-like object to which we re-write
        the trees file, annotating only the nodes with the identities at the
        sites indicated in 'includesites'.
    'treemuts_out' is a writeable file-like object to which we re-write the
        trees file, annotating branches with mutations.
    'includesites' is a dictionary keyed by residue numbers. It specifies
        which residues are included in the annotations of 'treesites_out'
        and 'treemuts_out'. Any residue number that is a key in the dictionary
        is included; the values for the keys are arbitrary.
    'gzipped' is a boolean switch, set to true if and only if 'treeinfile'
        is compressed with gzip.
    'assign_names' is a dictionary keyed by the sequence names for any sequences
        to which we want to attach name labels. The values are the those
        name labels.
    'i_to_sample', and 'sample_prefix' specify that we write mutation and
        site annotated trees for specific trees in the file. 'i_to_sample'
        should be a list of integers. 'sample_prefix' should be a string
        with a formatting character for an integer, such as "sample_%d.trees".
        A file of name 'sample_prefix % i_to_sample' is then created for each
        tree after the burnin (numbering going 1, 2, 3, ...) that corresponds
        to a number in 'i_to_sample'. Each such file contains the non-tree
        lines that precede the first tree, the single tree, and 'End;'.
    Returns a number corresponding to the number of paths generated.
    Raises ValueError if 'startname' or 'endname' is found on more than one
        code line before the trees, or if the first tree is reached before
        codes for both 'startname' and 'endname' have been found.
    """
    # patterns to remove from node and branch annotations
    replace_with_null = [
        re.compile(r'history\_\d+\=\{\}\,'),
        re.compile(r'\,history\_\d+\=\{\}'),
        re.compile(r'history\_\d+\=\{\}')]
    intmatch = re.compile(r'^\d+ ') # stripped line beginning with an integer code
    time_label = "time_since_root"
    # Node annotations for the site-annotated trees. They do not depend on
    # the individual tree, so build them once rather than once per tree.
    nodeinfo_keys = dict([('AA%d' % r, None) for r in includesites])
    nodeinfo_keys['sequence_name'] = None

    def F(d):
        # defines what is printed for mutations: the branch's protein
        # mutations, sorted and joined by '-', but only for branches that
        # were marked 'on_path'; None otherwise
        assert isinstance(d, dict)
        muts = d['protein_mutations']
        if muts and 'on_path' in d:
            return '-'.join([tup[1] for tup in sorted(muts)])
        return None

    print("\nExtracting mutational paths from %s." % treeinfile)
    if gzipped:
        f_in = gzip.open(treeinfile)
    else:
        f_in = open(treeinfile)
    startcode = endcode = None
    start_trees = False
    itree = 0
    ntrees = 0
    codes_to_names = {}
    nontreetext = [] # lines before the first tree; copied into sample files
    try:
        for rawline in f_in:
            line = rawline.strip()
            if line[ : 4] != 'tree':
                treesites_out.write(rawline)
                treemuts_out.write(rawline)
                if not start_trees:
                    nontreetext.append(rawline)
                    codematch = intmatch.search(line)
                    if codematch:
                        # a line of the form '<code> <name>...' maps a code to a name
                        code = codematch.group(0).strip()
                        for (x, y) in assign_names.items():
                            if x in line:
                                codes_to_names[code] = y
                                break
                        if startname in line:
                            if startcode:
                                raise ValueError("Duplicate codes for %s" % startname)
                            startcode = code
                            print("The path will start from %s, which has code %s in this file." % (startname, startcode))
                        if endname in line:
                            if endcode:
                                raise ValueError("Duplicate codes for %s" % endname)
                            endcode = code
                            print("The path will end at %s, which has code %s in this file." % (endname, endcode))
                continue
            if not start_trees:
                # first tree line: codes can no longer change, so check once
                start_trees = True
                if not (startcode and endcode):
                    raise ValueError("Starting trees, but have not found start and end code for %s and %s." % (startname, endname))
            itree += 1
            if itree <= burnin:
                print("Disregarding tree %d as burnin." % itree)
                continue
            ntrees += 1
            print("Parsing tree %d in the file, which is the %d path generated from this file." % (itree, ntrees))
            tree_preface = line[ : line.index('(')] # stuff before tree
            newick_tree = parse_tree.GetTreeString(line, replace_with_null)
            t = tree.Tree(newick_tree)
            tip_list = []
            internal_list = []
            tree.ListsOfNodes(t.GetRoot(), tip_list, internal_list)
            for node in tip_list:
                node.info['sequence_name'] = codes_to_names.get(node.name, '')
            total_time = tree.AssignNodeTimes(t.GetRoot(), 0.0, time_label)
            parse_tree.AssignMutations(t.GetRoot(), total_time, time_label, 'PROTEIN')
            tree.ApplyToNodes(t.GetRoot(), parse_tree.BreakProtSeqToAAs)
            (mutationpath, cumultime, fortime) = parse_tree.ProteinMutationPath(t, startcode, endcode, 'PROTEIN')
            paths_out.write("%s:%f" % mutationpath[0])
            for (mut, cumulative_time) in mutationpath[1 : ]:
                paths_out.write(", %s:%f" % (mut, cumulative_time))
            paths_out.write('; cumulativetime=%f forwardtime=%f\n' % (cumultime, fortime))
            paths_out.flush()
            treesites_out.write('%s%s\n' % (tree_preface, t.WriteNewick(nodeinfo_keys=nodeinfo_keys)))
            treesites_out.flush()
            # mark nodes on the start-to-end path with 'on_path' (read by F)
            tree.GetPath(tree.GetTipNode(startcode, t), tree.GetTipNode(endcode, t), markpath='on_path')
            treemuts_out.write('%s%s\n' % (tree_preface, t.WriteNewick(nodeinfo_keys={'sequence_name':None}, branchinfo_keys={'protein_mutations':F, 'on_path':None})))
            if ntrees in i_to_sample:
                samplefilename = sample_prefix % ntrees
                print("This tree is being written to %s." % samplefilename)
                fsample = open(samplefilename, 'w')
                try:
                    fsample.write(''.join(nontreetext))
                    fsample.write('%s%s\n' % (tree_preface, t.WriteNewick(nodeinfo_keys=nodeinfo_keys, branchinfo_keys={'protein_mutations':F, 'on_path':None})))
                    fsample.write("End;")
                finally:
                    fsample.close()
            treemuts_out.flush()
    finally:
        f_in.close()
    return ntrees


def _ConfirmOverwrite(prompt, keepmessage, overwritemessage):
    """Asks the user 'prompt'; returns True if the answer is 'y' or 'Y'.

    Prints 'overwritemessage' when returning True and 'keepmessage'
    otherwise. The answer is read with raw_input (Python 2)."""
    ans = raw_input(prompt)
    if ans.strip() in ['Y', 'y']:
        print(overwritemessage)
        return True
    print(keepmessage)
    return False


def main():
    """Runs the NP human H3N2 analysis pipeline.

    Merges the BEAST runs (dropping burnin), extracts mutational paths and
    annotated trees from the merged file, then builds the digraph *.dot
    files. Before overwriting any existing output, the user is asked for
    confirmation.
    """
    # input / output files and variables
    burnin = 60 # number of "burn in" trees in each file, which we discard
    gzipped_treeinfiles = ['NPhumanH3N2-MarkovJumps_%d.trees.gz' % i for i in range(1, 6)] # input *.trees files, gzipped
    gzipped_merged_treeinfile = 'NPhumanH3N2-MarkovJumps_merged.trees.gz' # trees in gzipped_treeinfiles are merged into this gzip file after removing first burnin from each
    treesitesfile = 'NPhumanH3N2-MarkovJumps_sites.trees'
    treemutsfile = 'NPhumanH3N2-MarkovJumps_mutations.trees'
    pathsfile = 'NPhumanH3N2-MarkovJumps_paths.txt'
    digraphfile = 'NPhumanH3N2-MarkovJumps_digraph.dot'
    pathtree_sample_prefix = 'NPhumanH3N2-MarkovJumps_treesample_%s.trees' # base name for tree sample files
    digraph_sample_prefix = 'NPhumanH3N2-MarkovJumps_sample_%s.dot' # base name for digraph sample files
    startname = 'A/Aichi/2/1968_1968.00'
    endname = 'A/Brisbane/10/2007_2007.10'
    startseqfile = 'Aichi68-NP.fasta' # nucleotide sequence of starting sequence
    assign_names = {'A/Aichi/2/1968_1968.00':'Aichi/1968',
                    'A/Brisbane/10/2007_2007.10':'Brisbane/2007',
                    'A/Nanchang/933/1995_1995.00':'Nanchang/1995'
                   }
    (digraphstartname, digraphendname) = ('Aichi/1968', 'Brisbane/2007')
    proteinalignmentfile = 'NPhumanH3N2_unique_protein_alignment.fasta'
    i_to_sample = [1, 801] # generate single trees and dot files for these paths (1-based)
    cutoff_frac = 0.01 # only show nodes present on at least this many paths
    maxweight = 4.0 # making this value greater lightens all colors in digraphs
    labelcutoff = 0.6 # label edges with at least this weight
    # Begin execution
    print("\nBuilding files with annotated trees, paths, and digraph files.")
    # merge input files
    print("\nFirst creating the merged tree file %s." % gzipped_merged_treeinfile)
    generate_new = (not os.path.isfile(gzipped_merged_treeinfile)) or _ConfirmOverwrite(
            "File %s already exists. Do you want to overwrite it [y/n]?" % gzipped_merged_treeinfile,
            "Not generating a new file.", "Will overwrite the file(s).")
    if generate_new:
        print("Generating %s by merging the following files, after removing the first %d trees as burnin from each: %s" % (gzipped_merged_treeinfile, burnin, ', '.join(gzipped_treeinfiles)))
        itotal = MergeTreeFiles(gzipped_treeinfiles, gzipped_merged_treeinfile, burnin)
        print("This merged file contains a total of %d trees." % itotal)
    # we only include sites that differ in at least one sequence - determine these
    a = fasta.Read(proteinalignmentfile)
    if not a:
        raise ValueError("No sequences found in %s." % proteinalignmentfile)
    n = len(a[0][1]) # length of alignment
    includesites = {}
    for r in range(n):
        if len(set([seq[r] for (head, seq) in a])) > 1:
            includesites[r + 1] = True # include this site (1-based) as it differs
    print("In node annotations, we will include only the %d of %d sites that differ in at least one sequence." % (len(includesites), n))
    # begin getting the trees and paths
    outfiles = [(pathsfile, 'mutational paths'), (treesitesfile, 'site-annotated trees'), (treemutsfile, 'mutation-annotated trees')]
    for (fname, description) in outfiles:
        print("We will create a %s file of %s." % (description, fname))
    fileexists = any([os.path.isfile(fname) for (fname, description) in outfiles])
    generate_new = (not fileexists) or _ConfirmOverwrite(
            "One or more of these files already exists. Do you want to overwrite all of them and generate new ones [y/n]?",
            "Not generating new file - using existing ones.", "Will overwrite these files.")
    if generate_new:
        paths_out = open(pathsfile, 'w')
        treesites_out = open(treesitesfile, 'w')
        treemuts_out = open(treemutsfile, 'w')
        try:
            print("\nAnalyzing trees in %s." % gzipped_merged_treeinfile)
            # burnin is 0 here because MergeTreeFiles already removed it from each run
            npaths = GetPathsAndTrees(gzipped_merged_treeinfile, 0, startname, endname, paths_out, treesites_out, treemuts_out, includesites, gzipped=True, assign_names=assign_names, i_to_sample=i_to_sample, sample_prefix=pathtree_sample_prefix)
            print("Generated a total of %d paths and annotated trees from this file." % npaths)
        finally:
            paths_out.close()
            treesites_out.close()
            treemuts_out.close()
    # Now generate the *.dot files
    startseq = fasta.Translate(fasta.Read(startseqfile))[0][1].strip()
    print("\nNow generating the *.dot files containing the digraphs.")
    generate_new = (not os.path.isfile(digraphfile)) or _ConfirmOverwrite(
            "File %s already exists. Do you want to overwrite it and any sampled path digraph files [y/n]?" % digraphfile,
            "Not generating a new file.", "Will overwrite the file(s).")
    if generate_new:
        f_paths = open(pathsfile)
        try:
            paths = [line.strip() for line in f_paths]
        finally:
            f_paths.close()
        # keep only the mutations (drop '; cumulativetime=... forwardtime=...')
        # and skip blank lines so they are not counted as empty paths
        paths = [path.split(';')[0] for path in paths if path]
        cutoff = int(cutoff_frac * len(paths))
        print("Building the digraphs, excluding nodes that do not appear in at least %.3f (%d of %d) of the paths. All internal nodes that lack both incoming and outgoing edges after this exclusion are also removed. In the overall digraph file, paths with weights of at least %.3f are annotated. Cycles in sequence space are also removed from the overall digraph file in %s, although not from any sampled path digraph files." % (cutoff_frac, cutoff, len(paths), labelcutoff, digraphfile))
        for i in i_to_sample:
            # i is a 1-based path number, so paths[i - 1] exists for 1 <= i <= len(paths)
            if 1 <= i <= len(paths):
                dotfile = digraph_sample_prefix % i
                print("Writing digraph for sampled path %d to %s, with any cycles included." % (i, dotfile))
                (nodes, edges, startnode, endnode) = dotgraphs.MakeDigraph([paths[i - 1]], 0, removecycles=False)
                dotgraphs.WriteDOTFile(nodes, edges, dotfile, weightrange=(cutoff_frac, maxweight), logscaleweight=True, startnode=startnode, endnode=endnode, labelcutoff=labelcutoff, startseq=startseq, startname=digraphstartname, endname=digraphendname)
            else:
                print("Not writing a digraph for sampled path %d, since there are only %d paths." % (i, len(paths)))
        (nodes, edges, startnode, endnode) = dotgraphs.MakeDigraph(paths, cutoff, removecycles=True)
        dotgraphs.WriteDOTFile(nodes, edges, digraphfile, weightrange=(cutoff_frac, maxweight), logscaleweight=True, startname=digraphstartname, endname=digraphendname, labelcutoff=labelcutoff, startseq=startseq, startnode=startnode, endnode=endnode)
        print("Wrote the final digraph to %s. It has %d nodes and %d edges." % (digraphfile, len(nodes), len(edges)))


# Run the pipeline only when executed as a script, so importing this module
# (e.g. to reuse MergeTreeFiles) does not start the interactive analysis.
if __name__ == '__main__':
    main()
