#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# Basic
import sys, os, pdb
## Analysis
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *
from spindle_unit_dict import SpindleUnitDict
from base_funcs import moving_average
try: import line_profiler
except: pass

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
from graph_lib import *

'''
Name: seed_graph_funcs.py
Description: library of graphing functions for seed specific graphs
'''

# Shared unit-conversion table. The graphing functions below multiply lengths
# by uc['um'][1] before plotting them on axes labelled in micrometers.
uc = SpindleUnitDict()

# Helper functions
def get_spindle_xlink_distr_histogram( dd, xstate=r'integrated', xstage='merge'):
    """ Helper returning the histogram pieces of a spindle xlink distribution.
        Input:  dd = distribution data (dict-like) keyed by
                     'spindle_xlink_distance_<stage>'; each entry is indexed by
                     xstate and holds 'mid_points', 'hist' and 'n_points'
                xstate = (integrated) sum of all xlink positions over time, or
                         (final) location of xlinks at the final state
                xstage = 'stage1', 'stage2', or 'merge' (stage1 + stage2 summed)
        Output: (bin mid points, bin counts, total number of counts, xstage)
    """
    if xstage == 'merge':
        # Combine both stages bin-by-bin; the stage1 bins are reported, so the
        # two stages are assumed to share the same bin layout.
        stage1 = dd['spindle_xlink_distance_stage1'][xstate]
        stage2 = dd['spindle_xlink_distance_stage2'][xstate]
        counts = np.add(stage1['hist'], stage2['hist'])
        total = np.add(stage1['n_points'], stage2['n_points'])
        return stage1['mid_points'], counts, total, xstage
    data = dd['spindle_xlink_distance_{}'.format(xstage)][xstate]
    return data['mid_points'], data['hist'], data['n_points'], xstage

def graph_avg_mt_length(sd, ax, color = 'b', spb_ind=2, xlabel=True):
    """ Function to graph the average MT length on the first, second, or all spbs vs time
        Input:  sd = seed that contains data for graph. Must have a PostAnalysis object.
                ax = matplotlib axis object
                color = color of the graphed line
                spb_ind = index of the spb with mts to be analyzed. 0=first, 1=second, 2=all mts
                xlabel = whether or not to put on an xlabel
        Output: avgMTlen_arr = numpy array of average MT lengths as a function of time,
                multiplied by uc['um'][1] (the axis is labelled in micrometers)
    """
    time = sd.PostAnalysis.timedata['time'] # list of simulation times
    avgMTlen_dict = sd.PostAnalysis.timedata['avg_mt_length']
    # Pick the requested SPB entry at every time point, then dimensionalize
    avgMTlen_arr = np.array([avgMTlen_dict[ts][spb_ind] for ts in time])*uc['um'][1]
    # Graphing options
    if spb_ind == 2:
        ax.set_title(r'Average MT Lengths')
    else:
        ax.set_title(r'Average MT Lengths on SPB{}'.format(spb_ind))
    ax.set_ylabel(r'MT Length ($\mu$m)')
    if xlabel: ax.set_xlabel(r'Time (min)')
    # Make graph
    ax.plot(time, avgMTlen_arr, color=color, label=sd.label)
    return avgMTlen_arr

def graph_spb_sep(sd, ax, color = 'b', xlabel=True):
    """ Plot the distance between the two SPBs as a function of time.
        Input:  sd = seed holding the data; needs a PostAnalysis object whose
                     timedata has 'time' and 'spb_separation' entries
                ax = matplotlib axis object to draw on
                color = line color
                xlabel = if True, label the x axis with time
        Output: spbsep_arr = numpy array of the SPB separation at each time,
                multiplied by uc['um'][1] (axis is labelled in micrometers)
    """
    timedata = sd.PostAnalysis.timedata
    time = timedata['time']
    separations = timedata['spb_separation']
    # Look up the separation at every time point and dimensionalize it
    spbsep_arr = np.array([separations[t] for t in time]) * uc['um'][1]

    ax.set_title(r'SPB Separation')
    ax.set_ylabel(r'Distance ($\mu$m)')
    if xlabel:
        ax.set_xlabel(r'Time (min)')
    ax.plot(time, spbsep_arr, color=color, label=sd.label)
    return spbsep_arr

def graph_interpolar_fraction(sd, ax, color='b', xlabel=True, label='sd'):
    """ Function to graph the fraction of MTs that are considered to be
        interpolar vs time.  Interpolarity is defined as MTs that are
        oriented so that their dot product is negative and there is a
        distance between the two <40 nm for longer than 300nm.
        Input:  sd = seed that contains data for graph. Must have a PostAnalysis object.
                ax = matplotlib axis object
                color = color of the graphed line
                xlabel = whether or not to put on an xlabel
                label = legend label. If 'sd' then use the seed objects label.
        Output: if_arr = numpy array of interpolar fraction vs time
    """
    time = sd.PostAnalysis.timedata['time']
    if_dict = sd.PostAnalysis.timedata['interpolar_fraction']
    # Convert data dictionary into a numpy array ordered by time
    if_arr = np.array([if_dict[ts] for ts in time])
    # Graphing options
    if label == 'sd': label = sd.label
    ax.set_title(r'Interpolar Fraction of MTs')
    if xlabel: ax.set_xlabel(r'Time (min)')
    ax.set_ylabel('# Interpolar/\nTotal MTs')
    # Use the resolved legend label so a caller-supplied label is honored
    ax.plot(time, if_arr, color=color, label=label)
    return if_arr

