"""Script for analyzing dates of mutations along the mutational path.

Estimates the date of each mutation along each path where it occurs
exactly once.

Jesse Bloom, 2012."""


import re
import stats
import pylab
import matplotlib


def _read_probable_path(filename):
    """Return the mutation names listed one per line in `filename`.

    Surrounding whitespace is stripped from each line.  Blank lines are
    skipped, because an empty name can never match a mutation on a path.
    """
    with open(filename) as f:
        names = [line.strip() for line in f]
    return [name for name in names if name]


def _read_paths(filename):
    """Parse the mutational paths listed in `filename`.

    Each non-blank line is split at ';'.  The part before it is a
    comma-separated list of 'mutation:time' entries; the part after it must
    contain 'cumulativetime=<float> forwardtime=<float>'.

    Returns a list of (totaltime, forwardtime, [(m1, t1), (m2, t2), ...])
    tuples, one per path.

    Raises ValueError if a line does not contain exactly one ';' or if the
    timing information cannot be matched.
    """
    tmatch = re.compile(r'cumulativetime\=(?P<ctime>\d+\.\d+) forwardtime\=(?P<ftime>\d+\.\d+)')
    paths = []
    with open(filename) as f:
        for line in f:
            if not line.strip():
                continue # tolerate blank lines (e.g. at the end of the file)
            (ipath, timetup) = line.split(';')
            m = tmatch.search(timetup)
            if not m:
                raise ValueError("Failed to match:\n%s" % timetup)
            steps = []
            for entry in ipath.split(','):
                if not entry.strip():
                    continue # path without mutations, or a stray comma
                fields = entry.split(':')
                steps.append((fields[0].strip(), float(fields[1])))
            paths.append((float(m.group('ctime')), float(m.group('ftime')), steps))
    return paths


def _dates_on_paths(mut, paths, lastyear):
    """Estimate the date of `mut` on every path where it occurs exactly once.

    `paths` is a list as returned by `_read_paths`, and `lastyear` is the
    date assigned to the end of each path.

    Returns (dates, nforward): `dates` has one entry per path on which `mut`
    occurs exactly once, and `nforward` counts how many of those occurrences
    lie on the forward part of the path.
    """
    dates = []
    nforward = 0
    for (totaltime, forwardtime, path) in paths:
        times = [t for (mut2, t) in path if mut2 == mut]
        if len(times) != 1:
            continue # only use paths where the mutation occurs exactly once
        t_from_tip = totaltime - times[0]
        if t_from_tip > forwardtime:
            # Mutation is on the reverse part of the path: step back to the
            # date where the forward part starts (lastyear - forwardtime),
            # then add the time by which the mutation precedes that point
            # along the path.
            dates.append(lastyear - forwardtime + (t_from_tip - forwardtime))
        else:
            # Mutation is on the forward part of the path.
            nforward += 1
            dates.append(lastyear - t_from_tip)
    return (dates, nforward)


def _plot_dates(probablepath, median_list, errlow_list, errhigh_list, for_rev_list, plotfile):
    """Plot the median dates with their error bars, save to `plotfile`, and show.

    The lists are parallel to `probablepath`.  `for_rev_list` holds 'F' or
    'R' for each mutation; 'R' entries are drawn in red, 'F' entries in blue.
    """
    n = len(probablepath)
    rev_i = [i for (i, direction) in enumerate(for_rev_list) if direction == 'R']
    for_i = [i for (i, direction) in enumerate(for_rev_list) if direction == 'F']

    def errorbars(indices, fmt):
        # The y position of each mutation is its index in probablepath.
        return pylab.errorbar(
            [median_list[i] for i in indices],
            indices,
            xerr=[[errlow_list[i] for i in indices], [errhigh_list[i] for i in indices]],
            fmt=fmt)

    (lmargin, rmargin, bmargin, tmargin) = (0.09, 0.01, 0.07, 0.07)
    matplotlib.rc('font', size=10)
    matplotlib.rc('legend', numpoints=1)
    matplotlib.rc('legend', fontsize=10)
    pylab.figure(figsize=(7, 7))
    pylab.axes([lmargin, bmargin, 1 - lmargin - rmargin, 1 - tmargin - bmargin])
    revbar = errorbars(rev_i, 'sr')
    forbar = errorbars(for_i, 'sb')
    xticker = matplotlib.ticker.FixedLocator([1970, 1980, 1990, 2000])
    yticker = matplotlib.ticker.FixedLocator(list(range(n)))
    xminorlocator = matplotlib.ticker.AutoMinorLocator()
    axes = pylab.gca()
    axes.set_ylim((-1, n))
    axes.set_xlim((1966, 2008))
    axes.xaxis.set_major_locator(xticker)
    axes.yaxis.set_major_locator(yticker)
    axes.xaxis.set_minor_locator(xminorlocator)
    # Label each y tick with the name of the corresponding mutation.
    axes.yaxis.set_major_formatter(matplotlib.ticker.FixedFormatter(probablepath))
    pylab.xlabel('Year', size=11)
    pylab.title('Median and median-centered 80% credible intervals for dates of mutations\non NP evolutionary path from Aichi/1968 to Brisbane/2007', size=11)
    pylab.legend([revbar[0], forbar[0]], ['from Aichi/1968 to common ancestor', 'from common ancestor to Brisbane/2007'], loc='lower right')
    pylab.savefig(plotfile)
    pylab.show()


