#!/usr/bin/env python
# In case of poor (Sh**y) commenting contact christopher.edelmaier@colorado.edu
# Basic
import argparse
import fnmatch
import gc
import itertools
import os
import pdb
import sys
## Analysis
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'Lib'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'Spindle'))

from spindle_sim import SpindleSim
from sim_graph_funcs import *
from stylelib.ase1_styles import ase1_sims_stl
from stylelib.ase1_styles import ase1_runs_stl

from scipy.stats import ks_2samp
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw
import scipy.io as sio

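'''
Name: BareMultiRun.py
Description: Loads one spindle simulation for every combination of MT number
(--nmt) and chromosome number (--nx) under a superdirectory, analyzes each one,
and graphs the late-time averages of the variables of interest (SPB separation,
interpolar fraction, number of ipMTs, attachment states, KC-MT occupancy).
Input: Simulation directories {superdir}/mt{nmt}_x{nx}/simulations/{supersim}
Output: late_complex_*.pdf graphs written to the current working directory
'''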

def parse_args(argv=None):
    """Parse the command-line options for BareMultiRun.py.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from sys.argv[1:], which is what the
            script entry point relies on.

    Returns:
        The argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an argument cannot be parsed.
    """
    parser = argparse.ArgumentParser(prog='BareMultiRun.py')

    # General options that are actually required
    parser.add_argument('--superdir', required=True, type=str,
            help='Superdirectory')
    parser.add_argument('--supersim', required=True, type=str,
            help='Supersimulation')
    # Both lists are iterated unconditionally when the simulations are
    # located, so a missing list is rejected here with a clear usage message
    # rather than failing later with a TypeError on None.
    parser.add_argument('--nmt', nargs='+', type=str, required=True,
            help='List of MT number to plot (separate curves).')
    parser.add_argument('--nx', nargs='+', type=str, required=True,
            help='List of Chromosome number to plot (separate X labels).')

    parser.add_argument('--stddevmean', action='store_true',
            help='Write out std. dev of mean, rather than sqrt variance')

    # Minimum of extra options to make this work
    parser.add_argument('--nopost', action='store_true',
            help="Do not use post analyis program to decrease time of analysis.")
    parser.add_argument('-F', '--fitness', action='store_true',
            help='Create fitness for spindle simulations.')
    parser.add_argument('-A', '--analyze', action='store_true',
            help='Analyze data from multiple simulations')
    parser.add_argument('--datadir', type=str,
            help='Name of the data directory in which all analyzed data files will be read/written. \
                    Also the directory where all the graphs will be placed when saved. \
                    Default is set to {workdir}/data/.')
    parser.add_argument('-d', '--workdir', type=str,
            help='Name of the working directory where simulation will be run.')

    return parser.parse_args(argv)


# Class definition
class BareMultiRun(object):
    """Load a grid of spindle simulations and graph their late-time averages.

    One SpindleSim is created for every (nmt, nx) pair given in the options.
    The Graph* methods each write a PDF with one marker series per nmt value
    and one x position per nx value; the GetLate* methods return the per-seed
    windowed late-time samples those graphs are built from.

    Attributes set by the constructor:
        opts: The parsed options namespace (workdir is rewritten to an
            absolute path by ReadOpts).
        cwd: Working directory at construction time.
        nseeds: Divisor used for the --stddevmean error bars.
        simnames: Simulation directory paths, in nmt-major order.
        sims: Nested dict, sims['mt<nmt>']['x<nx>'] -> SpindleSim.
    """

    # Base averaging window, in the units of the seed time axis.
    # NOTE(review): 8.0/60.0 looks like 8 seconds expressed in minutes - confirm.
    _MEASURE_TIME = 8.0/60.0
    # The window is widened by this factor so the error bars are visible.
    _WINDOW_FACTOR = 4
    # Only windows starting after this time count as "late".
    # NOTE(review): the original comment called this "10 minutes" - confirm units.
    _LATE_TIME = 10.0
    # Marker shapes cycled over the MT-number series.
    _MARKERS = ('s', 'o', 'v', '^', 'H', 'D', '8', 'p', '<', '>', '*', 'h', 'd')
    # Integer attachment codes counted for each attachment type name.
    _ATTACH_CODES = {
        'unattached': (0,),
        'monotelic': (1,),
        'merotelic': (2,),
        'syntelic': (3,),
        'amphitelic': (4,),
        'singlepole': (0, 1, 3),
        'doublepole': (2, 4),
    }
    # Order in which GraphSpindleAttach produces its figures.
    _ATTACH_ORDER = ('unattached', 'monotelic', 'merotelic', 'syntelic',
                     'amphitelic', 'singlepole', 'doublepole')

    def __init__(self, opts):
        """Store the options, then locate and analyze every simulation.

        Args:
            opts: Options namespace as returned by parse_args().
        """
        self.opts = opts
        self.cwd = os.getcwd()

        print("opts: {}".format(opts))
        #FIXME: Nseeds
        # NOTE(review): hard-coded; presumably should follow the real number
        # of seeds per simulation - confirm before changing, it alters the
        # --stddevmean error bars.
        self.nseeds = 12

        self.ReadOpts()
        self.AnalyzeSims()

    def ReadOpts(self):
        """Resolve the working directory and create one SpindleSim per run.

        Fills self.simnames and self.sims from every (nmt, nx) combination of
        the options.

        Raises:
            IOError: If a working directory was given but does not exist.
        """
        if not self.opts.workdir:
            self.opts.workdir = os.path.abspath(self.cwd)
        elif not os.path.exists(self.opts.workdir):
            raise IOError( "Working directory {} does not exist.".format(
                self.opts.workdir) )
        else:
            self.opts.workdir = os.path.abspath(self.opts.workdir)

        # Come up with all of the simulations based on nmt and nx that can be gotten
        # to in the current superdir/supersim combinations
        self.simnames = []
        maindir = os.path.abspath(self.opts.superdir)
        located = []
        for nmt in self.opts.nmt:
            for nx in self.opts.nx:
                subdir = os.path.join(maindir, 'mt{}_x{}'.format(nmt, nx),
                                      'simulations', self.opts.supersim)
                self.simnames.append(subdir)
                located.append((nmt, nx, subdir))
                print(subdir)

        # The keys are built straight from nmt/nx instead of being parsed back
        # out of the path, so a supersim containing a path separator cannot
        # shift the components.
        self.sims = {}
        for nmt, nx, subdir in located:
            by_nx = self.sims.setdefault('mt{}'.format(nmt), {})
            # Hand over this instance's options rather than a module-level
            # `opts`, which only exists when the file is run as a script.
            by_nx['x{}'.format(nx)] = SpindleSim(os.path.abspath(subdir), self.opts)

    def AnalyzeSims(self):
        """Analyze every simulation and, if requested, compute its fitness.

        All simulations are analyzed first; when the fitness option is set a
        second pass calls Fitness() with the wt.mat and lstream.mat files
        located relative to this source file.
        """
        all_sims = [sim for by_nx in self.sims.values() for sim in by_nx.values()]

        for sim in all_sims:
            sim.Analyze()
            sim.CalcSimSuccess()

        if self.opts.fitness:
            datadir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', 'Data')
            wt_path = os.path.join(datadir, 'wt.mat')
            lstream_path = os.path.join(datadir, 'lstream.mat')

            for sim in all_sims:
                sim.Fitness(wt_path, lstream_path)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _timeseries(sd, key):
        """Return the post-analysis values of `key` for one seed, in time order."""
        timedata = sd.PostAnalysis.timedata
        values = timedata[key]
        return [values[ts] for ts in timedata['time']]

    def _late_window_means(self, sim, series_for_seed):
        """Window-average one series per seed and keep the late-time windows.

        Sets self.measure_time and self.nmeasure as a side effect.

        Args:
            sim: Simulation object exposing `seeds`; every seed needs a `time`
                array and whatever `series_for_seed` reads from it.
            series_for_seed: Callable mapping a seed to its array of samples.
                The sample count must be a multiple of the window length.

        Returns:
            A list with one array per seed, holding the window means whose
            window start time exceeds the late-time cutoff.

        Raises:
            ValueError: If the simulation has no seeds, or the time step is
                so coarse that the averaging window would be empty.
        """
        if len(sim.seeds) == 0:
            raise ValueError("Simulation has no seeds to average over.")

        # The reference time axis is the last seed's, cut to the shortest seed
        min_size = min(sd.time.size for sd in sim.seeds)
        time = sim.seeds[-1].time[:min_size]

        # Set the measurement time for everybody
        self.measure_time = self._MEASURE_TIME
        self.nmeasure = self._WINDOW_FACTOR * int(self.measure_time / (time[1] - time[0]))
        if self.nmeasure < 1:
            raise ValueError("Time step is too large for a {} averaging window.".format(
                self.measure_time))

        late = time[0::self.nmeasure] > self._LATE_TIME

        windowed = []
        for sd in sim.seeds:
            series = np.asarray(series_for_seed(sd))
            means = np.mean(series.reshape(-1, self.nmeasure), axis=1)
            windowed.append(means[late])
        return windowed

    def _graph_late_quantity(self, collect, ylabel, ylim, filename,
                             scale_by_nmt=False, skip_x0=False):
        """Draw one late-time graph and save it to `filename`.

        Args:
            collect: Callable taking a simulation and returning its per-seed
                late-time samples (one of the GetLate* methods).
            ylabel: Label of the y axis.
            ylim: (low, high) limits of the y axis.
            filename: Output file name, relative to the current directory.
            scale_by_nmt: Multiply mean and error by 2 * nmt.
            skip_x0: Plot zero instead of querying simulations without
                chromosomes (nx == 0).
        """
        plt.style.use(ase1_runs_stl)
        fig, ax = plt.subplots()

        colors = mpl.cm.rainbow(np.linspace(0, 1, len(self.opts.nmt)))
        markers = itertools.cycle(self._MARKERS)
        # One x slot per chromosome number, shared by every series
        xvals = list(range(1, len(self.opts.nx) + 1))

        # Loop over the vertical stack (number of MTs) first
        for mcolor, nmt in zip(colors, self.opts.nmt):
            imt = int(nmt)
            knmt = "mt{}".format(imt)
            mmarker = next(markers)
            scale = 2.*imt if scale_by_nmt else 1.0

            yvals = []
            yerrs = []
            # Loop over the number of chromosomes
            for nx in self.opts.nx:
                knx = "x{}".format(int(nx))
                if skip_x0 and knx == "x0":
                    samples = np.array([0.0, 0.0])
                else:
                    # Pool the late-time windows of every seed into one sample
                    samples = np.array([collect(self.sims[knmt][knx])]).flatten()

                yerr = np.std(samples, ddof=1) * scale
                if self.opts.stddevmean:
                    yerr = yerr / np.sqrt(self.nseeds)
                yvals.append(np.mean(samples) * scale)
                yerrs.append(yerr)

            # Now we can plot this
            ax.scatter(xvals, yvals, zorder=100,
                       s=100, marker=mmarker, color=mcolor,
                       label="MT{}".format(imt))
            ax.errorbar(xvals, yvals, yerr=yerrs,
                        ecolor=mcolor, elinewidth=2, capsize=7, capthick=1, zorder=0,
                        fmt='none', marker='none')

        # Save this. Everything goes through fig/ax so the result does not
        # depend on which pyplot figure happens to be current.
        ax.set_ylim(*ylim)
        ax.set_xticks(xvals)
        ax.set_xticklabels(["X{}".format(nx) for nx in self.opts.nx])
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.tight_layout()
        fig.savefig(filename, dpi=fig.dpi)
        plt.close(fig)

    # ------------------------------------------------------------------
    # Spindle length
    # ------------------------------------------------------------------
    def GraphSpindleLength(self):
        """Graph the late-time SPB separation; writes late_complex_spindle_length.pdf."""
        self._graph_late_quantity(self.GetLateSpindleLength,
                                  ylabel=r'SPB Separation ($\mu$m)',
                                  ylim=(0.0, 3.0),
                                  filename='late_complex_spindle_length.pdf')

    def GetLateSpindleLength(self, sim):
        """Return the per-seed late-time window means of the SPB separation.

        The raw 'spb_separation' values are multiplied by uc['um'][1]
        (presumably the conversion factor to microns - confirm against the
        module that defines `uc`).

        Args:
            sim: Simulation whose seeds are read.

        Returns:
            A list with one array of window means per seed.
        """
        def separation(sd):
            return np.array(self._timeseries(sd, 'spb_separation'))*uc['um'][1]

        return self._late_window_means(sim, separation)

    # ------------------------------------------------------------------
    # Interpolar fraction and number of ipMTs
    # ------------------------------------------------------------------
    # What about the spindle late time IPF?
    def GraphSpindleIPF(self):
        """Graph the late-time interpolar fraction; writes late_complex_ipf.pdf."""
        self._graph_late_quantity(self.GetLateSpindleIPF,
                                  ylabel=r'Interpolar Fraction',
                                  ylim=(0.0, 1.1),
                                  filename='late_complex_ipf.pdf')

    # What about the spindle late time number of ipMTs?
    def GraphSpindleNipMT(self):
        """Graph the interpolar fraction scaled by 2 * nmt; writes late_complex_nipmt.pdf."""
        self._graph_late_quantity(self.GetLateSpindleIPF,
                                  ylabel=r'N ipMT',
                                  ylim=(0.0, 41.0),
                                  filename='late_complex_nipmt.pdf',
                                  scale_by_nmt=True)

    def GetLateSpindleIPF(self, sim):
        """Return the per-seed late-time window means of 'interpolar_fraction'.

        Args:
            sim: Simulation whose seeds are read.

        Returns:
            A list with one array of window means per seed.
        """
        def fraction(sd):
            return np.array(self._timeseries(sd, 'interpolar_fraction'))

        return self._late_window_means(sim, fraction)

    # ------------------------------------------------------------------
    # Attachment states
    # ------------------------------------------------------------------
    # Graph the spindle attachment states
    def GraphSpindleAttach(self):
        """Graph every attachment type; writes one late_complex_<type>.pdf each."""
        for atype in self._ATTACH_ORDER:
            # Bind atype as a default so each collector keeps its own type
            def collect(sim, atype=atype):
                return self.GetLateSpindleAttach(sim, atype)

            self._graph_late_quantity(collect,
                                      ylabel=r'N {}'.format(atype),
                                      ylim=(0.0, 4.1),
                                      filename='late_complex_{}.pdf'.format(atype),
                                      skip_x0=True)

    def GetLateSpindleAttach(self, sim, atype):
        """Return per-seed late-time means of the chromosome count in a state.

        For every time point the entries of 'kc_atypes' whose code belongs to
        `atype` are counted; the counts are then window-averaged.

        Args:
            sim: Simulation whose seeds are read.
            atype: One of 'unattached', 'monotelic', 'merotelic', 'syntelic',
                'amphitelic', 'singlepole' or 'doublepole'.

        Returns:
            A list with one array of window means per seed.

        Raises:
            ValueError: If `atype` is not a known attachment type.
        """
        if atype not in self._ATTACH_CODES:
            raise ValueError("Unknown attachment type {!r}".format(atype))
        codes = self._ATTACH_CODES[atype]

        def matching_count(sd):
            frames = self._timeseries(sd, 'kc_atypes')
            # The chromosome count is taken from the first time point
            nchromo = len(frames[0])
            return np.array([
                float(sum(1 for ic in range(nchromo) if frame[ic] in codes))
                for frame in frames
            ])

        return self._late_window_means(sim, matching_count)

    # ------------------------------------------------------------------
    # Binding site occupancy
    # ------------------------------------------------------------------
    # Get the occupancy of the binding sites
    def GraphSpindleOccupancy(self):
        """Graph the late-time KC-MT occupancy; writes late_complex_occupancy.pdf."""
        self._graph_late_quantity(self.GetLateSpindleOccupancy,
                                  ylabel=r'KC-MT Occupancy',
                                  ylim=(0.0, 24.0),
                                  filename='late_complex_occupancy.pdf',
                                  skip_x0=True)

    def GetLateSpindleOccupancy(self, sim):
        """Return the per-seed late-time window means of 'kc_occupancy'.

        Args:
            sim: Simulation whose seeds are read.

        Returns:
            A list with one array of window means per seed.
        """
        def occupancy(sd):
            return np.array(self._timeseries(sd, 'kc_occupancy'))

        return self._late_window_means(sim, occupancy)

################
if __name__ == "__main__":
    # Script entry point: parse the command line, load and analyze the
    # simulations, then produce every graph in turn.
    opts = parse_args()
    multirun = BareMultiRun(opts)
    for make_graph in (multirun.GraphSpindleLength,
                       multirun.GraphSpindleIPF,
                       multirun.GraphSpindleNipMT,
                       multirun.GraphSpindleAttach,
                       multirun.GraphSpindleOccupancy):
        make_graph()
