#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# Basic
import sys, os, pdb
import json
import yaml
import cPickle  as pickle
import re
from collections import OrderedDict
## Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *
from criteria_funcs import *
from copy import deepcopy

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
from base_funcs import moving_average
from spindle_unit_dict import SpindleUnitDict
from read_posit_spindle import ReadPositSpindle
from read_posit_thermo import ReadPositThermo
from read_posit_forces import ReadPositForces
from read_posit_xlinks import ReadPositXlinks
from read_posit_chromosomes import ReadPositChromosomes

'''
Name: SpindlePostAnalysis.py
Description: Does any more spindle post analysis (reading posit) stuff that needs doing
Input:
Output:
'''

# Module-level unit-conversion table. Methods in this file index it as
# uc['um'][1] to scale lengths before histogramming
# (presumably simulation units -> micrometres; confirm against SpindleUnitDict).
uc = SpindleUnitDict()

#Class definition
class SpindlePostAnalysis(object):
    def __init__(self, sd_dir_path, opts=None, uc=None,):
        """Open the posit readers of one seed directory and create empty result stores.

        Args:
            sd_dir_path: Path of the seed directory that holds the posit and
                yaml files.
            opts: Options object. Only ``opts.fitness`` is read here; when
                ``opts`` is None fitness analysis stays disabled.
            uc: Unit-conversion table indexed as ``uc[name][1]``. When None a
                fresh ``SpindleUnitDict()`` is used.
        """
        print("    path: {}".format(os.path.relpath(
            sd_dir_path,
            os.path.join(sd_dir_path, '../../../..'))))

        # Fall back to the default table so the conversions below never index None.
        if uc is None:
            uc = SpindleUnitDict()
        self.uc = uc

        self.opts = opts
        self.analyze_xlinks = False
        self.analyze_chromosomes = False
        self.analyze_fitness = False
        self.sd_dir_path = os.path.realpath(sd_dir_path)

        # List of beginning and end times if spindle was successful.
        # If not, self.success_spindle remains None.
        self.success_spindle = None

        # Every reader attribute exists from the start so later truthiness
        # checks cannot raise AttributeError.
        self.spindlereader = None
        self.thermoreader = None
        self.forcereader = None
        self.xlinkreader = None
        self.chromosomereader = None

        # FIXME these might not always be the correct names for the posit and yaml files
        default_yaml = "spindle_bd_mp.default.yaml"
        equil_yaml = "spindle_bd_mp.equil.yaml"

        self.spindlereader = ReadPositSpindle(sd_dir_path, "spindle_bd_mp.posit", default_yaml, equil_yaml)
        self.posit_flag = self.spindlereader.LoadPosit()
        self.n_bonds = len(self.spindlereader.microtubules)
        self.xlinkreader = ReadPositXlinks(sd_dir_path, "crosslinks.posit", default_yaml, equil_yaml)
        self.analyze_xlinks = self.xlinkreader.LoadPosit()

        # Chromosomes are analyzed only when the configuration names a
        # chromosome_config and the reader actually produced chromosomes.
        if self.spindlereader.CheckDefaultThenEquil('chromosome_config'):
            self.analyze_chromosomes = True
        if self.analyze_chromosomes:
            self.chromosomereader = ReadPositChromosomes(sd_dir_path, "chromosomes.posit", default_yaml, equil_yaml)
            self.chromosomereader.LoadPosit()
            if not self.chromosomereader.chromosomes:
                self.analyze_chromosomes = False

        self.thermoreader = ReadPositThermo(sd_dir_path, "spindle_bd_mp.thermo", default_yaml, equil_yaml)
        self.thermoreader.LoadPosit()
        self.forcereader = ReadPositForces(sd_dir_path, "spindle_bd_mp.forces", default_yaml, equil_yaml)
        # Without chromosomes there is no chromosome reader to query.
        # NOTE(review): 0 is assumed to be the right count for
        # ReadPositForces.LoadPosit in that case -- confirm.
        n_chromo = 0
        if self.analyze_chromosomes:
            n_chromo = self.chromosomereader.chromosomes.n_chromo
        self.forcereader.LoadPosit(n_chromo)

        # We always have a spindle reader
        self.nposit = self.spindlereader.CheckDefaultThenEquil('n_posit')
        self.ngraph = self.spindlereader.CheckDefaultThenEquil('n_graph')
        self.delta = self.spindlereader.CheckDefaultThenEquil('delta')
        self.wait_time = self.spindlereader.CheckDefaultThenEquil('wait_time') * self.uc['sec'][1] / 60.0

        # If there are chromosomes, check whether the anaphase flag is turned on.
        self.do_anaphase = False
        if self.analyze_chromosomes and self.chromosomereader.GetAnaphase():
            self.do_anaphase = True

        self.timedata = {}  # Dictionary of data arrays vs time
        self.timedata['time'] = []  # List of frame times
        self.distrdata = {}
        self.xlinkdata = {}
        self.kcdata = {}

        self.anaphase_onset = 0.0
        self.anaphase_onset_step = 0
        self.naf = 0
        self.nxsome = 0
        if self.analyze_chromosomes:
            self.naf = self.chromosomereader.chromosomes.naf
            self.nxsome = self.chromosomereader.chromosomes.n_chromo

        # Short, medium and long target lengths in sim units
        um = self.uc['um'][1]
        self.ldict = {}
        self.ldict['short'] = 1.05 / um
        self.ldict['med'] = 1.825 / um
        self.ldict['long'] = 2.15 / um
        self.name_ldict = {}
        self.name_ldict['long'] = '2.15um'
        self.name_ldict['med'] = '1.825um'
        self.name_ldict['short'] = '1.05um'

        self.measuretime = 5.0  # measurement timing in seconds
        # Previously derived from measuretime:
        # np.int_(self.measuretime / self.uc['sec'][1] / self.delta / self.nposit)
        self.nmeasure = 1

        if self.analyze_chromosomes:
            print("Post Analysis on Chromosomes")
            print("Anaphase: {}".format(self.do_anaphase))
        if self.opts is not None and self.opts.fitness:
            print("Post Analysis on Fitness")
            self.analyze_fitness = True

        # If we find a .start-ing file, tag this as bad and move on
        self.bad_seed = False
        if os.path.isfile(os.path.join(self.sd_dir_path, '.start-ing')):
            print("Found bad or unfinished seed via .start-ing file")
            self.bad_seed = True

    # Load or analyze
    def CheckLoadAnalyze(self, file_path, data, force_analyze=False):
        if data == None: data = copy(self.timedata)
        if os.path.isfile(file_path) and not force_analyze:
            # Skip analysis and load
            try:
                self.load(file_path, data)
                return False
            except EOFError: return False
            except: raise
        else:
            return True

    # Only does the analyze, and creates a pandas DF?
    def Analyze(self, force_analyze=False):
        if not self.CheckLoadAnalyze(os.path.join(self.sd_dir_path, "SpindlePostAnalysisSeed.pickle"),
                                     self.timedata, force_analyze):
            return False # no need to re-analyze
        elif not self.posit_flag:
            return False
        elif self.bad_seed:
            return False
        else:
            for frame in xrange(self.spindlereader.nframes):
            #for frame in xrange(100):
                #print "frame: {}".format(frame)
                self.spindlereader.ReadFramePosit()
                if self.thermoreader:
                    self.thermoreader.ReadFramePosit()
                if self.forcereader:
                    self.forcereader.ReadFramePosit()
                if self.analyze_xlinks:
                    self.xlinkreader.ReadFramePosit()
                if self.analyze_chromosomes:
                    self.chromosomereader.ReadFramePosit()

                # Look at each xlink and see where it happens to lie, which SPB it's on,
                # etc
                if (frame % self.nmeasure == 0):
                    trueframe = frame * self.nposit / self.ngraph
                    frametime = frame * self.nposit * self.delta * self.uc['min'][1]
                    # print "Trueframe: {}, time: {}".format(trueframe, frametime)
                    self.timedata['time'].append(frametime)
                    self.SPBSeparation(frametime)
                    # self.XlinkSpindle1D(frametime, target=42.0) # 1050nm = short spindle!
                    self.AvgMTlength(frametime)
                    self.MTLengthDistributions(frametime)
                    self.AvgMTDirectorSplay(frametime)
                    self.InterpolarFraction(frametime)
                    if self.thermoreader:
                        self.AnalyzeThermo(frametime)
                    if self.forcereader:
                        self.AnalyzeForces(frametime)
                    if self.analyze_xlinks:
                        self.XlinkDistance(frametime)
                        self.NumXlinks(frametime)
                    if self.analyze_chromosomes:
                        self.KCSpindle1D(frametime)
                        self.KCAttachType(frametime)
                        self.KCNAttachedEndon(frametime)
                        self.KCDistance(frametime)
                        self.KCOccupancy(frametime)
                        self.kMTLifetimes(frametime)
                    if self.analyze_fitness:
                        self.FitnessStreams(frametime)

                if (frame*10000/self.spindlereader.nframes) % 1000 == 0:
                    sys.stdout.write('\r')
                    sys.stdout.write("Post Analysis on spindle: {}% ".format(frame*100/self.spindlereader.nframes))
                    sys.stdout.flush()
                    # print "{}% ".format(frame*100/self.spindlereader.nframes)
            print ""

        # If we are doing anaphase, check to see when it occurred
        if self.do_anaphase:
            self.anaphase_onset_step = 0
            # Opem up the sim.log file and look for anaphase
            with open(os.path.join(self.sd_dir_path, 'sim.log'), 'r') as fsimlog:
                for line in fsimlog:
                    manaphase = re.search(r'Turning off the SAC, step: (\d+)', line)
                    if manaphase:
                        self.anaphase_onset_step = np.int_(manaphase.group(1))
                        self.anaphase_onset = self.anaphase_onset_step * self.delta * self.uc['min'][1]
            print "Anaphase onset: step {}, {} min".format(self.anaphase_onset_step, self.anaphase_onset)

        return True # Seed was analyzed and must be saved in spindle_seed

    # Unload the posits at the end...
    def UnloadPosit(self):
        if self.spindlereader:
            self.spindlereader.UnloadPosit()
            del self.spindlereader
        if self.thermoreader:
            self.thermoreader.UnloadPosit()
            del self.thermoreader
        if self.xlinkreader:
            self.xlinkreader.UnloadPosit()
            del self.xlinkreader
        if self.chromosomereader and self.analyze_chromosomes:
            self.chromosomereader.UnloadPosit()
            del self.chromosomereader

    ### Generate distribution data
    #def MakeDistributionData(self):
    #    if not self.CheckLoadAnalyze(os.path.join(self.sd_dir_path, "SpindleDistributionAnalysis.pickle"), self.distrdata, self.opts.analyze):
    #        return False

    #    # Xlink distribution data
    #    self.MakeXlinkDistributionData()

    #    # Kinetochore distribution data
    #    if self.analyze_chromosomes:
    #        self.MakeKCDistributionData()

    #    return True

    ### Analysis of Xlinks
    def XlinkDistance(self, frametime):
        """ Find singly and doubly bound crosslinks distances
        from their attached spbs.
        """
        # Make sure you have the xlink distances for both stages
        if ( 'spb_stage1_xlink_distance' not in self.xlinkdata or
                frametime not in self.xlinkdata['spb_stage1_xlink_distance']):
            self.SPBStage1XlinkDistanceTime(frametime)
        if ( 'spb_stage2_xlink_distance' not in self.xlinkdata or
                frametime not in self.xlinkdata['spb_stage2_xlink_distance']):
            self.SPBStage2XlinkDistanceTime(frametime)
        if ( 'spindle_xlink_distance' not in self.xlinkdata or
                frametime not in self.xlinkdata['spindle_xlink_distance']):
            self.SpindleXlinkDistance(frametime)

    def SPBStage1XlinkDistanceTime(self, frametime):
        # TODO make this available for all xlinks not just the first species
        [distances, weights] = self.xlinkreader.GenerateSPBStage1XlinkDistance(self.spindlereader.spbs, self.spindlereader.microtubules, xtype=0)
        if 'spb_stage1_xlink_distance' not in self.xlinkdata:
            self.xlinkdata['spb_stage1_xlink_distance'] = {}
        self.xlinkdata['spb_stage1_xlink_distance'][frametime] = [distances, weights]

    def SPBStage2XlinkDistanceTime(self, frametime):
        # TODO make this available for all xlinks not just the first species aka loop over xtype
        [distances, weights] = self.xlinkreader.GenerateSPBStage2XlinkDistance(self.spindlereader.spbs, self.spindlereader.microtubules, xtype=0)
        if 'spb_stage2_xlink_distance' not in self.xlinkdata:
            self.xlinkdata['spb_stage2_xlink_distance'] = {}
        self.xlinkdata['spb_stage2_xlink_distance'][frametime] = [distances, weights]

    def SpindleXlinkDistance(self, frametime):
        '''Find the distribution along the entire spindle relative to
        the first spb of stage 1 and stage 2 crosslinks. The distribution is normalized to the length of the
        spindle.
        '''
        spbs = self.spindlereader.spbs
        mts = self.spindlereader.microtubules
        # TODO Modify so that you get both stage1 and stage2 crosslinks
        # TODO make this available for all xlinks not just the first species
        [stage1_data, stage2_data] = self.xlinkreader.GenerateXlinkSpindle1D(spbs, mts, xtype=0)
        # print "SpindleXlinkDistance is not implemented yet"
        if 'spindle_xlink_distance_stage1' not in self.xlinkdata:
            self.xlinkdata['spindle_xlink_distance_stage1'] = {}
        self.xlinkdata['spindle_xlink_distance_stage1'][frametime] = stage1_data

        if 'spindle_xlink_distance_stage2' not in self.xlinkdata:
            self.xlinkdata['spindle_xlink_distance_stage2'] = {}
        self.xlinkdata['spindle_xlink_distance_stage2'][frametime] = stage2_data

    def MakeDistributionData(self):
        """Build the histogram/distribution data in ``self.distrdata``.

        ``CheckLoadAnalyze`` is asked about "SpindleDistributionAnalysis.pickle"
        first. When distributions have to be (re)built, MT length, crosslink
        and kinetochore histograms are generated from ``self.timedata``,
        ``self.xlinkdata`` and ``self.kcdata``; the latter two are deleted
        once consumed. Fitness streams and anaphase results are handled on
        every call, independent of the save flag.

        Returns:
            ``pickle_flag or self.opts.analyze`` -- truthy when new
            distribution data was generated.
        """
        pickle_path = os.path.join(self.sd_dir_path, "SpindleDistributionAnalysis.pickle")
        pickle_flag = self.CheckLoadAnalyze(pickle_path, self.distrdata, self.opts.analyze)

        save_flag = (pickle_flag or self.opts.analyze)

        ### MT distributions added to self.distrdata
        if save_flag:
            # MT lengths by spb
            d = self.distrdata['mt_lengths'] = {}
            # FIXME only works for current layout of MTLengthsHist which might change in the future
            hist_arr, n_points_arr, bin_mids = self.MTLengthHist(self.timedata['mt_lengths'])
            for idx, label in enumerate(('spb1', 'spb2', 'merge')):
                d[label] = {'mid_points': bin_mids, 'hist': hist_arr[idx], 'n_points': n_points_arr[idx]}

            # MT lengths by index
            # TODO Make this more modular with above function
            # NOTE(review): unlike MTLengthHist, this histogram uses every
            # stored frame; a start time was computed here before but never
            # applied. Confirm whether a start-time cut was intended.
            d = self.distrdata['mt_lengths_by_index'] = {}
            df = pd.DataFrame.from_dict(self.timedata['mt_lengths_by_index'], orient='index').sort_index()
            for i in xrange(0, self.n_bonds):
                d[i] = {}
                dist = np.array(df[i])*(uc['um'][1])
                d[i]['hist'], bin_edges = np.histogram(dist, bins=60, range=(0.0, 3.75))
                d[i]['bin_mids'] = moving_average(bin_edges)
                d[i]['n_points'] = dist.size

        # If there are no xlinks to analyze don't pickle data
        ### Xlink distributions added to self.distrdata
        if self.analyze_xlinks and save_flag:
            # Keys for making histograms with flag to normalize or not
            keys = [ ('spb_stage1_xlink_distance', False),
                     ('spb_stage2_xlink_distance', False),
                     ('spindle_xlink_distance_stage1', True),
                     ('spindle_xlink_distance_stage2', True)
                   ]
            # Loop over keys to make integrated and final stage xlink distributions
            for k, n in keys:
                self.distrdata[k] = {}
                # Final configuration
                d = self.distrdata[k]['final'] = {}
                d['mid_points'], d['hist'], d['n_points'] = self.XlinkDistanceHist(
                    self.xlinkdata[k], xstate=r'final', key=k, norm=n)
                # Integrated configuration
                d = self.distrdata[k]['integrated'] = {}
                d['mid_points'], d['hist'], d['n_points'] = self.XlinkDistanceHist(
                    self.xlinkdata[k], xstate=r'integrated', key=k, norm=n)
            # Raw per-frame crosslink data is no longer needed once binned.
            del self.xlinkdata

        # If there are no chromosomes to analyze, don't pickle the data either
        if self.analyze_chromosomes and save_flag:
            keys = [ ('kc_spindle_1d', True)
                   ]
            # Loop over the keys to make the integrated (and other) distributions for KCs
            for k, n in keys:
                self.distrdata[k] = {}
                # Always take integrated condition
                integrated = self.distrdata[k]['integrated'] = {}
                # Do the 3 different target lengths for EM, then the full one
                for k2 in list(self.ldict) + ['full']:
                    d = integrated[k2] = {}
                    d['bin_edges'], d['mid_points'], d['hist'], d['n_points'] = self.KCSpindle1DHist(
                        self.kcdata[k][k2], xstate=r'integrated')

            # Load this distribution data, specifically for the stretch, not compatible with the one above
            self.distrdata['inter_kc_stretch'] = {}
            d = self.distrdata['inter_kc_stretch']
            d['bin_edges'], d['mid_points'], d['hist'], d['n_points'] = self.InterKCStretchHist()

            del self.kcdata

        if self.analyze_fitness:
            combined = {}
            # Compute the length stream for everything: concatenate the
            # per-frame results of every stream name, per target length.
            for k, vname in self.name_ldict.iteritems():
                target = self.timedata['fitness'][vname]
                if vname not in combined:
                    combined[vname] = {}
                for ftime, stream in target.iteritems():
                    for sname, sresults in stream.iteritems():
                        if sname not in combined[vname]:
                            combined[vname][sname] = []
                        combined[vname][sname] += sresults

            self.distrdata['fitness'] = combined

        if self.do_anaphase:
            # Freshly analyzed seeds store the onset; otherwise the stored
            # values are copied back onto the object.
            if 'anaphase_onset' not in self.distrdata:
                self.distrdata['anaphase_onset'] = self.anaphase_onset
                self.distrdata['anaphase_onset_step'] = self.anaphase_onset_step
            else:
                self.anaphase_onset = self.distrdata['anaphase_onset']
                self.anaphase_onset_step = self.distrdata['anaphase_onset_step']

            # Calculate the success of anaphase
            self.CalcSegregationSuccess()
            self.distrdata['segregation_success'] = self.segregation_success
            print("segregation success: {}".format(self.segregation_success))

        return save_flag

    def XlinkDistanceHist(self, xlink_data, xstate=r'final', **kwargs):
        """Bin crosslink distances, either along the spindle or from an SPB.

        Args:
            xlink_data: Dictionary ``{frametime: [distances, weights]}``.
            xstate: ``'final'`` bins only the last frame; any other value
                bins the frames from the start time onwards.
            **kwargs: ``norm`` (default False) selects the normalized axis
                (-0.4 to 1.4, no unit scaling); otherwise distances are
                scaled by ``uc['um'][1]`` and binned from 0 to 2.75.
                Other keywords are ignored.

        Returns:
            Tuple ``(bin_mids, hist, n_points)``: bin positions, weighted
            counts per bin, and the number of distance samples.
        """
        # Turn xlink_data dictionary into pandas data frame indexed by the frametime
        df = pd.DataFrame.from_dict(xlink_data, orient='index').sort_index()

        # Define scope of histogram distributions. A missing 'norm' keyword
        # falls back to the un-normalized branch instead of leaving the
        # variables unbound.
        if kwargs.get('norm', False):
            units_factor = 1.0
            x_axis_range = (-.4, 1.4)
        else:
            units_factor = uc['um'][1]
            x_axis_range = (0.0, 2.75)

        # Define how data will be collected
        if xstate == r'final':
            last_frame = df.iloc[-1]
            distances = np.array(last_frame[0])*units_factor
            weights = np.array(last_frame[1])
        else:
            # Start time info to determine when to start collecting data.
            # TODO Add option and think about this more
            time = self.timedata['time']
            if self.timedata['succ_info']['succ']:
                start_time = self.timedata['succ_info']['start_time'] + time[15]
            elif len(time) > 100:
                start_time = time[100]
            else:
                start_time = time[-1]  # Should this be changed to something else?
            data = df.loc[start_time:].sum()
            distances = np.array(data[0])*units_factor
            weights = np.array(data[1])

        # Make histogram and collect locations for histogram bars (bin_mids)
        hist, bin_edges = np.histogram(distances, weights=weights, bins=110, range=x_axis_range)
        bin_mids = moving_average(bin_edges)
        return (bin_mids, hist, distances.size)

    def NumXlinks(self, frametime):
        """Create a list of the number croslinks for a xlink species.
        index n = stage n, index 3 = Total number
        """
        xl_stage_list = self.xlinkreader.GenerateNumXlinks(0)
        # print "SpindlePostAnalysis.NumXlinks xl_stage_list: {}".format(xl_stage_list)
        if 'num_xlinks' not in self.timedata:
            self.timedata['num_xlinks'] = {}
        self.timedata['num_xlinks'][frametime] = xl_stage_list

    ### Analysis of Chromosomes
    def KCSpindle1D(self, frametime):
        # Do all 3 spindle lengths!
        for k,v in self.ldict.iteritems():
            [distances, weights] = self.chromosomereader.GenerateKCSpindle1D(self.spindlereader.spbs, target_length = v, use_target = True)
            if 'kc_spindle_1d' not in self.kcdata:
                self.kcdata['kc_spindle_1d'] = {}
            if k not in self.kcdata['kc_spindle_1d']:
                self.kcdata['kc_spindle_1d'][k] = {}
            self.kcdata['kc_spindle_1d'][k][frametime] = [distances, weights]
        # Do the final one for spindles of length > 1 micron
        [distances, weights] = self.chromosomereader.GenerateKCSpindle1D(self.spindlereader.spbs, target_length = 40.0, use_target = False)
        if 'full' not in self.kcdata['kc_spindle_1d']:
            self.kcdata['kc_spindle_1d']['full'] = {}
        self.kcdata['kc_spindle_1d']['full'][frametime] = [distances, weights]

    #def MakeKCDistributionData(self):
    #    # Kinetochore 1d spindle
    #    name = 'kc_spindle_1d'
    #    self.distrdata[name] = {}
    #    # Integrated configuration
    #    self.distrdata[name]['integrated'] = {}
    #    for k,v in self.ldict.iteritems():
    #        self.distrdata[name]['integrated'][k] = {}
    #        d = self.distrdata[name]['integrated'][k]
    #        d['mid_points'], d['hist'], d['n_points'] = self.KCSpindle1DHist(self.kcdata[name][k], xstate=r'integrated')
    #    # Do the final one
    #    self.distrdata[name]['integrated']['full'] = {}
    #    d = self.distrdata[name]['integrated']['full']
    #    d['mid_points'], d['hist'], d['n_points'] = self.KCSpindle1DHist(self.kcdata[name]['full'], xstate=r'integrated')

    #    del self.kcdata

    def KCSpindle1DHist(self, kc_data, xstate=r'final'):
        """Histogram kinetochore positions collected over all stored frames.

        Args:
            kc_data: Dictionary ``{frametime: [distances, weights]}``.
            xstate: Any value other than ``'final'`` pools every frame in
                ``kc_data``. ``'final'`` collects nothing, so the returned
                histogram is empty.

        Returns:
            Tuple ``(bin_edges, bin_mids, hist, n_points)`` for 12 bins over
            the range -0.10 to 1.10.
        """
        # Seed with an empty float array so an empty selection still
        # concatenates to the same empty array np.array([]) would give.
        dist_parts = [np.array([])]
        weight_parts = [np.array([])]

        if xstate != r'final':
            # Gather the pieces first and join once, instead of growing the
            # arrays with np.append on every frame.
            for frametime in kc_data:
                frame_entry = kc_data[frametime]
                dist_parts.append(np.ravel(frame_entry[0]))
                weight_parts.append(np.ravel(frame_entry[1]))

        distances = np.concatenate(dist_parts)
        weights = np.concatenate(weight_parts)

        hist, bin_edges = np.histogram(distances, weights=weights, bins=12, range=(-0.10, 1.10))
        bin_mids = moving_average(bin_edges)

        return (bin_edges, bin_mids, hist, distances.size)

    def InterKCStretchHist(self):
        """Histogram the absolute difference between paired kinetochore distances.

        For every stored frame time, entries ``2*ic`` and ``2*ic+1`` of
        ``self.timedata['kc_distance']`` are treated as one chromosome's
        pair; ``|d1 - d0|`` of all pairs and frames is binned with unit
        weights.

        Returns:
            Tuple ``(bin_edges, bin_mids, hist, n_points)`` for 55 bins over
            the range 0.0 to 2.75.
        """
        kc_by_time = self.timedata['kc_distance']
        frames = [kc_by_time[ts] for ts in self.timedata['time']]

        # Two entries per chromosome
        n_pairs = len(frames[0]) // 2

        distances = np.array([])
        for pair in range(n_pairs):
            first = np.array([frame[2*pair] for frame in frames], dtype=float)
            second = np.array([frame[2*pair+1] for frame in frames], dtype=float)
            distances = np.append(distances, np.fabs(second - first))

        weights = np.ones((len(distances)))

        hist, bin_edges = np.histogram(distances, weights=weights, bins=55, range=(0.0, 2.75))
        bin_mids = moving_average(bin_edges)

        return (bin_edges, bin_mids, hist, distances.size)

    # Calculate the segregation success if anaphase started
    def CalcSegregationSuccess(self):
        """Set ``self.segregation_success`` to the fraction of segregated chromosomes.

        A chromosome counts as segregated when, at the last frame before
        ``anaphase_onset + 2.0``, one of its two kinetochore distances is
        below the 0.3 threshold and the other is within 0.3 of the final
        SPB separation.
        """
        threshold = 0.3
        time = np.asarray(self.timedata['time'])
        anaphase_b_time = self.anaphase_onset + 2.0
        print("Anaphase A start at: {}".format(self.anaphase_onset))
        print("Anaphase A end at:   {}".format(anaphase_b_time))

        # Kinetochore distances per frame, two entries per chromosome
        kc_by_time = self.timedata['kc_distance']
        frames = [kc_by_time[ts] for ts in time]
        nchromo = len(frames[0]) // 2

        # Final spindle length, scaled by the same factor as the KC distances
        sep_by_time = self.timedata['spb_separation']
        spbsep_arr = np.array([sep_by_time[ts] for ts in time])*uc['um'][1]
        final_spindle_separation = spbsep_arr[-1]

        before_b = time < anaphase_b_time
        successful_segregation = np.zeros(nchromo)

        # Loop over chromosomes
        for ic in range(nchromo):
            kc0 = np.array([frame[2*ic] for frame in frames], dtype=float)[before_b]
            kc1 = np.array([frame[2*ic+1] for frame in frames], dtype=float)[before_b]
            end0 = kc0[-1]
            end1 = kc1[-1]

            # Either sister may be the one at distance ~0; the other must
            # then sit about one spindle length away.
            zero_then_one = (end0 < threshold) and (np.fabs(end1 - final_spindle_separation) < threshold)
            one_then_zero = (end1 < threshold) and (np.fabs(end0 - final_spindle_separation) < threshold)
            if zero_then_one or one_then_zero:
                successful_segregation[ic] = 1

        self.segregation_success = np.mean(successful_segregation)

    # Generate the kinetochore attachment types at this particular time
    def KCAttachType(self, frametime):
        atypes = self.chromosomereader.GenerateAttachmentTypes(self.spindlereader.spbs, self.spindlereader.microtubules)
        if 'kc_atypes' not in self.timedata:
            self.timedata['kc_atypes'] = {}
        self.timedata['kc_atypes'][frametime] = atypes

    # How many end-on attachments does each chromosome have at this time?
    def KCNAttachedEndon(self, frametime):
        endon = self.chromosomereader.GenerateEndOnAttachments(self.spindlereader.microtubules)
        if 'kc_nend' not in self.timedata:
            self.timedata['kc_nend'] = {}
        self.timedata['kc_nend'][frametime] = endon

    # How many are bound by the chromosomes?
    def KCOccupancy(self, frametime):
        nbound = self.chromosomereader.GenerateNBound()
        if 'kc_occupancy' not in self.timedata:
            self.timedata['kc_occupancy'] = {}
        if 'kc_occupancy_x' not in self.timedata:
            self.timedata['kc_occupancy_x'] = {}
        self.timedata['kc_occupancy'][frametime] = nbound[0]
        self.timedata['kc_occupancy_x'][frametime] = nbound[1]

    # Grab the current kMT lifetimes
    def kMTLifetimes(self, frametime):
        current_lifetimes = self.chromosomereader.GetAttachmentLifetimes()
        if 'attachment_lifetimes' not in self.timedata:
            self.timedata['attachment_lifetimes'] = {}
        self.timedata['attachment_lifetimes'][frametime] = current_lifetimes

    def KCDistance(self, frametime):
        """Store the kinetochore distances of this frame.

        The values from ``GetKCDistance`` are turned into an array, scaled by
        ``uc['um'][1]`` and saved in ``self.timedata['kc_distance'][frametime]``.
        """
        raw_distance = self.chromosomereader.GetKCDistance(self.spindlereader.spbs)
        store = self.timedata.setdefault('kc_distance', {})
        store[frametime] = np.array(raw_distance) * uc['um'][1]

    ### Analysis of MTs
    def MTLengthDistributions(self, frametime):
        # MT lengths by spb
        lengths = self.spindlereader.GenerateMTLengthSPBDistributions()
        if 'mt_lengths' not in self.timedata:
            self.timedata['mt_lengths'] = {}
        self.timedata['mt_lengths'][frametime] = lengths
        # MT lengths by index
        lengths = self.spindlereader.GenerateMTLengthIndexDistributions()
        if 'mt_lengths_by_index' not in self.timedata:
            self.timedata['mt_lengths_by_index'] = {}
        self.timedata['mt_lengths_by_index'][frametime] = lengths

    def MTLengthHist(self, mt_data, **kwargs):
        """Histogram MT lengths pooled from the start time onwards.

        Args:
            mt_data: Dictionary keyed by frame time; columns 0 and 1 of the
                resulting frame table are binned separately and combined
                (presumably one column per SPB -- confirm against the reader).
            **kwargs: Accepted but not used.

        Returns:
            Tuple ``(hist_arr, n_points_arr, bin_mids)`` where both lists hold
            three entries: column 0, column 1, and the two combined.
        """
        frame_table = pd.DataFrame.from_dict(mt_data, orient='index').sort_index()

        # Start time info to determine when to start collecting data.
        # TODO Add option and think about this more
        times = self.timedata['time']
        succ_info = self.timedata['succ_info']
        if succ_info['succ']:
            start_time = self.timedata['succ_info']['start_time'] + times[15]
        elif len(times) > 100:
            start_time = times[100]
        else:
            start_time = times[-1]

        # Pool everything from start_time on, then scale by uc['um'][1]
        pooled = frame_table.loc[start_time:].sum()
        samples = [np.array(pooled[col])*(uc['um'][1]) for col in (0, 1)]
        samples.append(np.append(samples[0], samples[1]))

        hist_arr = []
        for sample in samples:
            counts, bin_edges = np.histogram(sample, bins=110, range=(0.0, 3.75))
            hist_arr.append(counts)
        n_points_arr = [sample.size for sample in samples]
        bin_mids = moving_average(bin_edges)
        return (hist_arr, n_points_arr, bin_mids)

    def AvgMTlength(self, frametime):
        """Create a list of the average MT lengths at each time step.
        index n = spb n, index 2 = all
        """
        avg_mt_length = self.spindlereader.GenerateAvgMTLength()
        if 'avg_mt_length' not in self.timedata:
            self.timedata['avg_mt_length'] = {}
        self.timedata['avg_mt_length'][frametime] = avg_mt_length

    def InterpolarFraction(self, frametime):
        interpolar_fraction, interpolar_length_fraction = self.spindlereader.GenerateInterpolarFraction()
        if 'interpolar_fraction' not in self.timedata:
            self.timedata['interpolar_fraction'] = {}
        self.timedata['interpolar_fraction'][frametime] = interpolar_fraction

        if 'interpolar_length_fraction' not in self.timedata:
            self.timedata['interpolar_length_fraction'] = {}
        self.timedata['interpolar_length_fraction'][frametime] = interpolar_length_fraction

    def AvgMTDirectorSplay(self, frametime):
        spb_mt_director, mt_splay_data = self.spindlereader.GenerateMTSplay()

        if 'spb_mt_directors' not in self.timedata:
            self.timedata['spb_mt_directors'] = {}
        self.timedata['spb_mt_directors'][frametime] = spb_mt_director

        if 'mt_splay' not in self.timedata:
            self.timedata['mt_splay'] = {}
        self.timedata['mt_splay'][frametime] = mt_splay_data

    ### Fitness functions
    def FitnessStreams(self, frametime):
        '''Generate the matching conditions based on the old spindle_analysis_new code that writes
           out the length streams in a particular way. Do this because this is the form that the
           matching MATLAB matrices are in

           The conditions are:
                spindle length around target +/- 2
                at least 12 of antiparallel overlap (300 nm)
        '''
        # Use the information from the already generate InterpolarFraction to determine success
        spb_sep = self.timedata['spb_separation'][frametime]
        is_spindle = self.timedata['interpolar_fraction'][frametime] > 0.0

        for k,v in self.ldict.iteritems():
            name = self.name_ldict[k]
            if 'fitness' not in self.timedata:
                self.timedata['fitness'] = {}
            if self.name_ldict[k] not in self.timedata['fitness']:
                self.timedata['fitness'][name] = {}

        if (is_spindle):
            for k,v in self.ldict.iteritems():
                if (spb_sep > (v - 2.0)) and (spb_sep < (v + 2.0)):
                    streams = self.spindlereader.BinLengths()
                    name = self.name_ldict[k]

                    if 'fitness' not in self.timedata:
                        self.timedata['fitness'] = {}
                    if self.name_ldict[k] not in self.timedata['fitness']:
                        self.timedata['fitness'][name] = {}
                    self.timedata['fitness'][name][frametime] = streams

    ### Analysis of SPBs
    def SPBSeparation(self, frametime):
        spb_sep = self.spindlereader.GenerateSPBSeparation()
        if 'spb_separation' not in self.timedata:
            self.timedata['spb_separation'] = {}
        self.timedata['spb_separation'][frametime] = spb_sep

    ### Other Analysis
    def SeedSuccess(self, **kwargs):
        """ Function to make the success info dictionary in timedata """
        if 'crit_func' not in kwargs: kwargs['crit_func'] = interpolar_fraction_threshold
        if 'seed_num' not in kwargs: kwargs['seed_num'] = 'Nan'

        self.timedata['succ_info'] = OrderedDict()

        success_spindle = kwargs['crit_func'](self.timedata, **kwargs)
        succ = int(bool(success_spindle))
        st = 0
        self.timedata['succ_info']['sd_num'] = kwargs['seed_num']
        self.timedata['succ_info']['succ'] = succ
        if succ:
            # Get the time the spindle starts
            st = self.timedata['time'][success_spindle[0]]
            # CJE subtract off the start time off the wait time
            st = st - self.wait_time
            # Calc spb_seperations, 15 frames = 22.5 secs
            try:
                spb_sep = pd.Series(self.timedata['spb_separation'])
                self.timedata['succ_info']['SpindleSepAvg'] = spb_sep.iloc[success_spindle[0] + 15]
            except:
                if 'spb_separation' not in self.timedata:
                    print "spb_separation not in PostAnalysis.timedata. Cannot make SpindleSepAvg."
                else:
                    print "Cannot make SpindleSepAvg for some reason."
        else:
            self.timedata['succ_info']['SpindleSepAvg'] = 0

        self.timedata['succ_info']['start_time'] = st
        print "tagged success: {}, starttime: {}".format(succ, st)

        return

    # Generate the information from the thermo file!
    def AnalyzeThermo(self, frametime):
        """Store thermo-derived spindle forces for one frame.

        Two entries are written, both multiplied by uc['pN'][1] (presumably
        the conversion factor to piconewtons):
            self.timedata['spb_forces'][frametime]
                thermoreader.GenerateSpindleForce() evaluated with the
                reader's current spindle vector
            self.timedata['tangent_forces'][frametime]
                thermoreader.GenerateForceInformation() evaluated with the
                positions of the two SPBs

        Per the original author's note, the tangent-force step involves
        quaternions and rotations to obtain the forces acting on the spindle
        itself; that work happens inside the thermo reader.

        Virial collection (self.thermoreader.GenerateVirialMT(rspb0, rspb1)
        into timedata['mt_virial']) is currently disabled.

        Args:
            frametime: key under which this frame's results are stored.
        """
        self.timedata.setdefault('spb_forces', {})
        axis = self.spindlereader.GetSpindleVector()
        axis_forces = self.thermoreader.GenerateSpindleForce(axis)
        self.timedata['spb_forces'][frametime] = axis_forces * uc['pN'][1]

        # Rows 0 and 1 of the position array are the two SPBs.
        rspb = self.spindlereader.GetSPBPositions()
        rspb0, rspb1 = rspb[0,:], rspb[1,:]

        tangent = self.thermoreader.GenerateForceInformation(rspb0, rspb1)
        self.timedata.setdefault('tangent_forces', {})
        self.timedata['tangent_forces'][frametime] = tangent * uc['pN'][1]

    # Generate the information on the forces on the two spindle poles
    def AnalyzeForces(self, frametime):
        """Store the force on each of the two spindle poles for one frame.

        self.forcereader.GenerateSpindleForces() is called with the reader's
        current spindle vector and the module-level unit dictionary uc;
        elements 0 and 1 of its result are saved as
        self.timedata['pole_forces'][0][frametime] and
        self.timedata['pole_forces'][1][frametime].

        Args:
            frametime: key under which this frame's results are stored.
        """
        if 'pole_forces' not in self.timedata:
            # One time-keyed dictionary per pole.
            self.timedata['pole_forces'] = {0: {}, 1: {}}

        axis = self.spindlereader.GetSpindleVector()
        per_pole = self.forcereader.GenerateSpindleForces(axis, uc)

        for pole in (0, 1):
            self.timedata['pole_forces'][pole][frametime] = per_pole[pole]

    ###  Graphing functionality
    def GraphSPBStage1XlinkDistance(self, ax, label, color='b', me=1, xlabel = True, xstate=r'final', xdata_in = None):
        """Plot the weighted stage-1 crosslink SPB-distance distribution on ax.

        Args:
            ax: axes object that receives the title, labels and the curve.
            label: legend label for the curve.
            color: line color.
            me: unused.
            xlabel: when true, the x axis label is set as well.
            xstate: 'final' reads the data through self.GetFinalState(), any
                other value through self.GetIntegratedState().
            xdata_in: when truthy, used as the data instead of reading it
                from this object.

        Returns:
            The data that was histogrammed (xdata[0] are the values,
            xdata[1] their weights), so it can be handed back as xdata_in.
        """
        ax.set_title("Stage1 Xlink SPB Separations")
        ax.set_ylabel(r'Probability Density')
        if xlabel:
            ax.set_xlabel(r'Distance ($\mu$m)')

        # An explicitly supplied data set overrides the stored state.
        if xdata_in:
            xdata = xdata_in
        else:
            fetch = self.GetFinalState if xstate == r'final' else self.GetIntegratedState
            xdata = fetch('spb_stage1_xlink_distance')

        distances, weights = xdata[0], xdata[1]
        density, edges = np.histogram(distances, weights=weights, bins=110,
                                      range=(0.0, 2.75), density=True)
        # Plot against the bin mid points.
        ax.plot(moving_average(edges), density, color=color, label=label)

        return xdata

    def GraphSPBStage2XlinkDistance(self, ax, label, color='b', me=1, xlabel = True, xstate=r'final', xdata_in = None):
        """Plot the weighted stage-2 crosslink SPB-distance distribution on ax.

        Args:
            ax: axes object that receives the title, labels and the curve.
            label: legend label for the curve.
            color: line color.
            me: unused.
            xlabel: when true, the x axis label is set as well.
            xstate: 'final' reads the data through self.GetFinalState(), any
                other value through self.GetIntegratedState().
            xdata_in: when truthy, used as the data instead of reading it
                from this object.

        Returns:
            The data that was histogrammed (xdata[0] are the values,
            xdata[1] their weights), so it can be handed back as xdata_in.
        """
        ax.set_title("Stage2 Xlink SPB Separations")
        ax.set_ylabel(r'Probability Density')
        if xlabel:
            ax.set_xlabel(r'Distance ($\mu$m)')

        # An explicitly supplied data set overrides the stored state.
        if xdata_in:
            xdata = xdata_in
        else:
            fetch = self.GetFinalState if xstate == r'final' else self.GetIntegratedState
            xdata = fetch('spb_stage2_xlink_distance')

        distances, weights = xdata[0], xdata[1]
        density, edges = np.histogram(distances, weights=weights, bins=110,
                                      range=(0.0, 2.75), density=True)
        # Plot against the bin mid points.
        ax.plot(moving_average(edges), density, color=color, label=label)

        return xdata

    def GraphXlinkSpindle1D(self, ax, label, color='b', me=1, xlabel = True, xstate=r'final', xdata_in = None):
        """Plot the crosslink distribution along the spindle on ax.

        Only xdata[0] is histogrammed (130 bins over -0.15 .. 1.15,
        normalized to a density); no weights are applied.

        Args:
            ax: axes object that receives the title, labels and the curve.
            label: legend label for the curve.
            color: line color.
            me: unused.
            xlabel: when true, the x axis label is set as well.
            xstate: 'final' reads the data through self.GetFinalState(), any
                other value through self.GetIntegratedState().
            xdata_in: when truthy, used as the data instead of reading it
                from this object.

        Returns:
            The data whose first element was histogrammed, so it can be
            handed back as xdata_in.
        """
        ax.set_title("Xlink Spindle Distribution")
        ax.set_ylabel(r'Probability Density')
        if xlabel:
            ax.set_xlabel(r'Normalized Spindle Distance')

        # An explicitly supplied data set overrides the stored state.
        if xdata_in:
            xdata = xdata_in
        else:
            fetch = self.GetFinalState if xstate == r'final' else self.GetIntegratedState
            xdata = fetch('xlink_spindle_1d')

        density, edges = np.histogram(xdata[0], bins=130, range=(-0.15, 1.15), density=True)
        # Plot against the bin mid points.
        ax.plot(moving_average(edges), density, color=color, label=label)
        return xdata

    def GraphKCSpindle1D(self, ax, label, color='b', me=1, xlabel=True, xstate=r'final', xdata_in = None):
        """Plot the kinetochore distribution along the spindle on ax.

        Only xdata[0] is histogrammed (130 bins over -0.15 .. 1.15,
        normalized to a density).

        Args:
            ax: axes object that receives the title, labels and the curve.
            label: legend label for the curve.
            color: line color.
            me: unused.
            xlabel: when true, the x axis label is set as well.
            xstate: 'final' reads the data through self.GetFinalState(), any
                other value through self.GetIntegratedState().
            xdata_in: when truthy, used as the data instead of reading it
                from this object.

        Returns:
            The data whose first element was histogrammed, so it can be
            handed back as xdata_in.
        """
        ax.set_title("Kinetochore Spindle Distribution")
        ax.set_ylabel(r'Probability Density')
        if xlabel:
            ax.set_xlabel(r'Normalized Spindle Distance')

        # An explicitly supplied data set overrides the stored state.
        if xdata_in:
            xdata = xdata_in
        else:
            fetch = self.GetFinalState if xstate == r'final' else self.GetIntegratedState
            xdata = fetch('kc_spindle_1d')

        density, edges = np.histogram(xdata[0], bins=130, range=(-0.15, 1.15), density=True)
        # Plot against the bin mid points.
        ax.plot(moving_average(edges), density, color=color, label=label)
        return xdata


    ### Fun save/load hacks
    def savestate(self):
        # pandas = pd.DataFrame.from_dict(self.timedata)
        time_filename = os.path.join(self.sd_dir_path, "SpindlePostAnalysisSeed.pickle")
        distr_filename = os.path.join(self.sd_dir_path, "SpindleDistributionAnalysis.pickle")
        self.save(time_filename, self.timedata)
        self.save(distr_filename, self.distrdata)

    def save(self, filename, datadict):
        """Pickle datadict into the file filename.

        The file is opened in binary write mode (replacing any existing
        file) and written with pickle's default protocol.

        Args:
            filename: path of the pickle file to write.
            datadict: object to serialize.
        """
        # TODO Implement DataFrame pickling when there is time.  Sketch of
        # the idea: build one index-sorted DataFrame per timedata entry
        # (skipping 'time') with pd.DataFrame.from_dict(v, orient='index'),
        # join them with pd.concat(..., axis=1, keys=...), name the index
        # 'time' (and mirror it into a 'time' column), then store the result
        # with DataFrame.to_pickle().
        with open(filename, 'wb') as handle:
            pickle.dump(datadict, handle)

    def load(self, filename, datadict):
        """Merge the pickled contents of filename into datadict.

        The unpickled object is passed to datadict.update(), so existing
        keys are overwritten and other keys are left alone.

        NOTE: unpickling can run arbitrary code; only load files this
        analysis wrote itself.

        Args:
            filename: path of the pickle file to read.
            datadict: dictionary updated in place.
        """
        # TODO implement dataframe pickling to save time on a number of things
        with open(filename, 'rb') as handle:
            stored = pickle.load(handle)
        datadict.update(stored)



