#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# Basic
import sys, os, pdb
## Analysis
import cPickle as pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from base_funcs import moving_average
from seed_graph_funcs import *
from math import *
from spindle_unit_dict import SpindleUnitDict

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
from graph_lib import *

'''
Name: sim_graph_funcs.py
Description: File that contains all graphing algorithms for simulations
'''

# Module-level SpindleUnitDict instance (presumably unit-conversion constants --
# confirm). It is not referenced by any function defined in this module.
uc = SpindleUnitDict()

# Helper functions
def hist_prob_density_convert(mid_points=None, hist=None, n_points=1, **kwargs):
    """ Convert raw histogram data into a probability-density histogram.

        Every bin value is divided by ``bin_width * n_points``. The bin width
        is the spacing between the first two mid points, so uniform bins are
        assumed.

        Inputs:
            mid_points: sequence of bin mid points (at least two entries).
            hist: sequence of histogram values, one per bin.
            n_points: total number of samples that built ``hist``.
            **kwargs: ignored. This lets callers unpack a whole distribution
                      dict (``hist_prob_density_convert(**xl)``) that carries
                      extra keys.
        Outputs:
            (bin_mids, density_hist, n_points). ``bin_mids`` is ``mid_points``
            unchanged, and ``density_hist`` is a numpy array.
        Raises:
            ValueError: if fewer than two mid points are supplied, because the
                        bin width cannot be determined.
    """
    # None defaults avoid shared mutable list defaults.
    bin_mids = [] if mid_points is None else mid_points
    hist = [] if hist is None else hist
    if len(bin_mids) < 2:
        raise ValueError("hist_prob_density_convert needs at least two bin mid points")
    bin_width = bin_mids[1] - bin_mids[0]
    return bin_mids, np.divide(hist, bin_width * n_points), n_points


# Graphing functions
def graph_sim_spb_stageN_xlink_distributions(sim, ax, opts, **kwargs):
    """ Overlay every seed's SPB crosslink-distance distribution on one axis.

        Inputs:
            sim: simulation whose ``seeds`` are drawn, one rainbow colour per seed.
            ax: matplotlib axis to draw on.
            opts: unused; kept for a uniform graphing-function signature.
            **kwargs: forwarded to ``graph_spb_stageN_xlink_distance``.
        Outputs:
            None. A legend is placed outside ``ax``.
    """
    seeds = sim.seeds
    palette = mpl.cm.rainbow(np.linspace(0, 1, len(seeds)))
    for idx, seed in enumerate(seeds):
        graph_spb_stageN_xlink_distance(seed, ax, label=seed.label,
                                        color=palette[idx], **kwargs)
    legend_outside(ax, set_ax_pos=False)

def graph_sim_spb_stageN_xlink_distributions_error(sim, ax, opts, xlabel=True,**kwargs):
    """ Plot pooled crosslink-SPB distance densities for bipolar vs monopolar seeds.

        Distances and weights from every seed are pooled by outcome. A truthy
        ``succ_info_dict['succ']`` gives "Bipolar"; anything else gives
        "Monopolar". Each pool is histogrammed into 110 bins on [0, 2.75] as a
        weighted probability density (same normalisation as
        ``np.histogram(..., density=True)``) and drawn with error bars.

        Inputs:
            sim: simulation whose seeds provide ``GetXlinkDistributionData``.
            ax: matplotlib axis to draw on.
            opts: unused; kept for a uniform graphing-function signature.
            xlabel: if True, label the x axis.
            **kwargs: forwarded to ``GetXlinkDistributionData``. An optional
                      ``stage`` key (1 or 2) selects the title.
        Outputs:
            None. A category with no data inside the range is drawn as zeros.
    """
    n_bins = 110
    hist_range = (0.0, 2.75)

    def _density_with_error(values, weights):
        # Weighted density plus its standard error sqrt(sum w^2) / (W * bw).
        # For unit weights this reduces to the Poisson error sqrt(c) / (N * bw).
        wsum_hist, edges = np.histogram(values, weights=weights,
                                        bins=n_bins, range=hist_range)
        w2_hist, _ = np.histogram(values, weights=np.square(weights),
                                  bins=n_bins, range=hist_range)
        mids = moving_average(edges)
        norm = wsum_hist.sum() * (edges[1] - edges[0])
        if norm == 0:
            # No data in range: np.histogram(density=True) would return NaNs.
            zeros = np.zeros(n_bins)
            return mids, zeros, zeros
        return mids, wsum_hist / norm, np.sqrt(w2_hist) / norm

    succ_distances = np.array([])
    succ_weights = np.array([])
    fail_distances = np.array([])
    fail_weights = np.array([])

    for sd in sim.seeds:
        distances, weights = sd.GetXlinkDistributionData(**kwargs)
        if sd.succ_info_dict['succ']:
            succ_distances = np.append(succ_distances, distances)
            succ_weights = np.append(succ_weights, weights)
        else:
            fail_distances = np.append(fail_distances, distances)
            fail_weights = np.append(fail_weights, weights)

    # 'stage' is optional; indexing kwargs directly raised KeyError without it.
    stage = kwargs.get('stage')
    if stage == 1: ax.set_title("Stage 1 Xlink SPB Separations Avg")
    elif stage == 2: ax.set_title("Stage 2 Xlink SPB Separations Avg")
    else: ax.set_title("Stage 1 and 2 Xlink SPB Separations")
    ax.set_ylabel("Probability Density")
    if xlabel: ax.set_xlabel(r'Distance from Parent SPB ($\mu$m)')

    sbin_mids, succ_hist, serror = _density_with_error(succ_distances, succ_weights)
    fbin_mids, fail_hist, ferror = _density_with_error(fail_distances, fail_weights)

    ax.errorbar(sbin_mids, succ_hist, yerr=serror, label="Bipolar")
    ax.errorbar(fbin_mids, fail_hist, yerr=ferror, label="Monopolar")

    return


