#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# Basic
import sys, os, pdb
## Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *
from stylelib.ase1_styles import ase1_runs_stl
# sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))

from graph_lib import *
from sim_graph_funcs import *
from bin_functions import *
from spindle_unit_dict import *

'''
Name: run_graph_funcs.py
Description: library of graphing functions for SpindleRun.py objects
'''

# Module-wide unit lookup shared by the graphing functions below. They index it
# as uc[param][1] (a multiplicative scale applied to raw parameter values) and
# uc[param][2] (the type passed to DataFrame.astype). The meaning of other
# entries (e.g. uc[param][0]) is not used here -- TODO confirm in SpindleUnitDict.
uc = SpindleUnitDict()

# Helper functions
def graph_run_scatter_error(ax, xticks, p_vals, arr, label = None, marker = 's'):
    """ Generic function to graph scatter plot with error bars
        Inputs: ax     = mpl axes object
                xticks = array of unique x-values graphed (expected in
                         ascending order; the x-limits use the first and last
                         entries)
                p_vals = x-values corresponding with arr values,
                         (p_vals are repeated unlike xticks)
                arr    = y-values to take statistics on and then graphed with mean
                         and error
                label  = legend label
                marker = shape of the marker graphed, (default: square)
        Outputs: ax
    """
    # Expected number of samples per unique x-value, used to normalize the
    # marker colors. Force float division so Python 2 does not truncate it
    # when the samples do not divide evenly among the x-values.
    abs_bin_max = float(p_vals.size) / len(xticks)
    # Sort values into bins based on parameter value
    bins, bin_err, bin_edges, bin_n, bin_width, bin_max = BarArray(p_vals, arr, use_zeros=False)
    # Graph bin averages vs param values (xticks). Marker color is bin_n
    # relative to the expected samples per x-value.
    ax.scatter(xticks, bins, color=mpl.cm.rainbow(bin_n/abs_bin_max),
               zorder=100, s=100, label=label, marker=marker)
    # Graph error bars underneath the markers
    ax.errorbar(xticks, bins, yerr=bin_err,
        ecolor='k', elinewidth=2, capsize=7, capthick=1, zorder=0,
        fmt='none', marker='none')
    # Pad the x-limits by slightly more than one bin width for readability
    ax.set_xlim(left=xticks[0]-bin_width*1.05, right=xticks[-1]+bin_width*1.05)
    return ax

def graph_run_scatter_binomial_error(ax, xticks, p_vals, binary_arr, 
                                     xlog=False, percent_fac=100, 
                                     color='k', label=None, marker='s'):
    """ Generic function to graph points with error bars with the error bars based off the binomial error.
        Inputs: ax          = mpl axes object
                xticks      = array of unique x-values graphed (expected in
                              ascending order; the x-limits use the first and
                              last entries)
                p_vals      = x-values corresponding with binary_arr values,
                              (p_vals are repeated unlike xticks)
                binary_arr  = binary array corresponding to p_vals. These are
                              binned for a given p_val, turned into a fraction
                              (p_val # of 1's/total p_val #), and then used to find
                              the binomial error.
                xlog        = if True, use a log-scaled x-axis (non-positive
                              x-values are masked)
                percent_fac = scaling factor to multiply y-values by
                color       = color of the markers and error bars
                label       = legend label
                marker      = shape of the marker graphed, (default: square)
        Outputs: ax
    """
    if xlog:
        ax.set_xscale("log", nonposx='mask')
    bins, bin_err, bin_edges, bin_n, bin_width, bin_max = BinomialBarArray(p_vals, binary_arr)
    y = bins*percent_fac
    yerr = bin_err*percent_fac
    # Scatter the scaled success fraction vs. parameter value, using the
    # caller-supplied marker shape.
    ax.scatter(xticks, y, zorder=100,
               s=100, marker=marker, color=color, label=label)
    # Error bars from the scaled binomial error; drawn beneath the markers.
    ax.errorbar(xticks, y, yerr=yerr,
        ecolor=color, elinewidth=2, capsize=7, capthick=1, zorder=0,
        fmt='none', marker='none')
    # Set x-axis limits to make data readable
    if len(xticks) == 1:
        ax.set_xticks(xticks)
    if xlog:
        # A log axis cannot show x <= 0 (such points are masked), so pad 20%
        # around the positive x-values only. This skips a leading zero without
        # clipping the first point when all x-values are positive, and also
        # works when there is only a single x-value.
        xt = np.asarray(xticks, dtype=float)
        pos = xt[xt > 0]
        if pos.size:
            lo, hi = pos.min(), pos.max()
            ax.set_xlim(left=(lo - lo*.2), right=(hi + hi*.2))
    else:
        ax.set_xlim(left=xticks[0]-bin_width*1.05, right=xticks[-1]+bin_width*.5)

    return ax

