"""Module for various functions for parsing phylogenetic trees.

Uses the trees represented in the tree.py module.

Written by Jesse Bloom, 2011."""


import re
import tree
import fasta


def AssignMutationPathFromNodes(t, startname, endname, mutationpath_name):
    """Assigns mutations to branches along a path, using node identities.

    't' is a phylogenetic tree. For any residues we are considering,
        the node.info dictionaries must have residue identities specified
        with the keys AA1, AA2, etc. Not all positions need to have
        a residue identity specified -- if they do not, then no
        mutations for those residues are assigned. But if a residue has
        an identity assigned for one node, then it must be assigned
        for every node on the path (otherwise a KeyError is raised when
        the ancestor's identity is looked up).
    'startname' and 'endname' name two distinct tip nodes on the tree.
    'mutationpath_name' is a string whose meaning is explained below.
    This method first finds the common ancestor of 'startname' and
        'endname'. Then, walking from each of the two tips up to that
        common ancestor, it compares the residue identities of every node
        with those of its immediate ancestor. Every residue that differs
        is recorded as a mutation string of the form "A25K" (ancestor
        identity, residue number, node identity). The mutations for a
        branch are stored in the node's 'ancestorbranchinfo' dictionary
        under the key 'mutationpath_name', as a single string joined by
        '-' and ordered by residue number, such as "A25K-G225S". A branch
        with no mutations gets the empty string.
    """
    aamatch = re.compile(r'^AA(\d+)$')
    startnode = tree.GetTipNode(startname, t)
    endnode = tree.GetTipNode(endname, t)
    (commonancestor, _pathlist) = tree.GetPath(startnode, endnode)
    for node in (startnode, endnode):
        while node != commonancestor:
            mutations = [] # (residue_number, mutation_string) pairs
            for (key, value) in node.info.items():
                m = aamatch.search(key)
                if m:
                    r = int(m.group(1))
                    parentvalue = node.ancestor.info[key]
                    if value != parentvalue:
                        mutations.append((r, "%s%d%s" % (parentvalue, r, value)))
            # sort by residue number so the stored string does not depend
            # on dictionary iteration order
            mutations.sort()
            node.ancestorbranchinfo[mutationpath_name] = '-'.join([mut for (r, mut) in mutations])
            node = node.ancestor


def ProteinMutationPath(t, startname, endname, seqtype):
    """Defines the path of mutations between two nodes.

    Takes as input a phylogenetic tree 't' for which the strings
        'startname' and 'endname' give the names of two
        distinct tip nodes.  Traces along the phylogenetic
        tree from 'startname' to 'endname', and finds all amino
        acid mutations that occurred along this path.  Returns
        the 3-tuple:
            (mutation_path, cumulative_time, forward_time)
        cumulative_time is a number giving the cumulative branch
            lengths from startnode to endnode; forward_time is
            a number giving the cumulative branch length from
            the common ancestor of startnode and endnode to
            endnode.
        mutation_path is a list:
            [(mut1, t1), (mut2, t2), ... ]
        The first elements of the 2-tuples ('mut1', 'mut2', etc)
        give the mutations as strings in the form "N63K". The 't1', 't2', ...
        elements give the elapsed time along the path from the
        starting node to the mutation.  Note that the
        mutations are defined as they would have to be made
        to 'startname' to give 'endname'.  So initially (as we
        trace from 'startname' to the last common ancestor) they
        are the opposite of the mutation that actually occurred in
        forward time; as we start to trace from the last common ancestor
        to 'endname' they are the mutations that occurred in forward
        time.
    'seqtype' is either 'PROTEIN' or 'DNA' depending on which sequence
        type we are analyzing.  It is currently not used by this function.
    In order for this function to work, the function 'AssignMutations'
        in this module needs to already have been called on the root
        node of 't', so that each branch has a 'protein_mutations' entry.
        The stored 'protein_mutations' lists are read but not modified.
    """
    startnode = tree.GetTipNode(startname, t)
    endnode = tree.GetTipNode(endname, t)
    (commonancestor, pathlist) = tree.GetPath(startnode, endnode)
    assert startnode == pathlist[0] and endnode == pathlist[-1]
    reverse_trace = True
    mutationpath = []
    cumulative_time = forward_time = 0.0
    for node in pathlist:
        if node == commonancestor: # now we start tracing forward along the tree
            reverse_trace = False
            continue # we go to the next node, and look back along ancestor branch
        branchlength = node.ancestorbranch
        protein_mutations = node.ancestorbranchinfo['protein_mutations']
        if reverse_trace: # we are tracing backward along the tree
            # most recent to oldest; sorted() leaves the stored list untouched
            for (t_since_ancestor, mut) in sorted(protein_mutations, reverse=True):
                reversemut = "%s%s%s" % (mut[-1], mut[1 : -1], mut[0]) # reverse mutation
                # walking up from the node, the mutation is reached after
                # (branchlength - t_since_ancestor) time units
                mutationpath.append((reversemut, cumulative_time + branchlength - t_since_ancestor))
        else: # we are tracing forward now
            for (t_since_ancestor, mut) in sorted(protein_mutations): # oldest to most recent
                mutationpath.append((mut, cumulative_time + t_since_ancestor))
            forward_time += branchlength
        cumulative_time += branchlength
    return (mutationpath, cumulative_time, forward_time)


