"""Finds CD8 T-cell epitopes in influenza nucleoprotein.

Finds epitopes from two sources:
(1) Parses the epitopes from NetCTLpan, and counts as epitopes
those with sufficiently good ranks (%Rank at or below a cutoff).
(2) Parses data from the Immune Epitope Database and maps
it to the NP sequence.

Text files are generated summarizing the epitopes from
these two sources. 

Plots are then generated showing the density of epitopes
along the protein sequence, along with mutations from the
1968 -> 2007 nucleoprotein trajectory. Plots are also
generated showing the number of mutations at all sites that
change during the trajectory.

Written by Jesse Bloom, 2012."""


import os
import time
import random
import pylab
import matplotlib
import fasta
import align


def ReadImmuneEpitopeDatabase(infile, protnames, protseqfile, musclepath, identity_cutoff, maxlength, redundant_length_cutoff):
    """Reads epitopes from clustered output of Immune Epitope Database.

    infile -> *.csv file output by Immune Epitope Database, with
        clustering performed at 90% identity level. The first line
        is skipped as a header. Columns 1, 7 and 11 (comma separated)
        are read as the cluster name, the epitope sequence, and the
        source protein name.
    protnames -> a list of names for the protein (Source Molecule
        Name). We only consider epitopes that are indicated as
        having come from one of these names.
    protseqfile -> FASTA file containing the sequence of the protein
        in which we are searching for epitopes. We search
        for epitopes in the first sequence listed in this file.
    musclepath -> path to the MUSCLE alignment program.
    identity_cutoff -> only consider epitopes that align to the protein
        sequence with at least this identity.
    maxlength -> we only consider epitope sequences of <= this length.
    redundant_length_cutoff -> two matches are considered redundant if
        they share >= this many residue positions in the protein.

    For each cluster of epitopes, we look at all epitope sequences of
        length <= maxlength. They are aligned to the protein in protseqfile.
        If they align without creating any gaps in the protein, and if the
        identity of the alignment is >= identity_cutoff, then we consider
        them a possible match. For each cluster, we save the match with the
        highest identity in the alignment (the first one found wins ties).
        We then go through all of these matches, and look at the residues
        of the protein that are aligned to the match. If two matches share
        >= redundant_length_cutoff residue positions, then we save the
        shorter one.
    The returned variable is the list good_epitopes. Its entries
        are the tuples (identity, rstart, rend, epitope, "unknown")
        where identity is the identity of the epitope, rstart and rend are
        the beginning and end of the epitope in the protein
        (1, 2, 3, ... numbering), epitope is the epitope sequence
        in the protein (not in the database epitope), and 'unknown'
        corresponds to the fact that we do not know the MHC allele.
        The list is ordered from shortest to longest epitope.
    """
    with open(infile) as f:
        lines = f.readlines()[1 : ] # skip the header line
    clusters = {} # keyed by cluster names, entries are lists of epitopes
    for line in lines:
        # NOTE: this splits on every comma, so a quoted field that itself
        # contains a comma would shift the column indices used below.
        entries = line.split(',')
        (cluster, epitope, protein) = (entries[0], entries[6], entries[10].replace('"', ''))
        if protein not in protnames:
            continue # not from the correct protein
        clusters.setdefault(cluster, []).append(epitope)
    protseq = fasta.Read(protseqfile)[0]
    good_epitopes = []
    for (cluster, epitopes) in clusters.items():
        best = None # (identity, rstart, rend, matchedepitope) of best match in this cluster
        for epitope in epitopes:
            if len(epitope) > maxlength:
                continue # epitope is too long
            a = align.Align([protseq, ('epitope', epitope)], musclepath, 'MUSCLE')
            if len(a[0][1]) != len(protseq[1]):
                continue # alignment creates gaps in the protein
            identities = align.PairwiseStatistics(a)[0]
            if identities < identity_cutoff:
                continue # alignment is not good enough
            if best is not None and identities <= best[0]:
                continue # no better than the match we already have
            # NOTE: index() raises ValueError if the epitope is not found as
            # one contiguous stretch in its aligned sequence.
            rstart = a[1][1].index(epitope) + 1
            rend = rstart + len(epitope) - 1
            best = (identities, rstart, rend, protseq[1][rstart - 1 : rend])
        if best is not None:
            good_epitopes.append(best + ('unknown',))
    # now remove epitopes according to redundant_length_cutoff
    # sort from shortest to longest (ties broken by comparing the tuples)
    # so that the shorter of two redundant matches is the one kept
    good_epitopes.sort(key=lambda tup: (len(tup[3]), tup))
    cleaned_epitopes = []
    for tup in good_epitopes:
        keepepitope = True
        for cleantup in cleaned_epitopes:
            # number of residue positions shared by the two closed intervals
            nshared = max(0, min(tup[2], cleantup[2]) - max(tup[1], cleantup[1]) + 1)
            if nshared >= redundant_length_cutoff:
                keepepitope = False
                break
        if keepepitope:
            cleaned_epitopes.append(tup)
    return cleaned_epitopes