def add_linear_regression_lines(ax, x_orig, y):
    """ Add an OLS linear regression line with 95% confidence and prediction
        interval bands to data.
        Refer to http://markthegraph.blogspot.com/2015/05/using-python-statsmodels-for-ols-linear.html
        Inputs:  ax     = mpl axes object to draw on
                 x_orig = 1D array-like of x-values (no constant column)
                 y      = 1D array-like of y-values, same length as x_orig
        Outputs: nothing (the fit summary is printed to stdout)
    """
    import statsmodels.api as sm
    from scipy import stats
    from statsmodels.sandbox.regression.predstd import wls_prediction_std

    x_data = np.asarray(x_orig, dtype=float)
    y_data = np.asarray(y, dtype=float)

    # Linear regression line. Design matrix = [1, x].
    x = sm.add_constant(x_data)
    fitted = sm.OLS(y_data, x).fit()
    # Sample the fit over the range of the observed x-values. Taking min/max
    # of the design matrix would include its column of ones.
    x_pred = np.linspace(x_data.min(), x_data.max(), 50)
    x_pred2 = sm.add_constant(x_pred)
    y_pred = fitted.predict(x_pred2)
    ax.plot(x_pred, y_pred, linewidth=2)
    print(fitted.summary())

    # Confidence interval band for the mean response (two-sided, 95%)
    y_hat = fitted.predict(x)
    n = len(x_data)
    mean_x = x_data.mean()
    dof = n - fitted.df_model - 1  # n - 2 for a single regressor
    t = stats.t.ppf(1. - 0.025, df=dof)
    # Residual variance estimate
    s2 = np.sum(np.power(y_data - y_hat, 2)) / dof
    # S_xx = sum((x_i - mean_x)^2) over the *observed* x-values
    s_xx = np.sum(np.power(x_data - mean_x, 2))
    conf = t * np.sqrt(s2 * (1.0/n + np.power(x_pred - mean_x, 2)/s_xx))
    upper = y_pred + abs(conf)
    lower = y_pred - abs(conf)
    ax.fill_between(x_pred, lower, upper, color='k', alpha=0.2, label='confidence interval')

    # Prediction interval band (for new observations)
    sdev, pi_lower, pi_upper = wls_prediction_std(fitted, exog=x_pred2, alpha=0.05)
    ax.fill_between(x_pred, pi_lower, pi_upper, color='k', alpha=0.1, label='prediction interval')

def graph_run_xlink_distr_succ_compare( run, axarr, **kwargs ):
    """ Overlay each sim's combined xlink-distance-from-SPB distributions on a
        2x2 axis array. Per the layout used by the per-sim function: top row is
        bipolar spindles (stage1 left, stage2 right), bottom row is monopolar.
        Inputs:  run    = SpindleRun object
                 axarr  = 2x2 mpl axis array
                 kwargs = forwarded unchanged to graph_sim_xlink_distr_succ_compare
        Outputs: nothing
    """
    sims = run.sims
    # One rainbow color per sim
    palette = mpl.cm.rainbow(np.linspace(0, 1, len(sims)))
    for idx, sim in enumerate(sims):
        graph_sim_xlink_distr_succ_compare(sim, axarr, color=palette[idx], **kwargs)
    # Place the legend to the right of the top-right panel and shrink the
    # subplots so it fits in the figure.
    axarr[0, 1].legend(loc='center left', bbox_to_anchor=(1.0, -.25))
    plt.subplots_adjust(right=.82)

def graph_run_mt_length_distr_error(run, axarr, **kwargs):
    """ Graphs a 4x1 figure of mt length distributions for each sim.
        Inputs:  run    = SpindleRun object
                 axarr  = mpl axis array; axarr[1] receives the legend
                 kwargs = accepted for interface compatibility but NOT
                          forwarded to the per-sim graphing function
        Outputs: nothing
    """
    sims = run.sims
    palette = mpl.cm.rainbow(np.linspace(0, 1, len(sims)))
    for idx, sim in enumerate(sims):
        # Every sim is graphed with succ_type=1
        graph_sim_mt_length_distr_error(sim, axarr, succ_type=1, color=palette[idx])
    # Legend outside the second panel; leave room on the right for it
    axarr[1].legend(loc='center left', bbox_to_anchor=(1.0, -.25))
    plt.subplots_adjust(right=.82)