def graph_interpolar_length_fraction(sd, ax, color='b', xlabel=True, label='sd'):
    """ Function to graph the fraction of overlapping interpolar MT lengths
        over total MT lengths.  Interpolarity is defined as MTs that are
        oriented so that their dot product is negative and there is a distance
        between the two <40 nm for longer than 300nm.
        Input:  sd = seed that contains data for graph. Must have a PostAnalysis object.
                ax = matplotlib axis object
                color = color of the graphed line
                xlabel = whether or not to put on an xlabel
                label = legend label. If 'sd' then use the seed objects label.
        Output: ifl_arr = python list of interpolar length fractions, one per time
    """
    time = sd.PostAnalysis.timedata['time']
    ifl_dict = sd.PostAnalysis.timedata['interpolar_length_fraction']
    ifl_arr = [ifl_dict[ts] for ts in time]

    if label == 'sd': label = sd.label

    ax.set_title("Interpolar Length Fraction of MTs")
    if xlabel: ax.set_xlabel(r'Time (min)')
    ax.set_ylabel("Total Interpolar Length/\nMax Overlap Length")

    # Use the resolved legend label so a caller-supplied label is honored
    ax.plot(time, ifl_arr, color=color, label=label)

    return ifl_arr

def graph_num_xlinks(sd, ax, species_ind=2, color='b', xlabel=True, label='sd'):
    """ Function to graph the number of a crosslinker species stage over time
        Input:  sd = seed that contains data for graph. Must have a PostAnalysis object.
                ax = matplotlib axis object
                species_ind = index into each timedata['num_xlinks'] entry; the
                              stage of the crosslink species to be graphed,
                              3 graphs all stages
                color = color of the graphed line
                xlabel = whether or not to put on an xlabel
                label = legend label. If 'sd' then use the seed objects label.
        Output: xlink_arr = numpy array of the number crosslinkers vs time
    """
    # TODO make this also take in the species of xlink.
    # Get data for graphing
    time = sd.PostAnalysis.timedata['time']
    num_xlinks = sd.PostAnalysis.timedata['num_xlinks']
    # Options for graph display
    if label == 'sd': label = sd.label
    if species_ind==3: ax.set_title("Total Number of Xlinks")
    else: ax.set_title("Total Number of Stage{} Xlinks".format(species_ind))
    if xlabel: ax.set_xlabel(r'Time (min)')
    ax.set_ylabel(r'Number of Crosslinks')
    # Number of xlinks at each time step (the loop variable was previously
    # misspelled as `tf`, which raised NameError)
    xlink_arr = np.array([num_xlinks[ts][species_ind] for ts in time])
    # Make graph with the resolved legend label
    ax.plot(time, xlink_arr, color=color, label=label)
    return xlink_arr

def graph_mt_length_distributions( sd, ax, spb='merge', color='b', xlabel=True,
                                   label='sd', **kwargs ):
    """ Plot the MT length distribution, as a probability density, for the MTs
        of one SPB or of all SPBs.
        Input:  sd = seed with a PostAnalysis object; distrdata['mt_lengths'][spb]
                     must provide 'mid_points', 'hist' and 'n_points'
                ax = matplotlib axis object
                spb = which data set to use: 'spb1', 'spb2' or 'merge' (both SPBs)
                color = line color
                xlabel = whether or not to put on an xlabel
                label = legend label; 'sd' means use the seed's own label
                kwargs = accepted and ignored
        Output: None
    """
    if spb == 'merge':
        title = "Total MT Length Distribution"
    else:
        title = "{} MT Length Distribution".format(spb)
    ax.set_title(title)
    if xlabel:
        ax.set_xlabel(r'Length ($\mu$m)')
    ax.set_ylabel(r'Probability Density')
    legend_label = sd.label if label == 'sd' else label

    distr = sd.PostAnalysis.distrdata['mt_lengths'][spb]
    mids = distr['mid_points']
    # Divide counts by (bin width * total count) to get a probability density;
    # the bin width is taken from the first two bins (uniform bins assumed)
    bin_width = mids[1] - mids[0]
    density = np.divide(distr['hist'], bin_width * distr['n_points'])
    ax.plot(mids, density, color=color, label=legend_label)

def graph_all_mt_length_distr(sd, axarr, color='b', succ_type=2, **kwargs):
    """ Plot MT length distributions for spb1 (axarr[0]), spb2 (axarr[1]) and
        both SPBs combined (axarr[2]); only the last axis gets an x label.
        Input:  sd = seed with a PostAnalysis object and a succ_info_dict
                axarr = at least 3 matplotlib axes
                color = line color
                succ_type = spindle type to graph (0 = monopolar, 1 = bipolar,
                            2 = both). Nothing is drawn when the seed's
                            succ_info_dict['succ'] does not match and
                            succ_type is not 2.
                kwargs = forwarded to graph_mt_length_distributions (e.g. label);
                         must not contain spb, color or xlabel
        Output: None
    """
    # Skip seeds of the wrong spindle type (succ_type 2 accepts every seed)
    if sd.succ_info_dict['succ'] != succ_type and succ_type != 2:
        return
    panels = (('spb1', False), ('spb2', False), ('merge', True))
    for i, (spb_name, show_xlabel) in enumerate(panels):
        graph_mt_length_distributions(sd, axarr[i], spb=spb_name, color=color,
                                      xlabel=show_xlabel, **kwargs)
    # Tidy spacing of the figure holding these axes
    axarr[0].get_figure().tight_layout()

