"""Makes a mutation-annotated tree from a max clade credibility tree with site identities.

Written by Jesse Bloom, 2011."""


import re
import copy
import gzip
import fasta
import tree
import parse_tree


def _AnnotateTreeLine(line, startcode, endcode, codes_to_names):
    """Returns the annotated version of a single 'tree' line of the input.

    The text before the first '(' is kept verbatim as a preface. The tree
        string after it is parsed, its tips are labeled via 'codes_to_names'
        (tips without an entry get an empty label), each node is assigned its
        most probable amino acid identities, the path between the tips
        named 'startcode' and 'endcode' is marked, and the tree is written
        back out with those annotations. The returned string ends in a
        newline.
    Raises ValueError (from str.index) if 'line' contains no '('.
    """
    tree_preface = line[: line.index('(')]  # stuff before the tree itself
    t = tree.Tree(parse_tree.GetTreeString(line), make_lengths_non_negative=True)
    tip_list = []
    internal_list = []
    tree.ListsOfNodes(t.GetRoot(), tip_list, internal_list)
    for node in tip_list:
        node.info['sequence_name'] = codes_to_names.get(node.name, '')
    # assign amino acid identities based on most probable
    tree.ApplyToNodes(t.GetRoot(), parse_tree.SummarizeAAsFromDistribution)
    # Called for its side effect of marking nodes on the path with 'on_path';
    # the returned path itself is not needed here.
    tree.GetPath(tree.GetTipNode(startcode, t), tree.GetTipNode(endcode, t), markpath='on_path')
    parse_tree.AssignMutationPathFromNodes(t, startcode, endcode, 'mutations_on_path')
    # TODO: optionally label on-path branches with their sorted
    # 'protein_mutations' (an earlier draft of that output was not correct).
    newick = t.WriteNewick(
        nodeinfo_keys={'sequence_name': None, 'on_path': None},
        nodeinfo_res={re.compile(r'^AA\d+$'): None},
        branchinfo_keys={'on_path': None, 'mutations_on_path': None},
    )
    return '%s%s\n' % (tree_preface, newick)


def AnnotateTree(infile, outfile, startname, endname, assign_names):
    """Annotates a max clade credibility tree to show mutations.

    Takes a max clade credibility tree with site identities indicated.
        For each node, assigns it the site identity with the highest
        probability indicated. Then traces along a mutational path from an
        indicated starting node to an indicated ending node (both tips)
        and annotates branches with all mutations along that path
        based on site identities of the nodes.
    'infile' is the name of the file with the input max clade credibility tree.
        Every line that does not begin with 'tree' is copied unchanged to
        'outfile'. Such a line that, after stripping whitespace, begins with
        an integer followed by a space is treated as a translation entry
        giving the integer code used for a sequence in the tree. Every line
        that begins with 'tree' is replaced by its annotated version.
    'outfile' is the file that is created with the annotated tree.
    'startname' is the name of the sequence for which we start the path
        tracing.
    'endname' is the name of the sequence for which we end the path tracing.
    'assign_names' is a dictionary keyed by the sequence names for any sequences
        to which we want to attach name labels. The values are those
        name labels.
    Raises ValueError if 'startname' or 'endname' occurs in more than one
        translation entry, or if a 'tree' line is reached before the codes
        for both 'startname' and 'endname' have been found.
    Any branch lengths less than zero are set to zero.
    """
    intmatch = re.compile(r'^\d+ ')
    startcode = endcode = None
    codes_to_names = {}
    with open(infile) as f_in, open(outfile, 'w') as f_out:
        for line in f_in:
            if line.startswith('tree'):
                # Both codes are required; check each one explicitly.
                if startcode is None or endcode is None:
                    raise ValueError("Have not found start and end code for %s and %s." % (startname, endname))
                f_out.write(_AnnotateTreeLine(line, startcode, endcode, codes_to_names))
                continue
            f_out.write(line)
            line = line.strip()
            m = intmatch.search(line)
            if m is None:
                continue  # not a "<code> <name>" translation entry
            code = m.group(0).strip()
            for (seqname, label) in assign_names.items():
                if seqname in line:
                    codes_to_names[code] = label
                    break
            if startname in line:
                if startcode is not None:
                    raise ValueError("Duplicate codes for %s" % startname)
                startcode = code
                print("The path will start from %s, which has code %s in this file." % (startname, startcode))
            if endname in line:
                if endcode is not None:
                    raise ValueError("Duplicate codes for %s" % endname)
                endcode = code
                print("The path will end at %s, which has code %s in this file." % (endname, endcode))


def main():
    """Runs AnnotateTree on the NP human H3N2 Markov-jumps MCC tree.

    All inputs are hard-coded here: the input and output tree files, the
    tips that bound the traced path (Aichi/1968 to Brisbane/2007), and the
    display labels attached to selected tips.
    """
    aichi = 'A/Aichi/2/1968_1968.00'
    brisbane = 'A/Brisbane/10/2007_2007.10'
    labels = {
        aichi: 'Aichi/1968',
        brisbane: 'Brisbane/2007',
        'A/Nanchang/933/1995_1995.00': 'Nanchang/1995',
    }
    treefile = 'NPhumanH3N2-MarkovJumps_maxcladecredibility.trees'
    annotatedfile = 'NPhumanH3N2-MarkovJumps_annotated_maxcladecredibility.trees'
    print("Making an annotated version of the tree in %s..." % treefile)
    AnnotateTree(treefile, annotatedfile, aichi, brisbane, assign_names=labels)
    print("The annotated version of the tree has been written to %s." % annotatedfile)

# Only run the hard-coded pipeline when executed as a script, so that
# importing this module (e.g. to reuse AnnotateTree) has no side effects.
if __name__ == '__main__':
    main()
