"""Script for parsing through NP sequences.

Jesse Bloom, 2011."""


import fasta
import align


def RemoveRedundantSeqs(seqs):
    """Removes redundant sequences.

    Ignores case, also removes any substrings.  Keeps the first sequence listed
    in 'seqs'.  The returned list is a copy with redundant / substrings removed.

    'seqs' is a list of (header, sequence) tuples; it is not modified.  The
    returned list holds (header, UPPERCASED sequence) tuples such that no
    returned sequence is identical to, or a substring of, any other one:

      * if a sequence is identical to (ignoring case) or contained in a
        sequence already kept, it is dropped and the earlier one is kept;
      * if a sequence contains one or more sequences already kept, it
        replaces ALL of them, taking the position of the first one replaced;
      * otherwise it is appended.
    """
    newseqs = []
    for (head, seq) in seqs:
        seq = seq.upper()
        # Sequences in newseqs are already uppercased, so compare directly.
        if any(seq in y for (x, y) in newseqs):
            # identical to, or a substring of, a sequence already kept
            continue
        merged = []
        placed = False
        for (x, y) in newseqs:
            if y in seq:
                # y is a substring of seq: drop y, putting seq in the place
                # of the first such y so that ordering stays stable.
                if not placed:
                    merged.append((head, seq))
                    placed = True
            else:
                merged.append((x, y))
        if not placed:
            merged.append((head, seq))
        newseqs = merged
    return newseqs


def main():
    """Main body of script.

    Runs the whole pipeline; all file names and parameters are hard-coded
    at the top of this function:

    1. Read the nucleotide sequences in 'seqsfile' and remove redundant ones
       with RemoveRedundantSeqs.
    2. Read one sequence per target strain from '<strain>-NP.fasta' and put
       it at the front of the list, replacing any existing sequence identical
       to it (ignoring case).  Targets are prepended in turn, so the LAST
       strain in 'targetstrains' ends up first and is used as the reference.
    3. Discard sequences that look wrong when compared to that reference.
    4. Translate with fasta.Translate, keep the unique proteins, and collect
       the nucleotide sequences whose headers match a unique protein.
    5. Remove the manually specified outliers (ValueError unless each one
       matches exactly one nucleotide sequence header).
    6. Align the unique proteins with MUSCLE, write the alignment, and report
       identities in the target strains that are rare in the other sequences.
    """
    musclepath = '/Users/jbloom/muscle3.8/'
    seqsfile = 'NPhumanH3N2.fasta'
    targetstrains = ['Aichi68', 'BR10', 'Nan95']
    specified_outliers = ['A/Nanjing/49/1977', 'A/Victoria/1968', 'A/Albany/4/1977']
    lengthdiff = 21 # max allowed length difference from the reference
    identitycutoff = 0.9 # min allowed identity with the reference
    rare_cutoff = 0.05 # target identities below this frequency are reported
    #
    # Read in all unique sequences
    seqs = fasta.Read(seqsfile)
    print("Read %d nucleotide sequences from %s." % (len(seqs), seqsfile))
    print("\nNow parsing all %d nucleotide sequences to get unique ones..." % (len(seqs)))
    seqs = RemoveRedundantSeqs(seqs)
    print("After parsing, have %d unique nucleotide sequences remaining." % (len(seqs)))
    #
    # Read in target sequences
    print("\nNow adding specific target sequences.")
    (seqs, targetseqs) = _AddTargetSeqs(seqs, targetstrains)
    print("After adding the target sequences, have %d unique nucleotide sequences." % (len(seqs)))
    #
    # look for oddball sequences that might be problematic
    seqs = _DiscardOddballs(seqs, musclepath, lengthdiff, identitycutoff)
    print("Retained a total of %d sequences." % len(seqs))
    #
    # translate and get unique proteins
    prots = fasta.Translate(seqs, readthrough_n=True, truncate_incomplete=True)
    print("\nGetting the unique protein sequences...")
    prots = RemoveRedundantSeqs(prots)
    print("After removing redundant proteins, there are %d unique proteins." % len(prots))
    # relies on fasta.Translate keeping the nucleotide headers
    protheads = set([head for (head, prot) in prots])
    coding = [(head, seq) for (head, seq) in seqs if head in protheads]
    print("There are %d nucleotide sequences coding for unique proteins." % len(coding))
    #
    # now look at specified outlier sequences
    print("\nNow removing sequences manually specified as outliers.")
    (seqs, prots, coding) = _RemoveOutliers(specified_outliers, seqs, prots, coding)
    print("After removing the specified outliers, there are %d nucleotide sequences, %d protein sequences, and %d coding sequences remaining." % (len(seqs), len(prots), len(coding)))
    #
    # align the sequences; add ('nucleotide', seqs) and ('coding', coding)
    # to the list below to also align those
    targetheads = set([head for (head, seq) in targetseqs.values()])
    for (seqtype, s) in [('protein', prots)]:
        print("\nAligning the %s sequences..." % seqtype)
        s = align.Align(s, musclepath, 'MUSCLE')
        s = align.StripGapsToFirstSequence(s)
        f = 'NPhumanH3N2_unique_%s_alignment.fasta' % seqtype
        print("Now writing this alignment to %s" % f)
        fasta.Write(s, f)
        print("Inspecting target sequences for rare identities (present at less than %.2f frequency)" % rare_cutoff)
        _ReportRareIdentities(s, targetheads, rare_cutoff)