def graph_mt_length_distr_by_index(sd, ax, **kwargs):
    """ Plot one MT length distribution per MT index (used mostly for testing).
        Each MT i gets its own rainbow color and the legend label 'mt_i'; the
        legend is placed with legend_outside, and the current figure is saved
        to graph_mt_distr_error.png in the working directory.
        Input:  sd = seed with a PostAnalysis object providing n_bonds and
                     distrdata['mt_lengths_by_index']
                ax = matplotlib axis object
                kwargs = accepted and ignored
        Output: None
    """
    ax.set_title("Individual MT Length Distributions")
    ax.set_ylabel(r'Probability Density')
    ax.set_xlabel(r'Length ($\mu$m)')

    analysis = sd.PostAnalysis
    per_mt = analysis.distrdata['mt_lengths_by_index']
    palette = mpl.cm.rainbow(np.linspace(0, 1, analysis.n_bonds))

    # NOTE(review): this uses the 'bin_mids' key while the other distributions
    # in this file use 'mid_points' — confirm against PostAnalysis.
    for i, line_color in enumerate(palette):
        distr = per_mt[i]
        ax.plot(distr['bin_mids'], distr['hist'],
                color=line_color, label='mt_{}'.format(i))
    legend_outside(ax)
    # TODO Remove after a bit
    plt.savefig("graph_mt_distr_error.png")

##########################################

def graph_avg_mt_splay(sd, ax, spb_ind=2, color='b', xlabel=True, label='sd'):
    """ Plot the average MT splay angle vs time for one SPB or for all MTs.
        Input:  sd = seed with a PostAnalysis object; timedata['mt_splay'][t]
                     is indexed by spb_ind
                ax = matplotlib axis object
                spb_ind = which entry to plot; 2 is titled as the total splay
                color = line color
                xlabel = whether or not to put on an xlabel
                label = legend label; 'sd' means use the seed's own label
        Output: splay_arr = numpy array of np.arccos of the stored values (in
                radians); the stored values are presumably average cosines —
                TODO confirm
    """
    if spb_ind == 2:
        title = "Average Total MT Splay"
    else:
        title = "Average MT splay of SPB{} ".format(spb_ind)
    ax.set_title(title)
    if xlabel:
        ax.set_xlabel(r'Time (min)')
    ax.set_ylabel(r'Angle (rad)')
    legend_label = sd.label if label == 'sd' else label

    timedata = sd.PostAnalysis.timedata
    time = timedata['time']
    splay_by_time = timedata['mt_splay']
    cosines = np.array([splay_by_time[t][spb_ind] for t in time])
    splay_arr = np.arccos(cosines)

    ax.plot(time, splay_arr, color=color, label=legend_label)
    return splay_arr

# Graph a single one of the attachment types
def graph_kc_attachment_type(sd, ax, atype='amphitelic', color='b', xlabel=True, label='sd'):
    """ Plot how many chromosomes are in one kinetochore attachment state vs time.
        Input:  sd = seed with a PostAnalysis object; timedata['kc_atypes'][t]
                     holds one integer attachment code per chromosome
                ax = matplotlib axis object
                atype = 'unattached' (0), 'monotelic' (1), 'merotelic' (2),
                        'syntelic' (3) or 'amphitelic' (4); any other value
                        falls back to amphitelic
                color = line color
                xlabel = whether or not to put on an xlabel
                label = legend label; 'sd' means use the seed's own label
        Output: plot_arr = numpy array with the number of chromosomes in the
                requested state at each time
    """
    ax.set_title("{} Attachment Type".format(atype))
    if xlabel: ax.set_xlabel(r'Time (min)')
    ax.set_ylabel("N Chromosomes")
    # Resolve the default legend label like the other graph_* functions
    if label == 'sd': label = sd.label

    time = sd.PostAnalysis.timedata['time']
    attach_dict = sd.PostAnalysis.timedata['kc_atypes']
    attach_arr = [attach_dict[ts] for ts in time]
    nchromo = len(attach_arr[0])

    atype_codes = {'unattached': 0, 'monotelic': 1, 'merotelic': 2,
                   'syntelic': 3, 'amphitelic': 4}
    atypeint = atype_codes.get(atype, 4) # Default amphitelic attachment

    # Count matching chromosomes per time point; range works on Python 2 and 3
    plot_arr = np.zeros(len(attach_arr))
    for x, attach_row in enumerate(attach_arr):
        for ic in range(nchromo):
            if attach_row[ic] == atypeint:
                plot_arr[x] += 1.0

    ax.plot(time, plot_arr, color=color, label=label)
    ax.set_ylim([0.0, nchromo])
    return plot_arr