# Graphing functions
def graph_sim_xlink_distributions_error(sim, ax,
                                        xstate=r'integrated', stage=2, succ_type=1,
                                        color = 'b',
                                        **kwargs):
    """ Plot the mean crosslink distance density over seeds with a given outcome.

        Inputs:
            sim: simulation object. Its ``seeds`` supply ``succ_info_dict`` and
                 ``GetXlinkDistributionData(stage, xstate, **kwargs)``, and
                 ``sim.title`` labels the curve.
            ax: matplotlib axis to draw on.
            xstate: data state forwarded to ``GetXlinkDistributionData``.
            stage: crosslink stage forwarded to ``GetXlinkDistributionData``.
            succ_type: only seeds with ``succ_info_dict['succ'] == succ_type``
                       are averaged.
            color: curve colour.
            **kwargs: forwarded to ``GetXlinkDistributionData``. If it contains
                      ``xlabel == True``, the x axis is labelled.
        Outputs:
            None. Error bars depend on the number of matching seeds:
            - none: a zero curve is drawn on the last seed's bins;
            - one: sqrt(density / (bin_width * n_points));
            - several: the standard error of the per-seed densities.
            If ``sim`` has no seeds at all, nothing is drawn.
    """
    hists = []
    bin_mids = n_points = None
    for sd in sim.seeds:
        if sd.succ_info_dict['succ'] == succ_type:
            xl = sd.GetXlinkDistributionData(stage, xstate, **kwargs)
            bin_mids, hist, n_points = hist_prob_density_convert(**xl)
            hists.append(hist)

    if kwargs.get('xlabel', False) == True: ax.set_xlabel(r'Distance from Parent SPB ($\mu$m)')
    ax.set_ylabel('Probability Density')

    if not hists:
        if not sim.seeds:
            # No seed to take a bin layout from (previously an unbound-name crash).
            return
        # No seed had the requested outcome: draw zeros on the last seed's bins.
        xl = sim.seeds[-1].GetXlinkDistributionData(stage, xstate, **kwargs)
        bin_mids, hist, n_points = hist_prob_density_convert(**xl)
        hist_mean = np.zeros(hist.size)
        hist_err = np.zeros(hist.size)
    elif len(hists) == 1:
        hist_mean = hists[0]
        hist_err = np.sqrt(np.divide(hist_mean, (bin_mids[1] - bin_mids[0])*n_points))
    else:
        stacked = np.column_stack(hists)  # one column per seed
        hist_mean = np.mean(stacked, axis=1)  # mean of bins
        hist_err = np.divide(np.std(stacked, axis=1), np.sqrt(len(hists)))  # std error

    ax.errorbar(bin_mids, hist_mean, yerr=hist_err, label=sim.title, color=color)

