#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact christopher.edelmaier@colorado.edu
# Basic
import sys, os, pdb
import pickle
## Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *

sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Spindle'))
from correlationdata import CorrelationData
from spindle_unit_dict import SpindleUnitDict

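'''
Name: ChromosomeDiffusionAnalysis.py
    NOTE(review): this name does not match the XlinkAnalysis class defined
    below and may be left over from another script; confirm.
Description: Loads the xlink and FP correlation data files
    (xlink_<name>_*.dat, fp_<name>_*.dat) found in a run directory,
    prints them, and saves plots of them as PNG images.
Input: run directory path (argv[1]) and analysis name (argv[2]).
Output: xlink_<name>_{1d,2d,psi1,nexp}.png and fp_<name>_{psi1,2d}.png,
    written with matplotlib savefig relative to the current working directory.
'''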

#Class definition
class XlinkAnalysis(object):
    """Load, print and plot the xlink / FP correlation data of one run.

    On construction the run directory is probed for
    ``xlink_<analysis_name>_1d.dat`` and ``fp_<analysis_name>_psi1.dat``.
    Each data family is loaded only when its probe file exists; the
    ``analyze_xlinks`` and ``analyze_fp`` flags record what was loaded and
    gate what ``Print`` and ``Plot`` do.
    """

    def __init__(self, run_dir_path, analysis_name):
        """Resolve the run directory and load whatever data files exist.

        Args:
            run_dir_path: Directory holding the ``.dat`` files; stored as
                its ``os.path.realpath``.
            analysis_name: Tag substituted into every input and output
                file name.
        """
        self.run_dir_path = os.path.realpath(run_dir_path)
        self.analysis_name = analysis_name

        self.InitXlink()
        self.InitFP()

        # Step handed to the generic plot helpers.
        self.bin_size = 0.05

    def InitXlink(self):
        """Build the xlink file names and load them if the 1D file exists."""
        self.analyze_xlinks = False
        self.xlink_1d_fname = 'xlink_{}_1d.dat'.format(self.analysis_name)
        self.xlink_2d_fname = 'xlink_{}_2d.dat'.format(self.analysis_name)
        self.xlink_psi1_fname = 'xlink_{}_psi1.dat'.format(self.analysis_name)
        self.xlink_nexp_fname = 'xlink_{}_nexp.dat'.format(self.analysis_name)

        # Only the 1D file is probed; the 2D, psi1 and nexp files are loaded
        # without a separate existence check.
        # NOTE(review): presumably all four are always written together --
        # confirm against the code that produces them.
        if os.path.isfile(os.path.join(self.run_dir_path, self.xlink_1d_fname)):
            self.analyze_xlinks = True
            self.xlink_1d_data = CorrelationData(self.run_dir_path, self.xlink_1d_fname)
            self.xlink_2d_data = CorrelationData(self.run_dir_path, self.xlink_2d_fname)
            self.xlink_psi1 = CorrelationData(self.run_dir_path, self.xlink_psi1_fname)
            self.xlink_nexp = CorrelationData(self.run_dir_path, self.xlink_nexp_fname)

    def InitFP(self):
        """Build the FP file names and load them if the psi1 file exists."""
        self.analyze_fp = False
        self.fp_psi1_fname = 'fp_{}_psi1.dat'.format(self.analysis_name)
        self.fp_psi_fname = 'fp_{}_psi.dat'.format(self.analysis_name)

        # Only the psi1 file is probed; the psi file is loaded without a
        # separate existence check.
        if os.path.isfile(os.path.join(self.run_dir_path, self.fp_psi1_fname)):
            self.analyze_fp = True
            self.fp_psi1 = CorrelationData(self.run_dir_path, self.fp_psi1_fname)
            self.fp_psi = CorrelationData(self.run_dir_path, self.fp_psi_fname)

    def Print(self):
        """Print a label and the contents of every loaded data set.

        Single-argument ``print(...)`` calls are used so the output is the
        same on Python 2 and the module also parses on Python 3.
        """
        if self.analyze_xlinks:
            print("Xlink 1D Data:")
            self.xlink_1d_data.Print()
            print("Xlink 2D Data:")
            self.xlink_2d_data.Print()
            print("Xlink Psi1 Data:")
            self.xlink_psi1.Print()
            print("Xlink nexp Data:")
            self.xlink_nexp.Print()
        if self.analyze_fp:
            print("FP Psi1 Data:")
            self.fp_psi1.Print()
            print("FP 2D Data:")
            self.fp_psi.Print()

    def Plot(self):
        """Save one PNG per loaded data set."""
        if self.analyze_xlinks:
            self.Plot1Ddist()
            self.Plot2Ddist('xlink', self.xlink_2d_data)
            self.PlotPsi1('xlink', self.xlink_psi1)
            self.PlotNexp()
        if self.analyze_fp:
            self.PlotPsi1('fp', self.fp_psi1)
            self.Plot2Ddist('fp', self.fp_psi)

    def Plot1Ddist(self):
        """Plot the xlink 1D data over the fixed range [-5.0, 5.0)."""
        fname = 'xlink_{}_1d.png'.format(self.analysis_name)
        self.Plot1D(-5.0, 5.0, self.bin_size, self.xlink_1d_data.GetData(), fname)

    def PlotPsi1(self, name, psi1):
        """Plot a psi1 data set as a line plot.

        Args:
            name: Prefix of the output file name ('xlink' or 'fp' here).
            psi1: Data object; ``psi1.h[0][0]`` is used as the x-axis upper
                limit. NOTE(review): presumably a header value giving the
                axis extent -- confirm against CorrelationData.
        """
        fname = '{}_{}_psi1.png'.format(name, self.analysis_name)
        self.Plot1D(0.0, psi1.h[0][0], self.bin_size, psi1.GetData(), fname)

    def PlotNexp(self):
        """Plot the xlink nexp data from 0.0 up to ``xlink_nexp.h[0][0]``."""
        fname = 'xlink_{}_nexp.png'.format(self.analysis_name)
        self.Plot1D(0.0, self.xlink_nexp.h[0][0], self.bin_size, self.xlink_nexp.GetData(), fname)

    def Plot2Ddist(self, name, dist):
        """Plot a 2D data set as an image.

        Args:
            name: Prefix of the output file name ('xlink' or 'fp' here).
            dist: Data object; ``dist.h[0][0]`` is used as the upper limit
                of both image axes.
        """
        fname = '{}_{}_2d.png'.format(name, self.analysis_name)
        self.Plot2D(0.0, dist.h[0][0], self.bin_size, dist.GetData(), fname)

    ### Generic Plot Functions
    def Plot1D(self, xmin, xmax, bin_size, imagedata, name):
        """Plot ``imagedata`` against ``np.arange(xmin, xmax, bin_size)``.

        The figure is saved to ``name`` and then closed.
        """
        # NOTE(review): with a float step np.arange can return one element
        # more or fewer than expected, and plt.plot needs the x values to
        # match imagedata -- confirm the data length always agrees.
        xvals = np.arange(xmin, xmax, bin_size)
        plt.plot(xvals, imagedata)
        plt.savefig(name, bbox_inches='tight')
        plt.close()

    def Plot2D(self, xmin, xmax, bin_size, imagedata, name):
        """Show ``imagedata`` as an image spanning [xmin, xmax] on both axes.

        ``bin_size`` is accepted for signature symmetry with ``Plot1D`` but
        is not used. The figure is saved to ``name`` and then closed.
        """
        mextent = [xmin, xmax, xmin, xmax]
        plt.imshow(imagedata, origin='lower', extent=mextent)
        plt.savefig(name, bbox_inches='tight')
        plt.close()


##########################################
if __name__ == "__main__":
    # Command-line entry point: <script> RUN_DIR ANALYSIS_NAME
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print("xlink analysis")

    # Fail with a usage message instead of a bare IndexError traceback when
    # the two required arguments are missing.
    if len(sys.argv) < 3:
        sys.exit("usage: {} RUN_DIR ANALYSIS_NAME".format(sys.argv[0]))

    run_dir = sys.argv[1]
    analysis_name = sys.argv[2]

    # Load whatever data exists in the run directory, then print and plot it.
    xla = XlinkAnalysis(run_dir, analysis_name)
    xla.Print()
    xla.Plot()