def SummarizeAAsFromDistribution(node):
    """Assigns individual amino acid values based on given distributions.

    'node' is a tree.Node object. In its 'info' property, it should have
        keys specifying the distribution of amino acid probabilities.
        These keys are of the form AA25.set and AA25.set.prob, whose values
        are comma-separated strings.  Each entry of the set is expected to be
        wrapped in one character on each side (e.g. '"G","S"'), which is
        stripped off; the probabilities are parsed as floats (e.g. '0.7,0.3').
        NOTE(review): this assumes any surrounding braces such as {"G","S"}
        were already removed when the tree was parsed -- confirm upstream.
        For each such residue, takes the highest probability amino acid and
        sets it in the dictionary with the key AA25 (for example). If two
        residues are tied for being most probable, takes the first listed.
    Raises ValueError if a set has no matching probability entry, if the
        set and probability lists have different lengths, or if the set is
        empty.
    """
    setmatch = re.compile(r'^AA(\d+)\.set$')
    # Iterate over a snapshot: new 'AA%d' keys are added to node.info inside
    # the loop, and adding keys while iterating a dict raises RuntimeError.
    for (key, aa_set) in list(node.info.items()):
        m = setmatch.search(key)
        if not m:
            continue
        r = int(m.group(1))
        try:
            aa_probs = node.info['AA%d.set.prob' % r]
        except KeyError:
            raise ValueError("Found an entry for %s, but no corresponding set probabilities." % m.group(0))
        aa_set = aa_set.split(',')
        if aa_set == ['']:
            raise ValueError("Empty set.")
        aa_probs = [float(x) for x in aa_probs.split(',')]
        if len(aa_set) != len(aa_probs):
            raise ValueError("Different lengths.")
        # max() returns the first maximal index, so ties go to the first listed
        imaxprob = max(range(len(aa_probs)), key=lambda i: aa_probs[i])
        node.info['AA%d' % r] = aa_set[imaxprob][1 : -1]


def BreakDNASeqToAAs(node, threshold=0.05):
    """Creates individual amino acid assignments based on nucleotide sequence.

    'node' is a tree.Node object.  In its 'info' property, it should have a key
        "states" that specifies a nucleotide sequence that can be translated
        (i.e. is in the correct reading frame).  Note that the translation
        may not detect an untranslatable sequence, since a limited
        number of gaps and stop codons will be read through.
    'threshold' (default 0.05) is the maximum allowed fraction of the
        translated sequence that may be gap ('-') characters, and separately
        the maximum fraction that may be 'X' characters.  A ValueError is
        raised if either count exceeds it.
    Upon completion of the function, for each amino acid created by translating
        the sequence, there is a new entry in 'node.info' with the key
        'AA1', 'AA2', etc. with the value giving the upper case one-letter
        amino acid translation for that element of the sequence.
    Raises ValueError if 'states' is missing, is not a string, or has a
        length that is not a multiple of three.
    """
    if 'states' not in node.info:
        raise ValueError("No 'states' key in node.info:\n%s" % str(node.info))
    seq = node.info['states']
    if not (isinstance(seq, str) and (len(seq) % 3 == 0)):
        raise ValueError("Sequence is not a string with length a multiple of three:\n%s" % str(seq))
    prot = fasta.Translate([('head', seq)], readthrough_n=True, translate_gaps=True, readthrough_stop=True)[0][1]
    prot = prot.upper()
    n = len(prot)
    if prot.count('-') > n * threshold or prot.count('X') > n * threshold:
        raise ValueError("Too many 'X' or '-' characters in:\n%s" % prot)
    for (i, aa) in enumerate(prot):
        node.info['AA%d' % (i + 1)] = aa


