#!/usr/bin/env python
# In case of poor (Sh**y) commenting contact adam.lamson@colorado.edu
# Basic
import sys, os, pdb
from collections import OrderedDict
import fnmatch
## Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
from seed_base import SeedBase
from spindle_unit_dict import SpindleUnitDict

from read_posit_spindle import ReadPositSpindle
from read_posit_chromosomes import ReadPositChromosomes

from criteria_funcs import *

import scipy.io as sio
from scipy.stats import ks_2samp
import pickle
import yaml

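'''
Name: spindle_fitness.py
Description: Fitness measures that compare simulated spindle data (spindle
             length, kinetochore distances, attachment states) against
             experimental distributions and experimental tracks.
Input: A seed path and a strain name (see SpindleFitness.__init__).
Output: Fitness values stored on the SpindleFitness instance; PrintYAML
        writes a summary to fitness.yaml.
'''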

# Module-wide unit table. The methods below multiply simulated SPB separations
# by uc['um'][1] (presumably the factor that converts simulation lengths to
# microns -- TODO confirm against spindle_unit_dict).
uc = SpindleUnitDict()

#Class definition
class SpindleFitness(SeedBase):
    def __init__(self, path, strain_name):
        """Create the fitness calculator for one seed.

        path is handed unchanged to SeedBase.__init__; strain_name is stored
        and later selects the experimental data set in LoadStrainData.
        """
        SeedBase.__init__(self, path)
        # KS-test results, filled by KSTest as P[varname][length_label]
        self.P = {}
        # Length-bin label -> key under which that bin is found in the model data
        self.ldict = {}
        for label, length_key in (('long', '2.15um'),
                                  ('med', '1.825um'),
                                  ('short', '1.05um')):
            self.ldict[label] = length_key
        # Model variables that Print/PrintYAML report from self.P
        self.varnames = ['interpolar', 'pairing_length_max', 'total', 'angles']
        self.chromosome_seconds = 0.0
        self.lstream_model = []
        self.strain_name = strain_name

    def MakeDataDict(self, dat_file_list='', header=None):
        """Build the data dictionary through SeedBase.MakeDataDict.

        An empty/falsy dat_file_list is replaced by the three default spindle
        .dat files; header is passed through unchanged. Returns None.
        """
        #TODO Parse options if necessary
        default_files = ['spindle_separation.dat',
                         'frac_interpolar.dat',
                         'avg_mt_length.dat']
        SeedBase.MakeDataDict(self, dat_file_list or default_files, header)

    def LoadExperimentDistributions(self, filename_wt, filename_lstream):
        """Read the experimental distributions from two MATLAB files.

        filename_wt must contain one variable named '<measure>_<label>' for
        every measure in ('interpolar', 'pair', 'lengths', 'angles') and every
        label in self.ldict; they are stored as self.experiment[measure][label].
        The 'lstream' variable of filename_lstream is stored as
        self.experiment['lstream'].
        """
        wt_contents = sio.loadmat(filename_wt, squeeze_me=True)
        self.experiment = {}
        for measure in ('interpolar', 'pair', 'lengths', 'angles'):
            self.experiment[measure] = {}
        # Snapshot the keys: 'lstream' is only added after this loop.
        for measure in list(self.experiment):
            for label in self.ldict:
                self.experiment[measure][label] = wt_contents["{}_{}".format(measure, label)]
        lstream_contents = sio.loadmat(filename_lstream, squeeze_me=True)
        self.experiment['lstream'] = lstream_contents['lstream']

    def LoadStrainData(self):
        """Load the experimental tracks that belong to self.strain_name.

        The strain name is split on '_' into a base and a type, and every
        '*.csv' file found in ../Data/<base>/<type> (relative to this file) is
        treated as one experimental track. Depending on the strain this sets
        self.tracklist, self.timelist, self.anaphase (track -> hand-set anaphase
        cut-off, 0 meaning "no cut"), self.best_tracks and, for every strain
        except 'Klp56_wt', self.experiment_data (track -> DataFrame) and
        self.measure_time. An unknown strain name prints a message and exits
        the interpreter with status 1.
        """
        print("Reading in expeirmental data from type {}".format(self.strain_name))
        name_parts = self.strain_name.split('_')
        strain_base, strain_type = name_parts[0], name_parts[1]
        here = os.path.dirname(os.path.realpath(__file__))
        self.strain_dir = os.path.join(here, '..', 'Data', strain_base, strain_type)
        experiments = [os.path.join(self.strain_dir, entry)
                       for entry in os.listdir(self.strain_dir)
                       if fnmatch.fnmatch(entry, '*.csv')]

        def read_headerless_tracks(column_names, sample_interval, kc_offset=None):
            # Shared reader for the strains whose csv files carry no header and
            # no time column: the time axis is rebuilt from the row index.
            self.experiment_data = {}
            for csv_path in experiments:
                name_fields = os.path.basename(csv_path).split('_')
                track_number = name_fields[2] + '_' + name_fields[3].split('.')[0]
                frame = pd.read_csv(csv_path, names=column_names)
                # add the time information to the length...
                self.measure_time = sample_interval
                frame['time_length'] = frame.index.values * self.measure_time
                if kc_offset is not None:
                    frame['time_kc'] = frame['time_length'] + kc_offset
                self.experiment_data[track_number] = frame

        strain = self.strain_name
        if strain == 'Klp56_wt':
            ### FIXME hardcoded by hand for when anaphase begins
            self.tracklist = [1, 3, 4, 6, 7, 8, 9, 10, 12, 13, 61, 149, 150, 151, 152, 153]
            self.timelist = [21.5, 0, 18, 28.5, 0, 0, 52.5, 28.25, 0, 0, 18.75, 5.25, 13.5, 3.5, 0, 13.25]
            self.anaphase = dict(zip(self.tracklist, self.timelist))
            self.best_tracks = [1, 12, 13, 61, 150, 152, 153]
        elif strain == 'WT_wt':
            # These should all be fine, no need to do the anaphase thing...
            self.tracklist = ['002-1', '002-2',
                              '003-1', '003-2', '003-3',
                              '004',
                              '005-1', '005-2', '005-3', '005-4',
                              '006', '007', '009',
                              '010-1', '010-2', '010-3']
            self.timelist = np.zeros(len(self.tracklist))
            self.anaphase = dict(zip(self.tracklist, self.timelist))
            # Every track except the excluded '003-2'
            self.best_tracks = [track for track in self.tracklist if track != '003-2']
            self.experiment_data = {}
            for csv_path in experiments:
                name_fields = os.path.basename(csv_path).split('_')
                # '<track>' or '<track>-<subtrack>' taken from the second field
                track_number = '-'.join(name_fields[1].split('-')[:2])
                frame = pd.read_csv(csv_path)
                self.experiment_data[track_number] = frame
                self.measure_time = frame['time_length'][1] - frame['time_length'][0]
        elif strain == 'WT_Cen2':
            print("Fitness working on Cen2GFP Data")
            self.tracklist = ['001_1', '001_2',
                              '002_1', '002_2', '002_3',
                              '003_1', '003_3', '003_4',
                              '004_1', '004_2']
            self.timelist = np.zeros(len(self.tracklist))
            self.anaphase = dict(zip(self.tracklist, self.timelist))
            # Set the anaphase by hand
            self.anaphase.update({'001_2': 35.2,
                                  '002_1': 21.87,
                                  '002_2': 23.07,
                                  '002_3': 15.6,
                                  '003_3': 8.53,
                                  '003_4': 5.2,
                                  '004_1': 21.2,
                                  '004_2': 27.87})
            # Every track except the excluded '003_1'
            self.best_tracks = [track for track in self.tracklist if track != '003_1']
            # Offset of 2 seconds for the kc timing
            read_headerless_tracks(["length", "kc0_dist", "kc1_dist", "length_error", "kc0_dist_error", "kc1_dist_error"],
                                   8.0/60.0,
                                   kc_offset=2.0/60.0)
        elif strain == 'Cerevisiae_basic':
            print("Fitness working on Cerevisiae basic data, length only!")
            self.tracklist = ['001_1', '002_1', '003_1', '004_1']
            self.timelist = np.zeros(len(self.tracklist))
            self.anaphase = dict(zip(self.tracklist, self.timelist))
            self.best_tracks = list(self.tracklist)
            read_headerless_tracks(["length"], 12.0/60.0)
        elif strain == 'Cerevisiae_length':
            print("Fitness working on Cerevisiae length data from s288c strain, length only!")
            self.tracklist = ['001_1', '002_1', '003_1', '004_1', '005_1']
            self.timelist = np.zeros(len(self.tracklist))
            self.anaphase = dict(zip(self.tracklist, self.timelist))
            self.best_tracks = list(self.tracklist)
            read_headerless_tracks(["length", "length_error"], 5.0/60.0)
        else:
            print("Something has gone horribly wrong")
            sys.exit(1)

    def StatisticalTests(self, data):
        """Run KSTest for each (model variable, experiment key) pairing."""
        comparisons = (('interpolar', 'interpolar'),
                       ('pairing_length_max', 'pair'),
                       ('total', 'lengths'),
                       ('angles', 'angles'))
        for varname, matlabname in comparisons:
            self.KSTest(varname, matlabname, data)

    def KSTest(self, varname, matlabname, data):
        """Score one model variable against experiment with a two-sample KS test.

        For every (label, length_key) in self.ldict the samples in
        data[length_key][varname] are compared with
        self.experiment[matlabname][label]. The score stored in
        self.P[varname][label] is log10(p-value)/100, or -10.0 when the
        variable is missing, has no samples, or the p-value is <= 0.
        """
        no_score = -10.0
        scores = {}
        self.P[varname] = scores

        for label, length_key in self.ldict.items():
            score = no_score
            if varname in data[length_key]:
                samples = np.asarray(data[length_key][varname])
                if samples.size != 0:
                    result = ks_2samp(samples, self.experiment[matlabname][label])
                    # log10 is only taken for p-values that are not <= 0
                    if not (result.pvalue <= 0.0):
                        score = np.log10(result.pvalue)/100.0
            scores[label] = score

    def CalcTimeSeriesCorrelation(self, timedata, spbsep_dict):
        """Average correlation between model and experimental spindle length.

        timedata is the sequence of model time points and spbsep_dict maps each
        of them to an SPB separation. The model trace is multiplied by
        uc['um'][1] and down-sampled so that one kept sample corresponds to one
        experimental measurement (self.measure_time).

        Sets self.nmeasure (down-sampling stride), self.max_time,
        self.experiment_data_adj (tracks trimmed to anaphase and max time) and
        self.length_correlation_avg (0.0 when the average is NaN or there are
        no tracks).

        Raises ValueError when the experimental sampling interval is shorter
        than the model time step, since no down-sampling stride exists then.
        """
        # np.int was only an alias of the builtin and is gone in NumPy >= 1.24.
        # NOTE(review): int() truncates, so a ratio such as 3.9999 becomes 3;
        # confirm truncation (rather than rounding) is intended.
        self.nmeasure = int(self.measure_time / (timedata[1] - timedata[0]))
        if self.nmeasure < 1:
            raise ValueError("measure_time {} is shorter than the model time step {}".format(
                self.measure_time, timedata[1] - timedata[0]))

        # Down-sampled model time axis and SPB separation trace
        time = timedata[0::self.nmeasure]
        spbsep_arr = np.array([spbsep_dict[ts] for ts in timedata]) * uc['um'][1]
        spbsep_arr = spbsep_arr[0::self.nmeasure]
        self.max_time = time[-1]

        # Create the model timeseries of data
        model_ts = pd.Series(spbsep_arr)

        corr_sum = 0.0
        self.experiment_data_adj = {}
        for track in self.best_tracks:
            df = self.experiment_data[track]
            # Cut at the hand-set anaphase time (0 means no cut) ...
            if self.anaphase[track] != 0:
                df = df[df.time_length <= self.anaphase[track]]
            # ... and at the end of the simulation
            df = df[df.time_length <= self.max_time]
            # Save off for later computation
            self.experiment_data_adj[track] = df

            # Series.corr pairs the samples by index label
            corr_sum += model_ts.corr(df['length'])

        avg_correlation = corr_sum / len(self.best_tracks) if self.best_tracks else 0.0
        # A single undefined (NaN) track correlation makes the whole average
        # NaN; that case is reported as 0.0.
        if np.isnan(avg_correlation):
            avg_correlation = 0.0
        self.length_correlation_avg = avg_correlation


    def CalcLengthFitness(self, timedata, spbsep_dict):
        """Sample the model spindle length after the spindle has formed.

        Sets self.flstream_exp to the experimental 'lstream' values <= 2.75 and
        self.lstream_model to the model spindle lengths taken every 15 seconds,
        starting 10.0 time units after the spindle first exceeds 1.1 (time
        units as in timedata, presumably minutes given the *60.0 conversion to
        seconds below -- TODO confirm). self.lstream_model stays [] when the
        spindle never exceeds 1.1 or when 5 or fewer samples remain.

        Raises ValueError when the model time step is longer than 15 seconds.
        """
        self.flstream_exp = self.experiment['lstream'][self.experiment['lstream'] <= 2.75]
        # Get the time and spb separation variables
        lspindle = np.array([spbsep_dict[ts] for ts in timedata]) * uc['um'][1]
        t = np.array(timedata)

        self.lstream_model = []

        # np.argmax returns 0 when nothing is above threshold, so the "never
        # reached" case has to be tested explicitly (the old t0.size == 0 test
        # could never be true for a scalar).
        above = lspindle > 1.1
        if not above.any():
            return
        t0 = t[np.argmax(above)] + 10.0
        mask = t >= t0

        if int(mask.sum()) > 5:
            # every 15 seconds in array notation
            measuretime = 15.0
            nmeasure = int(measuretime / ((t[1] - t[0])*60.0))
            if nmeasure < 1:
                raise ValueError("model time step {} is longer than the {} s sampling interval".format(
                    t[1] - t[0], measuretime))
            self.lstream_model = lspindle[mask][0::nmeasure]

    def ChromosomeFitness(self, timedata, spbsep_dict, kc_attachments, kc_distance, kc_nend):
        """Run the three chromosome-related fitness measures for one seed.

        Each dict argument is keyed by the time points in timedata. The values
        are gathered in time order and handed to PureAttachFitness,
        BiorientationFitness and SPBKCDistances (the latter reads
        self.nmeasure, which must already be set).
        """
        def in_time_order(series):
            return [series[ts] for ts in timedata]

        times = np.array(timedata)
        separations = in_time_order(spbsep_dict)
        attachments = in_time_order(kc_attachments)
        distances = in_time_order(kc_distance)
        end_on_counts = in_time_order(kc_nend)

        spindle_lengths = np.array(separations) * uc['um'][1]

        # Original attachment-based success test
        self.PureAttachFitness(timedata, spindle_lengths, attachments)

        # Biorientation fraction
        self.BiorientationFitness(timedata, attachments, end_on_counts)

        # SPB-KC distances and the KC position along the spindle
        self.SPBKCDistances(times, spindle_lengths, distances)

    # Just see how long we have attached for to get chromosome seconds
    def PureAttachFitness(self, timedata, lspindle, attach_arr):
        """Measure how long the spindle is long enough with enough amphitelic chromosomes.

        A time point qualifies when lspindle >= 1.0 and at least 3 entries of
        attach_arr at that time equal 4 (the code treated as amphitelic).
        get_start_end_time turns the qualifying points into index windows; the
        summed window durations times 60.0 are stored in
        self.chromosome_seconds, and that value divided by timedata[-1]*60.0 in
        self.chromosome_seconds_fraction.
        """
        min_spindle = 1.0
        min_amphitelic = 3
        amphitelic_code = 4

        # Number of amphitelic chromosomes at every time point
        nchromo = len(attach_arr[0])
        amphitelic_count = np.zeros(len(attach_arr))
        for idx, frame in enumerate(attach_arr):
            for ic in range(nchromo):
                if frame[ic] == amphitelic_code:
                    amphitelic_count[idx] += 1.0

        condition_ = np.logical_and(lspindle >= min_spindle, amphitelic_count >= min_amphitelic)
        # Abuse the start/end time stuff to get the fitness
        start_end_times = get_start_end_time(timedata, condition_, duration = 0.01, thresh = 0.9, break_length = 0.2, end_flag = False, chromosome_seconds = True)
        print("start_end_times: {}".format(start_end_times))
        self.chromosome_seconds = 0.0
        self.chromosome_seconds_fraction = 0.0
        if start_end_times:
            # Either a list of [start, end] index windows or one bare window
            nested = any(isinstance(el, list) for el in start_end_times)
            for window in (start_end_times if nested else [start_end_times]):
                start = timedata[window[0]]
                end = timedata[window[1]]
                if nested:
                    print("start: {}, end: {}".format(start, end))
                self.chromosome_seconds += (end - start) * 60.0
        print("chromosome_seconds: {}".format(self.chromosome_seconds))
        self.chromosome_seconds_fraction = self.chromosome_seconds / (timedata[-1]*60.0)
        print("chromosome_seconds_fraction: {}".format(self.chromosome_seconds_fraction))

    # Calculate the new biorientation end fitness
    def BiorientationFitness(self, timedata, attach_arr, kcnend_arr, naf=None):
        """Fraction of attachment sites that are end-on while amphitelic.

        For every time point and every chromosome whose attachment code is 4
        (amphitelic) the matching kcnend_arr entry is summed; the sum is
        normalised by len(timedata) * nchromo * (2 * naf) and stored in
        self.fbiorient (0.0 when that normalisation is zero).

        naf is the value the original code read from sd.PostAnalysis.naf. The
        original body referenced a name `sd` that is not a parameter of this
        method; when naf is not given the same lookup is still attempted so an
        existing module-level `sd` keeps working, otherwise a ValueError
        explains what is missing (instead of a bare NameError).
        """
        if naf is None:
            try:
                naf = sd.PostAnalysis.naf
            except NameError:
                raise ValueError("BiorientationFitness needs naf (sd.PostAnalysis.naf): "
                                 "pass it explicitly, no 'sd' object is in scope")
        # Get the number of attachments per chromosome dynamically
        sites_per_chromosome = 2 * naf

        self.fbiorient = 0.0
        nchromo = len(attach_arr[0]) if len(attach_arr) else 0

        # Use every time point; no re-sampling to the experimental rate here
        end_on_total = 0.0
        for attachments, nendon in zip(attach_arr, kcnend_arr):
            for ic in range(nchromo):
                # Check if amphitelic
                if attachments[ic] == 4:
                    end_on_total += nendon[ic]

        normalisation = len(timedata) * nchromo * sites_per_chromosome
        if normalisation:
            self.fbiorient = end_on_total / normalisation
        print("fbiorient: {}".format(self.fbiorient))

    # List of distances from the SPB to KC in the model point, keep for the whole comparison later
    def SPBKCDistances(self, times, lspindle, kcdist):
        """Store the down-sampled kinetochore distances and their spindle fractions.

        All three inputs are sampled with stride self.nmeasure. The sampled
        distance rows are kept as self.kc_dist_model; self.kc_spindle_model is
        the flat sequence of (distance / spindle length) for exactly the first
        six entries of every sampled row.
        """
        stride = self.nmeasure
        # The time axis is sampled like the other inputs but not used further
        sampled_times = times[0::stride]
        sampled_lengths = lspindle[0::stride]
        sampled_dist = kcdist[0::stride]

        # Keep the raw sampled distances for later comparisons
        self.kc_dist_model = sampled_dist

        # Kinetochore position as a fraction of the spindle length
        fractions = []
        for i, length in enumerate(sampled_lengths):
            row = sampled_dist[i]
            fractions.extend([row[0] / length,
                              row[1] / length,
                              row[2] / length,
                              row[3] / length,
                              row[4] / length,
                              row[5] / length])
        # An empty result stays a plain list, otherwise a flat numpy array
        self.kc_spindle_model = np.append([], fractions) if fractions else []

    # New Cen2 strain fitness information!
    def Cen2Fitness(self, sd):
        """Correlate one simulated seed with the experimental Cen2 tracks.

        sd must expose sd.PostAnalysis.timedata with the keys 'time',
        'spb_separation' and 'kc_distance' (the two latter keyed by time).
        Every track in self.best_tracks is trimmed to its anaphase cut-off and
        to the simulation end (stored in self.experiment_data_adj); the model
        traces are down-sampled with stride self.nmeasure.

        For each track three correlations are formed: spindle length,
        kinetochore distance (best pairing of the two sister kinetochores,
        averaged over chromosomes) and sister-kinetochore separation. The best
        combined track is stored in the self.max_* attributes and the track
        averages in the self.avg_* attributes. Finally
        Cen2ChromosomeHeuristics(sd) is run.

        Raises ValueError when the experimental sampling interval is shorter
        than the model time step.
        """
        # First, adjust for max time and things of that nature!
        time = np.array(sd.PostAnalysis.timedata['time'])
        self.max_time = time[-1]
        self.experiment_data_adj = {}
        for track in self.best_tracks:
            df = self.experiment_data[track]
            # Cut at the hand-set anaphase time (0 means no cut) ...
            if self.anaphase[track] != 0:
                df = df[df.time_length <= self.anaphase[track]]
            # ... and at the end of the simulation
            df = df[df.time_length <= self.max_time]
            self.experiment_data_adj[track] = df

        # Spindle length and kinetochore distances as arrays in time order
        spbsep_dict = sd.PostAnalysis.timedata['spb_separation']
        dist_dict = sd.PostAnalysis.timedata['kc_distance']
        spindle_length = np.array([spbsep_dict[ts] for ts in time])*uc['um'][1]
        kc_distance = np.array([dist_dict[ts] for ts in time]) # KC doesn't need unit conversion, already done!!!!!

        # Two kinetochores per chromosome; floor division keeps this an int
        self.nchromo = len(kc_distance[0]) // 2
        # np.int is gone in NumPy >= 1.24; the builtin truncates identically
        self.nmeasure = int(self.measure_time / (time[1] - time[0]))
        if self.nmeasure < 1:
            raise ValueError("measure_time {} is shorter than the model time step {}".format(
                self.measure_time, time[1] - time[0]))
        # Match the experimental sampling
        spindle_length = spindle_length[0::self.nmeasure]
        kc_distance = kc_distance[0::self.nmeasure]

        # See also Validation/cen2_fitness.py
        # What is my maximum correlation?
        self.max_correlation_combined = 0.0
        self.max_length_correlation = 0.0
        self.max_kc_distance_correlation = 0.0
        self.max_kc_sep_correlation = 0.0
        self.max_correlation_track = -1

        # What about the averages?
        self.avg_correlation_combined = 0.0
        self.avg_length_correlation = 0.0
        self.avg_kc_distance_correlation = 0.0
        self.avg_kc_sep_correlation = 0.0

        model_ts_length = pd.Series(spindle_length)
        ntracks = len(self.best_tracks)
        for track in self.best_tracks:
            df = self.experiment_data_adj[track]
            exp_dist0 = df['kc0_dist']
            exp_dist1 = df['kc1_dist']
            exp_length = df['length']

            # Length correlation: compare against the measured spindle length.
            # (The previous code correlated against the 'time_length' column,
            # i.e. the time axis, and left exp_length unused.)
            corr_length = model_ts_length.corr(exp_length)

            # Experimental sister separation does not depend on the chromosome
            kc_sep_exp = pd.Series(np.fabs(exp_dist0 - exp_dist1))

            # The model does not know which sister is "kc0", so try both pairings
            corr0 = np.zeros(self.nchromo)
            corr1 = np.zeros(self.nchromo)
            corr_kc_sep = np.zeros(self.nchromo)
            for ic in range(self.nchromo):
                dist0 = pd.Series(kc_distance[:,2*ic])
                dist1 = pd.Series(kc_distance[:,2*ic+1])
                # Do all 4 correlations that cross
                corr_0_0 = dist0.corr(exp_dist0)
                corr_1_1 = dist1.corr(exp_dist1)
                corr_0_1 = dist0.corr(exp_dist1)
                corr_1_0 = dist1.corr(exp_dist0)

                # Pick the best set
                if (corr_0_0 + corr_1_1) > (corr_0_1 + corr_1_0):
                    corr0[ic] = corr_0_0
                    corr1[ic] = corr_1_1
                else:
                    corr0[ic] = corr_0_1
                    corr1[ic] = corr_1_0

                # KC separation
                kc_sep = pd.Series(np.fabs(kc_distance[:,2*ic] - kc_distance[:,2*ic+1]))
                corr_kc_sep[ic] = kc_sep.corr(kc_sep_exp)

            # Average the kinetochore correlation measures
            corr_kc_combined = (np.mean(corr0) + np.mean(corr1))/2.0
            corr_kc_sep_combined = np.mean(corr_kc_sep)
            corr_combined = corr_length + corr_kc_combined + corr_kc_sep_combined

            if corr_combined > self.max_correlation_combined:
                self.max_correlation_combined = corr_combined
                self.max_length_correlation = corr_length
                self.max_kc_distance_correlation = corr_kc_combined
                self.max_kc_sep_correlation = corr_kc_sep_combined
                self.max_correlation_track = track

            # Averages
            self.avg_correlation_combined += corr_combined
            self.avg_length_correlation += corr_length
            self.avg_kc_distance_correlation += corr_kc_combined
            self.avg_kc_sep_correlation += corr_kc_sep_combined

        print("   max correlation with experimental track {} = {} (length: {}, kc: {}, kc_sep: {})".format(self.max_correlation_track, self.max_correlation_combined, self.max_length_correlation, self.max_kc_distance_correlation, self.max_kc_sep_correlation))
        # np.float64 keeps the numpy division semantics (np.float_ was removed
        # in NumPy 2.0)
        track_count = np.float64(ntracks)
        self.avg_correlation_combined /= track_count
        self.avg_length_correlation /= track_count
        self.avg_kc_distance_correlation /= track_count
        self.avg_kc_sep_correlation /= track_count
        print("   avg correlation {} (length: {}, kc: {}, kc_sep: {})".format(self.avg_correlation_combined, self.avg_length_correlation, self.avg_kc_distance_correlation, self.avg_kc_sep_correlation))

        # Now calculate the fitness of the integrated biorientation time, and the integrated fbiorientation
        self.Cen2ChromosomeHeuristics(sd)

    # Cerevisiae length only fitness information
    def CerevisiaeBasicFitness(self,sd):
        """Correlate one simulated seed with length-only experimental tracks.

        sd must expose sd.PostAnalysis.timedata with the keys 'time',
        'spb_separation' and 'kc_distance' (the two latter keyed by time).
        Every track in self.best_tracks is trimmed to its anaphase cut-off and
        to the simulation end (stored in self.experiment_data_adj); the model
        spindle length is down-sampled with stride self.nmeasure and correlated
        with each track's measured length. The best track goes into the
        self.max_* attributes, the track average into the self.avg_*
        attributes, and Cen2ChromosomeHeuristics(sd) is run at the end.

        Raises ValueError when the experimental sampling interval is shorter
        than the model time step.
        """
        # Adjust for the max length of things
        time = np.array(sd.PostAnalysis.timedata['time'])
        self.max_time = time[-1]
        self.experiment_data_adj = {}
        for track in self.best_tracks:
            df = self.experiment_data[track]
            if self.anaphase[track] != 0:
                df = df[df.time_length <= self.anaphase[track]]
            # Correct for max time
            df = df[df.time_length <= self.max_time]
            self.experiment_data_adj[track] = df

        # Spindle length as an array in time order; the kinetochore distances
        # are only needed here to count the chromosomes
        spbsep_dict = sd.PostAnalysis.timedata['spb_separation']
        dist_dict = sd.PostAnalysis.timedata['kc_distance']
        spindle_length = np.array([spbsep_dict[ts] for ts in time])*uc['um'][1]
        kc_distance = np.array([dist_dict[ts] for ts in time]) # KC doesn't need unit conversion, already done!!!!!

        # Two kinetochores per chromosome; floor division keeps this an int
        self.nchromo = len(kc_distance[0]) // 2
        # np.int is gone in NumPy >= 1.24; the builtin truncates identically
        self.nmeasure = int(self.measure_time / (time[1] - time[0]))
        if self.nmeasure < 1:
            raise ValueError("measure_time {} is shorter than the model time step {}".format(
                self.measure_time, time[1] - time[0]))
        # Match the experimental sampling
        spindle_length = spindle_length[0::self.nmeasure]

        print("nchromo: {}".format(self.nchromo))

        # See also Validation/cen2_fitness.py
        # What is my maximum correlation?
        self.max_correlation_combined = 0.0
        self.max_length_correlation = 0.0
        self.max_correlation_track = -1

        # What about the averages?
        self.avg_correlation_combined = 0.0
        self.avg_length_correlation = 0.0

        model_ts_length = pd.Series(spindle_length)
        ntracks = len(self.best_tracks)
        for track in self.best_tracks:
            df = self.experiment_data_adj[track]
            exp_length = df['length']

            # Length correlation: compare against the measured spindle length.
            # (The previous code correlated against the 'time_length' column,
            # i.e. the time axis, and left exp_length unused.)
            corr_length = model_ts_length.corr(exp_length)

            ## XXX: There are no chromosome pieces of information yet, so skip this part
            if corr_length > self.max_correlation_combined:
                self.max_length_correlation = corr_length
                self.max_correlation_combined = corr_length
                self.max_correlation_track = track

            # Averages
            self.avg_correlation_combined += corr_length
            self.avg_length_correlation += corr_length

        print("   max correlation with experimental track {} = {} (length: {})".format(self.max_correlation_track, self.max_correlation_combined, self.max_length_correlation))
        # np.float64 keeps the numpy division semantics (np.float_ was removed
        # in NumPy 2.0)
        track_count = np.float64(ntracks)
        self.avg_correlation_combined /= track_count
        self.avg_length_correlation /= track_count
        print("   avg correlation {} (length: {})".format(self.avg_correlation_combined, self.avg_length_correlation))

        # Now do chromosome biorientation heuristics
        self.Cen2ChromosomeHeuristics(sd)

    # Chromosome heuristic information about the biorientation time and the fraction of attachments that are bioriented
    # Do this for all times!
    def Cen2ChromosomeHeuristics(self, sd):
        """Compute biorientation heuristics over all model time points.

        Reads 'time', 'spb_separation', 'kc_atypes' and 'kc_nend' from
        sd.PostAnalysis.timedata and sd.PostAnalysis.naf; self.nchromo must
        already be set. A time point counts as bioriented when the spindle is
        at least the strain-specific target length and all self.nchromo
        chromosomes have attachment code 4.

        Sets self.integrated_biorientation_time (window durations * 60.0),
        self.fraction_integrated_biorientation_time and self.fbiorient. For a
        strain without a target spindle length a message is printed and the
        interpreter exits with status 1.
        """
        time = np.array(sd.PostAnalysis.timedata['time'])
        spbsep_dict = sd.PostAnalysis.timedata['spb_separation']
        attach_dict = sd.PostAnalysis.timedata['kc_atypes']
        kcnend_dict = sd.PostAnalysis.timedata['kc_nend']

        spindle_length = np.array([spbsep_dict[ts] for ts in time])*uc['um'][1]
        kc_attachments = np.array([attach_dict[ts] for ts in time])
        kc_nendon      = np.array([kcnend_dict[ts] for ts in time])

        target_attach = 4
        # FIXME: Different values for budding and fission yeast
        # 1.0 for fission yeast; 0.53 for the two Cerevisiae (budding yeast) strains
        spindle_targets = {'WT_Cen2': 1.0,
                           'Cerevisiae_basic': 0.53,
                           'Cerevisiae_length': 0.53}
        if self.strain_name not in spindle_targets:
            print("Something has gone horribly wrong in spindle_fitness")
            sys.exit(1)
        target_spindle = spindle_targets[self.strain_name]
        print("target spindle length {} for number of chromosomes {}".format(target_spindle, self.nchromo))
        target_amphitelic = self.nchromo

        # Integrated biorientation time
        self.integrated_biorientation_time = 0.0
        self.fraction_integrated_biorientation_time = 0.0
        # Number of chromosomes with the target attachment code at each time
        aa_ = np.zeros(len(kc_attachments))
        for x, frame in enumerate(kc_attachments):
            for ic in range(self.nchromo):
                if frame[ic] == target_attach:
                    aa_[x] += 1.0
        condition_ = np.logical_and(spindle_length >= target_spindle, aa_ >= target_amphitelic)
        # Abuse the start/end time stuff to get the fitness
        start_end_times = get_start_end_time(time, condition_, duration = 0.01, thresh = 0.9, break_length = 0.2, end_flag = False, chromosome_seconds = True)
        print("start_end_times: {}".format(start_end_times))
        if start_end_times:
            # Either a list of [start, end] index windows or one bare window
            nested = any(isinstance(el, list) for el in start_end_times)
            for window in (start_end_times if nested else [start_end_times]):
                start = time[window[0]]
                end = time[window[1]]
                if nested:
                    print("start: {}, end: {}".format(start, end))
                self.integrated_biorientation_time += (end - start) * 60.0
        self.fraction_integrated_biorientation_time = self.integrated_biorientation_time / (time[-1]*60.0)

        print("  Integrated biorientation time: {}".format(self.integrated_biorientation_time))
        print("  Fraction integrated biorientation time: {}".format(self.fraction_integrated_biorientation_time))

        # Fraction N end on attachments biorientation time
        end_on_total = 0.0
        naf = 2*sd.PostAnalysis.naf # Get the number of attachments per chromosome dynamically
        for attach_frame, nend_frame in zip(kc_attachments, kc_nendon):
            for ic in range(self.nchromo):
                # Check if amphitelic
                if attach_frame[ic] == 4:
                    end_on_total += nend_frame[ic]

        self.fbiorient = end_on_total / (len(time) * self.nchromo * naf)
        print("  Fraction normalized biorientation time: {}".format(self.fbiorient))

    ### Basic Print functionality
    def Print(self):
        # CJE trying to recreate how ChiPet works, because Adam is awesome
        for varname in self.varnames:
            print "{}: ".format(varname)
            for k in self.ldict.keys():
                print "  {}: {}".format(k, self.P[varname][k])

        print "lstream_model: {}".format(self.lstream_model)
        if self.chromosome_seconds:
            print "chromosome_seconds: {}".format(self.chromosome_seconds)

    # Print to YAML
    def PrintYAML(self, path, success):
        with open(os.path.join(path, 'fitness.yaml'), 'w') as stream:
            for varname in self.varnames:
                stream.write("{}:\n".format(varname))
                for k in self.ldict.keys():
                    stream.write("  {}: {}\n".format(k, self.P[varname][k]))

            stream.write("lstream_model:\n")
            for l in self.lstream_model:
                stream.write("  - {}\n".format(l))

            stream.write("success: {}\n".format(success))

            stream.write("length_correlation_avg: {}\n".format(self.length_correlation_avg))

            if self.chromosome_seconds:
                stream.write("chromosome_seconds: {}\n".format(self.chromosome_seconds))
                stream.write("chromosome_seconds_fraction: {}\n".format(self.chromosome_seconds_fraction))
                stream.write("fbiorient: {}\n".format(self.fbiorient))

    # Print the special cen2 yaml
    def PrintCen2YAML(self, path):
        with open(os.path.join(path, 'fitness.yaml'), 'w') as stream:
            # Write all the correlation measures first
            stream.write('max_correlation_combined: {}\n'.format(self.max_correlation_combined))
            stream.write('max_length_correlation: {}\n'.format(self.max_length_correlation))
            stream.write('max_kc_distance_correlation: {}\n'.format(self.max_kc_distance_correlation))
            stream.write('max_kc_sep_correlation: {}\n'.format(self.max_kc_sep_correlation))
            stream.write('avg_correlation_combined: {}\n'.format(self.avg_correlation_combined))
            stream.write('avg_length_correlation: {}\n'.format(self.avg_length_correlation))
            stream.write('avg_kc_distance_correlation: {}\n'.format(self.avg_kc_distance_correlation))
            stream.write('avg_kc_sep_correlation: {}\n'.format(self.avg_kc_sep_correlation))

            # Biorientation heuristics
            stream.write('integrated_biorientation_time: {}\n'.format(self.integrated_biorientation_time))
            stream.write('fraction_integrated_biorientation_time: {}\n'.format(self.fraction_integrated_biorientation_time))
            stream.write('fraction_normalized_biorientation_Time: {}\n'.format(self.fbiorient))

    # Print the cerevisiae basic parameters
    def PrintCerevisiaeBasicYAML(self, path):
        with open(os.path.join(path, 'fitness.yaml'), 'w') as stream:
            # Write all the correlation measures first
            stream.write('max_correlation_combined: {}\n'.format(self.max_correlation_combined))
            stream.write('max_length_correlation: {}\n'.format(self.max_length_correlation))
            #stream.write('max_kc_distance_correlation: {}\n'.format(self.max_kc_distance_correlation))
            #stream.write('max_kc_sep_correlation: {}\n'.format(self.max_kc_sep_correlation))
            stream.write('avg_correlation_combined: {}\n'.format(self.avg_correlation_combined))
            stream.write('avg_length_correlation: {}\n'.format(self.avg_length_correlation))
            #stream.write('avg_kc_distance_correlation: {}\n'.format(self.avg_kc_distance_correlation))
            #stream.write('avg_kc_sep_correlation: {}\n'.format(self.avg_kc_sep_correlation))

            # Biorientation heuristics
            stream.write('integrated_biorientation_time: {}\n'.format(self.integrated_biorientation_time))
            stream.write('fraction_integrated_biorientation_time: {}\n'.format(self.fraction_integrated_biorientation_time))
            stream.write('fraction_normalized_biorientation_Time: {}\n'.format(self.fbiorient))


    # Print the special cen2 yaml empty
    def PrintEmptyCen2YAML(self, path):
        with open(os.path.join(path, 'fitness.yaml'), 'w') as stream:
            # Write all the correlation measures first
            stream.write('max_correlation_combined: {}\n'.format(0.0))
            stream.write('max_length_correlation: {}\n'.format(0.0))
            stream.write('max_kc_distance_correlation: {}\n'.format(0.0))
            stream.write('max_kc_sep_correlation: {}\n'.format(0.0))
            stream.write('avg_correlation_combined: {}\n'.format(0.0))
            stream.write('avg_length_correlation: {}\n'.format(0.0))
            stream.write('avg_kc_distance_correlation: {}\n'.format(0.0))
            stream.write('avg_kc_sep_correlation: {}\n'.format(0.0))

            # Biorientation heuristics
            stream.write('integrated_biorientation_time: {}\n'.format(0.0))
            stream.write('fraction_integrated_biorientation_time: {}\n'.format(0.0))
            stream.write('fraction_normalized_biorientation_Time: {}\n'.format(0.0))

    def PrintEmptyYAML(self, path):
        with open(os.path.join(path, 'fitness.yaml'), 'w') as stream:
            for varname in self.varnames:
                stream.write("{}:\n".format(varname))
                for k in self.ldict.keys():
                    stream.write("  {}: 0.0\n".format(k))

            stream.write("lstream_model:\n")
            for l in self.lstream_model:
                stream.write("  - {}\n".format(l))

            stream.write("success: 0.0\n")

            stream.write("length_correlation_avg: {}\n".format(0.0))

            if self.chromosome_seconds:
                stream.write("chromosome_seconds: {}\n".format(self.chromosome_seconds))
                stream.write("chromosome_seconds_fraction: {}\n".format(self.chromosome_seconds_fraction))
                stream.write("fbiorient: {}\n".format(self.fbiorient))

    ### Fun save/load hacks
    def savestate(self):
        self.experiment_data = None
        filename = os.path.join(self.path, "SpindleFitness.pickle")
        self.save(filename)

    def save(self, filename):
        self.experiment_data = None
        f = open(filename, 'wb')
        pickle.dump(self.__dict__,f)
        f.close()

    def load(self, filename):
        f = open(filename, 'rb')
        tmp_dict = pickle.load(f)
        f.close()

        self.__dict__.update(tmp_dict)

if __name__ == "__main__":
    # Usage: spindle_fitness.py <wt_file> <lstream_file> [strain_name]
    #   wt_file      -- passed as filename_wt to LoadExperimentDistributions
    #   lstream_file -- passed as filename_lstream to LoadExperimentDistributions
    #   strain_name  -- forwarded to SpindleFitness; the class recognises
    #                   'WT_Cen2', 'Cerevisiae_basic' and 'Cerevisiae_length'
    if len(sys.argv) < 3:
        sys.stderr.write("usage: {} <wt_file> <lstream_file> [strain_name]\n".format(sys.argv[0]))
        sys.exit(1)

    # SpindleFitness.__init__ takes (path, strain_name); the previous call
    # passed only the path and therefore always failed with a TypeError.
    # NOTE(review): 'WT_Cen2' is an assumed default so that the old
    # two-argument invocation keeps working -- confirm it is the right one.
    strain_name = sys.argv[3] if len(sys.argv) > 3 else 'WT_Cen2'

    cwd = os.getcwd()
    fitness = SpindleFitness(cwd, strain_name)
    fitness.LoadExperimentDistributions(sys.argv[1], sys.argv[2])
    fitness.StatisticalTests()
    fitness.Print()