def main():
    """Main body of script.

    Reads the most probable path and the list of all mutational paths,
    estimates a date for each mutation of the most probable path on every
    path where it occurs exactly once, writes the median and the 10th / 90th
    percentile dates to the date file, and plots them.

    Raises ValueError if, for any mutation, the fraction of paths containing
    it exactly once is below `problematic_cutoff`.
    """
    # input / output files
    datefile = 'mutation_dates.txt' # output file
    plotfile = 'mutationdates.pdf'
    probablepathfile = 'mostprobablepath.txt' # mutations in most probable path
    pathfile = 'NPhumanH3N2-MarkovJumps_paths.txt' # lists all paths
    problematic_cutoff = 0.95 # stop if less than this fraction of paths have a mutation exactly once
    lastyear = 2007.1
    # read paths
    print("Setting the date of the last node as %.1f" % lastyear)
    probablepath = _read_probable_path(probablepathfile)
    print("Read a most probable path of %d mutations from %s." % (len(probablepath), probablepathfile))
    paths = _read_paths(pathfile)
    print("Read a total of %d mutational paths from %s." % (len(paths), pathfile))
    # see how many paths have all of the indicated mutations
    mutation_times = {} # estimated dates on all paths in which mutation occurs exactly once
    frac_forward = {} # fraction of those paths on which mutation occurs in forward direction
    problematic = False
    for mut in probablepath:
        (dates, nforward) = _dates_on_paths(mut, paths, lastyear)
        mutation_times[mut] = dates
        n_once = len(dates)
        # Guard the divisions so that a mutation that never occurs exactly
        # once is reported as problematic below instead of crashing here.
        if n_once:
            frac_forward[mut] = nforward / float(n_once)
        else:
            frac_forward[mut] = 0.0
        if paths:
            frac = n_once / float(len(paths)) # fraction of paths with one occurrence of mutation
        else:
            frac = 0.0
        print("For %s, %.3f paths have mutation exactly once. On %.2f of these paths, the mutation occurs in the forward direction.." % (mut, frac, frac_forward[mut]))
        if frac < problematic_cutoff:
            problematic = True
            print("WARNING, THIS IS A LOW FRACTION. WILL NOT CONTINUE.")
    if problematic:
        raise ValueError("Not continuing, since some mutations are not in most paths exactly once.")
    # now analyze median times and 80% intervals
    median_list = []
    errlow_list = []
    errhigh_list = []
    for_rev_list = []
    with open(datefile, 'w') as f:
        f.write('#MUTATION\tMEDIAN_DATE\t10TH_PERCENTILE\t90TH_PERCENTIL\tFRAC_ON_FORWARD_PATH\n')
        for mut in probablepath:
            mut_times = mutation_times[mut]
            mut_times.sort()
            n = len(mut_times) # at least 1 here, since the cutoff check above passed
            t90 = (mut_times[int(0.1 * n)], mut_times[int(0.9 * n)])
            median = stats.Median(mut_times)
            errlow_list.append(median - t90[0])
            errhigh_list.append(t90[1] - median)
            median_list.append(median)
            # classify by the direction in which the mutation occurs on most paths
            if frac_forward[mut] >= 0.5:
                for_rev_list.append('F')
            else:
                for_rev_list.append('R')
            print("%s has a median time of %.2f and a 80 percent credible interval from %.2f and %.2f." % (mut, median, t90[0], t90[1]))
            f.write('%s\t%.2f\t%.2f\t%.2f\t%.2f\n' % (mut, median, t90[0], t90[1], frac_forward[mut]))

    # now make a plot using pylab
    _plot_dates(probablepath, median_list, errlow_list, errhigh_list, for_rev_list, plotfile)


# Run the analysis.  NOTE: this call is not guarded by
# `if __name__ == '__main__'`, so importing this module also runs the script.
main() # run the script