def graph_sim_kc_spindle1d_error(sim, ax,
                                 xstate=r'integrated', succ_type=1,
                                 xtype = 'full',
                                 color = 'b',
                                 **kwargs):
    """ Plot the mean kinetochore spindle-1D density over seeds with a given outcome.

        Inputs:
            sim: simulation object with ``seeds`` and ``title``.
            ax: matplotlib axis to draw on.
            xstate: data state forwarded to ``GetKCDistributionData``.
            succ_type: only seeds with ``succ_info_dict['succ'] == succ_type``
                       are averaged.
            xtype: spindle class forwarded to ``GetKCDistributionData``. Callers
                   in this file use 'short', 'med', 'long' and 'full'.
            color: curve colour.
            **kwargs: forwarded to ``GetKCDistributionData``. If it contains
                      ``xlabel == True``, the x axis is labelled.
        Outputs:
            None. Nothing is drawn if any seed has
            ``PostAnalysis.analyze_chromosomes`` false, or if there are no seeds.
            Error bars depend on the number of matching seeds:
            - none: a zero curve is drawn;
            - one: sqrt(density / (bin_width * n_points));
            - several: the standard error of the per-seed densities.
    """
    hists = []
    bin_mids = n_points = None
    for sd in sim.seeds:
        if not sd.PostAnalysis.analyze_chromosomes:
            # Seed was not analysed for chromosomes: bail out without drawing.
            return
        if sd.succ_info_dict['succ'] == succ_type:
            kc = sd.GetKCDistributionData(xstate, xtype, **kwargs)
            bin_mids, hist, n_points = hist_prob_density_convert(**kc)
            hists.append(hist)

    if kwargs.get('xlabel', False) == True: ax.set_xlabel(r'Normalized Spindle 1D Distance')
    ax.set_ylabel('Probability Density')

    if not hists:
        if not sim.seeds:
            return
        # Zero curve on the last seed's bins. Pass xtype here as well, so the
        # bin layout matches the data requested in the loop above.
        kc = sim.seeds[-1].GetKCDistributionData(xstate, xtype, **kwargs)
        bin_mids, hist, n_points = hist_prob_density_convert(**kc)
        hist_mean = np.zeros(hist.size)
        hist_err = np.zeros(hist.size)
    elif len(hists) == 1:
        hist_mean = hists[0]
        hist_err = np.sqrt(np.divide(hist_mean, (bin_mids[1] - bin_mids[0])*n_points))
    else:
        stacked = np.column_stack(hists)  # one column per seed
        hist_mean = np.mean(stacked, axis=1)
        hist_err = np.divide(np.std(stacked, axis=1), np.sqrt(len(hists)))

    ax.errorbar(bin_mids, hist_mean, yerr=hist_err, label=sim.title, color=color, marker="o")

def graph_sim_kc_interkc(sim, ax,
                         xstate = r'integrated',
                         succ_type = 1,
                         xtype = 'full',
                         color = 'b',
                         **kwargs):
    """ Plot the mean spindle-KC separation density over seeds with a given outcome.

        Inputs:
            sim: simulation object with ``seeds`` and ``title``.
            ax: matplotlib axis to draw on.
            xstate: data state forwarded to ``GetKCStretchData``.
            succ_type: only seeds with ``succ_info_dict['succ'] == succ_type``
                       are averaged.
            xtype: spindle class forwarded to ``GetKCStretchData``.
            color: curve colour.
            **kwargs: forwarded to ``GetKCStretchData``. If it contains
                      ``xlabel == True``, the x axis is labelled.
        Outputs:
            None. Nothing is drawn if any seed has
            ``PostAnalysis.analyze_chromosomes`` false, or if there are no seeds.
            Error bars depend on the number of matching seeds:
            - none: a zero curve is drawn;
            - one: sqrt(density / (bin_width * n_points));
            - several: the standard error of the per-seed densities.
    """
    hists = []
    bin_mids = n_points = None
    for sd in sim.seeds:
        if not sd.PostAnalysis.analyze_chromosomes:
            # Seed was not analysed for chromosomes: bail out without drawing.
            return
        if sd.succ_info_dict['succ'] == succ_type:
            kc = sd.GetKCStretchData(xstate, xtype, **kwargs)
            bin_mids, hist, n_points = hist_prob_density_convert(**kc)
            hists.append(hist)

    if kwargs.get('xlabel', False) == True: ax.set_xlabel(r'Spindle-KC Separation')
    ax.set_ylabel('Probability Density')

    if not hists:
        if not sim.seeds:
            return
        # Zero curve on the last seed's bins. Pass xtype here as well, so the
        # bin layout matches the data requested in the loop above.
        kc = sim.seeds[-1].GetKCStretchData(xstate, xtype, **kwargs)
        bin_mids, hist, n_points = hist_prob_density_convert(**kc)
        hist_mean = np.zeros(hist.size)
        hist_err = np.zeros(hist.size)
    elif len(hists) == 1:
        hist_mean = hists[0]
        hist_err = np.sqrt(np.divide(hist_mean, (bin_mids[1] - bin_mids[0])*n_points))
    else:
        stacked = np.column_stack(hists)  # one column per seed
        hist_mean = np.mean(stacked, axis=1)
        hist_err = np.divide(np.std(stacked, axis=1), np.sqrt(len(hists)))

    ax.errorbar(bin_mids, hist_mean, yerr=hist_err, label=sim.title, color=color, marker='o')

