"""Module with functions for building *.dot digraph files.

Jesse Bloom, 2011."""


import math
import fasta

def HammingDistance(node, refnode):
    """Returns the Hamming distance between two nodes.

    'node' and 'refnode' are node names made of whitespace-separated
        tokens such as "R77 G102", where each token is a one-letter
        residue identity followed by an integer site number. The empty
        string is a node that lists no sites.
    Returns the number of sites at which the two nodes differ. A site
        that is listed by only one of the two nodes counts as a difference.
    """
    ref = dict([(int(x[1 : ]), x[0]) for x in refnode.split()])
    comp = dict([(int(x[1 : ]), x[0]) for x in node.split()])
    # Build the union of sites with sets: 'keys() + keys()' only works when
    # keys() returns a list, which is not the case on Python 3.
    sites = set(ref) | set(comp)
    # get() yields None for a site missing from one node, which never equals
    # a residue letter, so such sites are counted as differences.
    return sum(1 for site in sites if ref.get(site) != comp.get(site))


def Mutations(node1, node2, startseq):
    """Returns all mutations separating two nodes.

    'node1' and 'node2' are two nodes, named by whitespace-separated
        tokens such as "R77 G102" (residue letter followed by site number).
    'startseq' is the sequence of the reference sequence
        from which nodes are defined. It is indexed with 1-based site
        numbers, so site 'r' is read from 'startseq[r - 1]'.
    Returns a list of all mutations that need to be made to
        'node1' to give 'node2'. Sorted by residue number. Each mutation
        is a string such as "R102G": the residue in 'node1', the site,
        then the residue in 'node2'.
    """
    muts1 = dict([(int(x[1 : ]), x[0]) for x in node1.split()])
    muts2 = dict([(int(x[1 : ]), x[0]) for x in node2.split()])
    mutations = []
    # Use a set union (not 'keys() + keys()', which fails on Python 3) and
    # walk the sites in numerical order so the result is sorted by site.
    for site in sorted(set(muts1) | set(muts2)):
        if site in muts1 and site in muts2 and muts1[site] == muts2[site]:
            continue # both nodes agree at this site
        # A site a node does not list carries the reference residue. The
        # lookup in 'startseq' is only done when it is actually needed.
        if site in muts1:
            aa1 = muts1[site]
        else:
            aa1 = startseq[site - 1]
        if site in muts2:
            aa2 = muts2[site]
        else:
            aa2 = startseq[site - 1]
        mutations.append("%s%d%s" % (aa1, site, aa2))
    return mutations



def PathToNodesTimes(path, removecycles):
    """Converts a path to nodes.

    Takes a single path of the type input to 'MakeDigraph', which is a
        comma-separated string of 'mutation:time' entries such as
        "R470K:0.041556, R102G:0.460018". An empty (or all whitespace)
        path is treated as a path with no mutations.
    'removecycles' is a Boolean switch specifying that we remove
        mutational cycles that return to the same sequence (only done
        if True).
    Returns the list 'pathnodes'.
    'pathnodes' is a list holding the starting node followed by one node
        for each mutation in the path (fewer if cycles are removed),
        with each element representing a node name as defined
        in 'MakeDigraph'.  The first element will always be ''
        corresponding to the starting sequence.
    Raises ValueError if the wildtype residue given for a mutation does
        not match the residue currently at that site, or if a time
        cannot be converted to a number.
    """
    pathnodes = ['']
    if not path.strip():
        return pathnodes # no mutations, just the starting sequence
    currentseq = {} # current residue at each site that differs from the start
    wtseq = {} # starting residue at each site that is currently mutated
    for x in path.split(','):
        fields = x.strip().split(':')
        mutation = fields[0]
        # The time is not used for building nodes, but it is still parsed so
        # that malformed entries are rejected.
        float(fields[1])
        (wt, r, mut) = (mutation[0], int(mutation[1 : -1]), mutation[-1])
        if r in currentseq:
            if currentseq[r] != wt:
                raise ValueError("Mismatch for wildtype residue.")
            if mut == wtseq[r]:
                del currentseq[r] # reverted to the starting residue
            else:
                currentseq[r] = mut
        else:
            wtseq[r] = wt
            currentseq[r] = mut
        aalist = sorted(currentseq.items())
        pathnodes.append(' '.join(["%s%d" % (aa, site) for (site, aa) in aalist]))
    if removecycles:
        # Single pass: whenever a node name reappears, drop everything that
        # was visited since its first occurrence. This gives the same result
        # as repeatedly cutting out the earliest cycle and rescanning.
        kept = []
        position = {} # node name -> index in 'kept'
        for name in pathnodes:
            if name in position:
                cut = position[name] + 1
                for dropped in kept[cut : ]:
                    del position[dropped]
                del kept[cut : ]
            else:
                position[name] = len(kept)
                kept.append(name)
        pathnodes = kept
    return pathnodes