def BreakProtSeqToAAs(node):
    """Splits a protein sequence into one 'node.info' entry per residue.

    'node' is a tree.Node object whose 'info' dictionary holds the protein
        sequence under the key "states".
    After the call, 'node.info' maps 'AA1', 'AA2', ... to the upper-cased
        one-letter residue at that (1-based) position of the sequence.
    Raises ValueError if "states" is absent or is not a string.
    """
    info = node.info
    if 'states' not in info:
        raise ValueError("No 'states' key in node.info:\n%s" % str(info))
    sequence = info['states']
    if not isinstance(sequence, str):
        raise ValueError("Sequence is not a string.")
    for (position, residue) in enumerate(sequence.upper(), 1):
        info['AA%d' % position] = residue


def _NucleotideToProteinMutations(ntseq, nucleotide_mutations, descendentseq):
    """Derives amino-acid mutations from time-ordered nucleotide mutations.

    'ntseq' is the ancestor's nucleotide sequence, 'nucleotide_mutations' is a
        list of (time_since_ancestor, mut) 2-tuples (with 'mut' like "A27G")
        sorted by time, and 'descendentseq' is the descendent's nucleotide
        sequence.
    Applies the mutations one at a time, translating after each, and returns
        a list of (time_since_ancestor, protein_mut) 2-tuples for every amino
        acid change.
    Raises ValueError if a mutation's wildtype nucleotide does not match the
        current sequence, if the translated lengths change other than by the
        tolerated stop-codon cases, or if the final sequence is not
        'descendentseq'.
    """
    stopcodons = ('TAA', 'TAG', 'TGA')
    protseq = fasta.Translate([('head', ntseq)], readthrough_n=True, translate_gaps=True, readthrough_stop=True)[0][1]
    protein_mutations = []
    for (x_since_ancestor, nt_mut) in nucleotide_mutations:
        (wt_nt, i_nt, mut_nt) = (nt_mut[0], int(nt_mut[1 : -1]), nt_mut[-1])
        if ntseq[i_nt - 1] != wt_nt:
            raise ValueError("A mutation of %s was specified, but the sequence identity at this site is %s" % (nt_mut, ntseq[i_nt - 1]))
        newntseq = "%s%s%s" % (ntseq[ : i_nt - 1], mut_nt, ntseq[i_nt : ]) # new sequence
        newprotseq = fasta.Translate([('head', newntseq)], readthrough_n=True, translate_gaps=True, readthrough_stop=True)[0][1]
        if len(protseq) != len(newprotseq): # error check
            # We just disregard amino acid mutations that alter the stop codon.
            if len(protseq) - 1 == len(newprotseq) and newntseq[-3 : ] in stopcodons:
                pass
            elif len(protseq) == len(newprotseq) - 1 and ntseq[-3 : ] in stopcodons:
                pass
            else:
                raise ValueError("Protein sequence lengths of %d and %d differ after nucleotide mutation %s (final codons %s and %s).  Sequences are:\n%s\n%s" % (len(protseq), len(newprotseq), nt_mut, ntseq[-3 : ], newntseq[-3 : ], protseq, newprotseq))
        for i in range(min(len(protseq), len(newprotseq))):
            if newprotseq[i] != protseq[i]:
                prot_mut = "%s%d%s" % (protseq[i], i + 1, newprotseq[i])
                protein_mutations.append((x_since_ancestor, prot_mut))
        protseq = newprotseq
        ntseq = newntseq
    if ntseq != descendentseq: # error check to make sure final sequence OK
        if len(ntseq) != len(descendentseq):
            details = "Sequences of different length."
        else:
            details = '\n'.join(["Sequences differ at nucleotide %d (%s versus %s)" % (i + 1, a, b) for (i, (a, b)) in enumerate(zip(ntseq, descendentseq)) if a != b])
        if nucleotide_mutations:
            lastmut = nucleotide_mutations[-1][1]
        else:
            lastmut = 'none'
        raise ValueError("The nucleotide sequence after making the mutations differs from the descendent sequence (last mutation applied: %s).\n%s\n%s\n%s" % (lastmut, details, ntseq, descendentseq))
    return protein_mutations