def ReadMHCPathway(indir, epitopelength, percent_cutoff):
    """Reads epitopes from MHC-Pathway output.

    indir -> directory containing MHC-Pathway text output files,
        with files named with prefix equal to MHC allele. Only
        files with the extension '.txt' are read.
    epitopelength -> the length of the epitopes in amino acids.
    percent_cutoff -> only keep epitopes that have total scores
        that are at least in this percentile. So a value of 1
        would correspond to keeping the top 1% of epitopes,
        and a value of 100 would correspond to keeping all epitopes.
        Must satisfy 0 < percent_cutoff <= 100. The cutoff is applied
        separately to each file (allele).

    In each file the first four lines are skipped. The remaining
        non-blank lines are counted 1, 2, 3, ...; that count is taken
        as the end residue of the epitope on the line, and is checked
        against the first column. Lines whose third column is '---'
        have no score and are skipped.
    The returned variable is the list good_epitopes. Its entries
        are the tuples (score, rstart, rend, epitope, allele)
        where score is the number in the third column of the file,
        rstart and rend are the beginning and end of the epitope
        (1, 2, 3, ... numbering), epitope is the epitope sequence,
        and allele is the MHC allele. The list is sorted from the
        highest score to the lowest.
    """
    assert 0 < percent_cutoff <= 100
    good_epitopes = []
    for fname in os.listdir(indir):
        (allele, ext) = os.path.splitext(fname)
        if ext != '.txt':
            continue
        with open(os.path.join(indir, fname)) as f:
            lines = f.readlines()[4 : ] # skip the four header lines
        epitopes = [] # entries (score, rstart, rend, epitope, allele)
        rend = 0
        for line in lines:
            if line.isspace():
                continue # blank lines do not count as residues
            rend += 1
            if rend < epitopelength:
                continue # no full-length epitope can end this early
            entries = line.split()
            if entries[2].strip() == '---':
                continue # no score reported on this line
            assert len(entries) == 6, line
            assert rend == int(entries[0])
            epitope = entries[1].strip()
            assert len(epitope) == epitopelength
            score = float(entries[2])
            epitopes.append((score, rend - epitopelength + 1, rend, epitope, allele))
        # now get the percentile satisfying the cutoff
        epitopes.sort(reverse=True)
        nkept = int((percent_cutoff / 100.) * len(epitopes))
        good_epitopes += epitopes[ : nkept]
    good_epitopes.sort(reverse=True)
    return good_epitopes


