"""Converts FASTA sequences to NEXUS format."""


import re
import os.path
import fasta


def WriteFASTAasNEXUS(seqs, outfile, seqtype):
    """Writes aligned influenza NP sequences to a NEXUS file with dated names.

    Each header is parsed for a strain name and a date of the form
    year/month/day (month and day may be empty).  The taxon label written to
    the NEXUS matrix is the strain name with spaces replaced by underscores,
    followed by the date as a decimal year with two decimal places.  The
    first occurrence of a strain name is written unchanged; the second and
    later occurrences get a '_variantN' infix with N = 1, 2, ...

    'seqs' is a sequence of (header, sequence) 2-tuples; every sequence must
        have the same length as the first one.
    'outfile' is the path of the NEXUS file to create (overwritten if it
        already exists).
    'seqtype' can be 'PROTEIN' or 'DNA' (checked case-insensitively, but
        written to the 'Format datatype=' field exactly as given).

    Raises AssertionError if 'seqtype' is not recognised or the sequences
    differ in length, ValueError if a header does not match the expected
    pattern, and IndexError if 'seqs' is empty.

    All records are validated and formatted before 'outfile' is opened, so a
    failure never leaves a truncated NEXUS file behind or clobbers an
    existing one."""
    # Raised explicitly (rather than via 'assert') so that the checks are not
    # stripped under 'python -O'; the exception type is the same as before.
    if seqtype.upper() not in ('PROTEIN', 'DNA'):
        raise AssertionError("seqtype must be 'PROTEIN' or 'DNA', got %r" % (seqtype,))
    ntax = len(seqs)
    nchar = len(seqs[0][1])
    headmatch = re.compile(r'\S+ (?P<name>A/[\w\- ]+/\S+) *(?P<year>\d+)/(?P<month>\d*)/(?P<day>\d*) NP')
    lines = ["#NEXUS\n\nBegin DATA;\n\tDimensions ntax=%d nchar=%d;\n\tFormat datatype=%s gap=-;\n\tMatrix\n" % (ntax, nchar, seqtype)]
    names = {}  # strain name -> number of times it has been seen so far
    for (head, seq) in seqs:
        if len(seq) != nchar:
            raise AssertionError("All sequences not of the same length.")
        m = headmatch.search(head)
        if not m:
            raise ValueError("Failed to match header: %s" % head)
        name = m.group('name').replace(' ', '_')
        # 0 (falsy) for the first occurrence, then 1, 2, ... for repeats.
        variant = names.get(name, 0)
        names[name] = variant + 1
        # Decimal year: months count as 1/12 of a year and days as 1/30 of a
        # month, both measured from the start of the period.
        date = int(m.group('year'))
        if m.group('month'):
            date += (int(m.group('month')) - 1) / 12.
        if m.group('day'):
            date += (int(m.group('day')) - 1) / 30. / 12.
        if variant:
            lines.append("\t%s_variant%d_%.2f %s\n" % (name, variant, date, seq))
        else:
            lines.append("\t%s_%.2f %s\n" % (name, date, seq))
    lines.append(';\nEnd;')
    # 'with' guarantees the handle is closed even if a write fails.
    with open(outfile, 'w') as f:
        f.writelines(lines)


def main():
    """Main body of script.

    Converts each hard-coded FASTA alignment to a NEXUS file with the same
    base name and a '.nex' extension, printing progress messages as it goes.
    Paths are relative, so the FASTA files are looked up in the current
    working directory."""
    # NOTE(review): an earlier, commented-out version of this list also
    # included ('NPhumanH3N2_unique_nucleotide_alignment.fasta', 'DNA') and
    # ('NPhumanH3N2_unique_coding_alignment.fasta', 'DNA'); only the protein
    # alignment is converted now -- confirm whether that is intentional.
    alignments = [('NPhumanH3N2_unique_protein_alignment.fasta', 'PROTEIN')]
    for (fastafile, seqtype) in alignments:
        base = os.path.splitext(fastafile)[0]
        nexusfile = "%s.nex" % base
        # Single-argument print(...) produces identical output under both the
        # Python 2 print statement and the Python 3 print function.
        print("\nConverting the %s sequences in %s to NEXUS format and writing to %s." % (seqtype, fastafile, nexusfile))
        # presumably returns a list of (header, sequence) tuples -- verify
        # against the fasta module.
        seqs = fasta.Read(fastafile)
        print("Read %d sequences." % len(seqs))
        WriteFASTAasNEXUS(seqs, nexusfile, seqtype)

# There is no __name__ guard here, so the conversion runs both when this file
# is executed as a script and when it is imported as a module.
main() # run the script