def _CheckProteinMutations(protseq, protein_mutations, descendentseq):
    """Checks that time-ordered protein mutations turn 'protseq' into 'descendentseq'.

    'protein_mutations' is a list of (time, mut) 2-tuples with 'mut' like
        "C27G" (1-based residue number).  Raises ValueError if a mutation's
        wildtype residue does not match the current sequence, or if the
        mutated sequence differs from 'descendentseq'.
    """
    newseq = list(protseq)
    for (x, m) in protein_mutations:
        (wt, i, mut) = (m[0], int(m[1 : -1]) - 1, m[-1])
        if wt != newseq[i]:
            raise ValueError("Problem: %s" % m)
        newseq[i] = mut
    newseq = ''.join(newseq)
    if newseq != descendentseq:
        raise ValueError("Sequences differ:\n%s\n%s" % (newseq, descendentseq))


def AssignMutations(node, total_time, time_label, seqtype):
    """Assigns state mutations to node branches.

    'node' is a tree.Node object.  Typically, you would call this function
        on the root node; it recurses into both descendents.  For each branch
        descended from this node, looks for branch information (specified
        as 'node.rightbranchinfo' or 'node.leftbranchinfo') that is of the
        form 'history_XX=[A:G:11.481025704879713]' where 'XX' is the residue
        number.  A single value may hold several comma-separated mutations.
        Empty values (including the string '{}') are ignored.  The time in
        each mutation is measured backwards from the most recent sample and
        may be written as a decimal or with an exponent (e.g. 1.5E-05).
        If 'seqtype' is 'DNA' the entries are nucleotide mutations; if
        'seqtype' is 'PROTEIN' they are protein mutations.  Each node must
        have a node.info key "states" giving its sequence.
    This method creates two entries in each branch dictionary, keyed by
        "nucleotide_mutations" and "protein_mutations", holding time-sorted
        lists of 2-tuples '(time_since_ancestor, mut)', where 'mut' is a string
        of the form "A27G" and 'time_since_ancestor' is the time since that
        branch's ancestor.  If 'seqtype' is 'DNA', the amino acid mutations are
        determined by applying the nucleotide mutations to node.info["states"]
        and translating.  If 'seqtype' is 'PROTEIN', "nucleotide_mutations" is
        an empty list.  These two derived keys are skipped when the branch
        annotations are read, so the function may be called again on the same
        tree.
    'total_time' is a number that represents the total time from 'node' to its
        most recent descendent, and 'time_label' is a string giving the entry
        for the time since 'node' for all descendent nodes in the 'info'
        dictionaries.  If you first call
            'total_time = tree.AssignNodeTimes(node, 0.0, time_label)'
        and then pass the 'total_time' argument to this function, everything
        will work.
    Raises ValueError for an invalid 'seqtype', an unrecognized branch
        annotation, an unparsable mutation, a mutation time outside its
        branch, or mutations that do not convert the ancestor sequence into
        the descendent sequence.
    """
    historymatch = re.compile(r'^history_(\d+)$')
    timepattern = r'(\d+\.?\d*(?:[eE][-+]?\d+)?)'
    if seqtype == 'DNA':
        mutationmatch = re.compile(r'^\[([AGCT]):([AGCT]):' + timepattern + r'\]$')
    elif seqtype == 'PROTEIN':
        mutationmatch = re.compile(r'^\[([A-Y]):([A-Y]):' + timepattern + r'\]$')
    else:
        raise ValueError("Invalid 'seqtype' of: %s" % seqtype)
    if node.tip:
        return # nothing to do for tip nodes
    derivedkeys = ('nucleotide_mutations', 'protein_mutations') # written below
    branches = [(node.rightdescendent, node.rightbranch, node.rightbranchinfo), (node.leftdescendent, node.leftbranch, node.leftbranchinfo)]
    for (descendent, branchlength, branchinfo) in branches:
        mutations = [] # (time_since_ancestor, mutation string) from history annotations
        for (key, value) in list(branchinfo.items()):
            if key in derivedkeys:
                continue # produced by an earlier call; recomputed below
            m = historymatch.search(key)
            if not m:
                raise ValueError("Found a branch annotation for something other than 'history_XX='.  This could be valid if branches are being annotated with something in addition to mutation histories.  If so, just enter the code and remove this error check.  The offending entry is:\n%s" % key)
            site = int(m.group(1))
            if not value or value == '{}':
                continue # we allow for empty history annotations
            for mutstring in value.split(','):
                m = mutationmatch.search(mutstring)
                if not m:
                    raise ValueError("Failed to parse mutations from:\n%s\nFor key\n%s" % (mutstring, key))
                (wt, mut, x) = (m.group(1), m.group(2), float(m.group(3)))
                # x is time since most recent tip; convert to time since this branch's ancestor
                x_since_ancestor = total_time - x - node.info[time_label]
                if not (0 <= x_since_ancestor <= branchlength):
                    raise ValueError("The assigned time of the mutation seems invalid.  We have: x = %f, x_since_ancestor = %f, total_time = %f, node.info[time_label] = %f, descendent.info[time_label] = %f" % (x, x_since_ancestor, total_time, node.info[time_label], descendent.info[time_label]))
                mutations.append((x_since_ancestor, "%s%d%s" % (wt, site, mut)))
        mutations.sort()
        if seqtype == 'DNA':
            nucleotide_mutations = mutations
            protein_mutations = _NucleotideToProteinMutations(node.info['states'], nucleotide_mutations, descendent.info['states'])
            protein_mutations.sort()
        else:
            nucleotide_mutations = []
            protein_mutations = mutations
        branchinfo['nucleotide_mutations'] = nucleotide_mutations
        branchinfo['protein_mutations'] = protein_mutations
        if seqtype == 'PROTEIN':
            _CheckProteinMutations(node.info['states'], protein_mutations, descendent.info['states'])
        AssignMutations(descendent, total_time, time_label, seqtype)