##########################################
if __name__ == "__main__":
    # Usage: <script> <seed_dir> [analyze]
    #   seed_dir : seed directory handed to SpindlePostAnalysis
    #   analyze  : optional; when given, the force flag passed to
    #              CheckLoadAnalyze() is True
    if len(sys.argv) < 2:
        sys.exit("usage: {} <seed_dir> [analyze]".format(os.path.basename(sys.argv[0])))

    uc = SpindleUnitDict()
    filename = os.path.realpath("SpindlePostAnalysisSeed.pickle")
    # The second argument is optional; do not index past the end of argv.
    force_analyze = len(sys.argv) > 2 and sys.argv[2] == 'analyze'

    # Pass uc by keyword: positionally it would land in the `opts` parameter.
    s = SpindlePostAnalysis(sys.argv[1], uc=uc)
    s.CheckLoadAnalyze(filename, force_analyze)

    fig, ax = plt.subplots()
    # NOTE(review): only GraphSPBStage1XlinkDistance/GraphSPBStage2XlinkDistance
    # are defined near the other graphing methods -- confirm that
    # GraphSPBXlinkDistance still exists.
    s.GraphSPBXlinkDistance(ax=ax, label=None, xstate=r'integrated')

    fig2, ax2 = plt.subplots()
    s.GraphXlinkSpindle1D(ax=ax2, label=None, xstate=r'integrated')

    fig3, ax3 = plt.subplots()
    s.GraphKCSpindle1D(ax=ax3, label=None, xstate=r'integrated')

    plt.show()

    s.savestate()
