#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# YOLO edelmaie@colorado.edu (too)
# Basic
import sys, os, pdb
import gc
## Analysis
import numpy as np
import yaml

from microtubule import Microtubule

'''
Name: chromosome.py
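Description: Kinetochore and Chromosomes classes that read per-frame
    kinetochore and attachment-factor (AF) records from a binary posit file
    and analyse microtubule attachments (attachment type per chromosome,
    end-on attachments, bound counts and attachment lifetimes).
Input: path to the binary posit file and the parsed chromosome YAML config.
Output: per-frame analysis arrays returned by the Chromosomes methods.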
'''

def SumPoles(n):
    """Return the sum of the decimal digits of ``n``.

    Presumably applied to Kinetochore.superpole, whose decimal digit at
    position ``isite`` is ``pole + 1`` for every attached site, so the digit
    sum equals sum(pole + 1) over attached sites -- confirm against callers.

    The sign of ``n`` is ignored. Previously a negative ``n`` looped forever,
    because floor division keeps ``n // 10 == -1`` once ``n`` reaches -1.
    """
    n = abs(n)
    total = 0
    while n:
        n, digit = divmod(n, 10)
        total += digit
    return total

# Class definition for kinetochores
class Kinetochore():
    """State of a single kinetochore and its attachment-factor (AF) sites.

    The per-frame values are copied in by SetValues() from the flat arrays
    assembled by Chromosomes.ReadFrame(). The pole information is derived
    from the attached microtubules by GeneratePoleInformation().
    """

    def __init__(self, ikc, naf):
        """Create an empty kinetochore.

        ikc -- index of this kinetochore within the flat per-frame arrays
        naf -- number of attachment-factor sites on this kinetochore
        """
        self.ikc = ikc
        self.r = np.zeros(3)
        self.u = np.zeros(3)
        self.v = np.zeros(3)
        self.w = np.zeros(3)
        self.nbound = np.zeros(1)
        self.naf = naf
        # Per-site arrays, one row/entry per AF site
        self.rsite = np.zeros([self.naf, 3])
        self.attach = np.zeros(self.naf)    # attached MT index, -1 when unattached
        # Presumably the position along the attached MT (compared against the
        # MT length in Chromosomes.GenerateEndOnAttachments) -- confirm
        self.crosspos = np.zeros(self.naf)
        self.pole = np.zeros(self.naf)      # spb of the attached MT, -1 when unattached
        self.superpole = np.int_(0)         # decimal encoding of pole, see GeneratePoleInformation

    def SetValues(self, r, u, v, w, nbound, rsite, attach, crosspos):
        """Copy this kinetochore's entries out of the flat per-frame arrays.

        r, u, v, w -- flat arrays holding 3 entries per kinetochore
        nbound     -- one entry per kinetochore
        rsite      -- flat array holding 3 entries per (kinetochore, site)
        attach, crosspos -- one entry per (kinetochore, site)

        The (kinetochore, site) pair maps to flat index naf*ikc + isite.
        r, u, v and w are stored as slices of the inputs, not copies.
        """
        i = 3 * self.ikc
        self.r = r[i:i + 3]
        self.u = u[i:i + 3]
        self.v = v[i:i + 3]
        self.w = w[i:i + 3]
        self.nbound = nbound[self.ikc]
        for isite in range(self.naf):
            idx = self.naf * self.ikc + isite
            self.rsite[isite] = rsite[3 * idx:3 * idx + 3]
            self.attach[isite] = attach[idx]
            self.crosspos[isite] = crosspos[idx]

    # Generate the pole attachment information
    def GeneratePoleInformation(self, microtubules):
        """Fill ``pole`` and ``superpole`` from the attached microtubules.

        microtubules -- indexable by the MT index stored in ``attach``; each
            element must expose ``spb``.

        pole[isite] is the ``spb`` of the MT attached at that site, or -1 when
        the site is unattached. superpole packs all sites into one integer:
        decimal digit ``isite`` holds pole + 1, and 0 marks an unattached site.
        """
        self.pole.fill(-1.0)
        self.superpole = np.int_(0)
        for isite in range(self.naf):
            if self.attach[isite] == -1.0:
                continue
            self.pole[isite] = microtubules[np.int_(self.attach[isite])].spb
            if self.pole[isite] != -1.0:
                self.superpole += np.int_((self.pole[isite] + 1) * np.power(10, isite))

    def PrintFrame(self):
        """Print this kinetochore's current state and its AF sites."""
        print("Kinetochore[{}]".format(self.ikc))
        print("   r: {}".format(self.r))
        print("   u: {}".format(self.u))
        print("   v: {}".format(self.v))
        print("   w: {}".format(self.w))
        print("   nbound: {}".format(self.nbound))
        print("   superpole: {}".format(self.superpole))
        for isite in range(self.naf):
            print("      AF: r: {}".format(self.rsite[isite]))
            print("      AF: attach: {}".format(self.attach[isite]))
            print("      AF: crosspos: {}".format(self.crosspos[isite]))
            print("      AF: pole: {}".format(self.pole[isite]))

# Class definition for chromosomes; chromosome ic owns kinetochores 2*ic and 2*ic+1
class Chromosomes():
    """Per-frame kinetochore data and attachment analysis for all chromosomes.

    Chromosome ``ic`` owns kinetochores ``2*ic`` and ``2*ic + 1``. Each frame
    of the binary posit file stores, for every kinetochore, one
    ``dtype_chromo`` record followed by ``naf`` ``dtype_af`` records.
    """

    def __init__(self, posit_path, chromo_yaml):
        """Set up the reader from a parsed chromosome configuration.

        posit_path  -- path to the binary posit file. It is only opened when
            there is at least one chromosome.
        chromo_yaml -- parsed YAML dict. The following keys are read:
            ['chromosomes']['chromosome'] (its length is the chromosome
            count); ['chromosomes']['properties']['replicate'] (optional,
            overrides that count); ['AF_number_complexes']; ['kc_diameter'].
        """
        self.chromo_yaml = chromo_yaml
        self.posit_name = ''

        # Get the numbers
        self.n_chromo = np.int_(len(self.chromo_yaml['chromosomes']['chromosome']))

        props = self.chromo_yaml['chromosomes']['properties']
        # Overload the number if we have replicate set
        if 'replicate' in props:
            self.n_chromo = np.int_(props['replicate'])
        self.nkc = 2*self.n_chromo
        self.naf = np.int_(props['AF_number_complexes'])
        self.radiuskc = props['kc_diameter']/2.

        # Binary data: open in 'rb' so no newline translation can corrupt the
        # records. f_posit stays None without chromosomes, so UnloadPosit is
        # always safe to call.
        self.f_posit = None
        if self.n_chromo > 0:
            self.f_posit = open(posit_path, 'rb')

        # Scalar fields are declared without a shape. The old (type, 1) form
        # is deprecated in NumPy. The byte layout is unchanged.
        self.dtype_chromo = np.dtype([
            ('r', np.float32, 3),
            ('u', np.float32, 3),
            ('v', np.float32, 3),
            ('w', np.float32, 3),
            ('nbound', np.int32),
        ])
        self.dtype_af = np.dtype([
            ('r', np.float32, 3),
            ('attach', np.int32),
            ('crosspos', np.float32),
        ])

        self.kinetochores = [Kinetochore(x, self.naf) for x in range(self.nkc)]

        # Attachment-lifetime bookkeeping, one slot per (kinetochore, site).
        # last_attach holds the previous frame's attach value (-1 = unattached).
        # attach_lifetimes counts the frames the current attachment has lasted.
        self.last_attach = -1.0 * np.ones(self.nkc * self.naf)
        self.attach_lifetimes = np.zeros(self.nkc * self.naf)
        # Lifetimes (in frames) of completed attachments.
        # NOTE(review): starts with a placeholder 0 entry. It is kept for callers
        # that may rely on it -- confirm whether consumers drop index 0.
        self.kmt_lifetimes = np.zeros(1)
        self.first_kmt = True

    def UnloadPosit(self):
        """Close the posit file, if one was opened."""
        if self.f_posit is not None:
            self.f_posit.close()

    ### Print Frame
    def PrintFrame(self):
        """Print the current state of every kinetochore."""
        for kc in self.kinetochores:
            kc.PrintFrame()

    def PrintFrameOld(self):
        """Print the raw flat arrays from the last ReadFrame() (legacy output)."""
        print("OLD:")
        for ikc in range(self.nkc):
            i = 3*ikc
            j = i+3
            print("Kinetochore[{}]:".format(ikc))
            print("   r: {}".format(self.r[i:j:1]))
            print("   u: {}".format(self.u[i:j:1]))
            print("   v: {}".format(self.v[i:j:1]))
            print("   w: {}".format(self.w[i:j:1]))
            print("   nbound: {}".format(self.nbound[ikc]))
            for isite in range(self.naf):
                idx = self.naf*ikc + isite
                print("      AF: r: {}".format(self.rsite[3*idx:3*idx + 3]))
                print("      AF: attach: {}".format(self.attach[idx]))
                print("      AF: crosspos: {}".format(self.crosspos[idx]))

    ### Read functionality
    def ReadFrame(self):
        """Read the next frame from the posit file and update all state.

        Fills the flat float64 arrays:
        - r, u, v, w: 3 entries per kinetochore
        - nbound: 1 entry per kinetochore
        - rsite: 3 entries per site
        - attach, crosspos: 1 entry per site

        The site index is naf*ikc + isite. The method then updates the
        attachment-lifetime counters and pushes the values into each
        Kinetochore.

        An exhausted file yields an empty read, and indexing that read raises
        IndexError.
        """
        nkc, naf = self.nkc, self.naf
        nsite = nkc * naf
        # Sizes are known up front: preallocate instead of growing each array
        # with np.append, which copied the whole array on every record.
        self.r = np.zeros(3 * nkc)
        self.u = np.zeros(3 * nkc)
        self.v = np.zeros(3 * nkc)
        self.w = np.zeros(3 * nkc)
        self.nbound = np.zeros(nkc)
        self.rsite = np.zeros(3 * nsite)
        self.attach = np.zeros(nsite)
        self.crosspos = np.zeros(nsite)
        for ikc in range(nkc):
            cinfo = np.fromfile(self.f_posit, dtype=self.dtype_chromo, count=1)[0]
            i = 3 * ikc
            self.r[i:i + 3] = cinfo['r']
            self.u[i:i + 3] = cinfo['u']
            self.v[i:i + 3] = cinfo['v']
            self.w[i:i + 3] = cinfo['w']
            self.nbound[ikc] = cinfo['nbound']
            for isite in range(naf):
                sinfo = np.fromfile(self.f_posit, dtype=self.dtype_af, count=1)[0]
                idx = naf * ikc + isite
                self.rsite[3 * idx:3 * idx + 3] = sinfo['r']
                self.attach[idx] = sinfo['attach']
                self.crosspos[idx] = sinfo['crosspos']

        self._UpdateAttachLifetimes()

        # Now translate into kinetochore information
        for kc in self.kinetochores:
            kc.SetValues(self.r,
                         self.u,
                         self.v,
                         self.w,
                         self.nbound,
                         self.rsite,
                         self.attach,
                         self.crosspos)

    def _UpdateAttachLifetimes(self):
        """Advance per-site attachment lifetimes using the freshly read ``attach``.

        For each site:
        - unattached -> attached: start counting at 1.
        - same attachment as last frame: add 1.
        - attached -> unattached: append the lifetime to kmt_lifetimes and reset to 0.
        - attached -> different MT: append the lifetime and restart at 1.

        On the first frame the previous state is taken to equal the current
        one, so attachments already present count as continuing.
        """
        if self.first_kmt:
            self.last_attach = np.copy(self.attach)
            self.first_kmt = False
        was = self.last_attach != -1
        now = self.attach != -1
        same = self.last_attach == self.attach
        ended = was & ~same        # detached, or switched to another MT
        started = now & ~same      # newly attached, or switched to another MT
        continued = was & same
        detached = was & ~now
        # Record finished lifetimes (ascending site order) before touching counters
        self.kmt_lifetimes = np.append(self.kmt_lifetimes, self.attach_lifetimes[ended])
        self.attach_lifetimes[continued] += 1.0
        self.attach_lifetimes[started] = 1.0
        self.attach_lifetimes[detached] = 0.0
        self.last_attach[:] = self.attach

    @staticmethod
    def _ClassifyKinetochore(pole):
        """Return (kctype, side) for one kinetochore's per-site pole array.

        kctype is 0 (no site on pole 0 or 1), 1 (sites on exactly one of those
        poles) or 2 (sites on both). side is that pole when kctype is 1;
        otherwise it is -1.
        """
        on0 = 0 in pole
        on1 = 1 in pole
        if on0 and on1:
            return 2, -1
        if on0:
            return 1, 0
        if on1:
            return 1, 1
        return 0, -1

    def GeneratePoleInformation(self, microtubules):
        """Classify the attachment state of every chromosome.

        First refreshes each kinetochore's per-site pole information from
        ``microtubules``. Then returns a float array of length ``n_chromo``
        holding one code per chromosome:

            0 - Unattached  (neither kinetochore on pole 0 or 1)
            1 - Monotelic   (only one kinetochore attached, to a single pole)
            2 - Merotelic   (some kinetochore attached to both poles)
            3 - Syntelic    (both kinetochores attached to the same pole)
            4 - Amphitelic  (the kinetochores attached to different poles)
        """
        for kc in self.kinetochores:
            kc.GeneratePoleInformation(microtubules)

        atype = np.zeros(self.n_chromo)
        for ic in range(self.n_chromo):
            kc0t, side0 = self._ClassifyKinetochore(self.kinetochores[2*ic].pole)
            kc1t, side1 = self._ClassifyKinetochore(self.kinetochores[2*ic + 1].pole)
            if kc0t == 0 and kc1t == 0:
                atype[ic] = 0
            elif kc0t == 2 or kc1t == 2:
                atype[ic] = 2
            elif kc0t != kc1t:
                # Exactly one kinetochore is attached, to a single pole
                atype[ic] = 1
            elif side0 == side1:
                atype[ic] = 3
            else:
                atype[ic] = 4
        return atype

    # Look to see for each chromosome how many end on microtubules are attached
    def GenerateEndOnAttachments(self, microtubules, tip_distance=1.0):
        """Count end-on attachments per chromosome.

        A site counts when it is attached and its MT's ``l`` minus the site's
        ``crosspos`` is below ``tip_distance``. The default of 1.0 keeps the
        original threshold. Both kinetochores of a chromosome contribute.
        Returns a float array of length ``n_chromo``.
        """
        nendon = np.zeros(self.n_chromo)
        for ic in range(self.n_chromo):
            nend = 0
            for kc in (self.kinetochores[2*ic], self.kinetochores[2*ic + 1]):
                for isite in range(self.naf):
                    # Builtin int(): the np.int alias was removed in NumPy 1.24
                    ibond = int(kc.attach[isite])
                    if ibond > -1:
                        dtip = microtubules[ibond].l - kc.crosspos[isite]
                        if dtip < tip_distance:
                            nend += 1
            nendon[ic] = nend
        return nendon

    def GenerateNBound(self):
        """Return [total of nbound, per-kinetochore nbound array] for the last frame."""
        return [np.sum(self.nbound), self.nbound]