def graph_sim_xlink_distr_succ_compare(sim, axarr, 
                                       xstate=r'integrated',
                                       **kwargs):
    """ Draw a 2x2 grid comparing crosslink distributions by stage and outcome.

        The top row shows bipolar seeds (succ_type=1) and the bottom row
        monopolar seeds (succ_type=0). Columns are stage 1 then stage 2. Each
        panel is drawn by ``graph_sim_xlink_distributions_error``.

        Inputs:
            sim: simulation object passed to every panel.
            axarr: 2x2 numpy array of matplotlib axes.
            xstate: data state forwarded to every panel.
            **kwargs: forwarded to every panel.
        Outputs:
            None. Calls ``plt.tight_layout()``.
    """
    panels = (
        ((0, 0), 1, 1, "Stage 1 Xlink Bipolar Spindles"),
        ((0, 1), 2, 1, "Stage 2 Xlink Bipolar Spindles"),
        ((1, 0), 1, 0, "Stage 1 Xlink Monopolar Spindles"),
        ((1, 1), 2, 0, "Stage 2 Xlink Monopolar Spindles"),
    )
    for (row, col), stage_num, outcome, title in panels:
        panel_ax = axarr[row, col]
        graph_sim_xlink_distributions_error(sim, panel_ax, xstate=xstate,
                                            stage=stage_num, succ_type=outcome,
                                            **kwargs)
        panel_ax.set_title(title)
    plt.tight_layout()

def graph_sim_spb_xlink_dist_error(sim, axarr, succ_type=1, **kwargs):
    """ Plot stage 1 and stage 2 crosslink distributions side by side.

        Inputs:
            sim: simulation object passed to ``graph_sim_xlink_distributions_error``.
            axarr: sequence of at least two axes. [0] receives stage 1 and is
                   titled "Singly bound crosslinkers"; [1] receives stage 2
                   and is titled "Doubly bound crosslinkers".
            succ_type: outcome filter forwarded to both panels.
            **kwargs: forwarded to both panels.
        Outputs:
            None.
    """
    titles = ("Singly bound crosslinkers", "Doubly bound crosslinkers")
    for slot, title in enumerate(titles):
        graph_sim_xlink_distributions_error(sim, axarr[slot], stage=slot + 1,
                                            succ_type=succ_type, **kwargs)
        axarr[slot].set_title(title)

def graph_sim_spindle_xlink_distance(sim, axarr, succ_type=1,
                                     xstate=r'integrated', anchor=(1.0, .5),
                                     **kwargs):
    """ Overlay each matching seed's spindle crosslink-distance distributions.

        Inputs:
            sim: simulation whose ``seeds`` are drawn, one rainbow colour per seed.
            axarr: sequence of axes passed to ``graph_spindle_xlink_distance_all``.
                   The legend goes beside ``axarr[1]``.
            succ_type: only seeds with ``succ_info_dict['succ'] == succ_type``
                       are drawn. This was previously ignored in favour of a
                       truthiness test; the default of 1 keeps the old behaviour
                       for 0/1 flags.
            xstate: data state forwarded to ``graph_spindle_xlink_distance_all``.
            anchor: legend anchor passed to ``legend_outside``.
            **kwargs: forwarded to ``graph_spindle_xlink_distance_all``.
        Outputs:
            None.
    """
    # Colours are assigned over all seeds, so a seed keeps its colour
    # regardless of which outcome is plotted.
    colors = mpl.cm.rainbow(np.linspace(0, 1, len(sim.seeds)))
    for sd, c in zip(sim.seeds, colors):
        if sd.succ_info_dict['succ'] == succ_type:
            graph_spindle_xlink_distance_all(sd, axarr, color=c, xstate=xstate,
                                             label=sd.label, **kwargs)

    legend_outside(axarr[1], anchor=anchor, set_ax_pos=False)
    return