def _AddTargetSeqs(seqs, targetstrains):
    """Puts one sequence per target strain at the front of 'seqs'.

    For each strain, reads the first (header, sequence) in
    '<strain>-NP.fasta'.  If an entry of 'seqs' has the same sequence
    (ignoring case) it is replaced by the target; otherwise the target is
    added.  Either way the target is prepended, so the last strain ends up
    first.  Returns (new list of sequences, dict mapping strain to its
    (header, sequence)).
    """
    targetseqs = {}
    for strain in targetstrains:
        f = "%s-NP.fasta" % strain
        print("Read sequence for strain %s from %s." % (strain, f))
        target = fasta.Read(f)[0]
        targetseqs[strain] = target
        targetupper = target[1].upper()
        for (i, (head, seq)) in enumerate(seqs):
            if seq.upper() == targetupper:
                print("Nucleotide sequence for strain %s matches sequence for %s" % (strain, head))
                seqs = [target] + seqs[ : i] + seqs[i + 1 : ]
                break
        else:
            print("Nucleotide sequence for strain %s does not match any of the existing sequences." % strain)
            seqs = [target] + seqs
    return (seqs, targetseqs)


def _DiscardOddballs(seqs, musclepath, lengthdiff, identitycutoff):
    """Returns 'seqs' without the entries that look wrong relative to seqs[0].

    seqs[0] is the reference and is always kept.  Any other sequence is
    discarded if its length differs from the reference by more than
    'lengthdiff', if its MUSCLE alignment with the reference puts a gap in
    the reference, or if the identity (first value of
    align.PairwiseStatistics) is below 'identitycutoff'.
    """
    (refhead, refseq) = seqs[0]
    print("\nNow looking for oddball sequences that might be wrong by comparing them to a reference sequence for %s..." % refhead)
    retainedseqs = [(refhead, refseq)]
    for (head, seq) in seqs[1 : ]:
        if abs(len(seq) - len(refseq)) > lengthdiff:
            print("Length of %d differs from reference sequence length of %d by more than %d for %s" % (len(seq), len(refseq), lengthdiff, head))
            print("Discarding this sequence.")
            continue
        a = align.Align([(refhead, refseq), (head, seq)], musclepath, 'MUSCLE')
        # a[0] is taken to be the aligned reference -- assumes align.Align
        # returns the sequences in input order (TODO confirm)
        if '-' in a[0][1]:
            print("A gap is created in the reference sequence when it is aligned with %s.  Discarding this sequence." % head)
            continue
        identity = align.PairwiseStatistics(a)[0]
        if identity < identitycutoff:
            print("The identity with the reference sequence is only %.2f (less than cutoff of %.2f) for %s" % (identity, identitycutoff, head))
            print("Discarding this sequence.")
            continue
        retainedseqs.append((head, seq))
    return retainedseqs


def _RemoveOutliers(outliers, seqs, prots, coding):
    """Removes every entry whose header contains one of the 'outliers' strings.

    Raises ValueError unless each outlier string matches exactly one header
    in 'seqs' (checked after removing the previous outliers).  Returns the
    filtered (seqs, prots, coding) lists.
    """
    for outlier in outliers:
        print("Sequence %s is manually specified as an outlier which should be removed." % outlier)
        nmatches = len([head for (head, seq) in seqs if outlier in head])
        if nmatches != 1:
            raise ValueError("Failed to find exactly one sequence with header %s" % outlier)
        seqs = [(head, seq) for (head, seq) in seqs if outlier not in head]
        prots = [(head, seq) for (head, seq) in prots if outlier not in head]
        coding = [(head, seq) for (head, seq) in coding if outlier not in head]
    return (seqs, prots, coding)


def _ReportRareIdentities(aligned, targetheads, rare_cutoff):
    """Prints identities in the target sequences that are rare in the others.

    'aligned' is a list of (header, sequence) tuples assumed to all have the
    same length.  Entries whose header is in 'targetheads' are the targets;
    all remaining entries form the background.  For each target and each
    position (reported 1-based), prints a line if the target's identity is
    absent from that background column or present at a frequency below
    'rare_cutoff'.
    """
    # Select the targets by header rather than assuming they are the first
    # entries: earlier filtering can discard or displace a target.
    # NOTE(review): assumes align.Align / StripGapsToFirstSequence keep the
    # input headers unchanged -- confirm.
    targets = [(head, seq) for (head, seq) in aligned if head in targetheads]
    background = [(head, seq) for (head, seq) in aligned if head not in targetheads]
    if len(targets) != len(targetheads):
        print("Warning: found %d aligned sequences with target headers, but expected %d; missing targets may have been discarded by an earlier filtering step." % (len(targets), len(targetheads)))
    if not background:
        print("There are no non-target sequences to compare against, so rare identities cannot be assessed.")
        return
    nbackground = float(len(background))
    # Frequency of each identity in each background column, computed once
    # rather than separately for every target.
    columnfreqs = []
    for i in range(len(background[0][1])):
        counts = {}
        for (head, seq) in background:
            counts[seq[i]] = counts.get(seq[i], 0) + 1
        columnfreqs.append(dict([(x, n / nbackground) for (x, n) in counts.items()]))
    for (targethead, targetseq) in targets:
        print("Looking for rare identities in %s" % targethead)
        for (i, x) in enumerate(targetseq):
            freqs = columnfreqs[i]
            if x not in freqs:
                print("Identity of %s at position %d is not present in any other sequence." % (x, i + 1))
            elif freqs[x] < rare_cutoff:
                print("Identity of %s at position %d is present in only %.3f of the other sequences." % (x, i + 1, freqs[x]))

# Run the pipeline only when executed as a script, so that importing this
# module (e.g. to reuse RemoveRedundantSeqs) does not launch MUSCLE runs.
if __name__ == '__main__':
    main() # run the script
