#!/usr/bin/env python
# In case of poor (Sh**y) commenting contact christopher.edelmaier@colorado.edu
# Basic
import sys, os, pdb
import gc
import argparse
import fnmatch
## Analysis
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Spindle'))

from spindle_sim import SpindleSim
from sim_graph_funcs import *
from stylelib.ase1_styles import ase1_sims_stl

from scipy.stats import ks_2samp
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw
import scipy.io as sio

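'''
Name: FitnessTests.py
Description: Compares the spindle-length time series of a model simulation
    against a second ("experiment") simulation: a two-sample KS test on
    sampled lengths, pairwise cross correlation, and dynamic time warping.
Input: Command-line options (see parse_args): the -m/--model simulation
    directory plus either an -e/--experiment simulation directory or the
    name of a --real experimental data set.
Output: model_experiment.png and length_streams.png written to the current
    working directory, and the fitness metrics printed to stdout.
'''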

def parse_args():
    """Build the FitnessTests.py command-line parser and parse sys.argv.

    Returns:
        The argparse namespace with the attributes ``model``, ``experiment``,
        ``real``, ``nopost``, ``fitness``, ``analyze``, ``datadir`` and
        ``workdir``.
    """
    cli = argparse.ArgumentParser(prog='FitnessTests.py')
    add = cli.add_argument

    # What to compare: the model is mandatory, the reference is given either
    # as a second simulation or as the name of an experimental data set.
    add('-m', '--model', type=str, required=True,
        help='Model sim')
    add('-e', '--experiment', type=str, required=False,
        help='Experiment sim')
    add('--real', type=str, required=False,
        help='Experimental data to read')

    # Switches controlling how much work is done.
    add('--nopost', action='store_true',
        help="Do not use post analyis program to decrease time of analysis.")
    add('-F', '--fitness', action='store_true',
        help='Create fitness for spindle simulations.')
    add('-A', '--analyze', action='store_true',
        help='Analyze data from multiple simulations')

    # Directory options.
    add('--datadir', type=str,
        help='Name of the data directory in which all analyzed data files will be read/written. \
                    Also the directory where all the graphs will be placed when saved. \
                    Default is set to {workdir}/data/.')
    add('-d', '--workdir', type=str,
        help='Name of the working directory where simulation will be run.')

    return cli.parse_args()

# Class definition
class FitnessTests(object):
    """Compare the spindle-length time series of a model sim with a reference.

    Constructing an instance reads the options, builds the SpindleSim
    object(s) and runs their analysis.  The comparison stages are then called
    explicitly by the caller, in this order:

    * GraphTimeSeries       -- plots the length traces and caches the per-seed
                               length streams, averages and time axes that the
                               later stages read from ``self``.
    * OriginalLengthFitness -- two-sample KS test on sub-sampled lengths.
    * CrossCorrelation      -- best pairwise correlation per model seed.
    * DynamicTimeWarping    -- best pairwise DTW similarity per model seed.
    """

    def __init__(self, opts):
        """Store the options, then read and analyze the simulations.

        Args:
            opts: parsed command-line namespace; one of ``opts.experiment``
                or ``opts.real`` must be set, otherwise the process exits
                with status 1.
        """
        self.opts = opts
        self.cwd = os.getcwd()

        if self.opts.experiment:
            self.real = False
        elif self.opts.real:
            self.real = True
        else:
            print("Error in what kind of sim to read!")
            sys.exit(1)

        self.ReadOpts()
        self.AnalyzeSims()

    # Read in options
    def ReadOpts(self):
        """Resolve directories and create the simulation objects.

        Raises:
            IOError: if an explicit working directory does not exist.
        """
        if not self.opts.workdir:
            self.opts.workdir = os.path.abspath(self.cwd)
        elif not os.path.exists(self.opts.workdir):
            raise IOError("Working directory {} does not exist.".format(
                self.opts.workdir))
        else:
            self.opts.workdir = os.path.abspath(self.opts.workdir)

        # Use the options held by this instance rather than a module-level
        # `opts`, which only exists when the file is run as a script.
        self.model_dir = os.path.abspath(self.opts.model)
        self.model_sim = SpindleSim(self.model_dir, self.opts)

        if not self.real:
            self.experiment_dir = os.path.abspath(self.opts.experiment)
            self.experiment_sim = SpindleSim(self.experiment_dir, self.opts)
        else:
            self.LoadData()

    # Read in the data from a specified location
    def LoadData(self):
        """Load the pickled experimental tracks named by ``opts.real``.

        The DataFrames are stored in ``self.experiment_data``, keyed by the
        first two underscore-separated fields of each file name.

        Raises:
            ValueError: if ``opts.real`` is not a known data set.
        """
        print("Reading in data from {}".format(self.opts.real))
        if self.opts.real != 'klp56':
            raise ValueError(
                "Unknown experimental data set '{}' (only 'klp56' is "
                "supported).".format(self.opts.real))

        experiment_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            '..', 'Data', 'Klp56', 'Pandas')
        experiments = [os.path.join(experiment_dir, entry)
                       for entry in os.listdir(experiment_dir)
                       if fnmatch.fnmatch(entry, '*.pickle')]

        experiment_data = {}
        for experiment in experiments:
            track_name = '_'.join(os.path.basename(experiment).split('_')[:2])
            # NOTE: unpickling executes code from the file; these pickles come
            # from the Data directory next to this script, never point this
            # at untrusted files.
            experiment_data[track_name] = pd.read_pickle(experiment)

        # Keep the loaded tracks instead of discarding them on return.
        self.experiment_data = experiment_data

    # Analyze the sims
    def AnalyzeSims(self):
        """Run the analysis, success calculation and (optionally) fitness.

        The experiment simulation is only processed when the reference is a
        simulation (i.e. not in --real mode).
        """
        self.model_sim.Analyze()

        if not self.real:
            self.experiment_sim.Analyze()

        self.model_sim.CalcSimSuccess()
        if not self.real:
            self.experiment_sim.CalcSimSuccess()

        if self.opts.fitness:
            data_dir = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), '..', 'Data')
            wt_path = os.path.join(data_dir, 'wt.mat')
            lstream_path = os.path.join(data_dir, 'lstream.mat')
            self.model_sim.Fitness(wt_path, lstream_path)
            if not self.real:
                self.experiment_sim.Fitness(wt_path, lstream_path)

    def GraphTimeSeries(self):
        """Plot model (top row) and experiment (bottom row) length traces.

        Saves the figure as ``model_experiment.png`` and, through
        GraphTimeSeriesLength, caches the data the other stages need.

        Raises:
            NotImplementedError: in --real mode, where no experiment
                simulation exists to be graphed.
        """
        if self.real:
            raise NotImplementedError(
                "Graphing against real experimental data (--real) is not "
                "implemented; supply an experiment sim with -e/--experiment.")

        fig, axarr = plt.subplots(2, 2, figsize=(25, 16))

        self.GraphTimeSeriesLength(fig, axarr, self.model_sim, 0)
        self.GraphTimeSeriesLength(fig, axarr, self.experiment_sim, 1)

        plt.savefig("model_experiment.png")
        plt.close()
        gc.collect()
        mpl.rcdefaults()

    def GraphTimeSeriesLength(self, fig, axarr, sim, row):
        """Plot every seed of ``sim`` plus their average on one figure row.

        Args:
            fig: the figure being drawn.
            axarr: 2x2 array of axes; ``axarr[row][0]`` receives the per-seed
                traces and ``axarr[row][1]`` the average.
            sim: simulation whose ``seeds`` are plotted.
            row: 0 caches the results as the model, 1 as the experiment.

        Raises:
            ValueError: if ``sim`` has no seeds.
        """
        if len(sim.seeds) == 0:
            raise ValueError("Simulation has no seeds to graph.")

        colors = mpl.cm.rainbow(np.linspace(0, 1, len(sim.seeds)))

        # All seeds are compared on the time axis of the shortest one.
        min_size = min(sd.time.size for sd in sim.seeds)
        time = sim.seeds[-1].time[:min_size]

        avg = np.zeros(min_size)
        num_seeds = 0

        graph = graph_spb_sep
        axr = axarr[row]

        lstream = []

        for sd, col in zip(sim.seeds, colors):
            yarr = graph(sd, axr[0], color=col, xlabel=False)
            # Truncate each stream to the shared time axis so that later
            # stages can pair it element-by-element with `time`.
            trimmed = yarr[:min_size]
            lstream.append(trimmed)
            num_seeds += 1
            avg = np.add(avg, trimmed)

        avg /= float(num_seeds)
        axr[1].plot(time, avg)

        # Modify the limits, etc
        axarr[1, 0].legend(loc='center left', bbox_to_anchor=(2.2, -.19))

        axarr[-1, 1].set_xlabel(r'Time (min)')
        axarr[-1, 0].set_xlabel(r'Time (min)')
        fig.tight_layout()
        plt.subplots_adjust(hspace=.38, right=.85, top=.87)

        if row == 0:
            self.model_avg_length = avg
            self.model_lstream = lstream
            self.model_times = time
            self.model_nseeds = num_seeds
        elif row == 1:
            self.experiment_avg_length = avg
            self.experiment_lstream = lstream
            self.experiment_times = time
            self.experiment_nseeds = num_seeds

    @staticmethod
    def _sample_lengths(lstreams, times, measuretime=15.0, threshold=1.1,
                        delay=10.0, min_points=5):
        """Pool sub-sampled lengths from every seed into one 1-D array.

        For each seed the trace is kept from ``delay`` seconds after the
        length first exceeds ``threshold`` and then sampled roughly every
        ``measuretime`` seconds.  Seeds that never exceed the threshold, or
        that have ``min_points`` or fewer points after the start time, add
        nothing.

        Args:
            lstreams: iterable of per-seed length arrays.
            times: shared time axis in minutes (converted to seconds here).
            measuretime: seconds between retained measurements.
            threshold: length that must be exceeded before sampling starts.
            delay: seconds to skip after the threshold is first exceeded.
            min_points: a seed needs more than this many points to count.

        Returns:
            1-D numpy array of the pooled samples (empty if none qualify).
        """
        times_sec = np.asarray(times, dtype=float) * 60.0
        if times_sec.size < 2:
            return np.array([])

        dt = times_sec[1] - times_sec[0]
        # At least every point is taken; a stride of 0 would be an error.
        nmeasure = max(1, int(measuretime / dt)) if dt > 0 else 1

        sampled = []
        for lengths in lstreams:
            lengths = np.asarray(lengths)
            npts = min(lengths.size, times_sec.size)
            lengths = lengths[:npts]
            seed_times = times_sec[:npts]

            above = lengths > threshold
            # np.argmax returns 0 for an all-False mask, so the "never above
            # threshold" case has to be detected explicitly.
            if not above.any():
                continue

            t0 = seed_times[np.argmax(above)] + delay
            mask = seed_times >= t0
            if np.count_nonzero(mask) > min_points:
                sampled.append(lengths[mask][::nmeasure])

        if not sampled:
            return np.array([])
        return np.concatenate(sampled)

    def OriginalLengthFitness(self):
        """KS-test the pooled model lengths against the experiment lengths.

        Prints the full ``ks_2samp`` result (statistic and p-value) and saves
        a histogram of both samples to ``length_streams.png``.  Requires
        GraphTimeSeries to have been run first.
        """
        # Check the original length fitness, via what should be the original
        # setup where the length is taken to be at least 1.1 micron, etc etc
        measuretime = 15.0  # 15 seconds between measurements

        # Samples are accumulated over *all* seeds of each simulation.
        lstream_experiment = self._sample_lengths(
            self.experiment_lstream, self.experiment_times,
            measuretime=measuretime)
        lstream_model = self._sample_lengths(
            self.model_lstream, self.model_times, measuretime=measuretime)

        if len(lstream_experiment) > 0 and len(lstream_model) > 0:
            length_res = ks_2samp(lstream_model, lstream_experiment)
            print("length pvalue: {}".format(length_res))
        else:
            print("length pvalue: 0.0 (lengths not correct)")

        bins = np.linspace(0.0, 2.75, 10)

        # NOTE(review): `normed` is the keyword this file was written against;
        # newer matplotlib releases call it `density` -- confirm against the
        # installed version.
        for samples, label in ((lstream_model, 'model'),
                               (lstream_experiment, 'experiment')):
            if len(samples) > 0:  # a normalized histogram of nothing is NaN
                plt.hist(samples, bins, alpha=0.5, label=label, normed=True)
        plt.legend()
        plt.savefig("length_streams.png")
        plt.close()
        gc.collect()
        mpl.rcdefaults()

    def CrossCorrelation(self):
        """Print, per model seed, its best correlation with any experiment seed.

        Also prints the mean of those maxima and the correlation between the
        two seed-averaged traces.  Maxima start at 0, so negative
        correlations are reported as 0.  Requires GraphTimeSeries first.
        """
        max_correlation = np.zeros(self.model_nseeds)
        # Compare each model to all the experiments
        for mseed in range(self.model_nseeds):
            model_ts = pd.Series(self.model_lstream[mseed])
            for eseed in range(self.experiment_nseeds):
                experiment_ts = pd.Series(self.experiment_lstream[eseed])
                corr = model_ts.corr(experiment_ts)
                if corr > max_correlation[mseed]:
                    max_correlation[mseed] = corr

        for mseed in range(self.model_nseeds):
            print("Cross correlation max: model: {} = {}".format(
                mseed, max_correlation[mseed]))

        print("Cross correlation avg of max: {}".format(
            np.mean(max_correlation)))

        # What about the average length distributions?
        avg_e = pd.Series(self.experiment_avg_length)
        avg_m = pd.Series(self.model_avg_length)
        corr = avg_m.corr(avg_e)
        print("Cross correlation: model: avg, experiment: avg = {}".format(
            corr))

    @staticmethod
    def _dtw_similarity(model_times, model_lengths,
                        experiment_times, experiment_lengths):
        """Return 1 - (DTW distance / max distance) for two (time, length) traces.

        The normalisation is 2.75 times the number of model samples; 2.75 is
        presumably the largest expected length (it is also the upper edge of
        the histogram bins) -- TODO confirm.
        """
        model_ts = np.column_stack([model_times, model_lengths])
        experiment_ts = np.column_stack([experiment_times, experiment_lengths])
        max_dist = 2.75 * len(model_ts[:, 0])
        distance, _ = fastdtw(model_ts, experiment_ts, dist=euclidean)
        return 1.0 - distance / max_dist

    def DynamicTimeWarping(self):
        """Print, per model seed, its best DTW similarity to any experiment seed.

        Also prints the mean of those maxima and the similarity between the
        two seed-averaged traces.  Maxima start at 0, so negative
        similarities are reported as 0.  Requires GraphTimeSeries first.
        """
        max_timewarp = np.zeros(self.model_nseeds)
        # Compare the model to all experiments
        for mseed in range(self.model_nseeds):
            for eseed in range(self.experiment_nseeds):
                similarity = self._dtw_similarity(
                    self.model_times, self.model_lstream[mseed],
                    self.experiment_times, self.experiment_lstream[eseed])
                if similarity > max_timewarp[mseed]:
                    max_timewarp[mseed] = similarity

        for mseed in range(self.model_nseeds):
            print("DTW max: model: {} = {}".format(mseed, max_timewarp[mseed]))

        print("DTW avg of max: {}".format(np.mean(max_timewarp)))

        # What about the average distributions?
        similarity = self._dtw_similarity(
            self.model_times, self.model_avg_length,
            self.experiment_times, self.experiment_avg_length)
        print("DTW: model: avg, experiment: avg = {}".format(similarity))


##########################################
if __name__ == "__main__":
    # Parse the command line and build the comparison object (its constructor
    # already reads and analyzes the simulations), then run every comparison
    # stage in order.  The graphing stage comes first.
    opts = parse_args()
    x = FitnessTests(opts)
    for stage in (x.GraphTimeSeries,
                  x.OriginalLengthFitness,
                  x.CrossCorrelation,
                  x.DynamicTimeWarping):
        stage()