def graph_sim_kc_spindle1d(sim, axarr,
                           xstate=r'integrated',
                           **kwargs):
    """ Draw a 2x2 grid of bipolar KC spindle-1D densities by spindle class.

        Inputs:
            sim: simulation object passed to ``graph_sim_kc_spindle1d_error``.
            axarr: 2x2 numpy array of matplotlib axes. (0,0) is short, (0,1)
                   medium, (1,0) long and (1,1) full spindles.
            xstate: data state forwarded to every panel.
            **kwargs: forwarded to every panel.
        Outputs:
            None. Calls ``plt.tight_layout()``.
    """
    layout = [
        ((0, 0), 'short', "Short Spindles"),
        ((0, 1), 'med', "Medium Spindles"),
        ((1, 0), 'long', "Long Spindles"),
        ((1, 1), 'full', "Full Spindles"),
    ]
    for (row, col), spindle_class, title in layout:
        graph_sim_kc_spindle1d_error(sim, axarr[row, col], xstate=xstate,
                                     succ_type=1, xtype=spindle_class, **kwargs)
        axarr[row, col].set_title(title)

    plt.tight_layout()


def graph_sim_spindle_xlink_distance_error( sim, ax, succ_type=1,
                                            xstate=r'integrated', xstage='merge',
                                            color='b', **kwargs):
    """ Plot the seed-averaged normalized spindle crosslink-position density.

        Inputs:
            sim: simulation. Its seeds provide ``succ_info_dict`` and
                 ``PostAnalysis.distrdata``.
            ax: matplotlib axis to draw on.
            succ_type: only seeds with ``succ_info_dict['succ'] == succ_type``
                       are averaged.
            xstate: data state forwarded to ``get_spindle_xlink_distr_histogram``.
            xstage: one of 'stage1', 'stage2', 'merge'. It is forwarded to the
                    histogram helper, then rebound from its return value and
                    used to pick the axis title.
            color: curve colour.
            **kwargs: forwarded to ``get_spindle_xlink_distr_histogram``.
        Outputs:
            None. Returns without drawing when no seed matches ``succ_type``.
            With one matching seed the error is
            sqrt(density / (bin_width * n_points)); otherwise it is the
            standard error of the per-seed densities.
    """
    hists = []
    n_points_tot = 0
    # Combine data from the matching seeds to estimate the error.
    for sd in sim.seeds:
        if sd.succ_info_dict['succ'] != succ_type:
            continue
        bin_mids, hist, n_points, xstage = get_spindle_xlink_distr_histogram( sd.PostAnalysis.distrdata,
                                                                      xstate=xstate, xstage=xstage,
                                                                      **kwargs )
        bin_mids, hist, n_points = hist_prob_density_convert(bin_mids, hist, n_points)
        hists.append(hist)
        n_points_tot += n_points

    if not hists:
        # No matching seeds. Check this before touching sim.seeds[0], which
        # would fail for a simulation with no seeds.
        return

    # The curve is plotted against the first seed's stage-1 mid points.
    bin_mids = sim.seeds[0].PostAnalysis.distrdata['spindle_xlink_distance_stage1'][xstate]['mid_points']
    if len(hists) == 1:
        hist_mean = hists[0]
        hist_err = np.sqrt(np.divide(hist_mean, (bin_mids[1] - bin_mids[0])*n_points_tot))
    else:
        stacked = np.column_stack(hists)  # one column per seed
        hist_mean = np.mean(stacked, axis=1)  # mean of bins
        hist_err = np.divide(np.std(stacked, axis=1), np.sqrt(len(hists)))  # std error

    title_dict = {"stage1": "Singly bound crosslinkers",
                  "stage2": "Doubly bound crosslinkers",
                  "merge": "Merged"
                 }
    ax.set_title(title_dict[xstage])
    ax.errorbar(bin_mids, hist_mean, yerr=hist_err, label=sim.title, color=color)

def graph_sim_spindle_xlink_distance_error_all(sim, axarr, **kwargs):
    """ Plot stage 1, stage 2 and merged crosslink densities on three stacked axes.

        Inputs:
            sim: simulation object passed to ``graph_sim_spindle_xlink_distance_error``.
            axarr: sequence of at least three axes, in stage1, stage2, merge
                   order. Only the last one gets the x label.
            **kwargs: forwarded to every panel.
        Outputs:
            None.
    """
    for idx, stage_key in enumerate(('stage1', 'stage2', 'merge')):
        panel = axarr[idx]
        graph_sim_spindle_xlink_distance_error(sim, panel, xstage=stage_key, **kwargs)
        panel.set_ylabel('Probability Density')

    axarr[-1].set_xlabel(r'Normalized crosslink position ($x/x_s$)')

def graph_sim_spindle_xlink_distance_final(sim, axarr, **kwargs):
    """ Pair per-seed crosslink overlays (column 0) with seed-averaged densities (column 1).

        Inputs:
            sim: simulation object passed to both column drawers.
            axarr: 2-D numpy array of axes with at least two columns.
            **kwargs: forwarded to both column drawers.
        Outputs:
            None. Applies ``tight_layout`` to the figure and leaves room on the
            right for the legend.
    """
    left_col = axarr[:, 0]
    right_col = axarr[:, 1]
    graph_sim_spindle_xlink_distance(sim, left_col, anchor=(2.25, .5), **kwargs)
    graph_sim_spindle_xlink_distance_error_all(sim, right_col, **kwargs)
    figure = axarr[0, 0].get_figure()
    figure.tight_layout()
    figure.subplots_adjust(right=.8)