def ReadNetCTLpan(infile, epitopelength, rank_cutoff):
    """Reads epitopes from NetCTLpan output file.

    infile -> output of NetCTLpan Excel file in text format. It is
        read as tab-separated text: the first line is a header, and
        each later line has three leading columns followed by six
        columns for each allele.
    epitopelength -> the length of the epitopes in amino acids.
    rank_cutoff -> only keep epitopes with %Rank <= to this.

    For each line, the first column plus one is used as the start
        residue of the epitope and the third column as the epitope
        sequence. Within each group of six allele columns, the first
        is the allele name and the last is the %Rank.
    The returned variable is the list good_epitopes. Its entries
        are the tuples (score, rstart, rend, epitope, allele)
        where score is the %Rank, rstart and rend are the beginning
        and end of the epitope (1, 2, 3, ... numbering), epitope
        is the epitope sequence, and allele is the MHC allele.
        The list is sorted from the lowest %Rank to the highest.
    """
    with open(infile) as f:
        lines = f.readlines()
    allele_entry_length = 6 # number of columns for each allele in output
    ncolumns = len(lines[0].split('\t'))
    assert (ncolumns - 3) % allele_entry_length == 0
    # floor division so that nalleles stays an integer under true division
    nalleles = (ncolumns - 3) // allele_entry_length
    allepitopes = {} # keyed by (rstart, rend, epitope, allele), values are ranks
    for line in lines[1 : ]:
        entries = line.split('\t')
        assert len(entries) == 3 + nalleles * allele_entry_length
        r = int(entries[0]) + 1
        epitope = entries[2].strip()
        assert len(epitope) == epitopelength
        for i in range(nalleles):
            first = 3 + i * allele_entry_length # first column for this allele
            allele = entries[first].strip()
            rank = float(entries[first + allele_entry_length - 1])
            allepitopes[(r, r + epitopelength - 1, epitope, allele)] = rank
    good_epitopes = [(rank, rstart, rend, epitope, allele)
                     for ((rstart, rend, epitope, allele), rank) in allepitopes.items()
                     if rank <= rank_cutoff]
    good_epitopes.sort()
    return good_epitopes


def PValue(subset, fullset, nrandom):
    """One-tailed permutation P-value for the mean of subset exceeding that of fullset.

    subset -> a list of numbers.
    fullset -> a list of numbers, len(fullset) > len(subset)
    nrandom -> the number of random samples drawn to compute the P-value.

    Draws nrandom random sets from fullset, each of the same size as
        subset and each chosen without replacement. Sums are compared
        rather than means, which is equivalent because all sets have
        the same size. Returns the fraction of the random sets with a
        sum that is >= the sum of subset.
    """
    def _total(values):
        # plain left-to-right accumulation starting from 0
        acc = 0
        for value in values:
            acc += value
        return acc

    nsubset = len(subset)
    assert len(fullset) > nsubset
    observed = _total(subset)
    draws = [_total(random.sample(fullset, nsubset)) for _ in range(int(nrandom))]
    assert len(draws) == nrandom
    # count the random sets that fall strictly below the observed sum
    nbelow = len([total for total in draws if not total >= observed])
    return 1.0 - nbelow / float(nrandom)


