#!/usr/bin/env python
# In case of poor (Sh**y) commenting contact adam.lamson@colorado.edu
# Basic
import sys, os, pdb
from collections import OrderedDict
from subprocess import call
from shutil import rmtree
## Analysis
import yaml
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Image_Generation'))
from seed_base import SeedBase
from spindle_unit_dict import SpindleUnitDict
from spindle_fitness import SpindleFitness
from spindle_movie import make_spindle_movie
from SpindlePostAnalysis import SpindlePostAnalysis
from seed_graph_funcs import *
from criteria_funcs import *
try: import line_profiler
except: pass
# Chris Edelmaier was here, needed for special image generation of movies
from blur_2d import GaussianBlurCreator2D

import cv2

'''
Name: spindle_seed.py
Description:
'''

# Module-level SpindleUnitDict instance; passed to SpindlePostAnalysis in
# SpindleSeed.AnalyzePostAnalysis.
uc = SpindleUnitDict()

#Class definition
class SpindleSeed(SeedBase):
    def __init__(self, path, opts):
        """Initialise the seed and reset all spindle-specific bookkeeping.

        ``path`` and ``opts`` are forwarded unchanged to ``SeedBase.__init__``;
        every attribute assigned below starts at its "nothing analysed yet"
        value.
        """
        SeedBase.__init__(self, path, opts)

        # Analysis object; populated by AnalyzePostAnalysis().
        self.PostAnalysis = None

        # List of beginning and end times if the spindle was successful;
        # stays None otherwise.
        self.success_spindle = None
        self.succ_info_dict = OrderedDict()

        self.spb_sep_avg = 0

        # Status flags.
        self.analyzed = False  # whether timedata should be re-pickled
        self.criteria_flag = False  # set by MakeSeedSuccessDict() once the label is updated
        self.calc_avg_flag = False

    def AnalyzeAll(self):
        if (self.opts != None and not self.opts.nopost):
            self.AnalyzePostAnalysis()

    def AnalyzePostAnalysis(self):
        """Specialized analysis of distributions and other data.

        Creates ``self.PostAnalysis`` from the seed path, the options and the
        module-level unit dictionary.  For a bad seed the posit data is
        unloaded and the method returns right away.  Otherwise the analysis,
        the success dictionary and the distribution data are produced in that
        order, the state is saved if anything was (re)analysed, the posit data
        is unloaded and ``self.time`` is filled from the time data.
        """
        post = SpindlePostAnalysis(self.path, self.opts, uc)
        self.PostAnalysis = post

        # If there is a bad seed, bail early
        if post.bad_seed:
            print("Bad seed detected, returning without analyzing")
            post.UnloadPosit()
            return

        self.analyzed = post.Analyze(self.opts.analyze)
        # May set self.analyzed to True; maybe modified later with options.
        self.MakeSeedSuccessDict()

        made_distributions = post.MakeDistributionData()
        if made_distributions:
            self.analyzed = True

        # Only write the state back when something was recomputed.
        if self.analyzed:
            post.savestate()

        post.UnloadPosit()
        self.time = np.array(post.timedata['time'])

    def MakeSeedSuccessDict(self, **kwargs):
        # if not self.criteria_flag: self.CalcDurationCriteria(interpolar_fraction_threshold, **kwargs)

        # If we have a bad seed, report that!
        if self.PostAnalysis.bad_seed:
            print "Bad seed detected, continuing on!"
            return

        if 'succ_info' not in self.PostAnalysis.timedata: 
            if self.PostAnalysis.analyze_chromosomes:
                self.PostAnalysis.SeedSuccess(seed_num = self.seed_num,
                                              crit_func = amphitelic_attachment_threshold,
                                              namphi = 3,
                                              spindle_length = 1.05,
                                              duration = 0.01,
                                              end_flag = False,
                                              break_legnth = 0.2) 
                #self.PostAnalysis.SeedSuccess(seed_num = self.seed_num,
                #                              crit_func = amphitelic_attachment_interpolar_threshold,
                #                              namphi = 3,
                #                              spindle_length = 1.05,
                #                              duration = 2.0,
                #                              end_flag = False,
                #                              break_legnth = 0.2) 
                self.analyzed = True
            #elif self.PostAnalysis.wait_time >= 0.08: # Hack for now on the wait time beginining
            #    self.PostAnalysis.SeedSuccess(seed_num = self.seed_num,
            #                                  crit_func = first_bipolar_formation_threshold)
            #    self.analyzed = True
            else:
                self.PostAnalysis.SeedSuccess(seed_num = self.seed_num)
                self.analyzed = True

        self.succ_info_dict = self.PostAnalysis.timedata['succ_info']

        if not self.criteria_flag:
            self.label += '*' if self.succ_info_dict['succ'] else ''
            self.criteria_flag = True

        return self.succ_info_dict

    def GetXlinkDistributionData(self, stage=3, xstate=r'integrated', **kwargs):
        if stage == 1: 
            xl = self.PostAnalysis.distrdata['spb_stage1_xlink_distance'][xstate]
        elif stage == 2: 
            xl = self.PostAnalysis.distrdata['spb_stage2_xlink_distance'][xstate]
        else: 
            xl_one = self.PostAnalysis.distrdata['spb_stage1_xlink_distance'][xstate]
            xl_two = self.PostAnalysis.distrdata['spb_stage2_xlink_distance'][xstate]
            xl = {}
            for k in xl_one:
                xl[k] = np.add(xl_one[k], xl_two[k])
        return xl

    def GetKCDistributionData(self, xstate=r'integrated', xtype = 'full', **kwargs):
        if 'kc_distribution' not in kwargs: kwargs['kc_distribution'] = 'kc_spindle_1d'
        namedist = kwargs['kc_distribution']
        kc = self.PostAnalysis.distrdata[namedist][xstate][xtype]

        return kc

    def GetKCStretchData(self, xstate=r'integrated', xtype = 'full', **kwargs):
        kc = self.PostAnalysis.distrdata['inter_kc_stretch']

        return kc

    # Set my internal force data
    def SetForceData(self, time_resampled, force_data):
        self.time_resampled_forces = time_resampled
        self.force_data = force_data

    # Graphs all possible data with relation to time
    def GraphAllvsTime(self, axarr, color = 'b'):
        """Graph all time-dependent data of this seed, one axis per quantity.

        ``axarr[0]`` .. ``axarr[3]`` receive, in order, the plots drawn by
        ``graph_spb_sep``, ``graph_interpolar_fraction``,
        ``graph_interpolar_length_fraction`` and ``graph_avg_mt_splay``.
        Returns None.
        """
        # graph_avg_mt_length and graph_num_xlinks are intentionally not drawn.
        plotters = (
            graph_spb_sep,
            graph_interpolar_fraction,
            graph_interpolar_length_fraction,
            graph_avg_mt_splay,
        )
        for index, plot in enumerate(plotters):
            plot(self, axarr[index], color=color)

    def GraphKCvsTime(self, axarr, color = 'b'):
        """Forward this seed, ``axarr`` and ``color`` to ``graph_kc_spb_distance``."""
        graph_kc_spb_distance(self, axarr, color=color)

    def GraphkMTLifetimes(self, axarr, color = 'b'):
        """Forward this seed, ``axarr`` and ``color`` to ``graph_kmt_lifetimes``."""
        graph_kmt_lifetimes(self, axarr, color=color)

    def GraphChromosomeLengthIPFAmphi(self, axarr, color = 'b'):
        """Forward this seed, ``axarr`` and ``color`` to ``graph_length_ipf_amphi``."""
        graph_length_ipf_amphi(self, axarr, color=color)

    def GraphSPBForcevsTime(self, axarr, color = 'b'):
        """Forward this seed, ``axarr`` and ``color`` to ``graph_spb_force``."""
        graph_spb_force(self, axarr, color=color)

    def GraphXlinkAFForces(self, axarr, color = 'b'):
        """Forward this seed, ``axarr`` and ``color`` to ``graph_xlinkaf_force``."""
        graph_xlinkaf_force(self, axarr, color=color)

    def GraphTangentForcevsTime(self, axarr, color = 'b'):
        """Forward this seed, ``axarr`` and ``color`` to ``graph_tangent_force``."""
        graph_tangent_force(self, axarr, color=color)

    def GraphOccupancy(self, axarr, color = 'b'):
        """Forward this seed, ``axarr`` and ``color`` to ``graph_occupancy_attachment``."""
        graph_occupancy_attachment(self, axarr, color=color)

    # Graph the specialized distributions
    def GraphPostDistributions(self, axarr, color = 'b', xstate=r'final', **kwargs):
        """Graph the specialized stage-1 and stage-2 crosslink distributions.

        Stage 1 is drawn on ``axarr[0]`` and stage 2 on ``axarr[1]`` through
        ``graph_spb_stageN_xlink_distance``; ``color``, ``xstate`` and any
        extra keyword arguments are passed along unchanged.
        """
        for stage in (1, 2):
            graph_spb_stageN_xlink_distance(self, axarr[stage - 1], stage=stage,
                                            color=color, xstate=xstate, **kwargs)
        # The 1D kinetochore distribution (GraphKCSpindle1D on axarr[2]) is
        # currently not drawn here.

    def GraphXlinkSpindle1D(self, ax, color = 'b', xlabel = True, xstate = r'final', xdata_in = None):
        yarr = self.PostAnalysis.GraphXlinkSpindle1D(ax, self.label, color = color, xlabel = xlabel, xstate = xstate, xdata_in = xdata_in)
        return yarr

    def GraphKCSpindle1D(self, ax, color = 'b', xlabel = True, xstate = r'final', xdata_in = None):
        yarr = self.PostAnalysis.GraphKCSpindle1D(ax, self.label, color = color, xlabel = xlabel, xstate = xstate, xdata_in = xdata_in)
        return yarr

    def WriteAllData(self):
        # print self.SSAnalysis
        # print self.IFAnalysis
        # print self.MTAnalaysis
        return

    def Fitness(self, wt_mat, lstream_mat):
        """Compute and write the fitness report for the configured strain.

        The strain name is read from ``self.opts.fitness`` and a fresh
        ``SpindleFitness`` object is stored in ``self.fitness``.  Depending on
        the strain ('WT_wt', 'WT_Cen2', 'Cerevisiae_basic' or
        'Cerevisiae_length') the matching fitness routine is run and its YAML
        is printed; when the seed does not qualify the corresponding empty
        YAML is printed instead.  Any other strain name does nothing beyond
        creating the fitness object.

        ``wt_mat`` and ``lstream_mat`` are handed to
        ``LoadExperimentDistributions``.  Returns the strain name.
        """
        strain_name = self.opts.fitness
        print("strain name = {}".format(self.opts.fitness))
        fitness = SpindleFitness(self.path, strain_name)
        self.fitness = fitness

        if strain_name == 'WT_wt':
            post = self.PostAnalysis
            if post.bad_seed:
                fitness.PrintEmptyYAML(self.path)
            else:
                fitness.LoadExperimentDistributions(wt_mat, lstream_mat)
                fitness.LoadStrainData()
                fitness.StatisticalTests(post.distrdata['fitness'])
                timedata = post.timedata
                fitness.CalcLengthFitness(timedata['time'], timedata['spb_separation'])
                fitness.CalcTimeSeriesCorrelation(timedata['time'], timedata['spb_separation'])
                if post.analyze_chromosomes:
                    fitness.ChromosomeFitness(timedata['time'],
                                              timedata['spb_separation'],
                                              timedata['kc_atypes'],
                                              timedata['kc_distance'],
                                              timedata['kc_nend'])
                fitness.PrintYAML(self.path, self.succ_info_dict['succ'])
        elif strain_name == 'WT_Cen2':
            post = self.PostAnalysis
            if post.bad_seed or not post.analyze_chromosomes:
                fitness.PrintEmptyCen2YAML(self.path)
            else:
                fitness.LoadExperimentDistributions(wt_mat, lstream_mat)
                fitness.LoadStrainData()
                fitness.Cen2Fitness(self)
                fitness.PrintCen2YAML(self.path)
        elif strain_name in ("Cerevisiae_basic", "Cerevisiae_length"):
            # Both budding-yeast strains share the same routine.
            post = self.PostAnalysis
            if post.bad_seed or not post.analyze_chromosomes:
                fitness.PrintEmptyCerevisiaeYAML(self.path)
            else:
                ## XXX: Need to get EM data for budding yeast spindles, as well
                ## as strain data; LoadExperimentDistributions is not called.
                fitness.LoadStrainData()
                fitness.CerevisiaeBasicFitness(self)
                fitness.PrintCerevisiaeBasicYAML(self.path)

        return strain_name

    def MakeMovie(self, movie_length=30):
        """Render frames with ``spindle_analysis_new`` and assemble the seed movie.

        The default/equil parameter files and the posit file are looked up by
        bare file name, i.e. relative to the current working directory.  The
        equil file is rewritten in place with the graphing/grabbing flags
        switched on, the external ``spindle_analysis_new`` program is run with
        the ``--tomo`` mode that belongs to ``self.opts.movie``, and
        ``make_spindle_movie`` turns the grabbed frames into ``<name>.mov``.
        For the 'planar' option additional blurred images are created in
        ``<path>/blurdir``.  The frames directory is left on disk.

        Args:
            movie_length: requested movie length; when it is not positive a
                frame rate of 25 is used and the length is derived from it.

        Raises:
            ValueError: if n_steps or n_posit is missing from both parameter
                files, if the requested length gives a frame rate below 1, if
                n_graph evaluates to 0, or if ``spindle_analysis_new`` returns
                a non-zero status.
            SystemExit: if ``self.opts.movie`` is not a known movie option.
        """
        # Movie option -> (value passed to --tomo, compose_background flag).
        movie_modes = {
            'fixed': ('0', True),
            'planar': ('1', True),
            'chromosome0': ('2', False),
            'chromosome1': ('3', False),
            'chromosome2': ('4', False),
        }

        # Validate the option first so that a wrong option cannot leave a
        # rewritten equil file and a frames directory behind.
        movie_opt = self.opts.movie
        if movie_opt not in movie_modes:
            print("Wrong option for movie generation {}, exiting!".format(movie_opt))
            sys.exit(1)
        tomo_mode, compose_background = movie_modes[movie_opt]
        # The alternative frame name is the option name itself.
        alt_frame_name = movie_opt

        # Get default and equil files, modify and reload
        d_file = 'spindle_bd_mp.default.yaml'
        e_file = 'spindle_bd_mp.equil.yaml'
        p_file = 'spindle_bd_mp.posit'

        # NOTE(review): yaml.load without an explicit Loader can construct
        # arbitrary objects and is deprecated in newer PyYAML releases; it is
        # kept because these are local parameter files -- consider
        # yaml.safe_load once confirmed they carry no python-specific tags.
        # An empty file loads as None, so fall back to an empty mapping.
        with open(d_file, 'r') as df:
            d_dict = yaml.load(df) or {}
        with open(e_file, 'r') as ef:
            e_dict = yaml.load(ef) or {}

        # Figure out the proper rate of graphing; the equil file takes
        # precedence over the default file.
        if 'n_steps' in e_dict:
            n_steps = e_dict['n_steps']
        elif 'n_steps' in d_dict:
            n_steps = d_dict['n_steps']
        else:
            raise ValueError('n_steps not in parameter files')
        print("nsteps:  {}".format(n_steps))

        if 'n_posit' in e_dict:
            n_posit = e_dict['n_posit']
        elif 'n_posit' in d_dict:
            n_posit = d_dict['n_posit']
        else:
            raise ValueError('n_posit not in parameter files')
        print("nposit:  {}".format(n_posit))

        # NOTE(review): the divisions below are deliberately left untouched;
        # with integer operands under Python 2 they truncate before float()
        # is applied, and changing that would alter fps and the n_graph check.
        tot_frames = n_steps/n_posit
        print("tot_frames:  {}".format(tot_frames))
        if movie_length > 0:
            fps = float(tot_frames/movie_length)
            if fps < 1:
                raise ValueError('Specified movie length is too long. Must be less than {}.'.format(tot_frames))
        else:
            fps = 25
            movie_length = float(tot_frames/fps)
        # Seems redundant but n_graph needs to be a multiple of n_posit
        n_graph = int((tot_frames/(movie_length*fps)))*n_posit
        if n_graph == 0:
            raise ValueError('*** n_graph was calculated to be 0. '
                             'Check to n_steps and n_posit. ***')

        # Make the directory where movie frames will go
        grab_dir = os.path.join(self.path, 'frames')
        grab_file = os.path.join(self.path, 'frames/frame')
        if not os.path.exists(grab_dir):
            os.mkdir(grab_dir)

        # Modify the equil file to have the correct graphing parameters
        e_dict['delta'] = d_dict['delta'] # For time conversion
        e_dict['graph_flag'] = 1
        e_dict['grab_flag'] = 1
        e_dict['grab_file'] = grab_file
        e_dict['graph_boundary_flag'] = 0
        # n_graph is only sanity-checked above; it is not written to the file.

        with open(e_file, 'w') as ef:
            yaml.dump(e_dict, ef, default_flow_style=False)

        # Make title of simulation movie
        mov_file = "{}.mov".format(self.name)

        # Run spindle_analysis_new to create movie frames
        print("d_File: {}, e_File: {}".format(d_file, e_file))
        status = call(['spindle_analysis_new', d_file, e_file,
                       '-p', p_file, '--tomo', tomo_mode])
        if status:
            raise ValueError('spindle_analysis_new did not run properly')

        # Make movie with mask and other things
        [vid, anaphase_b_frame] = make_spindle_movie(self.path, name=mov_file,
                                 default_file=d_file, equil_file=e_file,
                                 fps=fps, frame_dir=grab_dir, alt_frame_name=alt_frame_name,
                                 compose_background=compose_background,
                                 do_anaphase=self.PostAnalysis.do_anaphase,
                                 anaphase_step=self.PostAnalysis.anaphase_onset_step)
        vid.release()

        # For the planar view also create the custom blurred images/movies.
        if movie_opt == 'planar':
            print("Creating custom fluorescent images of seed!")
            gb2d = GaussianBlurCreator2D(self.path, True, True, True, True, self.PostAnalysis.do_anaphase, anaphase_b_frame)
            blurdir = os.path.join(self.path, 'blurdir')
            if not os.path.exists(blurdir):
                os.mkdir(blurdir)
            gb2d.CreateData(blurdir)
            # The returned videos are not released here (as before).
            gb2d.CreateMovie('')

        # The frames directory is intentionally kept (no rmtree(grab_dir)).


##########################################
if __name__ == "__main__":
    # SpindleSeed.__init__ takes (path, opts); calling it with the path alone
    # raises TypeError.  No options are available when the module is run
    # directly, so pass None, which AnalyzeAll treats as "no options".
    # NOTE(review): assumes SeedBase.__init__ also accepts opts=None -- confirm.
    cwd = os.getcwd()
    sd = SpindleSeed(cwd, None)