def graph_sim_mt_length_distr(sim, axarr, opts, **kwargs):
    """ Overlay every seed's MT length distributions, one rainbow colour per seed.

        Inputs:
            sim: simulation whose ``seeds`` are drawn.
            axarr: axes passed to ``graph_all_mt_length_distr``. The legend goes
                   beside ``axarr[1]``.
            opts: unused; kept for a uniform graphing-function signature.
            **kwargs: forwarded to ``graph_all_mt_length_distr``.
        Outputs:
            None.
    """
    palette = mpl.cm.rainbow(np.linspace(0, 1, len(sim.seeds)))
    for idx, seed in enumerate(sim.seeds):
        graph_all_mt_length_distr(seed, axarr, color=palette[idx], **kwargs)

    legend_outside(axarr[1], set_ax_pos=False)
    # Make the axes' figure current so the pyplot-level adjustment applies to it.
    plt.figure(axarr[0].get_figure().number)
    plt.subplots_adjust(right=.8)

def graph_sim_mt_length_distr_succ_compare(sim, axarr, **kwargs):
    """ Plot MT length distributions in one column per seed outcome.

        Inputs:
            sim: simulation whose ``seeds`` are drawn, one rainbow colour per seed.
            axarr: 2-D numpy array of axes. Column ``succ`` (0 = monopolar,
                   1 = bipolar) receives each seed.
            **kwargs: accepted but not forwarded to ``graph_all_mt_length_distr``.
        Outputs:
            None. Sets a figure suptitle and a legend beside ``axarr[1,1]``.
    """
    colors = mpl.cm.rainbow(np.linspace(0, 1, len(sim.seeds)))
    for sd, color in zip(sim.seeds, colors):
        # Use the outcome flag as a column index. int() ensures a bool flag
        # selects a column instead of acting as a numpy boolean mask.
        col = int(sd.succ_info_dict['succ'])
        graph_all_mt_length_distr(sd, axarr[:, col], color=color)

    fig = axarr[0,0].get_figure()
    fig.suptitle('   Monopolar vs Bipolar Spindle  ', fontsize=30)
    plt.figure(fig.number)
    legend_outside(axarr[1,1], set_ax_pos=False)
    plt.subplots_adjust(right=.87, top=.87)
    return

def graph_sim_mt_length_distr_final(sim, axarr, **kwargs):
    """ Combine per-outcome MT length overlays with a seed-averaged column.

        Inputs:
            sim: simulation object passed to both drawers.
            axarr: 2-D numpy array of axes. Columns 0-1 hold the per-outcome
                   overlays; column 2 (at least four rows) holds the averages.
            **kwargs: forwarded to both drawers.
        Outputs:
            None. Overwrites the figure suptitle.
    """
    graph_sim_mt_length_distr_succ_compare(sim, axarr, **kwargs)
    # NOTE(review): the suptitle says "Bipolar mean", but succ_type=0 selects
    # the seeds that other plots in this file label "Monopolar". Confirm intent.
    graph_sim_mt_length_distr_error(sim, axarr[:, 2], succ_type=0, **kwargs)
    axarr[0, 0].get_figure().suptitle('   Monopolar vs Bipolar vs Bipolar mean  ', fontsize=30)

