#!/usr/bin/env python
# Basic
import sys, os, pdb
import gc
import argparse
# Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *

from spindle_movie import *
from spindle_run import SpindleRun
from spindle_seed import SpindleSeed
from spindle_sim import SpindleSim
from base_funcs import *
from seed_graph_funcs import *
from sim_graph_funcs import *
from run_graph_funcs import *
from spindle_unit_dict import SpindleUnitDict, ModifyXLabel


'''
Name: SpindleAnalysis.py
Description: Command that calls specific bulk analysis programs
Input: Command-line options; run SpindleAnalysis.py -h to list them.
Output: Analyzed data files, graphs, and/or movies, depending on the options given.
'''

def parse_args(argv=None):
    """Build the SpindleAnalysis command-line parser and parse arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse parse ``sys.argv[1:]``, exactly as before; passing
            an explicit list lets scripts and tests drive the parser.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(prog='SpindleAnalysis.py')

    # General options
    parser.add_argument('-sd', '--seed', action='store_true',
            help='Run analysis on only a single seed.')

    parser.add_argument('--sim', action='store_true',
            help='Run analysis on a single sim.')

    parser.add_argument('-d', '--workdir', type=str,
            help='Name of the working directory to analyze. '
                 'Default is the current directory.')

    parser.add_argument('--datadir', type=str,
            help='Name of the data directory in which all analyzed data files '
                 'will be read/written. Also the directory where all the '
                 'graphs will be placed when saved. '
                 'Default is set to {workdir}/data/.')

    parser.add_argument('-p', '--params', nargs='+', type=str,
            help='List of param names that will be used in graphing.')

    parser.add_argument('-f', '--file', type=str,
            help='File used in the reading function.')

    parser.add_argument('-t', '--title', type=str,
            help='Suptitle for plots if graphing is activated.')

    # Bare `-m` stores 'basic'; omitting the flag leaves movie as None.
    parser.add_argument('-m',
                        '--movie',
                        nargs='?',
                        type=str,
                        default=None,
                        const='basic',
                        help='flag that tells spindle analysis to make a movie of a seed (and submovies of blurred images).')

    parser.add_argument('--movie_all', action='store_true',
            help='flag that tells spindle analysis to make movies for all seed directories in the work_directory.')

    # Bare `--movie_comp` means 2 movies of each kind; omitted means None.
    parser.add_argument('--movie_comp', type=int, nargs='?', const=2,
            help='Option to make movies from simulations in a run. '
                 'This will create an equal number of successful and failed '
                 'movies unless there are not enough of one type, in which '
                 'case it will make as many as it can. Default number = 2. '
                 'Run directory must be the work dir. All movies are placed in '
                 'the data/[sim] directory.')

    # Different main actions taken by analysis program

    parser.add_argument('-A', '--analyze', action='store_true',
            help='Analyze data from multiple simulations')

    parser.add_argument('-G', '--graph', action='store_true',
            help='Graph data after analysis has been done.')

    parser.add_argument('-W', '--write', action='store_true',
            help='Make analyzed data files from raw data files.')

    parser.add_argument('-R', '--read', action='store_true',
            help='Read in all analyzed data files in "simulations" directory.')

    # NOTE(review): the default is the non-empty 'WT_Cen2', so fitness is
    # computed on every invocation unless an empty string is passed
    # explicitly. Confirm this is intended before relying on the flag.
    parser.add_argument('-F', '--fitness', type=str, default='WT_Cen2', nargs='?', const='WT_Cen2',
            help='Create fitness for spindle simulations.')

    parser.add_argument('-tg', '--test_graph', type=str,
            help="Test a specific graphing program defined in the argument of this option.")

    parser.add_argument('--nopost', action='store_true',
            help="Do not use post analysis program to decrease time of analysis.")

    # Graphing specific options
    parser.add_argument('--xlog', action='store_true',
            help='Criteria graph option that produces graphs with a logarithmic x-axis.')

    parser.add_argument('--ylog', action='store_true',
            help='Criteria graph option that produces graphs with a logarithmic y-axis. (not implemented yet)')

    parser.add_argument('--lin_reg', action='store_true',
            help='Adds linear regression line, confidence interval band, and prediction interval band to succ_frac graph')

    return parser.parse_args(argv)

##Class definition
class SpindleAnalysis(object):
    """Entry point that runs seed-, sim-, or run-level spindle analysis.

    Constructing an instance does all the work: ``ReadOpts`` normalizes the
    parsed command-line options and ``ProgOpts`` dispatches to the requested
    action (movie making, seed/sim/run analysis, graphing, fitness).
    """

    def __init__(self, opts):
        """Store the options and immediately run the selected analysis.

        Args:
            opts: argparse.Namespace produced by ``parse_args``. It is
                modified in place (workdir/datadir/movie/analyze may be set).
        """
        self.opts = opts
        self.cwd = os.getcwd()

        self.ReadOpts()

        self.ProgOpts()

    # Read in options
    def ReadOpts(self):
        """Normalize ``opts.workdir`` and create ``opts.datadir`` if needed.

        A missing workdir defaults to the current directory; a given one is
        made absolute. A data directory is created when graphs, data files,
        or comparison movies will be produced and none was given.

        Raises:
            IOError: if an explicitly given working directory does not exist.
        """
        if not self.opts.workdir:
            self.opts.workdir = os.path.abspath(self.cwd)
        elif not os.path.exists(self.opts.workdir):
            raise IOError("Working directory {} does not exist.".format(
                self.opts.workdir))
        else:
            self.opts.workdir = os.path.abspath(self.opts.workdir)

        # If data files or graphs are going to be made put them in a single
        # data directory for easy access
        if ((self.opts.graph or self.opts.write or
             self.opts.read or self.opts.movie_comp) and
                not self.opts.datadir):
            self.opts.datadir = create_datadir(self.opts.workdir)

    def ProgOpts(self):
        """Dispatch to the action selected on the command line.

        Precedence: --movie, --movie_all, --seed, --sim, otherwise the
        run-level analysis.
        """
        if self.opts.movie:
            self.MakeSeedMovie(self.opts.workdir)

        elif self.opts.movie_all:
            self.opts.movie = 'basic'
            for sd in sorted(os.listdir(self.opts.workdir)):
                # Entries are relative to workdir, not to the process cwd
                sd_path = os.path.join(self.opts.workdir, sd)
                # Only directories can be seeds; skip stray files
                if not os.path.isdir(sd_path):
                    continue
                print(sd_path)
                self.MakeSeedMovie(sd_path)

        elif self.opts.seed:
            self.AnalyzeSeed()

        elif self.opts.sim:
            self.AnalyzeSim()

        else:
            self.AnalyzeRun()

    @staticmethod
    def _FitnessDataPaths():
        """Return (wt_path, lstream_path) for the fitness reference files.

        Both files are looked up in ../Data/ relative to the directory that
        contains this script.
        """
        data_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                '..', 'Data')
        return (os.path.join(data_dir, 'wt.mat'),
                os.path.join(data_dir, 'lstream.mat'))

    @staticmethod
    def _SaveStackedFigure(graph_func, nrows, file_name):
        """Draw ``graph_func(axarr)`` on an nrows x 1 grid and save it.

        The figure is left open so a later ``plt.show()`` can display it.
        """
        fig, axarr = plt.subplots(nrows, 1, figsize=(15, 10))
        graph_func(axarr)
        fig.tight_layout()
        fig.savefig(file_name, dpi=fig.dpi)
        return fig

    def AnalyzeSeed(self):
        """Analyze the single seed in ``workdir`` and optionally graph it.

        With --graph, multi-panel PDFs are written to the current directory.
        Chromosome-specific PDFs are added when the seed's post analysis has
        ``analyze_chromosomes`` set. The figures are then shown.
        """
        sd = SpindleSeed(self.opts.workdir, self.opts)
        sd.AnalyzeAll()

        if self.opts.graph:
            plt.style.use(ase1_runs_stl)
            self._SaveStackedFigure(sd.GraphAllvsTime, 4,
                                    'spindle_parameters.pdf')

            if sd.PostAnalysis.analyze_chromosomes:
                plt.style.use(ase1_runs_stl)
                # (graph method, number of stacked panels, output file).
                # GraphSPBForcevsTime ('spb_forces.pdf', 3 panels) is
                # currently disabled.
                chromosome_figures = [
                    (sd.GraphKCvsTime, 3, 'kc_distances.pdf'),
                    (sd.GraphkMTLifetimes, 3, 'kmt_lifetimes.pdf'),
                    (sd.GraphTangentForcevsTime, 4, 'tangent_forces.pdf'),
                    (sd.GraphXlinkAFForces, 2, 'xlink_af_forces.pdf'),
                    (sd.GraphOccupancy, 4, 'occupancy.pdf'),
                    (sd.GraphChromosomeLengthIPFAmphi, 3,
                     'length_ipf_amphi.pdf'),
                ]
                for graph_func, nrows, file_name in chromosome_figures:
                    self._SaveStackedFigure(graph_func, nrows, file_name)

        if self.opts.test_graph:
            self.SeedGraph(sd)

        if self.opts.graph or self.opts.test_graph is not None:
            plt.show()

        if self.opts.fitness:
            sd.Fitness(*self._FitnessDataPaths())

    def AnalyzeSim(self):
        """Analyze the sim in ``workdir`` and compute its success.

        Optionally graphs the result, runs a test graph, and computes fitness.
        """
        # Use the instance's options, not a module-level global
        sim = SpindleSim(self.opts.workdir, self.opts)
        sim.Analyze()
        sim.CalcSimSuccess()

        if self.opts.graph:
            sim.GraphSimulation()

        if self.opts.test_graph:
            self.SimGraph(sim)

        if self.opts.fitness:
            sim.Fitness(*self._FitnessDataPaths())

    def AnalyzeRun(self):
        """Run-level analysis over the sims of the run in ``workdir``.

        Steps, each gated by its option:
        1. Analyze sims (--analyze/--graph).
        2. Write the criteria data frame (--write).
        3. Compute fitness (--fitness).
        4. Read the criteria data frame and make the criteria graphs (--read).
        5. Make success/failure comparison movies (--movie_comp). This step
           returns early.
        6. Make distribution and sim graphs (--graph).
        7. Make test graphs (--test_graph).
        """
        run = SpindleRun(self.opts)

        if self.opts.analyze or self.opts.graph:
            run.AnalyzeSims()
            self.opts.analyze = True

        if self.opts.write:
            if not self.opts.analyze:
                run.AnalyzeSims()
                self.opts.analyze = True
            run.MakeCriteriaDataFrame()
            run.WriteCriteriaDataFrame()

        if self.opts.fitness:
            run.Fitness(*self._FitnessDataPaths())

        if self.opts.read:
            if self.opts.file:
                run.ReadCriteriaDataFrame(self.opts.file)
            else:
                run.ReadCriteriaDataFrame()
            # For some reason you need this in order to standardize the graphs later
            graph_list = [SpindleRun.GraphSpindleSepScatter,
                          SpindleRun.GraphSuccFracScatter,
                          ]
            # Guard against an empty run before peeking at its first sim
            if (self.opts.fitness and len(run.sims) > 0 and
                    run.sims[0].analyze_chromosomes):
                graph_list.extend([SpindleRun.GraphChromosomeSeconds,
                                   SpindleRun.GraphFinalSpindleLength,
                                   SpindleRun.GraphFinalIPF,
                                   SpindleRun.GraphFinalOccupancy,
                                   SpindleRun.GraphSpecialLegos,
                                   ])
                # FIXME: Figure out why SpindleRun.GraphFinalForces isn't
                # working, then add it to graph_list.
            for GF in graph_list:
                # Apply the correct style only to runs
                with plt.style.context(ase1_runs_stl):
                    fig, ax = self.MakeCritGraph(run, GF)
                    # Clean up after yourself
                    plt.close(fig)
                gc.collect()

        if self.opts.movie_comp:
            run.ReadCriteriaDataFrame()
            mov_dict = run.MakeSuccFailMovieDict()
            for sim_name, comp_dict in mov_dict.items():
                sim = comp_dict['sim']
                sim_datadir = sim.SetSimDataDir()  # Make sure sim data path is set
                for i in range(self.opts.movie_comp):
                    for outcome in ('succ', 'fail'):
                        try:
                            sd_dir = comp_dict[outcome][i]
                        except (KeyError, IndexError):
                            # Fewer seeds of this outcome than requested:
                            # make as many movies as are available.
                            continue
                        # '{{}}' leaves a literal '{}' placeholder that
                        # MakeSeedMovie fills with the movie file name.
                        mov_path = os.path.join(
                            sim_datadir, '{}_{}_{{}}'.format(sim_name, outcome))
                        try:
                            self.MakeSeedMovie(sd_dir, mov_path)
                        except Exception as err:
                            # Best effort: report the failure and continue
                            # with the remaining movies.
                            print("Could not make {} movie {} for {}: {}".format(
                                outcome, i, sim_name, err))
            return

        if self.opts.graph:
            run.MakeDistributionGraphs()
            run.GraphSims()
        if self.opts.test_graph:
            self.RunGraph(run)

        return

    def MakeCritGraph(self, run, GraphFunc):
        """Draw one run-level criteria graph and save it.

        Args:
            run: the SpindleRun whose data are plotted.
            GraphFunc: SpindleRun graph method. It is called as
                ``GraphFunc(run, ax)`` and must return (file_name, title).

        Returns:
            (fig, ax) so the caller can close the figure.
        """
        if GraphFunc == SpindleRun.GraphSpecialLegos:
            # Importing Axes3D registers the '3d' projection on older
            # matplotlib versions.
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
        else:
            from stylelib.ase1_styles import cp_spindle_stl
            plt.style.use(cp_spindle_stl)
            fig, ax = plt.subplots(figsize=(2.5, 2.5))
        # The returned title is currently not drawn on these graphs
        file_name, _title = GraphFunc(run, ax)

        fig.tight_layout()
        fig.savefig(file_name)

        return fig, ax

    def MakeSeedMovie(self, sd_dir, mov_path=''):
        """Analyze the seed in ``sd_dir`` and make its movie.

        The process temporarily changes into ``sd_dir``. The original working
        directory is always restored, even if analysis or movie-making raises.

        Args:
            sd_dir: seed directory to analyze.
            mov_path: optional destination template. Its '{}' placeholder is
                filled with the file name returned by ``MakeMovie``, and the
                movie is moved there. Empty means leave it where MakeMovie
                wrote it.
        """
        orig_dir = os.getcwd()
        os.chdir(sd_dir)
        try:
            sd = SpindleSeed(sd_dir, self.opts)
            sd.AnalyzeAll()
            mov_file = sd.MakeMovie()
            if mov_path:
                mov_path = mov_path.format(mov_file)
                os.rename(mov_file, os.path.abspath(mov_path))
        finally:
            os.chdir(orig_dir)

    def SeedGraph(self, sd):
        """Draw the seed-level test graph named by ``opts.test_graph``.

        Raises:
            NotImplementedError: if the name is not a known graph function.
        """
        tg = self.opts.test_graph

        if tg == 'GraphSPBStage1XlinkDistance':
            fig, ax = plt.subplots()
            sd.GraphSPBStage1XlinkDistance(ax, xstate=r'integrated')
        elif tg == 'GraphSPBStage2XlinkDistance':
            fig, ax = plt.subplots()
            sd.GraphSPBStage2XlinkDistance(ax, xstate=r'integrated')
        elif tg == 'graph_num_xlinks':
            fig, ax = plt.subplots()
            graph_num_xlinks(ax, sd, species_ind=0, color='y', label="stage0")
            graph_num_xlinks(ax, sd, species_ind=1, color='r', label="stage1")
            graph_num_xlinks(ax, sd, species_ind=2, color='b', label="stage2")
        elif tg == 'GraphPostDistributions':
            if sd.PostAnalysis.analyze_chromosomes:
                fig, axarr = plt.subplots(3, 1, figsize=(25, 16))
            else:
                fig, axarr = plt.subplots(2, 1)
            sd.GraphPostDistributions(axarr, xstate=r'integrated')
            fig.tight_layout()
        elif tg == 'graph_avg_mt_length':
            fig, ax = plt.subplots()
            graph_avg_mt_length(ax, sd)
        elif tg == 'graph_spb_sep':
            fig, ax = plt.subplots()
            graph_spb_sep(ax, sd)
        elif tg == 'graph_mt_length_distributions':
            fig, ax = plt.subplots()
            graph_mt_length_distributions(sd, ax)

        elif tg == 'graph_all_mt_length_distr':
            fig, axarr = plt.subplots(3, 1, figsize=(12, 8))
            graph_all_mt_length_distr(sd, axarr)

        elif tg == 'graph_avg_mt_splay':
            fig, ax = plt.subplots()
            graph_avg_mt_splay(ax, sd, spb_ind=0, color='y', label="spb0")
            graph_avg_mt_splay(ax, sd, spb_ind=1, color='r', label="spb1")
            graph_avg_mt_splay(ax, sd, spb_ind=2, color='b', label="all")
            ax.legend()
        elif tg == 'graph_interpolar_fraction':
            fig, ax = plt.subplots()
            graph_interpolar_fraction(sd, ax)
        elif tg == 'graph_interpolar_length_fraction':
            fig, ax = plt.subplots()
            graph_interpolar_length_fraction(sd, ax)
        elif tg == 'graph_mt_length_distr_by_index':
            fig, ax = plt.subplots()
            graph_mt_length_distr_by_index(sd, ax)
        elif tg == 'graph_spb_stageN_xlink_distance':
            fig, axarr = plt.subplots(3, 1, figsize=(10, 6))
            # Stages are numbered 1..3, one per panel
            for i, stage_ax in enumerate(axarr):
                graph_spb_stageN_xlink_distance(sd, stage_ax, stage=i + 1)
            plt.tight_layout()
        elif tg == 'graph_spindle_xlink_distance':
            fig, ax = plt.subplots()
            graph_spindle_xlink_distance(sd, ax)
        elif tg == 'graph_spindle_xlink_distance_all':
            fig, axarr = plt.subplots(3, 1, figsize=(10, 8))
            graph_spindle_xlink_distance_all(sd, axarr)
        elif tg == 'graph_kc_attachment_types':
            fig, axarr = plt.subplots(5, 1, figsize=(15, 10))
            graph_kc_attachment_types(sd, axarr)
        else:
            raise NotImplementedError("{} not a graph function".format(tg))

    def SimGraph(self, sim):
        """Draw and show the sim-level test graph named by ``opts.test_graph``.

        Raises:
            NotImplementedError: if the name is not a known graph function.
        """
        tg = self.opts.test_graph

        if tg == r'graph_sim_spb_stage1_xlink_distributions':
            fig, ax = plt.subplots()
            graph_sim_spb_stageN_xlink_distributions(sim, ax, self.opts, stage=1, xstate=r'integrated')
        elif tg == r'graph_sim_spb_stage2_xlink_distributions':
            fig, ax = plt.subplots()
            graph_sim_spb_stageN_xlink_distributions(sim, ax, self.opts, stage=2, xstate=r'integrated')

        elif tg == r'graph_sim_spb_stage2_xlink_distributions_error':
            fig, ax = plt.subplots()
            graph_sim_spb_stageN_xlink_distributions_error(sim, ax, self.opts, stage=2, xstate=r'integrated')

        elif tg == r'graph_sim_xlink_distributions_error':
            fig, ax = plt.subplots()
            graph_sim_xlink_distributions_error(sim, ax, opts=self.opts, xstate=r'integrated')
        elif tg == r'graph_sim_xlink_distr_succ_compare':
            fig, axarr = plt.subplots(2, 2, figsize=(10, 10))
            graph_sim_xlink_distr_succ_compare(sim, axarr, opts=self.opts, xstate=r'integrated', xlabel=True)
        elif tg == r'graph_sim_mt_length_distr':
            fig, axarr = plt.subplots(3, 1)
            graph_sim_mt_length_distr(sim, axarr, opts=self.opts)
        elif tg == r'graph_sim_mt_length_distr_succ_compare':
            fig, axarr = plt.subplots(3, 2, figsize=(12, 8))
            graph_sim_mt_length_distr_succ_compare(sim, axarr, opts=self.opts)
        elif tg == r'graph_sim_mt_length_distr_error':
            fig, axarr = plt.subplots(4, 1, figsize=(12, 10))
            graph_sim_mt_length_distr_error(sim, axarr, opts=self.opts)
        elif tg == r'graph_sim_mt_length_by_index':
            fig, ax = plt.subplots()
            graph_sim_mt_length_by_index(sim, ax, self.opts)
            return  # XXX remove this after tests (skips plt.show())
        elif tg == r'graph_sim_mt_length_distr_final':
            fig, axarr = plt.subplots(4, 3, figsize=(12, 12))
            graph_sim_mt_length_distr_final(sim, axarr)
        elif tg == r'graph_sim_spindle_xlink_distance':
            fig, axarr = plt.subplots(3, 1, figsize=(10, 8))
            graph_sim_spindle_xlink_distance(sim, axarr)
        elif tg == r'graph_sim_spindle_xlink_distance_error':
            fig, ax = plt.subplots()
            graph_sim_spindle_xlink_distance_error(sim, ax)
        elif tg == r'graph_sim_spindle_xlink_distance_error_all':
            from stylelib.ase1_styles import ase1_sims_stl
            # NOTE(review): ase1_sims_stl is imported but ase1_xl_spindle_stl
            # is used; the latter must come from a star import above —
            # confirm it exists, or import it explicitly here.
            with plt.style.context(ase1_xl_spindle_stl):
                fig, axarr = plt.subplots(3, 1, figsize=(10, 8))
                graph_sim_spindle_xlink_distance_error_all(sim, axarr)
                fig.tight_layout()
        elif tg == r'graph_sim_spindle_xlink_distance_final':
            fig, axarr = plt.subplots(3, 2, figsize=(12, 8))
            graph_sim_spindle_xlink_distance_final(sim, axarr)
        elif tg == r'graph_sim_start_time_hist':
            from stylelib.ase1_styles import ase1_sims_stl
            with plt.style.context(ase1_sims_stl):
                fig, ax = plt.subplots()
                graph_sim_start_time_hist(sim, ax)
                fig.tight_layout()
        elif tg == r'graph_sim_kc_attachment_types':
            fig, axarr = plt.subplots(5, 2, figsize=(15, 10))
            graph_sim_kc_attachment_types(sim, axarr, opts=self.opts)
        else:
            raise NotImplementedError('{} not a graph function'.format(tg))

        plt.show()
        return

    def RunGraph(self, run):
        """Draw and show the run-level test graph named by ``opts.test_graph``.

        Raises:
            NotImplementedError: if the name is not a known graph function.
        """
        tg = self.opts.test_graph

        if tg == r'graph_run_xlink_distr_succ_compare':
            fig, axarr = plt.subplots(2, 2, figsize=(10, 8))
            run.CollectSims()
            run.AnalyzeSims()
            graph_run_xlink_distr_succ_compare(run, axarr, opts=self.opts, xstate=r'integrated', xlabel=True)
        elif tg == r'graph_run_mt_length_distr_error':
            # NOTE(review): fig/axarr are created but not passed on, and the
            # early return skips plt.show().
            fig, axarr = plt.subplots(4, 1, figsize=(10, 8))
            graph_run_mt_length_distr_error(run)
            return
        elif tg == r'graph_run_spindle_xlink_distance_error_all':
            fig, axarr = plt.subplots(3, 1, figsize=(10, 8))
            graph_run_spindle_xlink_distance_error_all(run, axarr)
        elif tg == r'graph_run_avg_start_time':
            # Set ase1_run style for these graphs
            from stylelib.ase1_styles import ase1_runs_stl
            with plt.style.context(ase1_runs_stl):
                fig, ax = plt.subplots()
                graph_run_avg_start_time(run, ax)
                fig.tight_layout()
        else:
            raise NotImplementedError('{} not a graph function'.format(tg))
        plt.show()
        return


##########################################
if __name__ == "__main__":
    # Keep the module-level name `opts`: some methods may look it up as a
    # global rather than through the SpindleAnalysis instance.
    opts = parse_args()
    # The constructor runs ReadOpts and ProgOpts, so it is built only for
    # those side effects.
    SpindleAnalysis(opts)