def main():
    """Main body of script.

    Reads the Immune Epitope Database and NetCTLpan epitopes, writes
    two text summaries per method to the current directory
    (<method>_epitopes.txt and <method>_epitopes_by_site.txt), and
    saves the plots as PDF files in the directory 'plots'. Raises
    IOError if that directory does not exist. All file names and
    cutoffs are set in the variables at the top of the function.
    """

    # input / output variables
    mapepitopesplot = 'mapped_epitopes.pdf' # name for created plot of epitopes mapped along sequence
    plotdir = 'plots' # save plots in this directory
    if not os.path.isdir(plotdir):
        raise IOError("Directory %s does not exist." % plotdir)
    mutationsfile = 'mutations.txt' # lists trajectory mutations by line
    infiles = {'NetCTLpan':'11639_NetCTLpan.txt',
               'ImmuneEpitopeDatabase':'ImmuneEpitopeDatabaseQuery.csv',
              } # input files
    epitopelength = 9 # length of epitopes for NetCTLpan
    score_cutoff = 1 # only consider NetCTLpan epitopes with %Rank <= this
    protseqfile = 'Aichi68-NP.fasta' # FASTA file containing sequence of target protein
    protlength = len(fasta.Read(protseqfile)[0][1]) # length of protein sequence
    musclepath = '/Users/jbloom/muscle3.8/' # path to MUSCLE alignment program
    identity_cutoff = 0.8 # Immune Epitope Database epitopes must align with >= this
    maxlength = 12 # Immune Epitope Database epitopes must be <= this long
    redundant_length_cutoff = 8 # consider Immune Epitope Database Epitopes identical if they share >= this many commonly aligned residues

    # read the epitopes from the two sources
    ied_epitopes = ReadImmuneEpitopeDatabase(infiles['ImmuneEpitopeDatabase'],
        ['NP', 'Nucleoprotein', 'nucleoprotein', 'Nucleocapsid protein', 'nucleocapsid protein'],
        protseqfile, musclepath, identity_cutoff, maxlength, redundant_length_cutoff)
    netctl_epitopes = ReadNetCTLpan(infiles['NetCTLpan'], epitopelength, score_cutoff)

    # loop over epitope methods, calculate epitopes per site, write text summary files
    nbyres = {} # nbyres[method][r] is the number of epitopes from method that residue r is part of (1, 2, 3, ...) numbering
    for (name, epitopes) in [('NetCTLpan', netctl_epitopes), ('ImmuneEpitopeDatabase', ied_epitopes)]:
        # print with a single parenthesized argument behaves the same as a statement or a function call
        print("Read a total of %d good epitopes for %s." % (len(epitopes), name))
        epitopefile = '%s_epitopes.txt' % name
        sitesfile = '%s_epitopes_by_site.txt' % name
        print("Writing these epitopes to %s and %s." % (epitopefile, sitesfile))
        nbyres[name] = dict((r, 0) for r in range(1, protlength + 1))
        with open(epitopefile, 'w') as f:
            # the title line ends with a newline so the column header starts on its own line
            f.write('# Epitopes from %s.\n' % name)
            f.write('#Score\tResidues\tEpitope\tAllele\n')
            for (score, rstart, rend, epitope, allele) in epitopes:
                f.write('%.3f\t%d-%d\t%s\t%s\n' % (score, rstart, rend, epitope, allele))
                for r in range(rstart, rend + 1):
                    assert 1 <= r <= protlength
                    nbyres[name][r] += 1
        with open(sitesfile, 'w') as f:
            f.write("# Number of epitopes from %s that each residue is part of.\n#Residue\tNsites\n" % name)
            for r in range(1, protlength + 1):
                f.write('%d\t%d\n' % (r, nbyres[name][r]))

    # make a graph showing the epitopes along the primary sequence
    barwidth = 1.0 # width of bars
    residues = list(range(1, protlength + 1))
    x = [r - barwidth / 2.0 for r in residues]
    # counts_* are lists indexed from 0, so residue r is at index r - 1
    counts_ied = [nbyres['ImmuneEpitopeDatabase'][r] for r in residues]
    counts_netctl = [nbyres['NetCTLpan'][r] for r in residues]
    # normalize to a maximum of one; "or 1" avoids dividing by zero if a method found no epitopes
    y_ied = [y / float(max(counts_ied) or 1) for y in counts_ied]
    y_netctl = [y / float(max(counts_netctl) or 1) for y in counts_netctl]
    (lmargin, rmargin, tmargin, bmargin) = (0.11, 0.02, 0.03, 0.17)
    matplotlib.rc('font', size=9)
    matplotlib.rc('text', usetex=True)
    matplotlib.rc('legend', fontsize=10)
    pylab.figure(figsize=(6, 2.5), facecolor='white')
    ax = pylab.axes([lmargin, bmargin, 1.0 - lmargin - rmargin, 1.0 - tmargin - bmargin], frameon=True)
    # database epitopes are drawn upwards and predicted epitopes downwards
    ied_bars = pylab.bar(x, y_ied, width=barwidth, linewidth=0, color='blue', alpha=0.6)
    netctl_bars = pylab.bar(x, [-y for y in y_netctl], width=barwidth, linewidth=0, color='red', alpha=0.6)
    ax.spines['top'].set_position(('data', -1)) # place top spine at y = -1 in data coordinates
    ax.spines['right'].set_color('none') # get rid of right spine
    ax.spines['left'].set_position(('outward', 6)) # move left spine out a bit
    ax.spines['bottom'].set_position(('outward', 6)) # move bottom spine down a bit
    # every LaTeX backslash is escaped explicitly; the resulting string is unchanged
    pylab.ylabel('\\begin{tabular}{r|l} \\multicolumn{2}{c}{Epitope density} \\\\ {\\large $\\longleftarrow$} predicted & database {\\large $\\longrightarrow$} \\end{tabular}', size=10, verticalalignment='center')
    ax.set_xlim(0, protlength + 1)
    ax.set_ylim(-1, 1)
    ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator([-1, 0, 1]))
    ax.yaxis.set_major_formatter(matplotlib.ticker.FixedFormatter(['1', '0', '1']))
    pylab.xlabel('Residue number', size=10)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    # NOTE(review): tick positions are hard-coded; 498 is presumably the protein length -- confirm
    ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator([1, 100, 200, 300, 400, 498]))
    ax.xaxis.set_major_formatter(matplotlib.ticker.FixedFormatter(['1', '100', '200', '300', '400', '498']))
    # now add the mutations
    minlabelspacing = 15 # slant lines if labels on the same side are closer than this
    arrowlength = {'nocomma':0.54, 'withcomma':0.22} # length of arrows for single and double labels
    textmargin = 0.02
    mutations = {} # keyed by residue number, entries are lists of mutation strings
    with open(mutationsfile) as f:
        for m in f:
            m = m.strip()
            if not m:
                continue # skip blank lines
            r = int(m[1 : -1]) # everything between the first and last character is the residue number
            mutations.setdefault(r, []).append(m)
    mutations = sorted((r, ', '.join(muts)) for (r, muts) in mutations.items())
    sign = 1 # labels alternate between above (1) and below (-1) the axis
    last_r = {1:-100, -1:-100} # x position of the last label placed on each side
    for (r, mutstring) in mutations:
        if (r - last_r[sign]) < minlabelspacing:
            xshift = minlabelspacing - (r - last_r[sign])
        else:
            xshift = 0
        if ',' in mutstring:
            al = arrowlength['withcomma']
        else:
            al = arrowlength['nocomma']
        last_r[sign] = r + xshift
        pylab.arrow(r + xshift, al * sign, -xshift, -al * sign, lw=0.75)
        pylab.text(r + xshift, (al + textmargin) * sign, '{\\bf %s}' % mutstring, fontsize=8, rotation='vertical', horizontalalignment='center', verticalalignment={1:'bottom', -1:'top'}[sign])
        sign *= -1