def graph_run_spindle_xlink_distance_error_all(run, axarr, **kwargs):
    """ Overlay graph_sim_spindle_xlink_distance_error_all for every sim in
        the run, one rainbow color per sim.
        Inputs:  run    = SpindleRun object
                 axarr  = mpl axis array; axarr[1] receives the legend
                 kwargs = forwarded unchanged to the per-sim function
        Output:  nothing
    """
    sims = run.sims
    palette = mpl.cm.rainbow(np.linspace(0, 1, len(sims)))
    for idx, sim in enumerate(sims):
        graph_sim_spindle_xlink_distance_error_all(sim, axarr, color=palette[idx], **kwargs)
    # Tighten the layout first, then attach the legend outside axarr[1] and
    # reserve space on the right for it.
    plt.tight_layout()
    axarr[1].legend(loc='center left', bbox_to_anchor=(1.0, .5))
    plt.subplots_adjust(right=.82)

def graph_run_kc_spindle1d(run, axarr,
                           xstate=r'integrated',
                           **kwargs):
    """ Overlay graph_sim_kc_spindle1d for every sim in the run, one rainbow
        color per sim, with kc_distribution forced to 'kc_spindle_1d'.
        Inputs:  run    = SpindleRun object
                 axarr  = 2D mpl axis array; axarr[0,1] receives the legend
                 xstate = accepted but currently unused by this function
                 kwargs = forwarded to graph_sim_kc_spindle1d (any
                          'kc_distribution' entry is overridden)
        Outputs: nothing
    """
    sims = run.sims
    palette = mpl.cm.rainbow(np.linspace(0, 1, len(sims)))
    kwargs['kc_distribution'] = 'kc_spindle_1d'
    for idx, sim in enumerate(sims):
        graph_sim_kc_spindle1d(sim, axarr, color=palette[idx], **kwargs)

    # Legend to the right of the top-right panel; make room for it
    axarr[0, 1].legend(loc='center left', bbox_to_anchor=(1.0, -.25))
    plt.subplots_adjust(right=.82)

def graph_run_avg_start_time(run, ax, param = '', **kwargs):
    """ Graph the binned mean (with error bars) of the sims' start_time values
        against a scanned run parameter.
        Inputs:  run    = SpindleRun object
                 ax     = mpl axis object
                 param  = Shortcut parameter name to be graphed. Used by
                          spindle_unit_dict, SpindleRun, SpindleSim, etc.
                          eg. pa, eqL, etc.
                 kwargs = generic dictionary of keyword arguments (unused)
        Outputs: file_name = file path for saving figure
                 title = possible title of figure based on run path
    """
    param = run.CheckGraphParam(param)
    # Convert raw parameter values using the unit dict's type and scale
    p_vals = run.sim_crit_df[param].astype(uc[param][2])*uc[param][1]
    xticks = np.unique(p_vals)
    start_time = run.sim_crit_df['start_time']
    # Graph the data
    graph_run_scatter_error(ax, xticks, p_vals, start_time)
    # Graphing options
    ax.set_ylabel("Start time (min)")
    # Start y axis at zero. `bottom` is the long-supported keyword; `ymin`
    # is a deprecated alias.
    ax.set_ylim(bottom=0)
    make_color_bar(ax, v_max=100, label=r'Bipolar Spindle Frequency ($\%$)')
    ModifyXLabel(ax, param)
    file_name = os.path.join(run.opts.datadir,
            '{}_{}_avg_start_time.pdf'.format(run.run_name, param))
    title = r'{}'.format(run.run_name.replace('_',' '))
    return file_name, title

def graph_run_succ_scan(run, ax, param = '', **kwargs):
    """ Stub for graphing run success vs. a scanned parameter. Not yet
        implemented: it validates `param` and computes the parameter values,
        but draws nothing on ax.
        Inputs:  run    = SpindleRun object
                 ax     = mpl axis object
                 param  = Shortcut parameter name to be graphed. Used by
                          spindle_unit_dict, SpindleRun, SpindleSim, etc.
                          eg. pa, eqL, etc.
                 kwargs = generic dictionary of keyword arguments (unused)
        Outputs: nothing yet (returns None)
    """
    param = run.CheckGraphParam(param)
    # The data frame lives on the SpindleRun passed in; this is a module-level
    # function, so there is no `self` in scope.
    p_vals = run.sim_crit_df[param].astype(uc[param][2])*uc[param][1]
    xticks = np.unique(p_vals)
    # TODO: graph success fraction vs xticks and return (file_name, title)
    return

def graph_run_avg_chromosome_seconds(run, ax, param = '', **kwargs):
    """ Placeholder for graphing average chromosome seconds of a run.
        Only passes `param` through run.CheckGraphParam and prints a notice;
        nothing is drawn on ax and None is returned.
    """
    run.CheckGraphParam(param)  # validated result unused until graphing exists
    print("not implemented yet")

##########################################
if __name__ == "__main__":
    # This module is a library; running it directly only prints a notice.
    print("Not implemented yet")




