#!/usr/bin/env python
# In case of poor (Sh**y) commenting contact christopher.edelmaier@colorado.edu
# Basic
import sys, os, pdb
import gc
import argparse
import fnmatch
## Analysis
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'Lib'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'Spindle'))

from spindle_sim import SpindleSim
from sim_graph_funcs import *
from stylelib.ase1_styles import ase1_sims_stl
from stylelib.ase1_styles import cp_spindle_stl

from scipy.stats import ks_2samp
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw
import scipy.io as sio

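'''
Name: MultiRun.py
Description: Takes in one or more simulation directories (--sim) together
with a display name for each (--simname), analyzes every simulation, and
graphs/prints the averages of the variables of interest for comparison
Input: Simulation directories and matching names; see parse_args() for all options
Output: Summary statistics printed to stdout and PDF graphs saved by the
Graph* methods
'''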

def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``sim``, ``simname``,
        ``stddevmean``, ``nopost``, ``fitness``, ``analyze``, ``datadir``
        and ``workdir``.
    """
    # The program name matches the module header ("Name: MultiRun.py"); the
    # previous value 'MotorDeleteGraph.py' mislabelled the usage/help text.
    parser = argparse.ArgumentParser(prog='MultiRun.py')
    # General options that are actually required
    #parser.add_argument('-w', '--wildtype', required=True, type=str,
    #        help='WT sim')

    # One display name is expected per simulation directory (checked in
    # MultiRun.ReadOpts).
    parser.add_argument('--sim', nargs='+', type=str,
            help='Simulation directory.')
    parser.add_argument('--simname', nargs='+', type=str,
            help='Simulation name for corresponding directory.')

    parser.add_argument('--stddevmean', action='store_true',
            help='Write out std. dev of mean, rather than sqrt variance')

    # Minimum of extra options to make this work
    parser.add_argument('--nopost', action='store_true',
            help="Do not use post analysis program to decrease time of analysis.")
    parser.add_argument('-F', '--fitness', type=str, default='WT_Cen2', nargs='?', const='WT_Cen2',
            help='Create fitness for spindle simulations.')
    parser.add_argument('-A', '--analyze', action='store_true',
            help='Analyze data from multiple simulations')
    parser.add_argument('--datadir', type=str,
            help='Name of the data directory in which all analyzed data files will be read/written. '
                 'Also the directory where all the graphs will be placed when saved. '
                 'Default is set to {workdir}/data/.')
    parser.add_argument('-d', '--workdir', type=str,
            help='Name of the working directory where simulation will be run.')

    return parser.parse_args()


# Class definition
class MultiRun(object):
    def __init__(self, opts):
        """Store the options, then load and analyze every simulation.

        opts: option namespace as returned by parse_args(). It is kept on
        self.opts and handed on to ReadOpts() / AnalyzeSims().
        """
        #FIXME: Nseeds
        self.nseeds = 12
        self.fig_pretty_size = (2,2)
        self.cwd = os.getcwd()
        self.opts = opts

        print("opts: {}".format(self.opts))

        # Construction does all the work: build the sims, then analyze them.
        self.ReadOpts()
        self.AnalyzeSims()

    def ReadOpts(self):
        """Validate the options and build one SpindleSim per --sim directory.

        Side effects:
            * self.opts.workdir becomes an absolute path (the current working
              directory when the option was not given).
            * self.simdirs holds the absolute path of every simulation directory.
            * self.sims holds the matching SpindleSim objects, in the same order.

        Raises:
            IOError: an explicit working directory does not exist.
            SystemExit: (status 1) --sim / --simname are missing or differ in
                length.
        """
        if not self.opts.workdir:
            self.opts.workdir = os.path.abspath(self.cwd)
        elif not os.path.exists(self.opts.workdir):
            raise IOError( "Working directory {} does not exist.".format(
                self.opts.workdir) )
        else:
            self.opts.workdir = os.path.abspath(self.opts.workdir)

        # Assert that the number of simulations equals the number of simulation names.
        # argparse leaves an omitted option as None, which used to blow up in len().
        if (self.opts.sim is None or self.opts.simname is None
                or len(self.opts.sim) != len(self.opts.simname)):
            print("please give me simulation directories and names in an equal number!")
            sys.exit(1)

        #self.wt_dir = os.path.abspath(self.opts.wildtype)
        #self.wt_sim = SpindleSim(self.wt_dir, opts)

        self.simdirs = []
        self.sims = []
        for simdir in self.opts.sim:
            mydir = os.path.abspath(simdir)
            self.simdirs.append(mydir)
            # Use the instance's options; the bare name `opts` only resolved
            # when a module-level global of that name happened to exist.
            self.sims.append(SpindleSim(mydir, self.opts))

    def AnalyzeSims(self):
        """Analyze every loaded simulation and, if enabled, compute its fitness.

        Each entry of self.sims first gets Analyze() and CalcSimSuccess().
        When self.opts.fitness is truthy, Fitness() is then called on each one
        with the paths of wt.mat and lstream.mat, located in ../../Data
        relative to this file.
        """
        #self.wt_sim.Analyze()
        #self.wt_sim.CalcSimSuccess()
        for spindle in self.sims:
            spindle.Analyze()
            spindle.CalcSimSuccess()

        if not self.opts.fitness:
            return

        data_root = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', 'Data')
        wt_path = os.path.join(data_root, 'wt.mat')
        lstream_path = os.path.join(data_root, 'lstream.mat')
        #self.wt_sim.Fitness(wt_path, lstream_path)

        for spindle in self.sims:
            spindle.Fitness(wt_path, lstream_path)

    # Graph the avg spindle length for various delete strains (or just multiple strains)
    def GraphSpindleLength(self):
        """Plot SPB separation vs. time and the late-time mean for every simulation.

        Saves two figures via plt.savefig:
            * 'avg_spindle_length_vs_t.pdf' - one seed-averaged trace per
              simulation (drawn by GraphSpindleLengthAvg).
            * 'late_spindle_length.pdf' - one point per simulation: the mean of
              the late-time lengths with an error bar (standard deviation, or
              that value divided by sqrt(self.nseeds) when --stddevmean is set).

        Also fills self.average_lengths_late and self.fsb_sim (both keyed by
        simulation name) and prints the plotted numbers.
        """
        plt.style.use(cp_spindle_stl)
        fig, ax = plt.subplots()

        colors = mpl.cm.rainbow(np.linspace(0,1,len(self.sims)))
        # Graph the specific versions of the input data averages
        self.average_lengths_late = {}
        self.fsb_sim = {}
        xtick_names = []
        #self.average_lengths_late['WT'] = self.GraphSpindleLengthAvg(fig, ax, self.wt_sim, r'WT', 'k')
        #xtick_names = ['WT']
        for isim, sim in enumerate(self.sims):
            simname = self.opts.simname[isim]
            self.average_lengths_late[simname] = self.GraphSpindleLengthAvg(fig, ax, sim, simname, colors[isim])
            self.fsb_sim[simname] = sim.fraction_integrated_biorientation_time_mean
            xtick_names.append(simname)

        ax.set_xlabel(r'Time (min)')
        ax.set_ylabel(r'SPB Separation ($\mu$m)')
        ax.legend()
        ax.set_ylim(0.0, 3.0)
        fig.tight_layout()
        plt.savefig('avg_spindle_length_vs_t.pdf', dpi=fig.dpi)
        plt.close()

        # Draw the late time spindle lengths from the average lengths dictionary
        fig, ax = plt.subplots()
        xvals = []
        yvals = []
        yerrs = []
        fsb = []
        for ix, simname in enumerate(xtick_names, 1):
            late_lengths = self.average_lengths_late[simname]
            spread = np.std(late_lengths, ddof=1)
            if self.opts.stddevmean:
                spread = spread/np.sqrt(self.nseeds)
            xvals.append(ix)
            yvals.append(np.mean(late_lengths))
            yerrs.append(spread)
            fsb.append(self.fsb_sim[simname])

        ax.scatter(xvals, yvals, zorder=100,
                   s=100, marker='s', color='k', label=None)
        ax.errorbar(xvals, yvals, yerr=yerrs,
                   ecolor='k', elinewidth=2, capsize=7, capthick=1, zorder=0,
                   fmt='none', marker='none')
        ax.set_ylim(0.0, 3.0)
        #plt.xticks(xvals, xtick_names, rotation=45)
        plt.xticks(xvals, xtick_names)

        # Print the length data to screen (along with the fraction simultaneous
        # biorientation measure). The first line lists the simulation names; it
        # used to be labelled "simane" and printed the x positions instead.
        print("---------\nsimname = {}".format(xtick_names))
        print("late_lengths = {}".format(yvals))
        print("late_lengths_errs = {}".format(yerrs))
        print("fsb = {}".format(fsb))
        print("--------")

        ax.set_ylabel(r'Avg. late spindle length ($\mu$m)')
        fig.tight_layout()
        plt.savefig('late_spindle_length.pdf', dpi=fig.dpi)
        plt.close()


    def GraphSpindleLengthAvg(self, fig, ax, sim, label, color):
        """Draw the seed-averaged SPB separation of one simulation onto ``ax``.

        For every seed the 'spb_separation' series (scaled by uc['um'][1]) is
        block-averaged over self.nmeasure consecutive samples. The mean over
        seeds is drawn with error bars (standard deviation, or standard
        deviation of the mean with --stddevmean). Summary statistics are
        printed, and the per-seed late-time spread and kinetochore separations
        are appended to 'moomoo.csv'.

        Args:
            fig: figure owning ``ax`` (not used here).
            ax: axes receiving the errorbar trace.
            sim: simulation object exposing ``seeds``; each seed provides
                ``time`` and ``PostAnalysis.timedata``.
            label: legend label, also used as prefix of the printed lines.
            color: colour of the trace.

        Returns:
            List with one array per seed holding the block-averaged lengths at
            resampled times > 10.0 (presumably minutes, going by the callers'
            axis labels).
        """
        min_size = 0
        for sd in sim.seeds:
            if (min_size == 0 or min_size > sd.time.size):
                min_size = sd.time.size

        # Set the measurement time for everybody
        self.measure_time = 8.0/60.0
        time = sim.seeds[-1].time[:min_size]
        # Builtin int() replaces np.int (alias removed in NumPy 1.24); same truncation.
        # The factor 4 widens the averaging window so we can actually see the error bars.
        self.nmeasure = 4 * int(self.measure_time / (time[1] - time[0]))

        time_resampled = time[0::self.nmeasure]
        late_mask = time_resampled > 10.0
        avg_std = np.zeros((len(sim.seeds), time_resampled.size))

        # Get the information on the per trace length fluctuations
        length_fluctuations = np.zeros(len(sim.seeds))
        late_variances = np.zeros(len(sim.seeds))
        average_lengths = []
        kckc_sep = []

        for isd, sd in enumerate(sim.seeds):
            timedata = sd.PostAnalysis.timedata
            seed_time = timedata['time']
            spbsep_dict = timedata['spb_separation']
            dist_dict = timedata['kc_distance']
            spbsep_arr = np.array([spbsep_dict[ts] for ts in seed_time])*uc['um'][1]
            spbsep_arr = np.mean(spbsep_arr.reshape(-1, self.nmeasure), axis=1)
            dist_arr = np.array([dist_dict[ts] for ts in seed_time])
            # Two entries per chromosome; floor division keeps this an int
            # under true division as well.
            nchromo = len(dist_arr[0]) // 2
            avg_std[isd,:] = spbsep_arr[:min_size]
            # Grab the mean length and the spread from this
            avg_length = np.mean(spbsep_arr)
            dev_length = np.std(spbsep_arr, ddof=1)
            length_fluctuations[isd] = dev_length / avg_length
            # Grab the average lengths after 10 minutes
            long_time_lengths = spbsep_arr[late_mask]
            average_lengths.append(long_time_lengths[:])
            late_variances[isd] = np.std(long_time_lengths, ddof=1)

            # Get the KC KC distance too, first, get only the late time array of the distances
            for ic in range(nchromo):
                dist0 = np.array([row[2*ic] for row in dist_arr], dtype=float)
                dist1 = np.array([row[2*ic+1] for row in dist_arr], dtype=float)
                kc_sep = np.fabs(dist1 - dist0)
                kc_sep_resampled = np.mean(kc_sep.reshape(-1, self.nmeasure), axis=1)
                kckc_sep.append(kc_sep_resampled[late_mask])

        num_seeds = len(sim.seeds)

        avg = np.mean(avg_std, axis=0)
        avg_stddev = np.std(avg_std, axis=0, ddof=1)
        if self.opts.stddevmean:
            avg_stddev = avg_stddev / np.sqrt(num_seeds)
        ax.errorbar(x = time_resampled, y = avg, yerr = avg_stddev, color = color, label = label)
        print("{}: Avg final spindle length = {} +- {}".format(label, avg[-1], avg_stddev[-1]))

        # Do the spindle length for all the times after 10 minutes in the simulations - this is enough to usually
        # get to metaphase length
        print("  {}: Long-time lengths = {} +- {}".format(label, np.mean(np.asarray(average_lengths)), np.std(np.asarray(average_lengths), ddof=1)))

        # Write out the average per trace error over the average lengths to find length fluctuations
        print("  {}: Length fluctuation parameters = {}".format(label, np.mean(length_fluctuations)))
        print("    : sqrt(var) = {}, {} +- {}".format(late_variances, np.mean(late_variances), np.std(late_variances, ddof=1)/np.sqrt(num_seeds)))
        # One kckc_sep entry exists per (seed, chromosome) pair, so its length
        # replaces the former hard-coded 3*num_seeds (identical for 3 chromosomes).
        print("  {}: KC separation = {} +- {}".format(label, np.mean(np.asarray(kckc_sep)), np.std(np.asarray(kckc_sep), ddof=1)/np.sqrt(len(kckc_sep))))

        # Print the raw information on the fluctuations and the ikc stretch
        late_variances_all = ", ".join(repr(e) for e in late_variances)
        kckc_sep_all = ", ".join(repr(e) for e in np.asarray(kckc_sep).flatten())
        with open("moomoo.csv", "a") as mfile:
            mfile.write('length_fluctuation, {}\n'.format(late_variances_all))
            mfile.write('ikc_stretch, {}\n'.format(kckc_sep_all))

        return average_lengths

    # Graph the avg interpolar fraction for various delete strains (or just multiple strains)
    def GraphSpindleIPF(self):
        """Plot the interpolar fraction of every simulation.

        Saves 'spindle_ipf_vs_t.pdf' (one seed-averaged trace per simulation,
        drawn by GraphSpindleIPFAvg) and 'late_ipf.pdf' (one point per
        simulation with an error bar), and stores the data returned by
        GraphSpindleIPFAvg in self.average_ipf_late keyed by simulation name.
        """
        plt.style.use(cp_spindle_stl)
        fig, ax = plt.subplots()

        palette = mpl.cm.rainbow(np.linspace(0,1,len(self.sims)))
        # Trace per simulation, remembering what each helper call returned
        self.average_ipf_late = {}
        xtick_names = []
        #self.average_ipf_late['WT'] = self.GraphSpindleIPFAvg(fig, ax, self.wt_sim, r'WT', 'k')
        #xtick_names = ['WT']
        for idx, spindle in enumerate(self.sims):
            tag = self.opts.simname[idx]
            self.average_ipf_late[tag] = self.GraphSpindleIPFAvg(fig, ax, spindle, tag, palette[idx])
            xtick_names.append(tag)

        ax.set_xlabel(r'Time (min)')
        ax.set_ylabel(r'Interpolar Fraction')
        ax.legend()
        fig.tight_layout()
        plt.savefig('spindle_ipf_vs_t.pdf', dpi=fig.dpi)
        plt.close()

        # Summary figure: mean interpolar fraction per simulation with its spread
        fig, ax = plt.subplots()
        xvals = [pos + 1 for pos in range(len(xtick_names))]
        yvals = [np.mean(self.average_ipf_late[tag]) for tag in xtick_names]
        yerrs = [np.std(self.average_ipf_late[tag], ddof=1) for tag in xtick_names]
        if self.opts.stddevmean:
            # Report the standard deviation of the mean instead
            yerrs = [err/np.sqrt(self.nseeds) for err in yerrs]

        ax.scatter(xvals, yvals, zorder=100,
                   s=100, marker='s', color='k', label=None)
        ax.errorbar(xvals, yvals, yerr=yerrs,
                   ecolor='k', elinewidth=2, capsize=7, capthick=1, zorder=0,
                   fmt='none', marker='none')
        ax.set_ylim(0.0, 1.0)
        plt.xticks(xvals, xtick_names, rotation=45)

        ax.set_ylabel(r'Interpolar Fraction')
        fig.tight_layout()
        plt.savefig('late_ipf.pdf', dpi=fig.dpi)
        plt.close()


    def GraphSpindleIPFAvg(self, fig, ax, sim, label, color):
        """Draw the seed-averaged interpolar fraction of one simulation onto ``ax``.

        For every seed the 'interpolar_fraction' series is block-averaged over
        self.nmeasure consecutive samples; the mean over seeds is drawn with
        the standard deviation of the mean as error bars, and two summary
        lines are printed.

        Args:
            fig: figure owning ``ax`` (not used here).
            ax: axes receiving the errorbar trace.
            sim: simulation object exposing ``seeds``; each seed provides
                ``time`` and ``PostAnalysis.timedata``.
            label: legend label, also used as prefix of the printed lines.
            color: colour of the trace.

        Returns:
            List with one array per seed holding the block-averaged
            interpolar fraction at resampled times > 10.0. Previously the full
            trace was returned although the late-time slice had been computed,
            which disagreed with the sibling *Avg methods and with the
            "Long-time"/"late" labels used on the result.
        """
        min_size = 0
        for sd in sim.seeds:
            if (min_size == 0 or min_size > sd.time.size):
                min_size = sd.time.size

        # Set the measurement time for everybody
        self.measure_time = 8.0/60.0
        time = sim.seeds[-1].time[:min_size]
        # Builtin int() replaces np.int (alias removed in NumPy 1.24); same truncation.
        self.nmeasure = 4 * int(self.measure_time / (time[1] - time[0]))

        time_resampled = time[0::self.nmeasure]
        late_mask = time_resampled > 10.0
        avg_std = np.zeros((len(sim.seeds), time_resampled.size))
        average_ipfs = []

        for isd, sd in enumerate(sim.seeds):
            seed_time = sd.PostAnalysis.timedata['time']
            ipf_dict = sd.PostAnalysis.timedata['interpolar_fraction']
            ipf_arr = np.array([ipf_dict[ts] for ts in seed_time])
            ipf_arr = np.mean(ipf_arr.reshape(-1, self.nmeasure), axis=1)
            avg_std[isd,:] = ipf_arr[:min_size]
            # Keep only the late-time part for the summary statistics
            average_ipfs.append(ipf_arr[late_mask])

        num_seeds = len(sim.seeds)

        avg = np.mean(avg_std, axis=0)
        avg_stddev = np.std(avg_std, axis=0, ddof=1) / np.sqrt(num_seeds)
        ax.errorbar(x = time_resampled, y = avg, yerr = avg_stddev, color = color, label = label)
        print("{}: Avg final spindle IPF = {} +- {}".format(label, avg[-1], avg_stddev[-1]))

        # Interpolar fraction over all times after 10 minutes in the simulations
        print("  {}: Long-time IPF = {} +- {}".format(label, np.mean(np.asarray(average_ipfs)), np.std(np.asarray(average_ipfs), ddof=1)))
        return average_ipfs

    # Graph each chromosome attachment state (merotelic, amphitelic, ...) for the simulations over time
    def GraphSpindleAttachStates(self):
        """Plot chromosome attachment-state counts for every simulation.

        For each attachment type three figures are saved:
            * 'spindle_<type>_vs_t.pdf' - traces drawn by GraphSpindleStateAvg,
            * 'late_<type>.pdf'         - one point per simulation (state count),
            * 'late_occ_<type>.pdf'     - one point per simulation (occupancy).
        Finally 'late_combined_ontop.pdf' overlays the 'singlepole' (black,
        left axis) and 'doublepole' (red, twinned right axis) summaries.

        self.average_state_late and self.average_occ_late are rebuilt for each
        attachment type, so after the call they hold the last type processed.
        """
        plt.style.use(cp_spindle_stl)
        attachment_ylabels = {'unattached' : 'unattached',
                'monotelic' : 'monotelic',
                'merotelic' : 'merotelic',
                'syntelic' : 'syntelic',
                'amphitelic' : 'amphitelic',
                'singlepole' : 'unattached or attached\nto a single pole',
                'doublepole' : 'attached to both poles'}
        attachment_types = ['unattached',
                            'monotelic',
                            'merotelic',
                            'syntelic',
                            'amphitelic',
                            'singlepole',
                            'doublepole']

        def draw_points(axis, xs, ys, errs, shade):
            # Square markers on top, error bars of the same colour underneath
            axis.scatter(xs, ys, zorder=100,
                         s=100, marker='s', color=shade, label=None)
            axis.errorbar(xs, ys, yerr=errs,
                          ecolor=shade, elinewidth=2, capsize=7, capthick=1, zorder=0,
                          fmt='none', marker='none')

        # (xvals, yvals, yerrs) of the types needed for the combined figure
        pole_summary = {}

        for atype in attachment_types:
            fig, ax = plt.subplots()

            palette = mpl.cm.rainbow(np.linspace(0,1,len(self.sims)))
            # Traces per simulation; keep what the helper returns for the summaries
            self.average_state_late = {}
            self.average_occ_late = {}
            xtick_names = []
            #self.average_state_late['WT'] = self.GraphSpindleStateAvg(fig, ax, self.wt_sim, atype, r'WT', 'k') 
            #xtick_names = ['WT']
            for idx, spindle in enumerate(self.sims):
                tag = self.opts.simname[idx]
                state_and_occ = self.GraphSpindleStateAvg(fig, ax, spindle, atype, tag, palette[idx])
                self.average_state_late[tag] = state_and_occ[0]
                self.average_occ_late[tag] = state_and_occ[1]
                xtick_names.append(tag)

            ax.set_xlabel(r'Time (min)')
            ax.set_ylabel('Number of chromosomes that are {}'.format(atype))
            ax.legend()
            fig.tight_layout()
            plt.savefig('spindle_{}_vs_t.pdf'.format(atype), dpi=fig.dpi)
            plt.close()

            # Per-simulation means and spreads of the late-time data
            xvals = [pos + 1 for pos in range(len(xtick_names))]
            yvals = [np.mean(self.average_state_late[tag]) for tag in xtick_names]
            occvals = [np.mean(self.average_occ_late[tag]) for tag in xtick_names]
            yerrs = [np.std(self.average_state_late[tag], ddof=1) for tag in xtick_names]
            occerrs = [np.std(self.average_occ_late[tag], ddof=1) for tag in xtick_names]
            if self.opts.stddevmean:
                yerrs = [err/np.sqrt(self.nseeds) for err in yerrs]
                occerrs = [err/np.sqrt(self.nseeds) for err in occerrs]

            # Late-time number of chromosomes in this state
            fig, ax = plt.subplots()
            draw_points(ax, xvals, yvals, yerrs, 'k')
            ax.set_ylim(0.0, 3.1)
            #plt.xticks(xvals, xtick_names, rotation=45)
            plt.xticks(xvals, xtick_names)

            ax.set_ylabel('Number of chromosomes\nthat are {}'.format(attachment_ylabels[atype]))
            fig.tight_layout()
            plt.savefig('late_{}.pdf'.format(atype), dpi=fig.dpi)
            plt.close()

            # Late-time occupancy of the chromosomes in this state
            fig, ax = plt.subplots()
            draw_points(ax, xvals, occvals, occerrs, 'k')
            ax.set_ylim(0.0, 18.1)
            #plt.xticks(xvals, xtick_names, rotation=45)
            plt.xticks(xvals, xtick_names)

            ax.set_ylabel('Occupancy of chromosomes\nthat are {}'.format(attachment_ylabels[atype]))
            fig.tight_layout()
            plt.savefig('late_occ_{}.pdf'.format(atype), dpi=fig.dpi)
            plt.close()

            # Remember the single/double pole numbers for the combined figure
            if atype in ('singlepole', 'doublepole'):
                pole_summary[atype] = (xvals, yvals, yerrs)

        single_x, single_y, single_err = pole_summary['singlepole']
        double_x, double_y, double_err = pole_summary['doublepole']

        # Single pole in black on the left axis, double pole in red on a twinned axis
        fig, ax = plt.subplots()
        draw_points(ax, single_x, single_y, single_err, 'k')
        ax2 = ax.twinx()
        draw_points(ax2, double_x, double_y, double_err, 'r')

        ax.set_ylim(0.0, 3.1)
        ax2.set_ylim(0.0, 3.1)

        plt.xticks(xvals, xtick_names)

        ax.set_ylabel('Number of chromosomes\nthat are {}'.format(attachment_ylabels['singlepole']))
        ax2.set_ylabel('Number of chromosomes\nthat are {}'.format(attachment_ylabels['doublepole']), color = 'r')
        ax2.tick_params('y', colors = 'r')
        fig.tight_layout()
        plt.savefig('late_combined_ontop.pdf', dpi=fig.dpi)
        plt.close()


    def GraphSpindleStateAvg(self, fig, ax, sim, atype, label, color):
        """Draw the seed-averaged count of chromosomes in one attachment state.

        At every time point the chromosomes whose 'kc_atypes' code belongs to
        ``atype`` are counted and their 'kc_occupancy_x' values are summed;
        both series are block-averaged over self.nmeasure consecutive samples.
        The mean count over seeds is drawn with the standard deviation of the
        mean as error bars, and two summary lines are printed.

        Args:
            fig: figure owning ``ax`` (not used here).
            ax: axes receiving the errorbar trace.
            sim: simulation object exposing ``seeds``; each seed provides
                ``time`` and ``PostAnalysis.timedata``.
            atype: one of 'unattached', 'monotelic', 'merotelic', 'syntelic',
                'amphitelic', 'singlepole' (codes 0, 1, 3) or 'doublepole'
                (codes 2, 4).
            label: legend label, also used as prefix of the printed lines.
            color: colour of the trace.

        Returns:
            [average_state, average_occ]: two lists with one array per seed,
            holding the block-averaged counts / summed occupancies at
            resampled times > 10.0.

        Raises:
            ValueError: ``atype`` is not one of the names listed above
                (previously this surfaced as an unbound-variable error).
        """
        # Integer codes stored in 'kc_atypes' for each named state
        atype_codes = {
            'unattached': (0,),
            'monotelic': (1,),
            'merotelic': (2,),
            'syntelic': (3,),
            'amphitelic': (4,),
            'singlepole': (0, 1, 3),
            'doublepole': (2, 4),
        }
        if atype not in atype_codes:
            raise ValueError("Unknown attachment type {!r}".format(atype))
        atypeint = atype_codes[atype]

        min_size = 0
        for sd in sim.seeds:
            if (min_size == 0 or min_size > sd.time.size):
                min_size = sd.time.size

        # Set the measurement time for everybody
        self.measure_time = 8.0/60.0
        time = sim.seeds[-1].time[:min_size]
        # Builtin int() replaces np.int (alias removed in NumPy 1.24); same truncation.
        self.nmeasure = 4 * int(self.measure_time / (time[1] - time[0]))

        time_resampled = time[0::self.nmeasure]
        late_mask = time_resampled > 10.0
        avg_std = np.zeros((len(sim.seeds), time_resampled.size))
        average_state = []
        average_occ = []

        for isd, sd in enumerate(sim.seeds):
            seed_time = sd.PostAnalysis.timedata['time']
            attach_dict = sd.PostAnalysis.timedata['kc_atypes']
            attach_arr = [attach_dict[ts] for ts in seed_time]
            # Do some occupancy stuff too, for both singlepole and doublepole, to check the turnover
            # for length fluctuations
            occupancy_dict = sd.PostAnalysis.timedata['kc_occupancy_x']
            occupancy_arr = np.array([occupancy_dict[ts] for ts in seed_time])
            nchromo = len(attach_arr[0])

            plot_arr = np.zeros(len(attach_arr))
            occupancy_arr_x = np.zeros(len(attach_arr))

            # Count the chromosomes in the requested state at every time point
            for x, states in enumerate(attach_arr):
                for ic in range(nchromo):
                    if states[ic] in atypeint:
                        plot_arr[x] += 1.0
                        occupancy_arr_x[x] += occupancy_arr[x][ic]
            plot_arr = np.mean(plot_arr.reshape(-1, self.nmeasure), axis=1)
            occupancy_arr_x = np.mean(occupancy_arr_x.reshape(-1, self.nmeasure), axis=1)
            avg_std[isd,:] = plot_arr[:min_size]
            average_state.append(plot_arr[late_mask])
            average_occ.append(occupancy_arr_x[late_mask])

        num_seeds = len(sim.seeds)

        avg = np.mean(avg_std, axis=0)
        avg_stddev = np.std(avg_std, axis=0, ddof=1) / np.sqrt(num_seeds)
        ax.errorbar(x = time_resampled, y = avg, yerr = avg_stddev, color = color, label = label)
        print("{}: Avg final spindle State {} = {} +- {}".format(label, atype, avg[-1], avg_stddev[-1]))

        # State count over all times after 10 minutes in the simulations
        print("  {}: Long-time state {} = {} +- {}".format(label, atype, np.mean(np.asarray(average_state)), np.std(np.asarray(average_state), ddof=1)))
        return [average_state, average_occ]

    # Graph the spindle occupancy of KC-MT binding sites
    def GraphOccupancyBindingSites(self):
        """Plot the binding-site occupancy of every simulation.

        Saves 'spindle_occupancy_vs_t.pdf' (one seed-averaged trace per
        simulation, drawn by GraphOccupancyAvg) and 'late_occupancy.pdf' (one
        point per simulation with an error bar), and stores the data returned
        by GraphOccupancyAvg in self.average_occupancy_late keyed by
        simulation name.
        """
        plt.style.use(cp_spindle_stl)
        fig, ax = plt.subplots()

        palette = mpl.cm.rainbow(np.linspace(0,1,len(self.sims)))
        # Trace per simulation, remembering what each helper call returned
        self.average_occupancy_late = {}
        xtick_names = []
        #self.average_occupancy_late['WT'] = self.GraphOccupancyAvg(fig, ax, self.wt_sim, r'WT', 'k')
        #xtick_names = ['WT']
        for idx, spindle in enumerate(self.sims):
            tag = self.opts.simname[idx]
            self.average_occupancy_late[tag] = self.GraphOccupancyAvg(fig, ax, spindle, tag, palette[idx])
            xtick_names.append(tag)

        ax.set_xlabel(r'Time (min)')
        ax.set_ylabel('Occupancy')
        ax.legend()
        fig.tight_layout()
        plt.savefig('spindle_occupancy_vs_t.pdf', dpi=fig.dpi)
        plt.close()

        # Summary figure: mean late-time occupancy per simulation with its spread
        fig, ax = plt.subplots()
        xvals = [pos + 1 for pos in range(len(xtick_names))]
        yvals = [np.mean(self.average_occupancy_late[tag]) for tag in xtick_names]
        yerrs = [np.std(self.average_occupancy_late[tag], ddof=1) for tag in xtick_names]
        if self.opts.stddevmean:
            # Report the standard deviation of the mean instead
            yerrs = [err/np.sqrt(self.nseeds) for err in yerrs]

        ax.scatter(xvals, yvals, zorder=100,
                   s=100, marker='s', color='k', label=None)
        ax.errorbar(xvals, yvals, yerr=yerrs,
                   ecolor='k', elinewidth=2, capsize=7, capthick=1, zorder=0,
                   fmt='none', marker='none')
        ax.set_ylim(0.0, 18.1)
        plt.xticks(xvals, xtick_names, rotation=45)

        ax.set_ylabel(r'Occupancy')
        fig.tight_layout()
        plt.savefig('late_occupancy.pdf', dpi=fig.dpi)
        plt.close()


    def GraphOccupancyAvg(self, fig, ax, sim, label, color):
        """Draw the seed-averaged 'kc_occupancy' of one simulation onto ``ax``.

        For every seed the 'kc_occupancy' series is block-averaged over
        self.nmeasure consecutive samples; the mean over seeds is drawn with
        the standard deviation of the mean as error bars, and two summary
        lines are printed.

        Args:
            fig: figure owning ``ax`` (not used here).
            ax: axes receiving the errorbar trace.
            sim: simulation object exposing ``seeds``; each seed provides
                ``time`` and ``PostAnalysis.timedata``.
            label: legend label, also used as prefix of the printed lines.
            color: colour of the trace.

        Returns:
            List with one array per seed holding the block-averaged occupancy
            at resampled times > 10.0.
        """
        min_size = 0
        for sd in sim.seeds:
            if (min_size == 0 or min_size > sd.time.size):
                min_size = sd.time.size

        # Set the measurement time for everybody
        self.measure_time = 8.0/60.0
        time = sim.seeds[-1].time[:min_size]
        # Builtin int() replaces np.int (alias removed in NumPy 1.24); same truncation.
        self.nmeasure = 4 * int(self.measure_time / (time[1] - time[0]))

        time_resampled = time[0::self.nmeasure]
        late_mask = time_resampled > 10.0
        avg_std = np.zeros((len(sim.seeds), time_resampled.size))
        average_occupancy = []

        for isd, sd in enumerate(sim.seeds):
            seed_time = sd.PostAnalysis.timedata['time']
            occupancy_dict = sd.PostAnalysis.timedata['kc_occupancy']
            occupancy_arr = np.array([occupancy_dict[ts] for ts in seed_time])
            occupancy_arr = np.mean(occupancy_arr.reshape(-1, self.nmeasure), axis=1)
            avg_std[isd,:] = occupancy_arr[:min_size]
            average_occupancy.append(occupancy_arr[late_mask])

        num_seeds = len(sim.seeds)

        avg = np.mean(avg_std, axis=0)
        avg_stddev = np.std(avg_std, axis=0, ddof=1) / np.sqrt(num_seeds)
        ax.errorbar(x = time_resampled, y = avg, yerr = avg_stddev, color = color, label = label)
        print("{}: Avg final spindle Occupancy = {} +- {}".format(label, avg[-1], avg_stddev[-1]))

        # Occupancy over all times after 10 minutes in the simulations
        print("  {}: Long-time occupancy = {} +- {}".format(label, np.mean(np.asarray(average_occupancy)), np.std(np.asarray(average_occupancy), ddof=1)))
        return average_occupancy

    def GraphBiorientationMeasure(self):
        """Plot the fraction of simultaneous biorientation for every simulation.

        Reads fraction_integrated_biorientation_time_mean / _std from each
        simulation into self.fsb / self.fsberr (keyed by simulation name),
        prints the plotted numbers and saves 'spindle_fsb.pdf'.
        """
        plt.style.use(cp_spindle_stl)
        fig, ax = plt.subplots()

        # Collect mean and spread per simulation
        self.fsb = {}
        self.fsberr = {}
        xtick_names = []
        for idx, spindle in enumerate(self.sims):
            tag = self.opts.simname[idx]
            self.fsb[tag] = spindle.fraction_integrated_biorientation_time_mean
            self.fsberr[tag] = spindle.fraction_integrated_biorientation_time_std
            xtick_names.append(tag)

        xvals = [pos + 1 for pos in range(len(xtick_names))]
        yvals = [self.fsb[tag] for tag in xtick_names]
        yerrs = [self.fsberr[tag] for tag in xtick_names]
        if self.opts.stddevmean:
            # Report the standard deviation of the mean instead
            yerrs = [err/np.sqrt(self.nseeds) for err in yerrs]

        ax.scatter(xvals, yvals, zorder=100,
                   s=100, marker='s', color='k', label=None)
        ax.errorbar(xvals, yvals, yerr=yerrs,
                   ecolor='k', elinewidth=2, capsize=7, capthick=1, zorder=0,
                   fmt='none', marker='none')
        ax.set_ylim(0.0, 1.1)
        plt.xticks(xvals, xtick_names, rotation=45)

        print("----------Biorientation measures from MultiRun----------")
        for series in (xvals, yvals, yerrs):
            print(series)

        ax.set_ylabel(r'Fraction simultaneous biorientation')
        fig.tight_layout()
        plt.savefig('spindle_fsb.pdf', dpi=fig.dpi)
        plt.close()

    # Graph the forces vs spindle length for the different simulations that were fed in
    def GraphForces(self):
        """Harvest the force data of every simulation and graph each contribution.

        self.df maps each simulation name to whatever HarvestForceInformation
        returns for that simulation; GraphSpindleLengthForcesContribution is
        then called for 'K5', 'K14', 'XL' and 'AF' in that order.
        """
        # Gather the force information once per simulation
        self.df = {}
        for idx, spindle in enumerate(self.sims):
            tag = self.opts.simname[idx]
            self.df[tag] = self.HarvestForceInformation(spindle)

        # Now that we have all the forces, graph them as K5, K14, XL, and AF
        # Remember, this is the axis forces
        for contribution in ('K5', 'K14', 'XL', 'AF'):
            self.GraphSpindleLengthForcesContribution(contribution)

    def GraphSpindleLengthForcesContribution(self, name):
        """Plot one axis-force contribution against binned spindle length.

        For every simulation the frame stored in self.df is binned on
        'spindle_length' (10 equal bins spanning 0 to 2.75); the mean of the
        selected force column per bin is drawn with an error bar of
        std / sqrt(self.nseeds). The figure is saved as
        'spindle_forces_vs_length_<name>.pdf'.

        Note: a 'bins' column is added to each frame in self.df.

        Args:
            name: 'K5', 'K14', 'XL' or 'AF'.

        Raises:
            ValueError: ``name`` is not one of the four contributions
                (previously this surfaced as an unbound-variable error).
        """
        # Local import: this file never imports itertools explicitly.
        import itertools

        force_columns = {
            'K5': 'xlink_type0_axis_force',
            'K14': 'xlink_type1_axis_force',
            'XL': 'xlink_type2_axis_force',
            'AF': 'af_axis_force',
        }
        if name not in force_columns:
            raise ValueError("Unknown force contribution {!r}".format(name))
        df_name = force_columns[name]

        plt.style.use(cp_spindle_stl)
        fig, ax = plt.subplots()

        bins = np.linspace(0.0, 2.75, 11)
        # Bin centres derived from the edges, so the two cannot drift apart
        bin_mids = 0.5 * (bins[:-1] + bins[1:])
        marker = itertools.cycle(('s', 'o', 'v', '^', 'H', 'D', '8', 'p', '<','>','*','h','d'))
        colors = mpl.cm.rainbow(np.linspace(0,1,len(self.sims)))
        for isim in range(len(self.sims)):
            simname = self.opts.simname[isim]
            df = self.df[simname]
            df['bins'] = pd.cut(df['spindle_length'], bins)
            # Group once and aggregate only the column that is plotted
            binned = df.groupby('bins')[df_name]
            bin_means = binned.mean()
            # Same FIXME'd seed count as the other graphs (was a literal 12)
            bin_errs = binned.std() / np.sqrt(self.nseeds)

            # next() works on both Python 2 and 3, unlike marker.next()
            ax.scatter(x = bin_mids, y = bin_means, zorder=100, s=100, color = colors[isim], label = simname, marker = next(marker), facecolors = 'none')
            ax.errorbar(x = bin_mids, y = bin_means, yerr = bin_errs, ecolor = colors[isim], elinewidth = 2,
                    capsize = 7, capthick = 1, zorder = 0, fmt = 'none', marker = 'none', label = None)

        ax.set_ylim(-100.0, 100.0)
        ax.set_xlabel(r'Spindle length ($\mu$m)')
        ax.set_ylabel('{} Force (pN)'.format(name))
        ax.legend()

        fig.tight_layout()
        plt.savefig('spindle_forces_vs_length_{}.pdf'.format(name), dpi=fig.dpi)
        plt.close()


    def HarvestForceInformation(self, sim):
        """Collect spindle length and pole axis forces for every seed of a simulation.

        Each seed's per-frame series is averaged over consecutive windows of
        self.nmeasure frames, the forces are sign-flipped, and the per-seed
        tables are stacked into one dataframe.

        Side effects: sets self.measure_time, self.nmeasure and
        self.length_force_dataframe_tmp (the same object that is returned).

        Args:
            sim: Simulation whose seeds provide `time` and
                `PostAnalysis.timedata` ('time', 'spb_separation',
                'pole_forces').

        Returns:
            pandas.DataFrame with the columns 'spindle_length',
            'xlink_type0_axis_force', 'xlink_type1_axis_force',
            'xlink_type2_axis_force' and 'af_axis_force'; one row per
            averaging window per seed.

        Raises:
            ValueError: If the averaging window is shorter than one frame, or
                a seed's frame count is not a multiple of the window (raised
                by the reshape).
            SystemExit: If the averaged length and force series of a seed end
                up with different lengths.
        """
        # Set the measurement time for everybody.  Only the spacing between the
        # first two frames is needed, read from the last seed.
        self.measure_time = 8.0/60.0
        ref_time = sim.seeds[-1].time
        # Built-in int(): the np.int alias no longer exists in current NumPy
        self.nmeasure = int(self.measure_time / (ref_time[1] - ref_time[0]))
        self.nmeasure = 4 * self.nmeasure
        if self.nmeasure < 1:
            raise ValueError('Averaging window is shorter than one frame (nmeasure = {})'.format(
                self.nmeasure))
        nmeasure = self.nmeasure

        def window_mean(samples):
            # Mean over consecutive groups of nmeasure frames
            return np.mean(np.array(samples).reshape(-1, nmeasure), axis=1)

        seed_frames = []
        for sd in sim.seeds:
            timedata = sd.PostAnalysis.timedata
            time = timedata['time']

            # We want the spindle lengths for binning the appropriate force distribution
            spbsep_dict = timedata['spb_separation']
            spindle_length_arr = window_mean([spbsep_dict[ts] for ts in time]) * uc['um'][1]

            # The forces are stored as a dictionary keyed by time; entry 0 of
            # 'pole_forces' is the one used here
            forces_dict_0 = timedata['pole_forces'][0]
            forces_axis_0_subtype = [
                window_mean([forces_dict_0[ts]['xlink_axis_subtype'][isub] for ts in time])
                for isub in range(3)]
            forces_axis_0_af = window_mean([forces_dict_0[ts]['af_axis'] for ts in time])

            # The length and force series must line up row for row
            if spindle_length_arr.shape[0] != forces_axis_0_af.shape[0]:
                print("HarvestForceInformation: spindle length and force series differ in length (%d vs %d)"
                        % (spindle_length_arr.shape[0], forces_axis_0_af.shape[0]))
                sys.exit(1)

            # Create a dataframe of the information to bin later
            # Convert with the minus sign here!
            data = {'spindle_length': spindle_length_arr,
                    'xlink_type0_axis_force': -1.0 * forces_axis_0_subtype[0],
                    'xlink_type1_axis_force': -1.0 * forces_axis_0_subtype[1],
                    'xlink_type2_axis_force': -1.0 * forces_axis_0_subtype[2],
                    'af_axis_force': -1.0 * forces_axis_0_af}
            seed_frames.append(pd.DataFrame(data))

        # Stack all seeds in one go instead of growing the frame seed by seed
        self.length_force_dataframe_tmp = pd.concat(seed_frames, ignore_index = True)

        # Return the dataframe from tmp
        return self.length_force_dataframe_tmp


################
if __name__ == "__main__":
    # Build the multi-simulation comparison from the command line options,
    # then produce every graph in a fixed order.
    opts = parse_args()
    multi_run = MultiRun(opts)
    graph_methods = (
        'GraphSpindleLength',
        'GraphSpindleIPF',
        'GraphSpindleAttachStates',
        'GraphOccupancyBindingSites',
        'GraphBiorientationMeasure',
        'GraphForces',
    )
    for method_name in graph_methods:
        getattr(multi_run, method_name)()