def HeuristicTraceBack(nodes, edges, node, cutoff):
    """Traces back to find last high probability predecessor of a node.

    'nodes' and 'edges' are dictionaries defining the nodes and
        edges, in the format returned by 'MakeDigraph'.
    'node' is a node in 'nodes'
    'cutoff' is a number specifying the threshold weight.  Nodes
        and edges with weights >= cutoff are over the threshold.
    It is assumed that 'node' is over the threshold, but has no
        single incoming edge over the threshold. This method traces
        back along incoming edges to find the closest node
        over the cutoff that originates edges to this node. The idea
        is that we will find the node where most of the edge flux into
        this node originates. However, this is not guaranteed -- that
        is why the title of this method includes the word "heuristic."
        The tracing is done back along the highest weighted edge for
        each node along the traceback. But in pathological cases, this
        could give us a node that actually is not the closest originator
        of the flux to this node.
    Returns the predecessor node found by this method.
    Raises ValueError if the traceback returns to a below-cutoff node
        it has already visited, since it could then never terminate.
    """
    assert nodes[node] >= cutoff, "Node not over the threshold."
    # internal function
    def _HighestPredecessor(inode):
        """Returns the predecessor node attached to the highest weight path."""
        weightedlist = [(weight, node1) for ((node1, node2), weight) in edges.items() if node2 == inode]
        assert weightedlist, "Empty list of predecessors."
        # Largest (weight, name) pair, so ties in weight go to the larger name.
        return max(weightedlist)[1]
    # end of internal function
    visited = set()
    predecessor = _HighestPredecessor(node)
    while nodes[predecessor] < cutoff:
        # The choice of predecessor depends only on the node, so seeing the
        # same node twice means the traceback is stuck in a cycle.
        if predecessor in visited:
            raise ValueError("Cycle encountered while tracing back from node '%s'." % node)
        visited.add(predecessor)
        predecessor = _HighestPredecessor(predecessor)
    return predecessor