def GetTreeString(line, replace_with_null=None):
    """Gets the newick tree from a line from a BEAST tree file.

    'line' should represent a line representing a tree from a BEAST
        tree file, beginning with something like 'tree STATE_0 '.
    'replace_with_null' is an optional list of regular expression patterns
        that we wish to remove from the tree string.  The replacements
        are done in the order that the items are specified.  None (the
        default) means no replacements.
    Returns the newick tree portion of line only, with all patterns
        matching those listed in 'replace_with_null' having been removed.
    Raises ValueError if the line does not start with the expected tree
        header or lacks the '= [&R] ' marker.
    """
    statematch = re.compile(r'^tree [A-Z]+_?\d+ ')
    m = statematch.search(line)
    if not m:
        raise ValueError("Failed to find expected beginning of tree line:\n%s" % line)
    line = line[len(m.group(0)) : ]
    if line.startswith('[&'):
        # skip a bracketed annotation that precedes the '=' sign
        i = tree.GetBalancedIndex(line, 0)
        line = line[i + 1 : ]
    if line.startswith(' = [&R] '):
        line = line[8 : ]
    elif line.startswith('= [&R] '):
        line = line[7 : ]
    else:
        raise ValueError("Failed to find expected information:\n%s" % line)
    for rematch in (replace_with_null or []):
        line = re.sub(rematch, '', line)
    return line
