#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# YOLO edelmaie@colorado.edu (too)
# Basic
import sys, os, pdb
import gc
## Analysis
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import yaml

from operator import attrgetter

from scipy import special
import scipy.misc

from read_posit_base import ReadPositBase
from microtubule import Microtubule
from spindle_pole_body import SpindlePoleBody
from gaussian_imaging import GaussianLine2D
from gaussian_imaging import GaussianLine3D

# sys.path.append(os.path.join(os.path.dirname(__file__), '~/projects/newagebob/analysis))
# sys.path.append('/Users/adamlamson/projects/newagebob/analysis/Spindle')
from spindle_unit_dict import SpindleUnitDict

from antiparallel_overlap import antiparallel_overlap
from base_funcs import moving_average

# import pandas as pd
# import matplotlib.pyplot as plt
# import matplotlib as mpl
# from math import *

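'''
Name: read_posit_spindle.py
Description: Reader for binary spindle "posit" trajectory files. ReadPositSpindle
    parses the file header and each frame's microtubule (MT) and spindle pole
    body (SPB) configuration, and provides per-frame analyses (MT lengths,
    splay, SPB separation, interpolar pairing) plus synthetic Gaussian-line
    images of the MTs.
Input: seed directory path, posit file name, and default/equil YAML parameter
    files (script usage: seed_path posit_name default_file equil_file).
Output: when run as a script, plots MT length histograms for SPB 0, SPB 1 and
    both combined.
'''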

# Module-level unit dictionary. The __main__ script below uses uc['um'][1] to
# scale MT lengths (presumably a simulation-length -> micron factor; TODO confirm).
uc = SpindleUnitDict()

#Class definition
class ReadPositSpindle(ReadPositBase):
    """Reader/analyzer for spindle posit files (microtubules + spindle pole bodies).

    The posit file begins with a header of four int32 values
    (ndim, nspheres, nmts, nspbs) followed by fixed-size frames of float32
    data; see ReadPosHeader, CountFrames and ReadFramePosit for the layout.
    Call LoadPosit() before reading frames and UnloadPosit() when finished.
    """

    def __init__(self, seed_path, posit_name,
                 default_file='default.yaml', equil_file='equil.yaml'):
        ReadPositBase.__init__(self, seed_path, posit_name, default_file, equil_file)
        self.microtubules = []
        self.spbs = []

    def LoadPosit(self):
        """Open the posit file, read its header and allocate MT/SPB objects.

        Returns:
            True if self.posit_path was opened, False if it does not exist.
        """
        if not os.path.exists(self.posit_path):
            print("     *** No posit file. Continuing without analyzing *** ")
            return False
        # Close a handle left over from an earlier LoadPosit call instead of leaking it.
        self.UnloadPosit()
        # Binary mode: the file is read with np.fromfile as raw int32/float32 data.
        self.f_posit = open(self.posit_path, 'rb')
        self.ReadPosHeader()
        self.CountFrames()
        self.microtubules = [Microtubule(i) for i in range(self.nmts)]
        self.spbs = [SpindlePoleBody(i) for i in range(self.nspbs)]
        return True

    def UnloadPosit(self):
        """Close the posit file if one was opened; safe to call repeatedly."""
        f_posit = getattr(self, 'f_posit', None)
        if f_posit is not None:
            f_posit.close()

    def ReadPosHeader(self):
        """Read the int32 header (ndim, nspheres, nmts, nspbs) at the current offset.

        Records the byte offset where frame data starts, derives mt_per_spb, and
        resets cur_frame to -1 (no frame read yet). Exits the process if the
        file contains spheres, which this reader does not support.
        """
        headerdt = np.dtype([
            ('ndim', np.int32),
            ('nspheres', np.int32),
            ('nmts', np.int32),
            ('nspbs', np.int32),
        ])
        header = np.fromfile(self.f_posit, dtype=headerdt, count=1)[0]
        self.f_data_start = self.f_posit.tell()
        # ndim is stored signed; only its magnitude is used.
        self.ndim = abs(header['ndim'])
        self.nspheres = header['nspheres']
        self.nmts = header['nmts']
        self.nspbs = header['nspbs']
        # Floor division: mt_per_spb is used as a range bound and to map an MT
        # index to its SPB, so it must stay an integer on Python 2 and 3.
        self.mt_per_spb = self.nmts // self.nspbs
        self.cur_frame = -1
        if self.nspheres != 0:
            print("Shouldn't have spheres, exiting!")
            sys.exit(1)

    def CountFrames(self):
        """Compute how many complete frames follow the current file offset.

        Must be called right after ReadPosHeader. Sets framesize (total bytes
        of frame data, despite the name), configsize (bytes per frame) and
        nframes, and restores the file position afterwards.
        """
        data_start = self.f_posit.tell()
        self.f_posit.seek(0, os.SEEK_END)
        data_end = self.f_posit.tell()
        self.f_posit.seek(data_start, os.SEEK_SET)
        self.framesize = data_end - data_start
        # Per frame: h (ndim x ndim) + time (1) + two sites per MT plus spheres
        # (ndim each) + r and u for every SPB (ndim each), all 4-byte float32.
        self.configsize = (self.ndim*self.ndim + 1 +
                           (self.nmts*2 + self.nspheres)*self.ndim +
                           2*(self.nspbs*self.ndim))*4
        self.nframes = self.framesize // self.configsize

    def ReadFramePosit(self):
        """Read the next frame and update self.spbs / self.microtubules in place.

        Sets self.h and self.rsite. For each SPB sets r and u; for each MT sets
        r (midpoint of its two sites), u (unit direction), v (site-to-site
        vector), l (length) and spb (owning SPB index, assuming MTs are stored
        contiguously per SPB). The frame time is read but not stored.
        Increments cur_frame.
        """
        framedt = np.dtype([
            ('h', 'f4', (self.ndim, self.ndim)),
            ('time', 'f4'),
            ('rsite', 'f4', (2*self.nmts+self.nspheres, self.ndim)),
            ('rspb', 'f4', (self.nspbs, self.ndim)),
            ('uspb', 'f4', (self.nspbs, self.ndim)),
        ])
        frame = np.fromfile(self.f_posit, dtype=framedt, count=1)[0]
        self.h = frame['h']
        self.rsite = frame['rsite']
        rspb = frame['rspb']
        uspb = frame['uspb']
        for ispb in range(self.nspbs):
            self.spbs[ispb].r = rspb[ispb]
            self.spbs[ispb].u = uspb[ispb]
        for xrod in range(self.nmts):
            site0 = self.rsite[2*xrod]
            vmt = self.rsite[2*xrod+1] - site0
            lmt = np.linalg.norm(vmt)
            mt = self.microtubules[xrod]
            mt.r = site0 + 0.5*vmt
            mt.u = np.divide(vmt, lmt)
            mt.v = vmt
            mt.l = lmt
            mt.spb = xrod // self.mt_per_spb
        self.cur_frame += 1

    def GetSpindlePoleBodies(self):
        """Return the list of SpindlePoleBody objects."""
        return self.spbs

    def GetSPBPositions(self):
        """Return the positions of SPB 0 and SPB 1 as a (2, ndim) array."""
        return np.array([self.spbs[0].r, self.spbs[1].r])

    def GetSpindleVector(self):
        """Return the vector from SPB 0 to SPB 1."""
        return self.spbs[1].r - self.spbs[0].r

    def _SplitMTsBySPB(self):
        """Return (MTs of SPB 0, MTs of SPB 1); MTs with any other spb are dropped."""
        spb0_mts = [mt for mt in self.microtubules if mt.spb == 0]
        spb1_mts = [mt for mt in self.microtubules if mt.spb == 1]
        return spb0_mts, spb1_mts

    def GenerateAvgMTLength(self):
        """Return [mean MT length of SPB 0, of SPB 1, mean of those two means].

        The third entry is the unweighted average of the two per-SPB means.
        There is no guard for an SPB without MTs (np.divide by a zero count).
        """
        nmts_0 = 0
        nmts_1 = 0
        spb0_mt_avg = 0
        spb1_mt_avg = 0
        for mt in self.microtubules:
            if mt.spb == 0:
                spb0_mt_avg += mt.l
                nmts_0 += 1
            elif mt.spb == 1:
                spb1_mt_avg += mt.l
                nmts_1 += 1

        spb0_mt_avg = np.divide(spb0_mt_avg, nmts_0)
        spb1_mt_avg = np.divide(spb1_mt_avg, nmts_1)

        return [spb0_mt_avg, spb1_mt_avg, np.divide((spb0_mt_avg + spb1_mt_avg), 2)]

    def GenerateMTLengthSPBDistributions(self):
        """Return [[lengths of SPB-0 MTs], [lengths of SPB-1 MTs]] for this frame."""
        spb_mt_lengths = [[], []]
        for mt in self.microtubules:
            spb_mt_lengths[mt.spb].append(mt.l)
        return spb_mt_lengths

    def GenerateMTLengthIndexDistributions(self):
        """Return MT lengths in a list indexed by mt.idx (assumed 0..nmts-1 — TODO confirm)."""
        mt_lengths = [0]*len(self.microtubules)
        for mt in self.microtubules:
            mt_lengths[mt.idx] = mt.l
        return mt_lengths

    def GenerateMTSplay(self):
        """Measure how splayed each SPB's MTs are about their mean direction.

        For each SPB s in {0, 1}, the director is the normalized sum of its MTs'
        unit vectors, and the splay is the mean dot product of each MT direction
        with that director (1.0 means perfectly aligned).

        Returns:
            (mt_avg_director, avg_splay) as nested lists. mt_avg_director[s] is
            the director of SPB s. avg_splay is [splay_0, splay_1, combined],
            where combined is the MT-count-weighted mean of the two splays.
        """
        spb_mts = [[], []]
        for mt in self.microtubules:
            spb_mts[mt.spb].append(mt)
        spb_mt_num = [len(spb_mts[0]), len(spb_mts[1])]

        # One director per SPB, sized by the file's dimensionality.
        mt_avg_director = np.zeros((2, self.ndim))
        avg_splay = np.zeros(3)
        for spb in [0, 1]:
            for mt in spb_mts[spb]:
                mt_avg_director[spb] = np.add(mt_avg_director[spb], mt.u)
            mt_avg_director[spb] = np.divide(mt_avg_director[spb],
                                             np.linalg.norm(mt_avg_director[spb]))

            for mt in spb_mts[spb]:
                avg_splay[spb] = np.add(avg_splay[spb],
                                        np.dot(mt_avg_director[spb], mt.u))
            avg_splay[spb] = np.divide(avg_splay[spb], spb_mt_num[spb])

        avg_splay[2] = np.divide((avg_splay[0]*float(spb_mt_num[0]) +
                                  avg_splay[1]*float(spb_mt_num[1])),
                                 spb_mt_num[0] + spb_mt_num[1])

        return (mt_avg_director.tolist(), avg_splay.tolist())

    def GenerateSPBSeparation(self):
        """Return the Euclidean distance between SPB 0 and SPB 1."""
        sep_vec = np.array(self.spbs[0].r) - np.array(self.spbs[1].r)
        return np.linalg.norm(sep_vec)

    def GenerateInterpolarFraction(self, thresh=12.):
        """Return (interpolar MT fraction, antiparallel overlap-length fraction).

        Every (SPB-0 MT, SPB-1 MT) pair is scored with antiparallel_overlap. An
        MT counts as interpolar (mt.paired = 1) if any of its pairings reaches
        ``thresh``; the flags are reset to 0 before returning. The second value
        is the summed pairing distance over the summed shorter-MT length of all
        pairs. Each value is 0.0 when its denominator is zero (no MTs, or no
        cross-SPB pairs) rather than raising ZeroDivisionError.
        """
        spb0_mts, spb1_mts = self._SplitMTsBySPB()
        tot_overlap_length = 0.0
        max_overlap_length = 0.0
        for mt0 in spb0_mts:
            for mt1 in spb1_mts:
                pairing_dist = antiparallel_overlap(mt0.r, mt0.u, mt0.l,
                                                    mt1.r, mt1.u, mt1.l)
                # The factor 2.0 appears in both sums and cancels in the ratio.
                max_overlap_length += 2.0 * min(mt0.l, mt1.l)
                tot_overlap_length += 2.0 * pairing_dist

                if pairing_dist >= thresh:
                    mt0.paired = 1.
                    mt1.paired = 1.

        # Tally the per-MT flags and clear them for the next call.
        tot_paired = 0.
        for mt in self.microtubules:
            tot_paired += mt.paired
            mt.paired = 0.

        paired_fraction = tot_paired/float(self.nmts) if self.nmts > 0 else 0.0
        overlap_fraction = (tot_overlap_length/max_overlap_length
                            if max_overlap_length > 0.0 else 0.0)
        return (paired_fraction, overlap_fraction)

    def BinLengths(self, thresh=12.0):
        """Collect per-MT length and angle streams for histogramming.

        Uses the same SPB-0 x SPB-1 pairing as GenerateInterpolarFraction, but
        returns per-MT data (the variables computed by the original
        spindle_analysis_new) as a dict of lists:
          'pairing_length_max': lengths of MTs with any nonzero overlap
          'interpolar': lengths of MTs whose best overlap is > ``thresh``
                        (strict, unlike the >= in GenerateInterpolarFraction)
          'polar': lengths of all remaining MTs
          'total': lengths of all MTs
          'angles': angle in radians, in [0, pi/2], between MT and spindle axis
        Assumes mt.idx indexes 0..nmts-1 (TODO confirm in Microtubule).
        """
        uspindle = self.spbs[1].r - self.spbs[0].r
        uspindle = uspindle / np.linalg.norm(uspindle)

        streams = {
            'pairing_length_max': [],
            'interpolar': [],
            'polar': [],
            'total': [],
            'angles': [],
        }

        # Largest antiparallel overlap found for each MT, indexed by mt.idx.
        pairing_dist_max = np.zeros(len(self.microtubules))
        spb0_mts, spb1_mts = self._SplitMTsBySPB()
        for mt0 in spb0_mts:
            for mt1 in spb1_mts:
                pairing_dist = antiparallel_overlap(mt0.r, mt0.u, mt0.l,
                                                    mt1.r, mt1.u, mt1.l)
                pairing_dist_max[mt0.idx] = max(pairing_dist_max[mt0.idx], pairing_dist)
                pairing_dist_max[mt1.idx] = max(pairing_dist_max[mt1.idx], pairing_dist)

        for imt in self.microtubules:
            best = pairing_dist_max[imt.idx]
            if best > 0.0:
                streams['pairing_length_max'].append(imt.l)

            if best > thresh:
                streams['interpolar'].append(imt.l)
            else:
                # FIXME (carried over): every non-interpolar MT, paired or not,
                # is counted as polar; the author was unsure this is right.
                streams['polar'].append(imt.l)

            streams['total'].append(imt.l)

            # Clip guards against |dot| slightly above 1 from float32 rounding,
            # which would otherwise make arccos return nan.
            cos_angle = np.clip(np.fabs(np.dot(uspindle, imt.u)), 0.0, 1.0)
            streams['angles'].append(np.arccos(cos_angle))

        return streams

    ### Print functionality
    def __repr__(self):
        """Summarize the header fields (requires LoadPosit to have run)."""
        return "\n".join([
            "ReadPositSpindle",
            "   ndim: {}".format(self.ndim),
            "   nspb: {}".format(self.nspbs),
            "   nmts: {}".format(self.nmts),
            "   nframes: {}".format(self.nframes),
        ])

    def PrintFrame(self):
        """Print the current frame's SPB and MT positions/orientations to stdout.

        MTs under SPB ispb are printed as indices ispb*mt_per_spb onwards,
        matching the contiguous MT -> SPB mapping used in ReadFramePosit.
        """
        print("****************")
        print("Frame: {}".format(self.cur_frame))
        for ispb in range(self.nspbs):
            print("   spb: {}".format(ispb))
            print("      r: {}".format(self.spbs[ispb].r))
            print("      u: {}".format(self.spbs[ispb].u))
            for imt in range(self.mt_per_spb):
                iimt = ispb*self.mt_per_spb + imt
                print("      mt: {}".format(iimt))
                print("         r: {}".format(self.microtubules[iimt].r))
                print("         u: {}".format(self.microtubules[iimt].u))

    ### Image generation functionality
    def _BlankImage2D(self, imageParams):
        """Return (x, y, imagedata, numPixelsX, numPixelsY) for a noisy square canvas.

        The canvas side is ceil((h[0][0] + modeloffset) / pixelsize) pixels.
        imagedata is bkglevel plus standard-normal noise scaled by noisestd;
        x and y are np.meshgrid pixel-index grids.
        """
        pixelSize = imageParams['pixelsize']
        noiseStd = imageParams['noisestd']
        bkglevel = imageParams['bkglevel']
        offset_distance = imageParams['modeloffset']

        # Both dimensions use h[0][0], so the image is square.
        numPixelsX = np.int_(np.ceil((self.h[0][0]+offset_distance)/pixelSize))
        numPixelsY = np.int_(np.ceil((self.h[0][0]+offset_distance)/pixelSize))
        [x, y] = np.meshgrid(np.arange(0, numPixelsX), np.arange(0, numPixelsY))

        imagedata = bkglevel * np.ones((numPixelsX, numPixelsY)) \
            + np.random.standard_normal((numPixelsX, numPixelsY))*noiseStd
        return x, y, imagedata, numPixelsX, numPixelsY

    def ImageMicrotubules2D(self, imageParams):
        """Render a synthetic 2D image of all MTs projected onto the sim y-z plane.

        imageParams keys read: pixelsize, noisestd, bkglevel, A, sigmaxy,
        modeloffset. Image x comes from sim y, image y from sim z, with the
        origin at the canvas centre. Each MT is drawn with GaussianLine2D, and
        the result is flipped left-right (np.fliplr) before being returned.
        """
        pixelSize = imageParams['pixelsize']
        A = imageParams['A']
        sigma = imageParams['sigmaxy']
        x, y, imagedata, numPixelsX, numPixelsY = self._BlankImage2D(imageParams)

        for microtubule in self.microtubules:
            # Start point of the MT (its first site) in pixel coordinates.
            r0 = microtubule.r - 0.5 * microtubule.v
            x0 = r0[1]/pixelSize + numPixelsX//2
            y0 = r0[2]/pixelSize + numPixelsY//2

            theta = -np.arctan2(microtubule.u[2], microtubule.u[1])
            # Length of the MT's projection onto the y-z plane, in pixels.
            mlen = np.sqrt(microtubule.v[1]*microtubule.v[1] +
                           microtubule.v[2]*microtubule.v[2])/pixelSize

            imagedata = imagedata + GaussianLine2D(x, y, A, sigma, x0, y0, mlen, theta)

        return np.fliplr(imagedata)

    def ImageMicrotubules2DPlanar(self, imageParams, rotation_matrix):
        """Render a 2D image of all MTs after rotating them with -rotation_matrix.

        Both MT endpoints are multiplied by -rotation_matrix, and the rotated
        x-y plane is imaged. The imageParams keys read are the same as for
        ImageMicrotubules2D. The result is flipped left-right before being
        returned.
        """
        pixelSize = imageParams['pixelsize']
        A = imageParams['A']
        sigma = imageParams['sigmaxy']
        x, y, imagedata, numPixelsX, numPixelsY = self._BlankImage2D(imageParams)

        for microtubule in self.microtubules:
            # Rotate both endpoints into the viewing frame.
            r0 = np.dot(-rotation_matrix, microtubule.r - 0.5 * microtubule.v)
            r1 = np.dot(-rotation_matrix, microtubule.r + 0.5 * microtubule.v)

            newv = r1 - r0
            newu = newv / np.linalg.norm(newv)

            x0 = r0[0]/pixelSize + numPixelsX//2
            y0 = r0[1]/pixelSize + numPixelsY//2

            theta = -np.arctan2(newu[1], newu[0])
            mlen = np.sqrt(newv[0]*newv[0] + newv[1]*newv[1])/pixelSize

            imagedata = imagedata + GaussianLine2D(x, y, A, sigma, x0, y0, mlen, theta)

        return np.fliplr(imagedata)

    def GetPlanarViewInformation(self):
        """Return [rhat, tomo] describing the SPB "tomo" axis and spindle direction.

        rhat is the unit vector from SPB 0 to SPB 1. tomo is the unit normal
        (cross product) of the two normalized SPB position vectors.
        """
        rhatspb0 = self.spbs[0].r / np.linalg.norm(self.spbs[0].r)
        rhatspb1 = self.spbs[1].r / np.linalg.norm(self.spbs[1].r)
        tomo = np.cross(rhatspb0, rhatspb1)
        tomo /= np.linalg.norm(tomo)
        rhat = self.spbs[1].r - self.spbs[0].r
        rhat /= np.linalg.norm(rhat)

        return [rhat, tomo]

    def ImageMicrotubules3D(self, imageParams):
        """Render a synthetic 3D image stack of all MTs.

        imageParams keys read: pixelsize, noisestd, bkglevel, A, sigmaxy,
        sigmaz, nzstacks. The volume side is h[0][0] + 60.0 (a fixed margin,
        unlike the 2D renderers, which use modeloffset), split into nzstacks
        z-slices. Axes are permuted to mimic the original YZ-projection images:
        image x comes from sim y, image y from sim z, image z from sim x. Each
        z-slice is flipped left-right before being returned.
        """
        pixelSize = imageParams['pixelsize']
        noiseStd = imageParams['noisestd']
        bkglevel = imageParams['bkglevel']
        A = imageParams['A']
        sigmaxy = imageParams['sigmaxy']
        sigmaz = imageParams['sigmaz']
        pixelSizeZ = (self.h[0][0] + 60.0)/imageParams['nzstacks']

        # XY pixel counts come from pixelsize; Z is fixed by the number of stacks.
        numPixelsX = np.int_(np.ceil((self.h[0][0]+60.0)/pixelSize))
        numPixelsY = np.int_(np.ceil((self.h[0][0]+60.0)/pixelSize))
        numPixelsZ = np.int_(np.ceil((self.h[0][0]+60.0)/pixelSizeZ))
        xpixels = np.arange(0, numPixelsX)
        ypixels = np.arange(0, numPixelsY)
        zpixels = np.arange(0, numPixelsZ)
        [x, y, z] = np.meshgrid(xpixels, ypixels, zpixels)

        imagedata = bkglevel * np.ones((numPixelsX, numPixelsY, numPixelsZ)) \
            + np.random.standard_normal((numPixelsX, numPixelsY, numPixelsZ))*noiseStd

        for microtubule in self.microtubules:
            r0 = microtubule.r - 0.5 * microtubule.v
            x0 = r0[1]/pixelSize + numPixelsX//2
            y0 = r0[2]/pixelSize + numPixelsY//2
            z0 = r0[0]/pixelSizeZ + numPixelsZ//2
            r1 = microtubule.r + 0.5 * microtubule.v
            x1 = r1[1]/pixelSize + numPixelsX//2
            y1 = r1[2]/pixelSize + numPixelsY//2
            z1 = r1[0]/pixelSizeZ + numPixelsZ//2

            imagedata = imagedata + GaussianLine3D(A, sigmaxy, sigmaz, x, y, z,
                                                   x0, y0, z0, x1, y1, z1)

        # np.fliplr of every z-slice == reversing axis 1 of the whole stack.
        # An explicit copy avoids writing a flipped view back into the array it
        # views, which is only safe on NumPy versions with overlap detection.
        return imagedata[:, ::-1, :].copy()

##########################################
if __name__ == "__main__":
    # Script: accumulate per-SPB MT lengths over a trajectory and plot
    # length histograms for SPB 0, SPB 1 and both combined.
    if len(sys.argv) < 5:
        print("Usage: {} seed_path posit_name default_file equil_file".format(sys.argv[0]))
        sys.exit(1)

    t = ReadPositSpindle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4])
    if not t.LoadPosit():
        # LoadPosit already reported the missing file; t.nframes is undefined.
        sys.exit(1)

    data = [[], []]
    for i in range(t.nframes):
        t.ReadFramePosit()
        # Frames 0-15 are skipped (presumably an initial transient — TODO confirm).
        if i > 15:
            frame_data = t.GenerateMTLengthSPBDistributions()
            for n in [0, 1]:
                data[n] += frame_data[n]
    t.UnloadPosit()

    # uc['um'][1] presumably converts simulation lengths to microns — TODO confirm.
    dist0 = np.array(data[0])*(uc['um'][1])
    dist1 = np.array(data[1])*(uc['um'][1])
    dist2 = np.append(dist0, dist1)

    hist_kwargs = dict(bins=110, range=(0.0, 3.75), density=False)
    hist0, bin_edges = np.histogram(dist0, **hist_kwargs)
    hist1, bin_edges = np.histogram(dist1, **hist_kwargs)
    hist2, bin_edges = np.histogram(dist2, **hist_kwargs)
    # Presumably the midpoints of consecutive bin edges (TODO confirm moving_average window).
    bin_mids = moving_average(bin_edges)

    fig, axarr = plt.subplots(3, 1)
    color = 'b'
    label = ''
    for ax, hist in zip(axarr, [hist0, hist1, hist2]):
        ax.plot(bin_mids, hist, color=color, label=label)

    plt.show()
    