def MakeDigraph(paths, cutoff, removecycles):
    """Constructs a digraph from the paths.

    'paths' is a sequence of strings of the form found in the
    mutationpaths files, such as:
        "R470K:0.041556, R102G:0.460018, K77R:1.117045"
    These strings represent mutations and the elapsed time between each
        mutation and the first sequence.
    The returned variable is a 4-tuple '(nodes, edges, startnode, endnode)'.
    'nodes' is a dictionary keyed by the node names for each sequence
        found along the path.  Sequences are named as strings giving
        the identity of each residue that differs from wildtype. So
        the string above would be:
        "R77 G102 K470"
        A string like the one above that also contains G102K would be
        "R77 K102 K470".
        The mutations in these strings are ordered by site.
        The values in 'nodes' are numbers representing the fraction of the
        total number of paths that contain that node at least once.
    'edges' is a dictionary keyed by 2-tuples '(startnode, endnode)'
        where each of these are strings found in 'nodes'. The values
        are numbers representing the fraction of the total number of paths
        that contain that edge at least once.
    'startnode' and 'endnode' are the names of the starting and ending nodes,
        which should be the same for all paths. Both are None if 'paths'
        is empty.
    The input variable 'cutoff' is an integer >= 0. Any node not in at least cutoff
        of the paths is removed from the returned set. After this cutoff is done,
        all nodes are examined to make sure that they have both an incoming and
        outgoing edges (except for the first and last nodes, which have only
        outgoing and incoming respectively). Any nodes that lack both such edges
        are removed from the returned set.  This exclusion criteria is applied
        repeatedly since exclusion of new nodes might exclude further nodes.
    The input variable 'removecycles' is a Boolean switch. If it is True,
        then when the paths are processed, any mutation cycles that return
        to the original sequence are removed.
    Raises ValueError if the paths do not all share the same start node
        and the same end node.
    """
    nodes = {}
    edges = {}
    startnode = endnode = None
    for path in paths:
        pathnodes = PathToNodesTimes(path, removecycles)
        if endnode is None:
            startnode = pathnodes[0]
            endnode = pathnodes[-1]
        if startnode != pathnodes[0]:
            raise ValueError("Different start nodes:\n%s\n%s" % (startnode, pathnodes[0]))
        if endnode != pathnodes[-1]:
            raise ValueError("Different end nodes:\n%s\n%s" % (endnode, pathnodes[-1]))
        # Count each node and each edge at most once per path.
        nodes_in_this_path = set()
        for name in pathnodes:
            if name not in nodes_in_this_path:
                nodes_in_this_path.add(name)
                nodes[name] = nodes.get(name, 0) + 1
        edges_in_this_path = set()
        for key in zip(pathnodes[ : -1], pathnodes[1 : ]):
            if key not in edges_in_this_path:
                edges_in_this_path.add(key)
                edges[key] = edges.get(key, 0) + 1

    # Repeatedly drop nodes that are below the cutoff or not properly
    # connected, and edges touching dropped nodes, until nothing changes.
    # A loop is used instead of recursion so that a long chain of
    # exclusions cannot exceed the recursion limit.
    while True:
        has_outgoing = set([node1 for (node1, node2) in edges])
        has_incoming = set([node2 for (node1, node2) in edges])
        kept_nodes = {}
        for (node, n) in nodes.items():
            if n >= cutoff and ((node in has_outgoing and node in has_incoming) or (node in has_outgoing and node == startnode) or (node in has_incoming and node == endnode)):
                kept_nodes[node] = n
        kept_edges = {}
        for ((node1, node2), weight) in edges.items():
            if node1 in kept_nodes and node2 in kept_nodes:
                kept_edges[(node1, node2)] = weight
        excluded_something = (len(kept_nodes) != len(nodes)) or (len(kept_edges) != len(edges))
        (nodes, edges) = (kept_nodes, kept_edges)
        if not excluded_something:
            break

    # Convert the counts to fractions of the total number of paths.
    npaths = float(len(paths))
    cleaned_nodes = dict([(node, n / npaths) for (node, n) in nodes.items()])
    cleaned_edges = dict([(edge, n / npaths) for (edge, n) in edges.items()])
    return (cleaned_nodes, cleaned_edges, startnode, endnode)


