#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# YOLO edelmaie@colorado.edu (too)
# Basic
import sys, os, pdb
import gc
## Analysis
import numpy as np
import yaml

from operator import attrgetter

from scipy import special
import scipy.misc

from read_posit_base import ReadPositBase
from xlink import Xlink
from gaussian_imaging import GaussianSpot2D
from microtubule import Microtubule

# import pandas as pd
# import matplotlib.pyplot as plt
# import matplotlib as mpl
# from math import *

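'''
Name: read_posit_xlinks.py
Description: Reads the crosslinker (xlink) configuration and posit output of a
             simulation seed and derives per-frame data from it: simulated 2D
             images of the xlinks, xlink distances to the SPBs, xlink positions
             projected onto the spindle axis, and per-stage xlink counts.
Input:  seed path and posit name, optionally followed by the default and
        equil yaml file names (see the __main__ block).
Output: see the individual ReadPositXlinks methods.
'''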

#Class definition
class ReadPositXlinks(ReadPositBase):
    """Reader for the crosslinker (xlink) posit output of a single seed.

    Holds one Xlink object per crosslink species listed in the seed's
    crosslink config file, and provides helpers that derive images,
    distances and counts from the currently read frame.
    """

    def __init__(self, seed_path, posit_name,
                 default_file='default.yaml', equil_file='equil.yaml'):
        ReadPositBase.__init__(self, seed_path, posit_name, default_file, equil_file)

        self.xlinks = []  # List of Xlink objects, one per species

    def LoadPosit(self):
        """Read the crosslink config and build one Xlink object per species.

        Output:
            True on success. False if the seed names no crosslink config,
            the config cannot be opened or parsed, it lists no crosslinks,
            or an Xlink posit file cannot be opened (IOError).
        """
        # Find the crosslink yaml file
        self.xlink_yaml_name = self.CheckDefaultThenEquil('crosslink_file')
        if self.xlink_yaml_name is None:
            print("     *** No xlink config file in seed *** ")
            return False

        # Parse the config. safe_load: the config is plain data, and
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6 (previously hidden behind a bare except).
        xlink_yaml_path = os.path.join(self.seed_path, self.xlink_yaml_name)
        try:
            with open(xlink_yaml_path, 'r') as stream:
                self.xlink_yaml = yaml.safe_load(stream)
        except (IOError, OSError, yaml.YAMLError):
            print(" ***  Could not open xlink config file: {}  ***".format(self.xlink_yaml_name))
            return False

        # An empty file parses to None; a missing or null 'crosslink' entry
        # both mean that no crosslinks are configured.
        crosslinks = None
        if isinstance(self.xlink_yaml, dict):
            crosslinks = self.xlink_yaml.get('crosslink')
        if crosslinks is None:
            print(" *** No xlinks used in this seed. *** ")
            return False

        self.ntypes = np.int_(len(crosslinks))
        self.nstages = 3  # FIXME hardcoded for now
        try:
            self.xlinks = [Xlink(self.posit_path, self.xlink_yaml, self.nstages, x)
                           for x in range(self.ntypes)]
        except IOError:
            print(" *** No xlinks used in this seed. *** ")
            return False

        return True

    def UnloadPosit(self):
        """Call UnloadPosit() on every loaded Xlink species."""
        for xlink in self.xlinks:
            xlink.UnloadPosit()

    def ReadFramePosit(self):
        """Read the next frame for every species (deferred to Xlink.ReadFrame)."""
        for xlink in self.xlinks:
            xlink.ReadFrame()

    def PrintFrame(self):
        """Call PrintFrame() on every loaded Xlink species."""
        for xlink in self.xlinks:
            xlink.PrintFrame()

    ### Image generation functionality
    def ImageXlink2D(self, imageParams, h=None, xtype=0, microtubules=None):
        """Render a noisy 2D image of one xlink species as Gaussian spots.

        Input:
            imageParams: dict with keys 'pixelsize', 'noisestd', 'bkglevel',
                'A', 'sigmaxy' and 'modeloffset'.
            h: required despite the None default. Only h[0][0] is read; with
                'modeloffset' it sets the image size in pixels (presumably
                the box length along x -- confirm).
            xtype: index of the species in self.xlinks.
            microtubules: forwarded to Xlink.GeneratePositions.
        Output:
            2D numpy array: bkglevel + Gaussian noise (std noisestd) + one
            GaussianSpot2D per generated xlink position, flipped left-right.
        Raises:
            ValueError if h is None. Exits the process (sys.exit(1)) if a
            generated position contains NaN.
        """
        if h is None:
            raise ValueError("ImageXlink2D: h must be provided")

        pixelSize = imageParams['pixelsize']
        noiseStd = imageParams['noisestd']
        bkglevel = imageParams['bkglevel']
        A = imageParams['A']
        sigma = imageParams['sigmaxy']
        offset_distance = imageParams['modeloffset']

        # NOTE(review): both dimensions are sized from h[0][0], so the image
        # is always square -- confirm this is intended.
        numPixelsX = np.int_(np.ceil((h[0][0] + offset_distance) / pixelSize))
        numPixelsY = np.int_(np.ceil((h[0][0] + offset_distance) / pixelSize))
        [x, y] = np.meshgrid(np.arange(0, numPixelsX), np.arange(0, numPixelsY))

        # meshgrid returns (numPixelsY, numPixelsX) arrays; lay the background
        # out as (rows, cols) the same way so the sums below always line up.
        imagedata = bkglevel * np.ones((numPixelsY, numPixelsX)) \
            + np.random.standard_normal((numPixelsY, numPixelsX)) * noiseStd

        # Grab the correct xlink species to image
        xlinkr = self.xlinks[xtype].GeneratePositions(microtubules)

        # Integer pixel index of the image centre (floor, as Python 2 '/' did
        # on integers); positions are presumably relative to this centre.
        xcenter = numPixelsX // 2
        ycenter = numPixelsY // 2
        for idx, xit in enumerate(xlinkr):
            # Components 1 and 2 are used as the x and y coordinates.
            x0 = xit[1] / pixelSize + xcenter
            y0 = xit[2] / pixelSize + ycenter

            if np.isnan(x0).any() or np.isnan(y0).any():
                print("found NaN!")
                print("idx: {}".format(idx))
                print("xit: {}".format(xit))
                sys.exit(1)

            imagedata = imagedata + GaussianSpot2D(x, y, A, sigma, x0, y0)

        return np.fliplr(imagedata)

    ### Generate the 1d positions of the bound xlinks along the spindle
    def GenerateSPBXlinkDistance(self, spbs, microtubules, xtype=0):
        """Xlink-to-SPB distances for stage 1 and stage 2 xlinks.

        Input: [spb objects], [microtubule objects], xlink species index
        Output: [[stage1 distances, stage2 distances],
                 [stage1 weights,   stage2 weights]]
        """
        # Pass by keyword: the helpers take (spbs, microtubules, xtype); the
        # old positional call passed xtype first and always raised TypeError.
        distances_stg1, weights_stg1 = self.GenerateSPBStage1XlinkDistance(
            spbs, microtubules, xtype=xtype)
        distances_stg2, weights_stg2 = self.GenerateSPBStage2XlinkDistance(
            spbs, microtubules, xtype=xtype)

        return [[distances_stg1, distances_stg2],
                [weights_stg1, weights_stg2]]

    def GenerateSPBStage1XlinkDistance(self, spbs, microtubules, xtype=0):
        """Distance from each stage 1 xlink to the SPB of the microtubule it is on.

        Input: [spb objects], [microtubule objects], xlink species index
        Output: [[distances], [weights]] with weight 1.0 per xlink.
        """
        xlinkset = self.xlinks[xtype]
        # Return value unused; presumably called so that xlinkset.stage1
        # reflects the current frame -- TODO confirm.
        xlinkset.GeneratePositions(microtubules)

        distances = []
        for xit in xlinkset.stage1:
            # xit.parent indexes the microtubule the xlink is attached to;
            # that microtubule's .spb indexes the SPB it belongs to.
            spb = microtubules[xit.parent].spb
            distances.append(np.linalg.norm(spbs[spb].r - xit.r))
        weights = [1.0] * len(distances)

        return [distances, weights]

    def GenerateSPBStage2XlinkDistance(self, spbs, microtubules, xtype=0):
        """Distances from each stage 2 xlink to the SPBs of both its microtubules.

        Input: [spb objects], [microtubule objects], xlink species index
        Output: [[distances], [weights]]. Each xlink contributes two entries
            (one per bonded microtubule), each weighted 0.5.
        """
        xlinkset = self.xlinks[xtype]
        # Return value unused; presumably refreshes xlinkset.stage2 -- TODO confirm.
        xlinkset.GeneratePositions(microtubules)

        distances = []
        weights = []
        for xit in xlinkset.stage2:
            spb1 = microtubules[xit.bond1].spb
            spb2 = microtubules[xit.bond2].spb
            distances += [np.linalg.norm(spbs[spb1].r - xit.r),
                          np.linalg.norm(spbs[spb2].r - xit.r)]
            weights += [.5, .5]  # 0.5 for each SPB

        # List (not tuple) for consistency with the stage 1 helper.
        return [distances, weights]

    def GenerateXlinkSpindle1D(self, spbs, microtubules, xtype=0, target=2.75):
        """ Creates a list of all xlink positions relative to a spindle for current frame.
            The spindle is the vector from spbs[0] to spbs[1]. Each xlink position is
            projected onto it and normalized by the spindle length, so 0 corresponds to
            spbs[0] and 1 to spbs[1]. Coincident SPBs give inf/nan values.

        Input: [spb objects], [microtubule objects], xlink species,
               target spindle length (deprecated, unused)
        Output: [ [[stage1 xlink distances on spindle], [stage1 weights of corresponding distances]],
                  [[stage2 xlink distances on spindle], [stage2 weights of corresponding distances]]
                ]
        """
        # NOTE: no spindle-length filtering around `target` is performed.
        rspindle = spbs[1].r - spbs[0].r
        spindle_length = np.linalg.norm(rspindle)
        length_sq = spindle_length * spindle_length

        xlinkset = self.xlinks[xtype]
        # Return value unused; presumably refreshes stage1/stage2 -- TODO confirm.
        xlinkset.GeneratePositions(microtubules)

        def _project(xit):
            # Shift to spbs[0] as origin, then normalized projection on the spindle.
            return np.divide(np.dot(xit.r - spbs[0].r, rspindle), length_sq)

        stg1_distances = [_project(xit) for xit in xlinkset.stage1]
        stg2_distances = [_project(xit) for xit in xlinkset.stage2]

        return [[stg1_distances, [1.0] * len(stg1_distances)],
                [stg2_distances, [1.0] * len(stg2_distances)]]

    def GenerateNumXlinks(self, xtype):
        """Count xlinks of species xtype by state.

        Output: [nfree, nbound1[0] + nbound1[1], nbound2, total of the three]
        """
        xs = self.xlinks[xtype]
        nstage1 = xs.nbound1[0] + xs.nbound1[1]
        return [xs.nfree,
                nstage1,
                xs.nbound2,
                xs.nfree + nstage1 + xs.nbound2]

##########################################
if __name__ == "__main__":
    # Usage: read_posit_xlinks.py seed_path posit_name [default_file [equil_file]]
    # default_file/equil_file fall back to the constructor defaults when omitted.
    if len(sys.argv) < 3:
        sys.stderr.write("Usage: {} seed_path posit_name [default_file [equil_file]]\n".format(sys.argv[0]))
        sys.exit(1)
    xlinks = ReadPositXlinks(*sys.argv[1:5])
    # Nothing to read or print if the xlink config/posit files could not be loaded.
    if not xlinks.LoadPosit():
        sys.exit(1)
    xlinks.ReadFramePosit()
    xlinks.PrintFrame()
