#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# YOLO edelmaie@colorado.edu (too)
# Basic
import sys, os, pdb
import gc
## Analysis
import numpy as np
import yaml

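'''
Name: xlink.py
Description: Reader for crosslink (xlink) binary position ("posit") files.
             Each xlink type has one file per stage; every frame stores the
             stage-0 (free) xlinks, the stage-1 xlinks (attached to one
             parent) and the stage-2 xlinks (attached to two objects).
Input: Binary posit files named "<posit_path>.stage<N>.type<xtype>".
Output: Per-frame xlink counts and positions reconstructed from the
        objects (microtubules) the xlinks are attached to.
'''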

# Class holders for various stages
# Stage 0 xlinks
class XlinkStage0:
    """Holder for one free (stage-0) xlink; `r` is assigned by the reader."""

    def __init__(self):
        # Position; an empty list until a value is assigned.
        self.r = list()

class XlinkStage1:
    """Holder for one stage-1 xlink.

    parent -- index of the object the xlink is attached to
    pos    -- coordinate of the xlink along that object
    r      -- generated position (empty list until assigned)
    """

    def __init__(self):
        self.parent, self.pos = 0, 0
        self.r = list()

class XlinkStage2():
    """Holder for one stage-2 xlink attached to two objects.

    bond1, bond2   -- indices of the two objects
    cross_position -- coordinate along each of the two objects
    r              -- generated position (empty list until assigned)
    """

    def __init__(self):
        self.bond1 = 0
        self.bond2 = 0
        # Fixed-size pair; a literal avoids the Python-2-only `xrange`.
        self.cross_position = [0, 0]
        self.r = []

#Class definition
class Xlink():
    """Reader for the binary per-stage posit files of one xlink type.

    At construction one file per stage is opened, named
    "<posit_path>.stage<istage>.type<xtype>" for istage in 0..nstages-1.
    Each ReadFrame() call consumes one frame from stage files 0, 1 and 2
    (so nstages is expected to be at least 3):

      * stage 0 -- `nfree` xlinks with stored positions (XlinkStage0),
      * stage 1 -- xlinks given by a parent index and a coordinate along
        that parent (XlinkStage1),
      * stage 2 -- xlinks given by two indices (bond1, bond2) and a
        coordinate along each (XlinkStage2).

    GeneratePositions() turns the stage-1 and stage-2 records into
    positions, given a sequence of objects exposing `r`, `l` and `u`.
    Call UnloadPosit() to close the files.
    """

    def __init__(self, posit_path, xlink_yaml, nstages, xtype):
        """Open the posit file of every stage for xlink type `xtype`.

        posit_path -- common filename prefix of the posit files
        xlink_yaml -- stored on the instance as-is (not read by this class)
        nstages    -- number of stage files to open
        xtype      -- xlink type index used in the filenames

        Propagates the error raised by `open` (IOError/OSError) when a
        file cannot be opened; files already opened are closed first.
        """
        self.xlink_yaml = xlink_yaml
        self.xtype = xtype
        self.nstages = nstages
        self.r = []
        self.r_flag = False  # True once positions have been generated this frame

        # One posit file name per stage
        self.posit_name = ["{}.stage{}.type{}".format(posit_path, istage, self.xtype)
                           for istage in range(self.nstages)]

        # On-disk record layouts (numpy default: fields packed without padding)
        self.nfreedt = np.dtype([('nfree', np.int32)])
        self.nstage1dt = np.dtype([
            ('nbound10', np.int32),
            ('nbound11', np.int32),
        ])
        self.xstage1dt = np.dtype([
            ('headtype', np.int32),
            ('parent', np.int32),
            ('pos', np.float32),
        ])
        self.nstage2dt = np.dtype([
            ('nbound2', np.int32),
        ])
        self.xstage2dt = np.dtype([
            ('head1', np.int32),
            ('head2', np.int32),
            ('bond1', np.int32),
            ('bond2', np.int32),
            ('crosspos', 'f4', 2),
        ])

        # Open every stage file. If a later open fails, close the handles
        # that were already opened so they do not leak, then re-raise.
        self.f_posit = []
        try:
            for name in self.posit_name:
                self.f_posit.append(open(name, 'rb'))
        except Exception:
            self.UnloadPosit()
            raise

    def UnloadPosit(self):
        """Close every opened posit file."""
        for f in self.f_posit:
            f.close()

    ### Print functionality
    def PrintFrame(self):
        """Print the xlink counts of the most recently read frame."""
        print("Xlink[{}]".format(self.xtype))
        print("   stage0: nfree: {}".format(self.nfree))
        print("   stage1: nbound1: {}".format(self.nbound1))
        print("   stage2: nbound2: {}".format(self.nbound2))

    ### Read information
    def ReadFrame(self):
        """Read the next frame of all three stages and invalidate cached positions."""
        self.ReadStage0()
        self.ReadStage1()
        self.ReadStage2()
        self.r_flag = False

    @staticmethod
    def _read_records(f, dtype, count):
        """Read exactly `count` items of `dtype` from the open binary file `f`.

        Raises ValueError for a negative count (a corrupt header) and
        IndexError on a short read such as end of file -- the same exception
        type the single-item header reads (`np.fromfile(..., count=1)[0]`)
        raise, so callers see one exception type at end of data.
        """
        count = int(count)
        if count < 0:
            raise ValueError("negative record count {} in posit file".format(count))
        if count == 0:
            # Nothing to read; never hand a non-positive count to np.fromfile
            # (a negative count there means "read the whole file").
            return np.zeros(0, dtype=dtype)
        recs = np.fromfile(f, dtype=dtype, count=count)
        if recs.shape[0] < count:
            raise IndexError("short read from posit file: expected {} items, got {}".format(
                count, recs.shape[0]))
        return recs

    def ReadStage0(self):
        """Read the free (stage-0) xlinks of the next frame from stage file 0."""
        header = np.fromfile(self.f_posit[0], dtype=self.nfreedt, count=1)[0]
        self.nfree = header['nfree']
        # The count is followed by nfree rows of 3 float32 values.
        coords = self._read_records(self.f_posit[0], np.float32, 3 * int(self.nfree))
        self.stage0 = []
        for row in coords.reshape(-1, 3):
            xlink = XlinkStage0()
            xlink.r = row
            self.stage0.append(xlink)

    def ReadStage1(self):
        """Read the stage-1 xlinks of the next frame from stage file 1.

        The header holds two counts (nbound10, nbound11); their sum is the
        number of records that follow.
        """
        # Kept as a float64 array (np.zeros default) so existing users of
        # nbound1 see the same type as before.
        self.nbound1 = np.zeros(2)
        header = np.fromfile(self.f_posit[1], dtype=self.nstage1dt, count=1)[0]
        self.nbound1[0] = header['nbound10']
        self.nbound1[1] = header['nbound11']
        self.ntotal1 = np.int32(self.nbound1[0] + self.nbound1[1])
        # One bulk read instead of one np.fromfile call per record.
        recs = self._read_records(self.f_posit[1], self.xstage1dt, self.ntotal1)
        self.stage1 = []
        for rec in recs:
            xlink = XlinkStage1()
            xlink.parent = rec['parent']
            xlink.pos = rec['pos']
            self.stage1.append(xlink)

    def ReadStage2(self):
        """Read the stage-2 xlinks of the next frame from stage file 2."""
        header = np.fromfile(self.f_posit[2], dtype=self.nstage2dt, count=1)[0]
        self.nbound2 = header['nbound2']
        # One bulk read instead of one np.fromfile call per record.
        recs = self._read_records(self.f_posit[2], self.xstage2dt, self.nbound2)
        self.stage2 = []
        for rec in recs:
            xlink = XlinkStage2()
            xlink.bond1 = rec['bond1']
            xlink.bond2 = rec['bond2']
            xlink.cross_position = rec['crosspos']
            self.stage2.append(xlink)

    @staticmethod
    def _point_on(mt, s):
        """Return mt.r + (s - 0.5*mt.l) * mt.u (s == mt.l/2 maps to mt.r)."""
        return mt.r + (s - 0.5 * mt.l) * mt.u

    def GeneratePositions(self, microtubules):
        """Store and return xlink positions for the current frame.

        microtubules -- indexable by the stored parent/bond indices; each item
                        must expose `r`, `l` and `u`.

        A stage-1 xlink is placed at `_point_on(parent, pos)`; a stage-2 xlink
        at the midpoint of `_point_on(bond1, cross_position[0])` and
        `_point_on(bond2, cross_position[1])`. Each computed position is also
        stored on the xlink's `r`. Stage-0 xlinks are not included. The list
        is cached and returned unchanged until the next ReadFrame(). Exits the
        process via sys.exit(1) if any generated position contains NaN.
        """
        if self.r_flag:
            return self.r
        self.r = []

        for xlink in self.stage1:
            r0 = self._point_on(microtubules[xlink.parent], xlink.pos)
            xlink.r = r0
            if np.isnan(r0).any():
                print("Encountered NaN in stage1 xlink generating")
                sys.exit(1)
            self.r.append(r0)

        for xlink in self.stage2:
            # Attached to two objects: take the midpoint of the two points.
            r1 = self._point_on(microtubules[xlink.bond1], xlink.cross_position[0])
            r2 = self._point_on(microtubules[xlink.bond2], xlink.cross_position[1])
            r0 = (r1 + r2) / 2.
            xlink.r = r0
            if np.isnan(r0).any():
                print("Encountered NaN in stage2 xlink generating")
                sys.exit(1)
            self.r.append(r0)

        self.r_flag = True  # positions are now cached for this frame
        return self.r