def WriteDOTFile(nodes, edges, filename, weightrange, logscaleweight, startnode, endnode, labelcutoff, startseq, startname, endname):
    """Writes a digraph in the DOT language to create a *.dot file.

    'nodes' and 'edges' define a digraph, as returned by MakeDigraph.
    'filename' gives a valid filename. Any existing file is overwritten.
    'weightrange' is a 2-tuple of the form '(minweight, maxweight)'. Edge and
        node color saturation is proportional to weights or logarithm of weights,
        scaled so that weightrange[0] is 0 and weightrange[1] is 1.
    'logscaleweight' specifies what the color saturation is proportional to.
        If 'True', proportional to logarithm of weights, if 'False' then
        proportional to weights.
    'startnode' and 'endnode' are the starting and ending nodes.
    'labelcutoff' gives the minimum weight of an edge before we label it.
        Nodes with weight above 'labelcutoff' that have no labeled incoming
        edge get an extra labeled connection from the predecessor found
        by 'HeuristicTraceBack'.
    'startseq' is the starting sequence of 'startnode'.
    'startname' and 'endname' are strings giving the names of the starting
        and ending nodes.
    Node areas are proportional to their weights.
    Node ranks (which determine horizontal placement) are equal to
        the Hamming distance from the starting node minus the Hamming distance
        from the ending node plus the Hamming distance between the starting
        and ending nodes.
    Writes the DOT language specification of the digraph to 'filename'.
    """
    nodesize = 0.9 # the height of a node with weight 1
    ranksep = 0.05 * nodesize # separation between ranks
    nodesep = 0.18 * nodesize # min separation between nodes of same rank
    penwidth = 10 # the penwidth of an edge
    fontsize = 40 # font size
    arrowsize = 1.6 # size of arrow
    fontname = 'Helvetica-Bold' # font style
    # Start and end nodes share one statement format and differ only in label.
    # The font color needs its 'fontcolor=' key: an attribute list entry
    # without a key is not valid DOT.
    terminalformat = '\t\tnode [shape=diamond label="%s" height=%f color="0.7 %f 0.9" penwidth=%f fontsize=%d fontname="%s" fontcolor="0.7 1.0 0.0"] "%s";\n'
    # Get maximum and minimum weights for color scaling
    def _ScaleWeight(weight):
        """Maps 'weight' onto the 0 to 1 saturation scale set by 'weightrange'."""
        if logscaleweight:
            return (math.log(weight) - math.log(weightrange[0])) / (math.log(weightrange[1]) - math.log(weightrange[0]))
        else:
            return (weight - weightrange[0]) / (weightrange[1] - weightrange[0])
    # Start on the graph. The 'with' block closes the file even if
    # something below raises.
    with open(filename, 'w') as f:
        f.write('digraph G { rankdir=TB; ranksep=%f; nodesep=%f;\n' % (ranksep, nodesep))
        # Group the nodes by rank, remembering the order ranks were first seen
        hamming_nodes = {}
        hammingdistances = []
        for (node, weight) in nodes.items():
            hammingdistance = HammingDistance(node, startnode) - HammingDistance(node, endnode) + HammingDistance(startnode, endnode)
            if hammingdistance not in hamming_nodes:
                hammingdistances.append(hammingdistance)
                hamming_nodes[hammingdistance] = {node:weight}
            else:
                hamming_nodes[hammingdistance][node] = weight
        # NOTE(review): each 'node [...]' statement sets defaults for the nodes
        # that follow, so attributes not repeated in a later statement
        # presumably carry over from earlier ones -- confirm this is intended.
        for hammingdistance in hammingdistances:
            f.write('\tsubgraph %d { label="%d" rank=same\n' % (hammingdistance, hammingdistance))
            for (node, weight) in hamming_nodes[hammingdistance].items():
                height = nodesize * math.sqrt(weight) # area proportional to weight
                if node == startnode:
                    f.write(terminalformat % (startname, height, _ScaleWeight(weight), penwidth, fontsize, fontname, node))
                elif node == endnode:
                    f.write(terminalformat % (endname, height, _ScaleWeight(weight), penwidth, fontsize, fontname, node))
                else:
                    f.write('\t\tnode [style=filled shape=circle label="" height=%f color="0.7 %f 0.9" penwidth=%f arrowsize=%f] "%s";\n' % (height, _ScaleWeight(weight), penwidth, arrowsize, node))
            f.write('\t}\n')
        # Now write the edges
        node_haslabel = dict([(node, False) for node in nodes if node != startnode]) # True iff labeled incoming edge
        labeled_edges = []
        for ((node1, node2), weight) in edges.items():
            if weight >= labelcutoff:
                # get the mutation separating these nodes
                mutations = Mutations(node1, node2, startseq)
                assert len(mutations) == 1
                edgelabel = mutations[0]
                node_haslabel[node2] = True
                labeled_edges.append('\t"%s" -> "%s" [weight=%f penwidth=%f color="0.7 %f 0.9" label="%s" fontsize=%d fontname="%s" fontcolor="0.7 1.0 0.0" labelfloat=false arrowsize=%f];\n' % (node1, node2, weight, penwidth * weight, _ScaleWeight(weight), edgelabel, fontsize, fontname, arrowsize))
            else:
                # write the edge
                f.write('\t"%s" -> "%s" [weight=%f penwidth=%f color="0.7 %f 0.9" arrowsize=%f];\n' % (node1, node2, weight, penwidth * weight, _ScaleWeight(weight), arrowsize))
        f.write(''.join(labeled_edges)) # we write the labeled edges last to make sure they are on top
        # Connections for nodes above labelcutoff lacking labeled incoming edges
        lackslabel = [node for node in node_haslabel if nodes[node] > labelcutoff and not node_haslabel[node]] # these are the nodes that need incoming labeled connections
        for node in lackslabel:
            predecessor = HeuristicTraceBack(nodes, edges, node, labelcutoff)
            mutations = Mutations(predecessor, node, startseq)
            # Join the mutations with dashes, breaking the label onto a new
            # line after every second mutation.
            mutstring = [mutations[0]]
            for (i, mutation) in enumerate(mutations[1 : ]):
                if i % 2 == 0:
                    mutstring.append('-%s' % mutation)
                else:
                    mutstring.append('-\\n%s' % mutation)
            mutstring = ''.join(mutstring)
            f.write('\t"%s" -> "%s" [weight=0, penwidth=%f color="0.0 1.0 0.9" label="%s" fontsize=%d fontname="%s" fontcolor="0.0 1.0 0.9" labelfloat=false arrowsize=%f];\n' % (predecessor, node, penwidth, mutstring, fontsize, fontname, arrowsize))
        f.write('}')