def graph_sim_mt_length_distr_error(sim, axarr, succ_type=0, color='b', **kwargs):
    """ Plot seed-averaged MT length distributions for SPB 0, SPB 1, both, and their difference.

        Inputs:
            sim: simulation. Each seed's ``PostAnalysis.distrdata['mt_lengths']``
                 has 'spb1', 'spb2' and 'merge' entries with 'hist' and
                 'n_points'. The first seed's 'spb1' entry also supplies
                 'mid_points'.
            axarr: sequence of at least four axes:
                   [0] SPB 0, [1] SPB 1, [2] total, [3] SPB1 - SPB0.
            succ_type: only seeds with ``succ_info_dict['succ'] == succ_type``
                       are used.
            color: colour of every curve.
            **kwargs: accepted for signature compatibility; unused.
        Outputs:
            None. Returns without drawing when no seed matches. With one seed
            the error is sqrt(hist / (bin_width * n_points)); otherwise it is
            the standard error across seeds.
    """
    ### TODO Still in one off stage, need to standardize this
    key_list = ['spb1', 'spb2', 'merge']
    hist_lists = [[] for _ in key_list]
    n_points_arr = [0] * len(key_list)
    for sd in sim.seeds:
        if sd.succ_info_dict['succ'] != succ_type:
            continue
        mt_l_dict = sd.PostAnalysis.distrdata['mt_lengths']
        for i, key in enumerate(key_list):
            hist_lists[i].append(mt_l_dict[key]['hist'])
            n_points_arr[i] += mt_l_dict[key]['n_points']

    num_succ_type = len(hist_lists[0])
    if num_succ_type == 0:  # no seeds with the requested outcome
        return

    bin_mids = sim.seeds[0].PostAnalysis.distrdata['mt_lengths']['spb1']['mid_points']
    bin_width = bin_mids[1] - bin_mids[0]
    hist_means = []
    hist_errs = []
    for i, hists in enumerate(hist_lists):
        if num_succ_type == 1:
            hist_mean = np.asarray(hists[0])
            # Per-key sample totals live in n_points_arr. The old code read an
            # undefined name `n_points` here and raised NameError.
            hist_err = np.sqrt(np.divide(hist_mean, bin_width * n_points_arr[i]))
        else:
            stacked = np.column_stack(hists)  # one column per seed
            hist_mean = np.mean(stacked, axis=1)  # mean of bins
            hist_err = np.divide(np.std(stacked, axis=1), np.sqrt(num_succ_type))  # std error
        hist_means.append(hist_mean)
        hist_errs.append(hist_err)
        axarr[i].errorbar(bin_mids, hist_mean, yerr=hist_err, label=sim.title, color=color)

    axarr[0].set_title("SPB 0 MT length distribution")
    axarr[1].set_title("SPB 1 MT length distribution")
    axarr[2].set_title("Total MT length distribution")
    axarr[3].set_title("SPB1 - SPB0 MT length distribution")

    # The error of the SPB1 - SPB0 difference is the linear sum of the two SPB
    # errors (a conservative bound). The merged curve is not part of the
    # difference, so its error is not added.
    diff_err = hist_errs[0] + hist_errs[1]
    axarr[3].errorbar(bin_mids, hist_means[1] - hist_means[0], yerr=diff_err, color=color)
    for i in range(4):
        axarr[i].set_ylabel("Counts")
        axarr[i].set_xlabel(r"MT length ($\mu$m)")

    plt.tight_layout()
    return

def graph_sim_start_time_hist(sim, ax, **kwargs):
    """ Bar-plot the spindle start times of seeds with ``succ == 1``.

        Inputs:
            sim: simulation. Each seed's ``succ_info_dict`` has 'succ' and,
                 for successful seeds, 'start_time'.
            ax: matplotlib axis to draw on.
            **kwargs: unused.
        Outputs:
            None. Draws 30 bins over [0, 25].
    """
    ax.set_xlabel("Spindle start time (min)")
    ax.set_ylabel("Counts")
    start_times = [sd.succ_info_dict['start_time'] for sd in sim.seeds
                   if sd.succ_info_dict['succ'] == 1]

    counts, edges = np.histogram(start_times,  bins=30, range=(0.0, 25.0))
    mids = moving_average(edges)
    ax.bar(mids, counts, mids[1] - mids[0])

def graph_sim_anaphase_onset(sim, ax, **kwargs):
    """ Bar-plot anaphase onset times of successful seeds with chromosome/anaphase analysis.

        Inputs:
            sim: simulation. Seeds are used only when both
                 ``PostAnalysis.analyze_chromosomes`` and
                 ``PostAnalysis.do_anaphase`` are set and
                 ``succ_info_dict['succ'] == 1``.
            ax: matplotlib axis to draw on.
            **kwargs: unused.
        Outputs:
            None. Draws 50 bins over [0, 100] and prints the raw onset list to stdout.
    """
    ax.set_xlabel("Anaphase onset time (min)")
    ax.set_ylabel("Counts")
    onsets = []
    for sd in sim.seeds:
        post = sd.PostAnalysis
        if not (post.analyze_chromosomes and post.do_anaphase):
            continue
        if sd.succ_info_dict['succ'] == 1:
            onsets.append(post.anaphase_onset)

    counts, edges = np.histogram(onsets,  bins=50, range=(0.0, 100.0))
    mids = moving_average(edges)
    ax.bar(mids, counts, mids[1] - mids[0])
    # FIXME: temporary dump of the raw data so it can be harvested for later use.
    print("name = {}".format(sim.name))
    print("anaphase_onset = {}".format(onsets))