# Graph every attachment type and the inter-kinetochore force; requires at least 6 axes in axarr
def graph_kc_attachment_types(sd, axarr, color='b', xlabel=True, label='sd'):
    """ Plot the number of chromosomes in each kinetochore attachment state vs
        time, plus the inter-kinetochore force.

        axarr[0..4] = unattached, monotelic, merotelic, syntelic, amphitelic
                      counts (attachment codes 0..4 in timedata['kc_atypes'])
        axarr[5]    = sd.force_data plotted against sd.time_resampled_forces
        Input:  sd = seed with a PostAnalysis object plus time_resampled_forces
                     and force_data attributes
                axarr = at least 6 matplotlib axes
                color = line color
                xlabel = whether or not to put an xlabel on axarr[4]
                label = 'sd' is resolved to sd.label, but the lines are
                        currently drawn without a legend label
        Output: [unattached_arr, monotelic_arr, merotelic_arr, syntelic_arr,
                 amphitelic_arr, nchromo]
    """
    count_titles = (r'Unattached', r'Monotelic', r'Merotelic', r'Syntelic', r'Amphitelic')
    if xlabel: axarr[4].set_xlabel(r'Time(min)')
    for i, title in enumerate(count_titles):
        axarr[i].set_title(title)
        axarr[i].set_ylabel(r'N Chromosomes')
    axarr[5].set_title(r'Inter-kinetochore force')
    axarr[5].set_ylabel(r'Force (pN)')

    if label == 'sd': label=sd.label

    time = sd.PostAnalysis.timedata['time']
    attach_dict = sd.PostAnalysis.timedata['kc_atypes']
    attach_arr = [attach_dict[ts] for ts in time]
    nchromo = len(attach_arr[0])

    # counts[code][x] = number of chromosomes with attachment `code` at time index x
    counts = [np.zeros(len(attach_arr)) for _ in range(len(count_titles))]
    for x, attach_row in enumerate(attach_arr):
        for ic in range(nchromo):
            attach = attach_row[ic]
            for code, code_arr in enumerate(counts):
                if attach == code:
                    code_arr[x] += 1.0
                    break
            else:
                print("Wrong kind of attachment: {}".format(attach))
    unattached_arr, monotelic_arr, merotelic_arr, syntelic_arr, amphitelic_arr = counts

    # Force data is sampled on its own (resampled) time grid
    time_resampled = sd.time_resampled_forces
    forces_arr = sd.force_data

    for i, code_arr in enumerate(counts):
        axarr[i].plot(time, code_arr, color=color)
        axarr[i].set_ylim([0.0, nchromo])
    axarr[5].plot(time_resampled, forces_arr, color=color)

    return [unattached_arr, monotelic_arr, merotelic_arr, syntelic_arr, amphitelic_arr, nchromo]