#    pylab.legend([ied_bars[0], netctl_bars[0]], ['Immune Epitope Database', 'NetCTLpan predictions'], ncol=2, loc='upper center', handlelength=1.5, borderaxespad=0, bbox_to_anchor=(0.5, 1.15))
    print("Creating %s" % mapepitopesplot)
    pylab.savefig(os.path.join(plotdir, mapepitopesplot))
    time.sleep(0.4) # wait for plot to be created
    pylab.close()

    # now make graphs showing density of epitopes
    mutated_sites = set(r for (r, mutstring) in mutations)
    #for muts_to_mark in [[384, 259, 280], []] + [[r] for r in mutated_sites]: # make plot marking each of these sets of mutations
    for muts_to_mark in [[384, 259, 280]]: # make plot marking each of these sets of mutations
        for (method, epitopesbysite) in [('database', counts_ied), ('predicted', counts_netctl)]:
            if muts_to_mark:
                plotfile = 'epitopedensity_%s_%s.pdf' % (method, '-'.join([str(r) for r in muts_to_mark]))
            else:
                plotfile = 'epitopedensity_%s.pdf' % method
            print("Making plot %s" % plotfile)
            xmax = max(epitopesbysite) + 1
            xs = list(range(0, xmax + 1))
            yall = dict((n, 0) for n in xs) # fraction of all sites that are in n epitopes
            ymutated = dict((n, 0) for n in xs) # fraction of mutated sites that are in n epitopes
            for r in range(1, protlength + 1):
                n = epitopesbysite[r - 1]
                yall[n] += 1 / float(len(epitopesbysite))
                if r in mutated_sites:
                    ymutated[n] += 1 / float(len(mutated_sites))
            yall = [yall[n] for n in xs]
            ymutated = [ymutated[n] for n in xs]
            ymax = max(yall + ymutated)
            (lmargin, rmargin, tmargin, bmargin) = (0.14, 0.01, 0.01, 0.16)
            pylab.figure(figsize=(3, 2.3), facecolor='white')
            ax = pylab.axes([lmargin, bmargin, 1.0 - lmargin - rmargin, 1.0 - tmargin - bmargin], frameon=True)
            pylab.plot(xs, yall, 'bo-', label='all sites', alpha=0.75)
            pylab.plot(xs, ymutated, 'rs--', label='mutated sites', alpha=0.75)
            pylab.xlabel('Number of %s epitopes' % method, size=10)
            pylab.ylabel('Fraction of sites', size=10)
            ax.set_ylim(0 - 0.02 * ymax, 1.08 * ymax)
            ax.set_xlim(-0.02 * xmax, xmax * 1.02)
            ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(5))
            ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(4))
            pylab.legend(ncol=2, loc='upper right', numpoints=1, handlelength=1.5, borderaxespad=0)
            if len(muts_to_mark) > 1: # compute P-values and put on plot
                muts_to_mark_string = '/'.join([str(m) for m in muts_to_mark])
                nrandom = 1e5 # this many random draws
                # epitopesbysite is indexed from 0, so residue r is at index r - 1
                epitopesformarkedmuts = [epitopesbysite[r - 1] for r in range(1, protlength + 1) if r in muts_to_mark]
                epitopesformuts = [epitopesbysite[r - 1] for r in range(1, protlength + 1) if r in mutated_sites]
                pvalue = ['\\begin{tabular}{r}']
                pvalue.append('%s $>$ mutated: $P = %.3f$ \\\\' % (muts_to_mark_string, PValue(epitopesformarkedmuts, epitopesformuts, nrandom=nrandom)))
                pvalue.append('%s $>$ all: $P = %.3f$ \\\\' % (muts_to_mark_string, PValue(epitopesformarkedmuts, epitopesbysite, nrandom=nrandom)))
                pvalue.append('mutated $>$ all: $P = %.3f$ \\\\' % PValue(epitopesformuts, epitopesbysite, nrandom=nrandom))
                pvalue.append('\\end{tabular}')
                pvalue = ''.join(pvalue)
                pylab.text(xmax, 0.65 * ymax, pvalue, fontsize=10, verticalalignment='bottom', horizontalalignment='right')
            textmargin = 0.02
            arrowlength = 0.3 * ymax
            xslants = [0.07 * xmax, -0.07 * xmax, 0.07 * xmax] # alternate the slant of the arrows
            nslants = len(xslants)
            for (i, r) in enumerate(muts_to_mark):
                x = epitopesbysite[r - 1]
                xslant = xslants[i % nslants]
                pylab.arrow(x + xslant, arrowlength, -xslant, -arrowlength, lw=1.5)
                pylab.text(x + xslant, arrowlength + textmargin, '{\\bf %d}' % r, fontsize=10, rotation='vertical', horizontalalignment='center', verticalalignment='bottom')
            pylab.savefig(os.path.join(plotdir, plotfile))
            time.sleep(1.0) # wait for plot to be created
            pylab.close()

# run the script only when executed directly, not when imported
if __name__ == '__main__':
    main()