# Graph the KC attachment types
def graph_sim_kc_attachment_types(sim, axarr, opts, xlabel=True, colors=None, **kwargs):
    """ Plot per-seed KC attachment-type traces and their average over seeds.

        Inputs:
            sim: simulation. Uses ``seeds`` (each with ``time``),
                 ``time_resampled_forces`` and ``avg_interkinetochore_force``.
            axarr: 2-D axis array with at least 6 rows and 2 columns.
                   Column 0 receives the per-seed traces (via
                   ``graph_kc_attachment_types``). Column 1 receives the five
                   averaged traces in rows 0-4 and the force in row 5.
            opts, xlabel: unused; kept for signature compatibility.
            colors: one colour per seed. Defaults to a rainbow over the seeds.
            **kwargs: unused.
        Outputs:
            None. Returns without drawing if ``sim`` has no seeds.
    """
    if colors is None:
        colors = mpl.cm.rainbow(np.linspace(0, 1, len(sim.seeds)))
    if not sim.seeds:
        return
    axarrseeds = axarr[0:6, 0]
    axarravg = axarr[0:6, 1]
    n_types = 5  # number of attachment-type traces returned per seed

    # Seeds may run for different lengths, so average over the common prefix.
    # min() also handles zero-length seeds, which the old "0 means unset" loop
    # silently skipped.
    min_size = min(sd.time.size for sd in sim.seeds)
    time = sim.seeds[-1].time[:min_size]
    averages = np.zeros((n_types, len(time)))

    # Default y-limit; replaced by the count each seed reports in mavg[5].
    nchromo = 3
    for sd, c in zip(sim.seeds, colors):
        mavg = graph_kc_attachment_types(sd, axarrseeds, color = c)
        for k in range(n_types):
            # Truncate each trace to the common length so seeds of different
            # duration can be summed (assumes mavg[k] is a per-time sequence).
            averages[k] += np.asarray(mavg[k])[:min_size]
        nchromo = mavg[n_types]

    averages /= len(sim.seeds)

    # Force information
    time_resampled = sim.time_resampled_forces
    sim_forces = sim.avg_interkinetochore_force

    for k in range(n_types):
        axarravg[k].plot(time, averages[k])
    axarravg[5].plot(time_resampled, sim_forces)

    for k in range(n_types):
        axarravg[k].set_ylim([0.0, nchromo])

def graph_sim_mt_length_by_index(sim, ax, opts, **kwargs):
    """ Plot per-MT-index length histograms summed over seeds, then save the figure.

        Inputs:
            sim: simulation. ``seeds[0].PostAnalysis.n_bonds`` gives the number
                 of MT indices. Each seed's
                 ``PostAnalysis.distrdata['mt_lengths_by_index'][i]`` has 'hist'
                 and 'bin_mids'.
            ax: matplotlib axis to draw on. Its figure is saved.
            opts: options with ``.sim``, plus ``.datadir`` when ``opts.sim`` is falsy.
            **kwargs: unused.
        Outputs:
            None. Writes "<save_path>_mt_distr_error.pdf". Returns early,
            without saving, if ``sim`` has no seeds.
    """
    ax.set_title("Individual MT Length Distributions")
    ax.set_ylabel(r'Probability Density')
    ax.set_xlabel(r'Length ($\mu$m)')
    if not sim.seeds:
        return
    n_bonds = sim.seeds[0].PostAnalysis.n_bonds

    hist_sums = [None] * n_bonds
    bin_mids = None
    for sd in sim.seeds:
        mt_l_dict = sd.PostAnalysis.distrdata['mt_lengths_by_index']
        for i in range(n_bonds):
            hist = mt_l_dict[i]['hist']
            bin_mids = mt_l_dict[i]['bin_mids']
            if hist_sums[i] is None:
                # Copy: the old code aliased the first seed's stored array, and
                # the in-place += below then corrupted that seed's distrdata.
                hist_sums[i] = np.array(hist, dtype=float)
            else:
                hist_sums[i] += hist

    colors = mpl.cm.rainbow(np.linspace(0, 1, n_bonds))
    for i, c in enumerate(colors):
        ax.plot(bin_mids, hist_sums[i], color = c, label = 'mt_{}'.format(i))
    legend_outside(ax)

    if opts.sim: save_path = sim.sim_path
    else: save_path = os.path.join(opts.datadir, sim.name)
    # Save the figure that owns ``ax`` rather than whichever pyplot figure is current.
    ax.get_figure().savefig("{}_mt_distr_error.pdf".format(save_path))
    return

##########################################
if __name__ == "__main__":
    # This module only provides graphing helpers; there is no command-line entry point.
    print("Not implemented yet")