# Graph the kinetochore distance to the SPB (0)
def graph_kc_spb_distance(sd, axarr, color = 'b', xlabel = True, label = 'sd'):
    """ Plot kinetochore (KC) distances, KC separation and its velocity vs time.

        For each chromosome ic, entries 2*ic and 2*ic+1 of timedata['kc_distance']
        are drawn on axarr[0] (solid / dashed), their absolute difference on
        axarr[1] and np.gradient of that difference on axarr[2]. The raw
        distances are also written to 'kc_distances.dat' in the current working
        directory (one row per time: the time, then the pair for each chromosome).

        If sd.PostAnalysis.do_anaphase is set, the raw data are cut at
        anaphase_onset + 2 min, velocities are printed and the onset is marked.
        Otherwise every quantity is block-averaged over windows of 8/60 min.

        Input:  sd = seed with a PostAnalysis object
                axarr = at least 3 matplotlib axes
                color = accepted for interface compatibility; each chromosome
                        gets its own rainbow color instead
                xlabel = whether or not to put an xlabel on axarr[2]
                label = 'sd' is resolved to sd.label but not used for the lines
        Output: None
    """
    if xlabel: axarr[2].set_xlabel(r'Time(min)')
    if label == 'sd': label=sd.label
    axarr[0].set_ylabel(r'SPB1-KC Distance($\mu$m)')
    axarr[1].set_ylabel(r'KC-sep ($\mu$m)')
    axarr[2].set_ylabel(r'KC-sep velocity ($\mu$m/min)')

    measure_time = 8.0/60.0
    time = np.asarray(sd.PostAnalysis.timedata['time'])
    # Samples per averaging window. np.int no longer exists in NumPy >= 1.24,
    # and the window must contain at least one sample.
    nmeasure = max(1, int(measure_time / (time[1] - time[0])))
    # reshape(-1, nmeasure) only works on whole windows: drop a trailing partial
    # window so the resampled times and values always have the same length.
    n_full = (len(time) // nmeasure) * nmeasure
    time_resampled = time[:n_full:nmeasure]

    dist_dict = sd.PostAnalysis.timedata['kc_distance']
    dist_arr = [dist_dict[ts] for ts in time]

    # Two entries (sister KCs) per chromosome; floor division keeps an int on Python 3
    nchromo = len(dist_arr[0]) // 2
    colors = mpl.cm.rainbow(np.linspace(0,1,nchromo))

    all_arrs = np.zeros((len(dist_arr[0]), len(dist_arr)))

    # Cut the time (and all arrays) at Anaphase + 2minutes
    do_anaphase = sd.PostAnalysis.do_anaphase
    if do_anaphase:
        anaphase_onset = sd.PostAnalysis.anaphase_onset
        anaphase_b_time = anaphase_onset + 2.0
        print("Anaphase B at: {}".format(anaphase_b_time))
        keep = time < anaphase_b_time
        adj_time = time[keep]

    # Plot doublets of each chromosome
    for ic in range(nchromo):
        dist0 = np.array([row[2*ic] for row in dist_arr], dtype=float)
        dist1 = np.array([row[2*ic+1] for row in dist_arr], dtype=float)
        all_arrs[2*ic] = dist0
        all_arrs[2*ic+1] = dist1
        if do_anaphase:
            adj_dist0 = dist0[keep]
            adj_dist1 = dist1[keep]
            kc_sep = np.fabs(adj_dist0 - adj_dist1)
            axarr[0].plot(adj_time, adj_dist0, color = colors[ic], linestyle = '-')
            axarr[0].plot(adj_time, adj_dist1, color = colors[ic], linestyle = '--')
            axarr[1].plot(adj_time, kc_sep, color = colors[ic], linestyle = '-')
            # Look at the velocity....
            velocity = np.gradient(kc_sep, adj_time)
            axarr[2].plot(adj_time, velocity, color = colors[ic], linestyle = '-')
            velocity_anaphase = velocity[adj_time > anaphase_onset]
            print(r"Chromosome {} max kc_sep velocity: {} $\mu$m/min".format(ic, np.amax(velocity)))
            # np.amax raises on an empty array (no samples after anaphase onset)
            if velocity_anaphase.size:
                print(r"Chromosoome {} segregation speed after anaphase (max): {} $\mu$m/min".format(ic, np.amax(velocity_anaphase)))
        else:
            # Block-average over nmeasure consecutive samples
            dist0_resampled = np.mean(dist0[:n_full].reshape(-1, nmeasure), axis=1)
            dist1_resampled = np.mean(dist1[:n_full].reshape(-1, nmeasure), axis=1)
            axarr[0].plot(time_resampled, dist0_resampled, color = colors[ic], linestyle = '-')
            axarr[0].plot(time_resampled, dist1_resampled, color = colors[ic], linestyle = '--')
            kc_sep = np.fabs(dist0 - dist1)
            kc_sep_resampled = np.mean(kc_sep[:n_full].reshape(-1, nmeasure), axis=1)
            axarr[1].plot(time_resampled, kc_sep_resampled, color = colors[ic], linestyle = '-')
            velocity = np.gradient(kc_sep_resampled, time_resampled)
            axarr[2].plot(time_resampled, velocity, color = colors[ic], linestyle = '-')

    if do_anaphase:
        axarr[0].axvline(sd.PostAnalysis.anaphase_onset, color = 'k', linestyle = '--')
        axarr[1].axvline(sd.PostAnalysis.anaphase_onset, color = 'k', linestyle = '--')

    # Dump the raw (un-averaged) distances: time, then d0 d1 for every chromosome
    with open('kc_distances.dat', 'w') as stream:
        for x in range(len(dist_arr)):
            stream.write('{} '.format(time[x]))
            for ic in range(nchromo):
                stream.write('{} {} '.format(all_arrs[2*ic][x], all_arrs[2*ic+1][x]))
            stream.write('\n')

# Graph the length, IPF, and amphitelic attachments
def graph_length_ipf_amphi(sd, axarr, color = 'b', xlabel = True, label = 'sd'):
    """ Plot spindle length, interpolar fraction and amphitelic count vs time,
        each block-averaged over windows of 8/60 min.

        axarr[0] = SPB separation (multiplied by uc['um'][1])
        axarr[1] = timedata['interpolar_fraction']
        axarr[2] = number of chromosomes with attachment code 4 (amphitelic)
        Input:  sd = seed with a PostAnalysis object
                axarr = at least 3 matplotlib axes
                color = line color
                xlabel = whether or not to put an xlabel on axarr[2]
                label = 'sd' is resolved to sd.label but not used for the lines
        Output: None
    """
    if xlabel: axarr[2].set_xlabel(r'Time(min)')
    if label == 'sd': label=sd.label
    axarr[0].set_ylabel(r'Spindle length ($\mu$m)')
    axarr[1].set_ylabel(r'Interpolar fraction')
    axarr[2].set_ylabel(r'N amphitelic')

    time = np.asarray(sd.PostAnalysis.timedata['time'])
    measure_time = 8.0/60.0
    # Samples per averaging window; np.int no longer exists in NumPy >= 1.24,
    # and the window must contain at least one sample.
    nmeasure = max(1, int(measure_time / (time[1] - time[0])))
    # Only whole windows can be reshaped; drop a trailing partial window so the
    # averaged values line up with time_resampled.
    n_full = (len(time) // nmeasure) * nmeasure
    time_resampled = time[:n_full:nmeasure]

    def block_average(values):
        """Mean of consecutive, non-overlapping windows of nmeasure samples."""
        return np.mean(np.asarray(values)[:n_full].reshape(-1, nmeasure), axis=1)

    # Spindle length
    spbsep_dict = sd.PostAnalysis.timedata['spb_separation']
    spbsep_arr = block_average(np.array([spbsep_dict[ts] for ts in time])*uc['um'][1])
    # Interpolar fraction
    ipf_dict = sd.PostAnalysis.timedata['interpolar_fraction']
    ipf_arr = block_average([ipf_dict[ts] for ts in time])
    # Amphitelic attachment (code 4), counted over the first nchromo entries
    attach_dict = sd.PostAnalysis.timedata['kc_atypes']
    attach_arr = [attach_dict[ts] for ts in time]
    nchromo = len(attach_arr[0])
    amphitelic_arr = np.zeros(len(attach_arr))
    for x, attach_row in enumerate(attach_arr):
        for ic in range(nchromo):
            if attach_row[ic] == 4:
                amphitelic_arr[x] += 1.0
    amphitelic_arr = block_average(amphitelic_arr)

    # Now plot this
    axarr[0].plot(time_resampled, spbsep_arr, color = color)
    axarr[1].plot(time_resampled, ipf_arr, color = color)
    axarr[2].plot(time_resampled, amphitelic_arr, color = color)

# Graph the forces on the 2 SPBs as a function of time
def graph_spb_force(sd, axarr, color = 'b', xlabel = True, label = 'sd'):
    """ Plot smoothed SPB forces vs time.

        Columns 0 and 2 of each timedata['spb_forces'] entry are smoothed with a
        5-point running mean and drawn as the SPB1 force (axarr[0]) and the SPB2
        force (axarr[1]). axarr[2] shows SPB2 - SPB1, labelled "Outward Force".
        If sd.PostAnalysis.do_anaphase is set, anaphase_onset is marked with a
        dashed vertical line on all three axes.
        Input:  sd = seed that contains data for graph. Must have a PostAnalysis object.
                axarr = at least 3 matplotlib axes
                color = accepted for interface compatibility; lines use fixed colors
                xlabel = whether or not to put an xlabel on axarr[2]
                label = 'sd' is resolved to sd.label but not used for the lines
        Output: None
    """
    if xlabel: axarr[2].set_xlabel(r'Time(min)')
    if label == 'sd': label=sd.label
    axarr[0].set_ylabel(r'SPB1 Force(pN)')
    axarr[1].set_ylabel(r'SPB2 Force(pN)')
    axarr[2].set_ylabel(r'Outward Force(pN)')

    time = np.asarray(sd.PostAnalysis.timedata['time'])
    force_dict = sd.PostAnalysis.timedata['spb_forces']
    force_arr = np.array([force_dict[ts] for ts in time])

    # This is very noisy, do a rolling average of the points
    force_spb1 = force_arr[:,0]
    force_spb2 = force_arr[:,2]

    def running_mean(x, N):
        """Mean of every length-N window of x (len(x) - N + 1 values)."""
        cumsum = np.cumsum(np.insert(x, 0, 0))
        return (cumsum[N:] - cumsum[:-N]) / float(N)

    # Pad both ends with the edge value so the smoothed series matches the
    # length of time (exact for odd Nwin). np.pad requires integer widths, so
    # use floor division: Nwin/2 is a float under Python 3.
    Nwin = 5
    half_win = Nwin // 2
    fspb1 = np.pad(running_mean(force_spb1, Nwin), (half_win, half_win), 'edge')
    fspb2 = np.pad(running_mean(force_spb2, Nwin), (half_win, half_win), 'edge')

    axarr[0].plot(time, fspb1, color = 'b')
    axarr[1].plot(time, fspb2, color = 'r')
    axarr[2].plot(time, -fspb1+fspb2, color = 'g')

    if sd.PostAnalysis.do_anaphase:
        for i in range(3):
            axarr[i].axvline(sd.PostAnalysis.anaphase_onset, color = 'k', linestyle = '--')

# Graph the balance of the AF and Xlink forces!
def graph_xlinkaf_force(sd, axarr, color = 'b', xlabel = True, label = 'sd'):
    """ Plot spindle length and selected pole-0 force components vs time.

        axarr[0] = SPB separation, multiplied by uc['um'][1]
        axarr[1] = columns 0 (red), 1 (blue) and 4 (green) of each
                   timedata['pole_forces'][0] entry. Per the author's notes
                   these should be the xlink force along the axis, the AF
                   force along the axis and the total chromosome force —
                   NOTE(review): not verifiable from this file.
        Input:  sd = seed with a PostAnalysis object
                axarr = at least 2 matplotlib axes
                color = accepted for interface compatibility; lines use fixed colors
                xlabel = whether or not to put an xlabel on axarr[1]
                label = 'sd' is resolved to sd.label but not used for the lines
        Output: None
    """
    if xlabel:
        axarr[1].set_xlabel(r'Time(min)')
    if label == 'sd':
        label = sd.label
    axarr[0].set_ylabel(r'Spindle length ($\mu$m)')
    axarr[1].set_ylabel(r'Pole0 forces mag (pN)')

    timedata = sd.PostAnalysis.timedata
    time = np.asarray(timedata['time'])
    separation_by_time = timedata['spb_separation']
    spindle_length = np.array([separation_by_time[t] for t in time]) * uc['um'][1]

    pole0_by_time = timedata['pole_forces'][0]
    pole0_forces = np.array([pole0_by_time[t] for t in time])

    axarr[0].plot(time, spindle_length, color = 'b')
    # (column, line color): xlink force, AF force, total chromosome force
    for column, line_color in ((0, 'r'), (1, 'b'), (4, 'g')):
        axarr[1].plot(time, pole0_forces[:, column], color = line_color)

# Graph the tangent forces on the 2 SPBs as a function of time
def graph_tangent_force(sd, axarr, color = 'b', xlabel = True, label = 'sd'):
    """ Plot smoothed SPB tangent forces and spindle length vs time.

        Columns 0 and 1 of each timedata['tangent_forces'] entry are smoothed
        with a 15-point running mean and drawn as the SPB1 (axarr[0]) and SPB2
        (axarr[1]) tangent forces. axarr[2] shows their sum with a zero line,
        and axarr[3] the SPB separation multiplied by uc['um'][1]. If
        sd.PostAnalysis.do_anaphase is set, anaphase_onset is marked on all four.
        Input:  sd = seed that contains data for graph. Must have a PostAnalysis object.
                axarr = at least 4 matplotlib axes
                color = accepted for interface compatibility; lines use fixed colors
                xlabel = whether or not to put an xlabel on axarr[3]
                label = 'sd' is resolved to sd.label but not used for the lines
        Output: None
    """
    if xlabel: axarr[3].set_xlabel(r'Time(min)')
    if label == 'sd': label=sd.label
    axarr[0].set_ylabel(r'SPB1 Tangent Force(pN)')
    axarr[1].set_ylabel(r'SPB2 Tangent Force(pN)')
    axarr[2].set_ylabel(r'SPB1+2 Force(pN)')
    axarr[3].set_ylabel(r'Spindle length')

    time = np.asarray(sd.PostAnalysis.timedata['time'])
    force_dict = sd.PostAnalysis.timedata['tangent_forces']
    force_arr = np.array([force_dict[ts] for ts in time])
    spbsep_dict = sd.PostAnalysis.timedata['spb_separation']
    spbsep_arr = np.array([spbsep_dict[ts] for ts in time])*uc['um'][1]

    # This is very noisy, do a rolling average of the points
    force_spb1 = force_arr[:,0]
    force_spb2 = force_arr[:,1]

    def running_mean(x, N):
        """Mean of every length-N window of x (len(x) - N + 1 values)."""
        cumsum = np.cumsum(np.insert(x, 0, 0))
        return (cumsum[N:] - cumsum[:-N]) / float(N)

    # Pad both ends with the edge value so the smoothed series matches the
    # length of time (exact for odd Nwin). np.pad requires integer widths, so
    # use floor division: Nwin/2 is a float under Python 3.
    Nwin = 15
    half_win = Nwin // 2
    fspb1 = np.pad(running_mean(force_spb1, Nwin), (half_win, half_win), 'edge')
    fspb2 = np.pad(running_mean(force_spb2, Nwin), (half_win, half_win), 'edge')

    axarr[0].plot(time, fspb1, color = 'b')
    axarr[1].plot(time, fspb2, color = 'r')
    axarr[2].plot(time, fspb1+fspb2, color = 'g')
    axarr[2].axhline(color = 'k')
    axarr[3].plot(time, spbsep_arr, color = 'k')

    if sd.PostAnalysis.do_anaphase:
        for i in range(4):
            axarr[i].axvline(sd.PostAnalysis.anaphase_onset, color = 'k', linestyle = '--')

# Graph the occupancy and the attachment types versus time
def graph_occupancy_attachment(sd, axarr, color = 'b', xlabel = True, label = 'sd'):
    """ Plot kinetochore occupancy and attachment-type counts vs time.

        axarr[0] = timedata['kc_occupancy']
        axarr[1] = chromosomes with attachment code 4 (amphitelic)
        axarr[2] = chromosomes with attachment code 2 (merotelic)
        axarr[3] = chromosomes with codes 0, 1 or 3 (unattached, monotelic,
                   syntelic) lumped together as "Other"
        Any other code is reported with a printed warning and not counted.
        Input:  sd = seed with a PostAnalysis object
                axarr = at least 4 matplotlib axes
                color = accepted for interface compatibility; lines are blue
                xlabel = whether or not to put an xlabel on axarr[3]
                label = 'sd' is resolved to sd.label but not used for the lines
        Output: None
    """
    if xlabel: axarr[3].set_xlabel(r'Time(min)')
    if label == 'sd': label=sd.label
    axarr[0].set_ylabel(r'Occupancy')
    axarr[1].set_ylabel(r'Amphitelic')
    axarr[2].set_ylabel(r'Merotelic')
    axarr[3].set_ylabel(r'Other')

    time = np.asarray(sd.PostAnalysis.timedata['time'])
    attach_dict = sd.PostAnalysis.timedata['kc_atypes']
    attach_arr = [attach_dict[ts] for ts in time]

    occupancy_dict = sd.PostAnalysis.timedata['kc_occupancy']
    occupancy_arr = np.array([occupancy_dict[ts] for ts in time])

    nchromo = len(attach_arr[0])

    amphitelic_arr = np.zeros(len(attach_arr))
    merotelic_arr  = np.zeros(len(attach_arr))
    other_arr      = np.zeros(len(attach_arr))

    for x, attach_row in enumerate(attach_arr):
        for ic in range(nchromo):
            attach = attach_row[ic]
            if attach == 4:
                amphitelic_arr[x] += 1.0
            elif attach == 2:
                merotelic_arr[x] += 1.0
            elif attach in (0, 1, 3):
                other_arr[x] += 1.0
            else:
                print("Wrong kind of attachment: {}".format(attach))

    axarr[0].plot(time, occupancy_arr, color = 'b')
    axarr[1].plot(time, amphitelic_arr, color = 'b')
    axarr[2].plot(time, merotelic_arr, color = 'b')
    axarr[3].plot(time, other_arr, color = 'b')

def graph_kmt_lifetimes(sd, axarr, color = 'b', xlabel = True, label = 'sd'):
    """ Print summary statistics of the kMT attachment lifetimes at the last time.

        Nothing is plotted; axarr, color, xlabel and label are accepted only so
        the signature matches the other graph_* functions. The stored lifetimes
        are multiplied by the sampling interval time[1] - time[0] before the mean
        and sample standard deviation (ddof=1) are printed, so they are
        presumably counted in samples — TODO confirm.
        Input:  sd = seed with a PostAnalysis object providing
                     timedata['time'] and timedata['attachment_lifetimes']
        Output: None
    """
    time = np.asarray(sd.PostAnalysis.timedata['time'])
    kmt_lifetimes_dict = sd.PostAnalysis.timedata['attachment_lifetimes']
    # Only the final time point is reported
    last_lifetimes = kmt_lifetimes_dict[time[-1]]
    print("last kmt lifetimes = {}".format(last_lifetimes))
    # Convert explicitly rather than relying on list * numpy-scalar coercion
    lifetimes_min = np.asarray(last_lifetimes, dtype=float) * (time[1] - time[0])
    print("kMT avg lifetime (min) = {} +/- {}".format(np.mean(lifetimes_min), np.std(lifetimes_min, ddof=1)))

def graph_spb_stageN_xlink_distance(sd, ax, stage=3, label=None, color='b', me=1,
                                    xlabel=True, xstate=r'integrated', **kwargs):
    """ Plot the probability density of xlink distances from their parent SPB.
        Input:  sd = seed providing GetXlinkDistributionData(stage, xstate, **kwargs),
                     which returns a dict with 'mid_points', 'hist', 'n_points'
                ax = matplotlib axis object
                stage = 1 or 2 for a single xlink stage; any other value is
                        titled as stages 1 and 2 together
                label = legend label, used as given
                color = line color
                me = unused
                xlabel = whether or not to put on an xlabel
                xstate, kwargs = forwarded to GetXlinkDistributionData
        Output: None
    """
    if stage == 1:
        title = "Stage1 Xlink SPB Separations"
    elif stage == 2:
        title = "Stage2 Xlink SPB Separations"
    else:
        title = "Stage 1 and 2 Xlink SPB Separations"
    ax.set_title(title)
    ax.set_ylabel("Probability Density")
    if xlabel:
        ax.set_xlabel(r'Distance from Parent SPB ($\mu$m)')

    distr = sd.GetXlinkDistributionData(stage, xstate, **kwargs)
    mids = distr['mid_points']
    # Normalize counts by (bin width * total count) to get a probability density
    bin_width = mids[1] - mids[0]
    density = np.divide(distr['hist'], bin_width * distr['n_points'])
    ax.plot(mids, density, color=color, label=label)

def graph_spindle_xlink_distance(sd, ax, label='sd', color='b',
                                 xlabel=True, **kwargs):
    """ Graph the normalized xlink distribution along the spindle as a probability density.
        Input:  sd = seed that contains data for graph. Must have a PostAnalysis object.
                ax = matplotlib axis object
                label = legend label. If 'sd' then use the seed objects label.
                color = color of the graphed line
                xlabel = whether or not to put on an xlabel
                kwargs = forwarded to get_spindle_xlink_distr_histogram
                         (xstate, xstage)
        Output: None
    """
    dd = sd.PostAnalysis.distrdata # Shortcut for PostAnalysis distribution data
    bin_mids, hist, n_points, xstage = get_spindle_xlink_distr_histogram(dd, **kwargs)
    # Resolve the default legend label like the other graph_* functions
    if label == 'sd': label = sd.label
    # Graphing options
    ax.set_title(xstage)
    ax.set_ylabel("Probability Density")
    if xlabel: ax.set_xlabel(r'Normalized xlink location along spindle')
    # Normalize counts by (bin width * total count) to get a probability density
    hist = np.divide(hist, (bin_mids[1]-bin_mids[0])*n_points)
    ax.plot(bin_mids,
            hist,
            color = color,
            label = label )
    return

def graph_spindle_xlink_distance_all( sd, axarr, label='sd', color='b',
                                  xstate=r'integrated', **kwargs ):
    """ Plot spindle xlink distributions for stage1 (axarr[0]), stage2
        (axarr[1]) and both stages merged (axarr[2]); only the last axis gets
        an x label.
        Input:  sd = seed with a PostAnalysis object
                axarr = at least 3 matplotlib axes
                label, color = forwarded to graph_spindle_xlink_distance
                xstate = 'integrated' or 'final', forwarded as well
                kwargs = extra keyword arguments; must not contain xstage or xlabel
        Output: None
    """
    for i, stage_name in enumerate(('stage1', 'stage2', 'merge')):
        graph_spindle_xlink_distance(sd, axarr[i], label, color, xstate=xstate,
                                     xstage=stage_name,
                                     xlabel=(stage_name == 'merge'), **kwargs)

if __name__ == "__main__":
    # This module only provides plotting helpers; there is no command-line entry point yet.
    # print() with a single argument behaves the same on Python 2 and 3.
    print("Not implemented yet")




