#!/usr/bin/env python
# In case of poor (Sh**y) commenting contact christopher.edelmaier@colorado.edu
# Basic
import sys, os, pdb
import gc
import argparse
import fnmatch
## Analysis
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Spindle'))

from spindle_sim import SpindleSim
from sim_graph_funcs import *
from stylelib.ase1_styles import ase1_sims_stl
from cen2_fitness import Cen2Fitness

from scipy.stats import ks_2samp
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw
import scipy.io as sio

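'''
Name: RealFitnessTests.py
Description: Compare a spindle simulation (the model) against experimental
    track data, plotting the comparisons and printing KS-test results.
Input: Command-line options; see parse_args() below.
Output: Plot files (PDF/PNG) saved under relative file names, plus p-values
    and fitness numbers printed to stdout.
'''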

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with attributes ``model``, ``experiment``,
        ``kinetochores``, ``nopost``, ``fitness``, ``analyze``, ``datadir``
        and ``workdir``.  ``model`` and ``experiment`` are mandatory; the
        boolean switches default to False and the directories to None.
    """
    cli = argparse.ArgumentParser(prog='RealFitnessTests.py')

    # Mandatory selections, followed by the model-type switch.
    cli.add_argument('-m', '--model', type=str, required=True,
                     help='Model sim')
    cli.add_argument('-e', '--experiment', type=str, required=True,
                     help='Experimental data to read')
    cli.add_argument('-k', '--kinetochores', action='store_true',
                     help='Kinetochore or no kinetochore model')

    # Optional on/off switches, registered in the order they appear in --help.
    switches = (
        (('--nopost',),
         "Do not use post analyis program to decrease time of analysis."),
        (('-F', '--fitness'),
         'Create fitness for spindle simulations.'),
        (('-A', '--analyze'),
         'Analyze data from multiple simulations'),
    )
    for flags, text in switches:
        cli.add_argument(*flags, action='store_true', help=text)

    # Directory options.
    datadir_help = 'Name of the data directory in which all analyzed data files will be read/written. \
                    Also the directory where all the graphs will be placed when saved. \
                    Default is set to {workdir}/data/.'
    cli.add_argument('--datadir', type=str, help=datadir_help)
    cli.add_argument('-d', '--workdir', type=str,
                     help='Name of the working directory where simulation will be run.')

    return cli.parse_args()

# Class definition
class RealFitnessTests(object):
    def __init__(self, opts):
        self.opts = opts
        self.cwd = os.getcwd()

        self.ReadOpts()
        self.AnalyzeSims()

    # Read in options
    def ReadOpts(self):
        if not self.opts.workdir:
            self.opts.workdir = os.path.abspath(self.cwd)
        elif not os.path.exists(self.opts.workdir):
            raise IOError( "Working directory {} does not exist.".format(
                self.opts.workdir) )
        else:
            self.opts.workdir = os.path.abspath(self.opts.workdir)

        self.model_dir = os.path.abspath(self.opts.model)
        self.model_sim = SpindleSim(self.model_dir, opts)
        self.truncname = self.model_sim.name.split('_')[0]

        self.LoadData()

    # Read in the data from a specified location
    def LoadData(self):
        print "Reading in data from {}".format(self.opts.experiment)
        if self.opts.experiment == 'klp56':
            experiment_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'Data', 'Klp56', 'Pandas')
            listoffiles = os.listdir(experiment_dir)
            pattern = '*.csv'
            experiments = []
            for entry in listoffiles:
                if fnmatch.fnmatch(entry, pattern):
                    experiments += [os.path.join(experiment_dir, entry)]

            ### FIXME hardcoded by hand for when anaphase begins
            self.tracklist = [1, 3, 4, 6, 7, 8, 9, 10, 12, 13, 61, 149, 150, 151, 152, 153]
            self.timelist = [21.5, 0, 18, 28.5, 0, 0, 52.5, 28.25, 0, 0, 18.75, 5.25, 13.5, 3.5, 0, 13.25]
            self.anaphase = dict(zip(self.tracklist, self.timelist))
            self.best_tracks = [1, 12, 13, 61, 150, 152, 153]

            self.experiment_data = {}
            for experiment in experiments:
                track_names = os.path.basename(experiment).split('_')
                track_name = track_names[0] + '_' + track_names[1]
                track_number = np.int(track_names[1])
                #df = pd.read_pickle(experiment)
                df = pd.read_csv(experiment)
                self.experiment_data[track_number] = df
                self.measure_time = df['time'][1] - df['time'][0] ### FIXME
        elif self.opts.experiment == 'wt':
            experiment_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'Data', 'WT', 'Pandas')
            listoffiles = os.listdir(experiment_dir)
            pattern = '*.csv'
            experiments = []
            for entry in listoffiles:
                if fnmatch.fnmatch(entry, pattern):
                    experiments += [os.path.join(experiment_dir, entry)]

            # These shoudl all be fine, no need to do the anaphase thing...
            self.tracklist = ['002-1',
                              '002-2',
                              '003-1',
                              '003-2',
                              '003-3',
                              '004',
                              '005-1',
                              '005-2',
                              '005-3',
                              '005-4',
                              '006',
                              '007',
                              '009',
                              '010-1',
                              '010-2',
                              '010-3']
            self.timelist = np.zeros(16)
            self.anaphase = dict(zip(self.tracklist, self.timelist))
            self.best_tracks = ['002-1',
                                '002-2',
                                '003-1',
                                #'003-2',
                                '003-3',
                                '004',
                                '005-1',
                                '005-2',
                                '005-3',
                                '005-4',
                                '006',
                                '007',
                                '009',
                                '010-1',
                                '010-2',
                                '010-3']

            self.experiment_data = {}
            for experiment in experiments:
                track_names = os.path.basename(experiment).split('_')
                track_subtrack = track_names[1].split('-')
                if len(track_subtrack) > 1:
                    track_name = track_names[0] + '_' + track_subtrack[0] + '-' + track_subtrack[1]
                    track_number = track_subtrack[0] + '-' + track_subtrack[1]
                else:
                    track_name = track_names[0] + '_' + track_subtrack[0]
                    track_number = track_subtrack[0]
                df = pd.read_csv(experiment)
                self.experiment_data[track_number] = df
                self.measure_time = df['time_length'][1] - df['time_length'][0]

        elif self.opts.experiment == "cen2":
            experiment_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'Data', 'WT', 'Cen2')
            listoffiles = os.listdir(experiment_dir)
            pattern = '*.csv'
            experiments = []
            for entry in listoffiles:
                if fnmatch.fnmatch(entry, pattern):
                    experiments += [os.path.join(experiment_dir, entry)]

            self.tracklist = ['001_1',
                              '001_2',
                              '002_1',
                              '002_2',
                              '002_3',
                              '003_1',
                              '003_3',
                              '003_4',
                              '004_1',
                              '004_2']
            self.timelist = np.zeros(len(self.tracklist))
            self.anaphase = dict(zip(self.tracklist, self.timelist))
            # Set the anaphse by hand
            self.anaphase['001_2'] = 35.2
            self.anaphase['002_1'] = 21.87
            self.anaphase['002_2'] = 23.07
            #self.anaphase['002_3'] = 15.6
            self.anaphase['002_3'] = 17.7
            self.anaphase['003_3'] = 8.53
            self.anaphase['003_4'] = 5.2
            self.anaphase['004_1'] = 21.2
            self.anaphase['004_2'] = 27.87
            self.best_tracks = ['001_1',
                                '001_2',
                                '002_1',
                                '002_2',
                                '002_3',
                                #'003_1',
                                '003_3',
                                '003_4',
                                '004_1',
                                '004_2']

            self.experiment_data = {}
            for experiment in experiments:
                track_names = os.path.basename(experiment).split('_')
                track_name     = track_names[2]
                track_subtrack = track_names[3].split('.')[0]

                track_number = track_name + '_' + track_subtrack
                df = pd.read_csv(experiment, names = ["length", "kc0_dist", "kc1_dist", "length_error", "kc0_dist_error", "kc1_dist_error"])
                # add the time information to the length...
                self.measure_time = 8.0/60.0
                df['time_length'] = df.index.values * self.measure_time
                # Offset of 2 seconds for the kc timing
                df['time_kc'] = df['time_length'] + 2.0/60.0
                self.experiment_data[track_number] = df
        else:
            print "Please choose a correct experiment to compare to!"
            sys.exit(1)

    # Analyze the sims
    def AnalyzeSims(self):
        self.model_sim.Analyze()
        self.model_sim.CalcSimSuccess()

        if self.opts.fitness:
            wt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'Data', 'wt.mat')
            lstream_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'Data', 'lstream.mat')
            self.model_sim.Fitness(wt_path, lstream_path)

    def GraphExperimentAll(self):
        """Plot the experimental tracks on their own (no model overlay).

        Reads ``self.best_tracks``, ``self.experiment_data``,
        ``self.anaphase`` and ``self.max_time``; stores the time-trimmed
        frames in ``self.experiment_data_adj``.  Each frame must provide the
        columns ``time_length``, ``length`` and ``kc0_dist`` .. ``kc5_dist``.

        Saves three figures:
            wt_experimental_time_traces.pdf  - length vs. time per track
            wt_experimental_kcspbd_dist.pdf  - histogram of the kc*_dist values
            wt_experimental_kc1d_dist.pdf    - histogram of kc*_dist / length
        """
        kc_columns = ['kc{}_dist'.format(i) for i in range(6)]

        # --- Spindle length vs time -------------------------------------
        fig, ax = plt.subplots()
        colors = mpl.cm.rainbow(np.linspace(0,1,len(self.best_tracks)))
        ax.set_ylabel(r'SPB-SPB Distance ($\mu$m)')
        ax.set_xlabel(r'Time (min)')

        self.experiment_data_adj = {}
        for track, color in zip(self.best_tracks, colors):
            df = self.experiment_data[track]
            # Drop everything after the hand-set anaphase time (0 = no cut)
            if self.anaphase[track] != 0:
                df = df[df.time_length <= self.anaphase[track]]
            # ... and everything after self.max_time
            df = df[df.time_length <= self.max_time]
            self.experiment_data_adj[track] = df

            spbsep_arr = df['length']
            # Positional access: a label lookup of 0 raises KeyError when
            # the trimmed track is empty or no longer starts at label 0.
            if len(spbsep_arr) > 0 and spbsep_arr.iloc[0] > 2.5:
                print("bad track: {}".format(track))
            ax.plot(df['time_length'], spbsep_arr, color=color)

        fig.tight_layout()
        fig.savefig('wt_experimental_time_traces.pdf', dpi=fig.dpi)
        # Release this figure explicitly; a bare plt.close() further down
        # would only close whichever figure is current at that point.
        plt.close(fig)

        # --- Histogram of the KC distances ------------------------------
        # Per track the values are laid out column by column (all kc0, then
        # all kc1, ...).
        dist_pieces = [self.experiment_data_adj[track][kc_columns].values.T.ravel()
                       for track in self.best_tracks]
        dist_stream_experiment = np.concatenate(dist_pieces) if dist_pieces else np.array([])

        bins = np.linspace(0, 2.75, 10)

        fig, ax = plt.subplots()
        ax.hist(dist_stream_experiment, bins, alpha = 0.5, label = 'wt', density = True)
        ax.set_xlabel(r'SPB-KC Distance $\mu$m')
        ax.set_ylabel(r'Probability')
        fig.tight_layout()
        fig.savefig('wt_experimental_kcspbd_dist.pdf', dpi=fig.dpi)
        plt.close(fig)
        gc.collect()
        mpl.rcdefaults()

        # --- Histogram of the KC position along the spindle -------------
        # For experiments, assume that the KC is on the spindle, and that it
        # is always the distance away from the SPB that it is tagged to, so
        # dividing by the spindle length gives a 1d position.
        kc_pieces = []
        for track in self.best_tracks:
            df = self.experiment_data_adj[track]
            lengths = df['length'].values
            # One vectorized division per track, kept in row order
            # (kc0..kc5 of sample 0, then of sample 1, ...).
            kc_pieces.append((df[kc_columns].values / lengths[:, np.newaxis]).ravel())
        kc_spindle_experiment = np.concatenate(kc_pieces) if kc_pieces else np.array([])

        bins = np.linspace(0, 1.0, 10)

        fig, ax = plt.subplots()
        ax.hist(kc_spindle_experiment, bins, alpha = 0.5, label = 'wt', density = True)
        ax.set_xlabel(r'KC Normalized Distance on Spindle')
        ax.set_ylabel(r'Probability')
        fig.tight_layout()
        fig.savefig('wt_experimental_kc1d_dist.pdf', dpi=fig.dpi)
        plt.close(fig)
        gc.collect()
        mpl.rcdefaults()



    def GraphTimeSeries(self):
        """Draw the model (top) and experiment (bottom) traces in one figure.

        Delegates the drawing to ``GraphTimeSeriesLength`` and
        ``GraphTimeSeriesReal``, copies the experiment panel's y-limits onto
        the model panel, sets ``self.plot_base_name`` and saves the figure as
        ``<plot_base_name>timeseries.pdf``.
        """
        fig, axarr = plt.subplots(2, 1)
        model_ax, experiment_ax = axarr[0], axarr[1]

        self.GraphTimeSeriesLength(fig, axarr, self.model_sim, 0)
        self.GraphTimeSeriesReal(fig, axarr, 1)

        # Put the model panel on the experiment's vertical scale
        model_ax.set_ylim(experiment_ax.get_ylim())

        # Base name reused by the other plotting methods for their files
        self.plot_base_name = "".join(
            ("model_", self.truncname, "_experiment_", self.opts.experiment))
        outfile = self.plot_base_name + "timeseries.pdf"
        plt.savefig(outfile, dpi=fig.dpi)

        # Free the figure and restore the default matplotlib settings
        plt.close()
        gc.collect()
        mpl.rcdefaults()

    def GraphEMTomographySeeds(self):
        """Compare each seed's EM-style distributions with the experiment.

        For every seed in ``self.model_sim.seeds`` and each of the four
        measurements, overlays experiment and model histograms for the three
        spindle-length classes, titles each panel with the two-sample KS
        result and saves ``em_<seed name>_<measurement>.png``.  The mean
        p-value (and a mean of ``log10(p)/100``) is accumulated and the
        per-seed averages are printed at the end.

        Reads ``self.emdata_model`` and ``self.emdata_exp`` (both filled in
        by ``GraphTimeSeriesLength``); sets ``self.ldict``.
        """
        # [key in the model data, key in the experiment data]
        varnames = [['interpolar', 'interpolar'],
                    ['pairing_length_max', 'pair'],
                    ['total', 'lengths'],
                    ['angles', 'angles']]
        # (experiment key, model key) for subplot rows 0, 1, 2
        length_keys = [('short', '1.05um'),
                       ('med', '1.825um'),
                       ('long', '2.15um')]
        conversions = {}
        conversions['interpolar'] = uc['um'][1]
        conversions['pairing_length_max'] = uc['um'][1]
        conversions['total'] = uc['um'][1]
        conversions['angles'] = 1.0
        bins = {}
        bins['interpolar'] = np.linspace(0.0, 2.0, 9)
        bins['pairing_length_max'] = np.linspace(0.0, 2.0, 9)
        bins['total'] = np.linspace(0.0, 2.0, 9)
        bins['angles'] = np.linspace(0, 2*pi, 9)
        self.ldict = {}
        self.ldict['long'] = '2.15um'
        self.ldict['med'] = '1.825um'
        self.ldict['short'] = '1.05um'

        def log_term(pvalue):
            # Same floor as GraphEMTomographyCombined: a p-value of 0 (or
            # NaN) contributes -10.0 instead of -inf.
            if pvalue > 0.0:
                return np.log10(pvalue)/100.0
            return -10.0

        # Running sums over seeds, one entry per measurement
        emfitness = {modelname: 0.0 for modelname, _ in varnames}
        emlogfitness = {modelname: 0.0 for modelname, _ in varnames}

        for sd in self.model_sim.seeds:
            distrdata = self.emdata_model[sd.name]
            # Group by the type of measurement
            for modelname, matlabname in varnames:
                scale = conversions[modelname]
                edges = bins[modelname]
                fig, axarr = plt.subplots(3, 1, figsize=(25,16))

                pvalues = []
                for axr, (exp_key, model_key) in zip(axarr, length_keys):
                    exp_vals = self.emdata_exp[matlabname][exp_key]
                    model_vals = np.array(distrdata[model_key][modelname])
                    # Experiment first, then the model seed, on the same bins
                    axr.hist(np.array(exp_vals)*scale, bins=edges, alpha = 0.5, density = True)
                    axr.hist(model_vals*scale, bins=edges, alpha = 0.5, density = True)
                    # KS test on the unscaled values (both sides share units)
                    res = ks_2samp(model_vals, exp_vals)
                    axr.set_title(str(res))
                    pvalues.append(res.pvalue)

                fig.tight_layout()

                plt.savefig('em_' + sd.name + '_' + modelname + '.png')
                plt.close()
                gc.collect()
                mpl.rcdefaults()

                # Average over the three length classes
                emfitness[modelname] += sum(pvalues)/3.
                emlogfitness[modelname] += sum(log_term(p) for p in pvalues)/3.

        total_base = sum(emfitness[modelname] for modelname, _ in varnames)
        total_log = sum(emlogfitness[modelname] for modelname, _ in varnames)

        # Normalize to the number of seeds actually processed (this used to
        # be a hard-coded 12).
        nseeds = len(self.model_sim.seeds)
        if nseeds > 0:
            total_base /= float(nseeds)
            total_log /= float(nseeds)

        # ... and to the four measurements
        print("Avg Total Base EM Fitness {}".format(total_base/4.))
        print("Avg Total Log EM Fitness {}".format(total_log/4.))

    def GraphEMTomographyCombined(self):
        """Compare the pooled model EM-style distributions with experiment.

        For each of the four measurements the values of all seeds are pooled
        per spindle-length class, histogrammed against the experiment and
        compared with a two-sample KS test.  Each figure is saved as
        ``em_combined_<measurement>.pdf`` and the mean p-value and mean
        ``log10(p)/100`` (floored at -10.0) over the measurements are printed.

        Reads ``self.emdata_model`` and ``self.emdata_exp``; sets
        ``self.ldict``.
        """
        # [key in the model data, key in the experiment data]
        varnames = [['interpolar', 'interpolar'],
                    ['pairing_length_max', 'pair'],
                    ['total', 'lengths'],
                    ['angles', 'angles']]
        # (key in `combined`, key in the experiment data, key in the model
        # data) for subplot rows 0, 1, 2
        rows = (('short', 'short', '1.05um'),
                ('medium', 'med', '1.825um'),
                ('long', 'long', '2.15um'))
        micron = uc['um'][1]
        conversions = {'interpolar': micron,
                       'pairing_length_max': micron,
                       'total': micron,
                       'angles': 1.0}
        bins = {'interpolar': np.linspace(0.0, 2.0, 9),
                'pairing_length_max': np.linspace(0.0, 2.0, 9),
                'total': np.linspace(0.0, 2.0, 9),
                'angles': np.linspace(0, 2*pi, 9)}
        self.ldict = {'long': '2.15um', 'med': '1.825um', 'short': '1.05um'}

        total_base = 0.0
        total_log = 0.0
        for modelname, matlabname in varnames:
            scale = conversions[modelname]
            edges = bins[modelname]

            # Pool the scaled measurements of all seeds, per length class.
            # Seeds that lack this measurement for a class are skipped.
            combined = {}
            for sd in self.model_sim.seeds:
                distrdata = self.emdata_model[sd.name]
                for label, _, model_key in rows:
                    per_length = distrdata[model_key]
                    if modelname not in per_length:
                        continue
                    scaled = np.array(per_length[modelname])*scale
                    if label in combined:
                        combined[label] = np.append(combined[label], scaled)
                    else:
                        combined[label] = scaled

            fig, axarr = plt.subplots(3, 1)
            experiment = self.emdata_exp[matlabname]

            # One panel per length class: experiment, pooled model, KS test
            inres = {}
            for axr, (label, exp_key, _) in zip(axarr, rows):
                exp_scaled = np.array(experiment[exp_key])*scale
                axr.hist(exp_scaled, bins=edges, alpha = 0.5, density = True)
                axr.hist(combined[label], bins=edges, alpha = 0.5, density = True)
                inres[label] = ks_2samp(combined[label], exp_scaled).pvalue
                axr.set_title(str(inres[label]))

            combined_val = (inres['short'] + inres['long'] + inres['medium'])/3.

            # log10(p)/100 per class; a p-value that is not positive counts
            # as -10.0
            logres = {}
            for label in ('short', 'medium', 'long'):
                pvalue = inres[label]
                logres[label] = np.log10(pvalue)/100. if pvalue > 0.0 else -10.0
            combined_log = (logres['short'] + logres['medium'] + logres['long'])/3.
            if np.isnan(combined_log) or np.isinf(combined_log):
                combined_log = -10.0 # Set a minimum value...

            total_base += combined_val
            total_log += combined_log

            fig.tight_layout()

            fig.savefig('em_combined_' + modelname + '.pdf', dpi=fig.dpi)
            plt.close()
            gc.collect()
            mpl.rcdefaults()

        print("Total Base EM Fitness {}".format(total_base/4.))
        print("Total Log EM Fitness {}".format(total_log/4.))


    def GraphTimeSeriesLength(self, fig, axarr, sim, row):
        """Plot every seed's SPB separation on ``axarr[row]`` and cache it.

        The simulation output is thinned to roughly the experimental sampling
        interval (``self.measure_time``) before plotting.

        Args:
            fig:   figure that owns the axes (only ``tight_layout`` is used).
            axarr: indexable collection of axes; ``axarr[row]`` is drawn on
                   and ``axarr[0]`` receives the x label.
            sim:   object exposing ``seeds``; each seed must provide
                   ``time``, ``PostAnalysis.timedata``,
                   ``PostAnalysis.distrdata``, ``fitness.experiment``,
                   ``label`` and ``name``.
            row:   0 stores the results as ``self.model_*``, 1 as
                   ``self.experiment_*``.

        Also sets ``self.min_size``, ``self.max_time``, ``self.xlims``,
        ``self.emdata_model`` and ``self.emdata_exp``.
        """
        colors = mpl.cm.rainbow(np.linspace(0,1,len(sim.seeds)))

        # Shortest time array among the seeds (a running value of 0 is
        # treated as "not set yet", exactly as before).
        min_size = 0
        for sd in sim.seeds:
            if min_size == 0 or min_size > sd.time.size:
                min_size = sd.time.size

        # Keep every nmeasure-th sample.  Builtin int replaces the removed
        # np.int, and the stride is clamped to 1 because a stride of 0 (the
        # experiment sampled faster than the model) makes the slices below
        # raise ValueError.
        time = sim.seeds[-1].time[:min_size]
        nmeasure = max(1, int(self.measure_time / (time[1] - time[0])))
        time = time[0::nmeasure]
        min_size = time.size
        self.min_size = min_size

        avg = np.zeros(min_size)
        num_seeds = 0
        lstream = []
        dstream = []

        # EM fitness data of each seed, keyed by seed name
        self.emdata_model = {}
        # Largest final time over all seeds
        self.max_time = 0.0

        axr = axarr[row]
        axr.set_title(r'Model')
        axr.set_ylabel(r'SPB Separation ($\mu$m)')

        for sd, col in zip(sim.seeds, colors):
            timedata = sd.PostAnalysis.timedata
            time = timedata['time']
            spbsep_dict = timedata['spb_separation']
            # Scaled by uc['um'][1] (presumably a conversion to microns,
            # given the axis label), then thinned
            spbsep_arr = np.array([spbsep_dict[ts] for ts in time])*uc['um'][1]
            spbsep_arr = spbsep_arr[0::nmeasure]
            if 'kc_distance' in timedata:
                kcdist_dict = timedata['kc_distance']
                # don't have to convert the kcdist array, already should be
                # in microns
                kcdist_arr = [kcdist_dict[ts] for ts in time][0::nmeasure]
                dstream += [kcdist_arr]
            time = time[0::nmeasure]

            axr.plot(time, spbsep_arr, color=col, label=sd.label)
            lstream += [spbsep_arr]
            num_seeds += 1
            avg = np.add(avg, spbsep_arr[:min_size])
            self.max_time = max(self.max_time, time[-1])

            # Keep the EM tomography data for the EM fitness plots
            self.emdata_model[sd.name] = sd.PostAnalysis.distrdata['fitness']
            self.emdata_exp = sd.fitness.experiment

        axarr[0].set_xlabel(r'Time (min)')
        fig.tight_layout()
        # Remembered so the experiment panel can share this x range
        self.xlims = axarr[0].get_xlim()

        # Note: ``avg`` is the un-normalized sum over seeds and ``time`` is
        # the thinned time array of the last seed.
        if row == 0:
            self.model_avg_length = avg
            self.model_lstream = lstream
            self.model_times = time
            self.model_nseeds = num_seeds
            self.model_dist_stream = dstream
        elif row == 1:
            self.experiment_avg_length = avg
            self.experiment_lstream = lstream
            self.experiment_times = time
            self.experiment_nseeds = num_seeds

    def GraphTimeSeriesReal(self, fig, axarr, row):
        """Plot the experimental SPB-separation traces on ``axarr[row]``.

        Each track in ``self.best_tracks`` is cut at its hand-set anaphase
        time (when non-zero) and at ``self.max_time``; the trimmed frames are
        stored in ``self.experiment_data_adj`` for the later fitness plots.

        Args:
            fig:   figure that owns the axes (only ``tight_layout`` is used).
            axarr: indexable collection of axes.
            row:   index of the axes to draw on.

        Requires ``self.max_time`` and ``self.xlims``, which
        ``GraphTimeSeriesLength`` sets.
        """
        axr = axarr[row]
        colors = mpl.cm.rainbow(np.linspace(0,1,len(self.best_tracks)))

        axr.set_title(r'WT Experiment')
        axr.set_ylabel(r'SPB Separation ($\mu$m)')

        self.experiment_data_adj = {}
        for track, color in zip(self.best_tracks, colors):
            df = self.experiment_data[track]
            # Drop everything after the hand-set anaphase time (0 = no cut)
            if self.anaphase[track] != 0:
                df = df[df.time_length <= self.anaphase[track]]
            # ... and everything after the longest model run
            df = df[df.time_length <= self.max_time]
            self.experiment_data_adj[track] = df

            spbsep_arr = df['length']
            # Positional access: a label lookup of 0 raises KeyError when
            # the trimmed track is empty or no longer starts at label 0.
            if len(spbsep_arr) > 0 and spbsep_arr.iloc[0] > 2.5:
                print("bad track: {}".format(track))
            axr.plot(df['time_length'], spbsep_arr, color=color)

        # Decorate the panel that was drawn on (previously hard-coded to
        # axarr[1], which is the same axes for the existing row=1 caller).
        axr.set_xlabel(r'Time (min)')
        axr.set_xlim(self.xlims)
        axr.axhline(2.75, color='k', linestyle='--')
        fig.tight_layout()

    def OriginalLengthFitness(self):
        """KS-test the model spindle lengths against the experimental ones.

        From every experimental track and every model seed, only the lengths
        recorded from 10 time units after the length first exceeds 1.1 are
        kept (times are multiplied by 60 first; the plots label time in
        minutes, so this presumably gives seconds).  Traces that never
        exceed 1.1, or that keep 5 or fewer points, contribute nothing.

        Prints the KS result and saves the two histograms as
        ``<plot_base_name>_lengthstreams.png``.  Requires
        ``self.experiment_data_adj``, ``self.model_lstream``,
        ``self.model_times``, ``self.model_nseeds`` and
        ``self.plot_base_name`` to have been set by the time-series plots.
        """
        # Create the experimental distribution
        lstream_experiment = []
        for track in self.best_tracks:
            df = self.experiment_data_adj[track]
            kept = self._LengthsAfterOnset(df['time_length'].values * 60.0,
                                           df['length'].values)
            lstream_experiment = np.append(lstream_experiment, kept)

        # Now do the model.  The accumulator is created once, before the
        # loop; it used to be reset per seed, so only the last seed counted.
        lstream_model = []
        model_times = np.asarray(self.model_times) * 60.0
        for mseed in range(self.model_nseeds):
            kept = self._LengthsAfterOnset(model_times, self.model_lstream[mseed])
            lstream_model = np.append(lstream_model, kept)

        if len(lstream_experiment) > 0 and len(lstream_model) > 0:
            length_res = ks_2samp(lstream_model, lstream_experiment)
            print("length pvalue: {}".format(length_res))
        else:
            print("length pvalue: 0.0 (lengths not correct)")

        bins = np.linspace(0.0, 2.75, 10)

        plt.hist(lstream_model, bins, alpha = 0.5, label = 'model', density = True)
        plt.hist(lstream_experiment, bins, alpha = 0.5, label = 'experiment', density = True)
        plt.legend()
        plt.savefig(self.plot_base_name + "_lengthstreams.png")
        plt.close()
        gc.collect()
        mpl.rcdefaults()

    def _LengthsAfterOnset(self, times, lengths, threshold=1.1, delay=10.0, min_points=5):
        """Return the lengths recorded at least ``delay`` after onset.

        Onset is the first sample whose length exceeds ``threshold``.  An
        empty array is returned when the threshold is never exceeded or when
        ``min_points`` or fewer samples remain.
        """
        times = np.asarray(times)
        lengths = np.asarray(lengths)
        # Seeds can be longer than the shared time axis; compare like with like
        npoints = min(times.size, lengths.size)
        times = times[:npoints]
        lengths = lengths[:npoints]

        above = lengths > threshold
        # np.argmax returns 0 on an all-False array, which would silently
        # treat the first sample as the onset, so test explicitly.
        if not above.any():
            return np.array([])

        mask = times >= times[np.argmax(above)] + delay
        if mask.sum() > min_points:
            return lengths[mask]
        return np.array([])

    def GraphKC(self):
        """Histogram and KS-test the SPB-KC distances, model vs. experiment.

        The experimental sample pools the ``kc0_dist`` .. ``kc5_dist``
        columns of every trimmed track; the model sample pools
        ``self.model_dist_stream`` over the first ``self.model_nseeds``
        entries.  Prints the KS result and saves
        ``<plot_base_name>_kcdist.pdf``.
        """
        kc_columns = ['kc0_dist', 'kc1_dist', 'kc2_dist',
                      'kc3_dist', 'kc4_dist', 'kc5_dist']

        # Experimental values, one track at a time
        experiment_dists = []
        for track in self.best_tracks:
            frame = self.experiment_data_adj[track]
            # Read but not used below; kept so a frame without this column
            # fails here, as it always has.
            kc_times = frame['time_kc']
            experiment_dists = np.append(experiment_dists,
                                         [frame[column] for column in kc_columns])

        # Model values, one seed at a time
        model_dists = []
        for seed_index in range(self.model_nseeds):
            model_dists = np.append(model_dists, self.model_dist_stream[seed_index])

        # The KS test needs data on both sides
        if len(experiment_dists) == 0 or len(model_dists) == 0:
            # NOTE(review): this fallback says "length" although it reports
            # on the KC distances - looks copied from the length test.
            print("length pvalue: 0.0 (lengths not correct)")
        else:
            ks_result = ks_2samp(model_dists, experiment_dists)
            print("KC Dist pvalue: {}".format(ks_result))

        edges = np.linspace(0, 2.75, 10)

        fig, ax = plt.subplots()
        ax.hist(model_dists, edges, alpha = 0.5, label = 'Model', density = True)
        ax.hist(experiment_dists, edges, alpha = 0.5, label = 'Experiment', density = True)
        ax.legend()
        ax.set_xlabel(r'SPB-KC Distance $\mu$m')
        ax.set_ylabel(r'Probability')
        fig.tight_layout()
        fig.savefig(self.plot_base_name + '_kcdist.pdf', dpi=fig.dpi)

        # Free the figure and restore the default matplotlib settings
        plt.close()
        gc.collect()
        mpl.rcdefaults()

    def KCSpindle1D(self):
        """Histogram and KS-test the kinetochore positions as a fraction of spindle length.

        Every SPB-KC distance is divided by the spindle length of the same
        sample, for each experimental track in ``self.best_tracks`` and for
        each model seed. Prints the ``ks_2samp`` result for the two pooled
        samples and saves the overlaid, density-normalised histograms to
        ``<plot_base_name>_kc1d.pdf``.
        """
        n_kc = 6
        kc_columns = ['kc{}_dist'.format(ikc) for ikc in xrange(n_kc)]

        # For experiments, assume that the KC is on the spindle, and that it is
        # always the distance away from the SPB that it is tagged to, so that we
        # can generate a 1d length distribution.
        experiment_chunks = []
        for track in self.best_tracks:
            df = self.experiment_data_adj[track]
            lengths = np.asarray(df['length'], dtype=float)
            # One row per sample, one column per kinetochore
            kc_dists = np.column_stack([np.asarray(df[col], dtype=float) for col in kc_columns])
            # Row-major flattening keeps the per-sample kc0..kc5 ordering
            experiment_chunks.append((kc_dists / lengths[:, np.newaxis]).ravel())
        # Concatenate once instead of growing the array with np.append per sample
        if experiment_chunks:
            kc_spindle_experiment = np.concatenate(experiment_chunks)
        else:
            kc_spindle_experiment = np.array([])

        # Do the exact same thing for the model: the original fitness version
        # has this binned over spindle lengths of 1 micron, which is not
        # reproduced here.
        model_fractions = []
        for mseed in xrange(self.model_nseeds):
            lengths = self.model_lstream[mseed]
            dist_model = self.model_dist_stream[mseed]

            for i in xrange(len(lengths)):
                # Only the first n_kc entries of each sample are used
                model_fractions.extend(dist_model[i][ikc] / lengths[i] for ikc in xrange(n_kc))
        kc_spindle_model = np.asarray(model_fractions, dtype=float)

        # Check the KS test for this setup! The whole ks_2samp result is printed,
        # not only its p-value.
        if len(kc_spindle_experiment) > 0 and len(kc_spindle_model) > 0:
            kc_1d_res = ks_2samp(kc_spindle_model, kc_spindle_experiment)
            print("KC 1D pvalue: {}".format(kc_1d_res))
        else:
            # Same label as the success path so the fallback is attributed to
            # the KC 1D test and not to the spindle length test.
            print("KC 1D pvalue: 0.0 (lengths not correct)")

        # Ten edges, i.e. nine equal-width bins over the unit interval
        bins = np.linspace(0, 1.0, 10)

        fig, ax = plt.subplots()
        ax.hist(kc_spindle_model, bins, alpha = 0.5, label = 'Model', density = True)
        ax.hist(kc_spindle_experiment, bins, alpha = 0.5, label = 'Experiment', density = True)
        ax.legend()
        #plt.title(self.truncname + ' KC 1D Histograms')
        ax.set_xlabel(r'KC Distance on Spindle')
        ax.set_ylabel(r'Probability')
        fig.tight_layout()
        fig.savefig(self.plot_base_name + '_kc1d.pdf', dpi=fig.dpi)
        # Close exactly the figure created above, then restore the default rc settings
        plt.close(fig)
        gc.collect()
        mpl.rcdefaults()


    def CrossCorrelation(self):
        """Print the correlation of every model seed's length trace with the experiments.

        For each seed, ``self.model_lstream[mseed]`` is correlated (via
        ``pandas.Series.corr``) with the ``'length'`` column of every track in
        ``self.best_tracks``. The per-seed maximum and average are printed,
        followed by their means over all seeds.

        The maxima start at 0.0, so a seed whose correlations are all negative
        (or NaN) reports a maximum of 0.0. With no experimental tracks the
        averages stay at 0.0 as well.
        """
        max_correlation = np.zeros(self.model_nseeds)
        avg_correlation = np.zeros(self.model_nseeds)

        # The experimental series do not depend on the seed, so build them once
        experiment_series = [pd.Series(self.experiment_data_adj[track]['length'])
                             for track in self.best_tracks]
        n_tracks = len(experiment_series)

        # Compare each model to all the experiments
        for mseed in xrange(self.model_nseeds):
            model_ts = pd.Series(self.model_lstream[mseed])

            corr_sum = 0.0
            for experiment_ts in experiment_series:
                corr = model_ts.corr(experiment_ts)
                if corr > max_correlation[mseed]:
                    max_correlation[mseed] = corr
                corr_sum += corr
            # Guard against an empty track list, which would divide by zero
            if n_tracks > 0:
                avg_correlation[mseed] = corr_sum / n_tracks

        for mseed in xrange(self.model_nseeds):
            print("Cross correlation max: model: {} = {}".format(mseed, max_correlation[mseed]))
            print("Cross correlation avg: model: {} = {}".format(mseed, avg_correlation[mseed]))

        print("Cross correlation avg of max: {}".format(np.mean(max_correlation)))
        print("Cross correlation avg of avg: {}".format(np.mean(avg_correlation)))

        # What about the average length distributions?
        #avg_e = pd.Series(self.experiment_avg_length)
        #avg_m = pd.Series(self.model_avg_length)
        #corr = avg_m.corr(avg_e)
        #print "Cross correlation: model: avg, experiment: avg = {}".format(corr)

    def DynamicTimeWarping(self):
        """Print a dynamic-time-warping similarity of every model seed to the experiments.

        Each seed's (time, length) sequence is compared with the
        ``['time_length', 'length']`` columns of every track in
        ``self.best_tracks`` using ``fastdtw`` with a euclidean point distance.
        The distance is turned into ``1 - distance / max_dist`` with
        ``max_dist = 2.75 * number_of_model_samples``. The per-seed maximum and
        average are printed, followed by their means over all seeds.

        The maxima start at 0.0, so a seed whose similarities are all negative
        reports a maximum of 0.0. With no experimental tracks the averages stay
        at 0.0 as well.
        """
        max_timewarp = np.zeros(self.model_nseeds)
        avg_timewarp = np.zeros(self.model_nseeds)

        # The experimental sequences do not depend on the seed, so build them
        # once. Column selection plus ``.values`` replaces ``DataFrame.as_matrix``,
        # which newer pandas releases no longer provide.
        experiment_series = [self.experiment_data_adj[track][['time_length', 'length']].values
                             for track in self.best_tracks]
        n_tracks = len(experiment_series)

        # Compare the model to all experiments
        for mseed in xrange(self.model_nseeds):
            # NOTE(review): model_times is used unscaled here while the length
            # stream comparison multiplies it by 60 -- confirm the units match
            # the experimental 'time_length' column.
            model_ts = np.column_stack([self.model_times, self.model_lstream[mseed]])
            # Multiply by the number of samples in the sequence
            max_dist = 2.75 * len(model_ts)

            similarity_sum = 0.0
            for experiment_ts in experiment_series:
                distance, _ = fastdtw(model_ts, experiment_ts, dist=euclidean)
                similarity = 1.0 - distance/max_dist
                if similarity > max_timewarp[mseed]:
                    max_timewarp[mseed] = similarity
                similarity_sum += similarity
            # Guard against an empty track list, which would divide by zero
            if n_tracks > 0:
                avg_timewarp[mseed] = similarity_sum / n_tracks

        for mseed in xrange(self.model_nseeds):
            print("DTW max: model: {} = {}".format(mseed, max_timewarp[mseed]))
            print("DTW avg: model: {} = {}".format(mseed, avg_timewarp[mseed]))

        print("DTW avg of max: {}".format(np.mean(max_timewarp)))
        print("DTW avg of avg: {}".format(np.mean(avg_timewarp)))

        ## What about the average distributions?
        #model_avg_ts = np.column_stack([self.model_times, self.model_avg_length])
        #max_dist = 2.75 * len(model_avg_ts[:,0])
        #experiment_avg_ts = np.column_stack([self.experiment_times, self.experiment_avg_length])
        #distance, _ = fastdtw(model_avg_ts, experiment_avg_ts, dist=euclidean)
        #similarity = 1.0 - distance/max_dist
        #print "DTW: model: avg, experiment: avg = {}".format(similarity)


##########################################
if __name__ == "__main__":
    # Script entry point: build the test object from the command line and run
    # either the cen2 fitness or the standard battery of comparisons.
    opts = parse_args()
    fitness_tests = RealFitnessTests(opts)

    if opts.experiment == "cen2":
        # cen2 experiments are handled entirely by their own fitness driver
        Cen2Fitness(fitness_tests).Run()
    else:
        fitness_tests.GraphTimeSeries()
        fitness_tests.CrossCorrelation()
        # Disabled for now: OriginalLengthFitness(), DynamicTimeWarping()

        # Kinetochore comparisons only apply to the kinetochore model
        if opts.kinetochores:
            fitness_tests.GraphKC()
            fitness_tests.KCSpindle1D()

        fitness_tests.GraphExperimentAll()

        # Do the EM fitness (GraphEMTomographySeeds() is disabled for now)
        fitness_tests.GraphEMTomographyCombined()
