#!/usr/bin/env python
# Basic
import sys, os, pdb
import gc
import argparse
import re
# Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Spindle'))

from read_posit_spindle import ReadPositSpindle
from read_posit_chromosomes import ReadPositChromosomes
from read_posit_xlinks import ReadPositXlinks
from quaternion import *

import scipy.io as sio

# Image stuff
import cv2
from spindle_unit_dict import SpindleUnitDict
from matplotlib.colors import LinearSegmentedColormap
from scipy.ndimage.filters import gaussian_filter

def parse_args():
    """Parse the blur_2d.py command line.

    Returns:
        argparse.Namespace with boolean flags ``microtubules``, ``kinetochores``,
        ``cut7`` and ``graph`` (False unless given) and the string ``workdir``
        (None unless ``-d`` is given).
    """
    parser = argparse.ArgumentParser(prog='blur_2d.py')

    # Species selectors: all plain on/off switches, registered in display order.
    species_flags = (
        ('-mt', '--microtubules', 'Run image generation for microtubules'),
        ('-kc', '--kinetochores', 'Run image generation for kinetochores'),
        ('-cut7', '--cut7', 'Run image generation for Cut7'),
    )
    for short_flag, long_flag, description in species_flags:
        parser.add_argument(short_flag, long_flag, action='store_true',
                help=description)

    parser.add_argument('-d', '--workdir', type=str,
            help='Directory name')
    parser.add_argument('-g', '--graph', action='store_true',
            help='Generate Python Images')

    return parser.parse_args()

def normalize_image(image):
    """Linearly rescale ``image`` onto [0, 1] for display.

    Returns:
        (a, b, normalized): the original minimum ``a``, maximum ``b`` (useful
        for display ranges) and ``(image - a) / (b - a)``.

    A constant image (including an all-zero one) has no span to stretch. It
    is shifted to all zeros instead of dividing by zero, which would
    otherwise fill the result with NaN.
    """
    a = np.min(image)
    b = np.max(image)
    span = b - a
    if span == 0:
        return a, b, image - a
    # float() keeps this a true division for integer arrays on Python 2.
    return a, b, (image - a) / float(span)

def display_range(image, displayrange):
    """Return a copy of ``image`` saturated at a fraction of its intensity span.

    Values above ``min + displayrange * (max - min)`` are clipped to that
    level. With ``displayrange == 1.0`` nothing is clipped. The input array
    is left untouched.

    Args:
        image: numpy array of intensities.
        displayrange: fraction of the [min, max] span kept before saturating.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    amin = np.min(image)
    amax = np.max(image)
    adj_max = displayrange * (amax - amin) + amin

    # Work on a copy. Callers keep using the arrays they pass in; for example,
    # merge_image_stack receives images that were already stored for the .mat
    # dump. Clipping in place would corrupt those arrays.
    newimage = np.array(image, copy=True)
    newimage[image > adj_max] = adj_max
    return newimage

def merge_image(imagedata_r, imagedata_g, imagedata_b, displayrange_r, displayrange_g, displayrange_b, filename):
    """Save a 1x1-inch RGB image built from three single-channel images.

    Each channel is first clipped with display_range at ``displayrange_*`` of
    its own [min, max] span. It is then normalized to [0, 1] independently of
    the other channels.

    Args:
        imagedata_r, imagedata_g, imagedata_b: 2D intensity arrays, assumed
            to share a shape. The red channel's shape sizes the output.
        displayrange_r, displayrange_g, displayrange_b: per-channel fraction
            of the intensity span kept before saturating.
        filename: output path passed to savefig.
    """
    fig = plt.figure()
    fig.set_size_inches(1, 1)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)

    image_rgb = np.zeros([imagedata_r.shape[0], imagedata_r.shape[1], 3])
    channels = ((imagedata_r, displayrange_r),
                (imagedata_g, displayrange_g),
                (imagedata_b, displayrange_b))
    for idx, (channel, drange) in enumerate(channels):
        clipped = display_range(channel, drange)
        # normalize_image returns (min, max, normalized); only the image is used.
        image_rgb[:, :, idx] = normalize_image(clipped)[2]
    ax.imshow(image_rgb)

    fig.savefig(filename)
    plt.close(fig)

def merge_image_stack(imagedata, displayranges, colors, filename):
    """Blend a stack of single-channel images into one 8x8-inch RGB image.

    Each layer is clipped with display_range at its own display range and
    summed into the channel named by its color ('r', 'g' or 'b'). Layers with
    any other color are skipped. Each channel sum is then normalized to
    [0, 1]. The composite is shown with bicubic interpolation and a white
    scale bar of 1/0.1067 pixels (1 um at 106.7 nm pixels).

    Args:
        imagedata: sequence of 2D arrays. The first one's shape sizes the output.
        displayranges: per-layer fraction of the intensity span kept.
        colors: per-layer channel letter.
        filename: output path passed to savefig.

    Exits the process with status 1 if the three sequences differ in length.
    """
    fig = plt.figure()
    fig.set_size_inches(8, 8)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)

    nstack = len(imagedata)
    if len(displayranges) != nstack or len(colors) != nstack:
        print("Don't have simliar lengths of displayranges or colors, exiting!\n")
        sys.exit(1)

    shape2d = [imagedata[0].shape[0], imagedata[0].shape[1]]
    channel_order = ('r', 'g', 'b')

    # Accumulate the display-range-adjusted layers per color channel
    channel_sums = [np.zeros(shape2d) for _ in channel_order]
    for layer, drange, color in zip(imagedata, displayranges, colors):
        if color in channel_order:
            idx = channel_order.index(color)
            channel_sums[idx] = channel_sums[idx] + display_range(layer, drange)

    # Normalize each channel on its own and stack into an RGB image
    image_rgb = np.zeros(shape2d + [3])
    for idx, channel_sum in enumerate(channel_sums):
        image_rgb[:, :, idx] = normalize_image(channel_sum)[2]
    ax.imshow(image_rgb, interpolation='bicubic')

    # Scale bar: 1 um expressed in 106.7 nm pixels
    micron_per_pixel = 0.1067
    pixels_per_micron = 1 / micron_per_pixel
    bar_start = 2 # FIXME: Change back to 35
    ax.plot([bar_start, bar_start + pixels_per_micron], [2, 2], color=[1, 1, 1], linewidth=16)

    plt.savefig(filename)
    plt.close()

def make_image(imagedata, filename, color = None, displayrange = 0.25):
    """Save a single-channel image as a 2x2-inch PNG tinted by ``color``.

    Args:
        imagedata: 2D intensity array.
        filename: output path passed to savefig.
        color: 'r', 'g' or 'b' places the image in that channel only. Any
            other value, including None, copies it into all three channels
            (grayscale).
        displayrange: fraction of the intensity span kept before saturating
            (passed to display_range).

    Returns:
        (minval, maxval, nimage): the min and max of the display-range-adjusted
        RGB image, and its [0, 1]-normalized version (the array that is shown).
    """
    fig = plt.figure()
    fig.set_size_inches(2, 2)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)

    rgb_names = ('r', 'g', 'b')
    if color in rgb_names:
        target_channels = [rgb_names.index(color)]
    else:
        target_channels = [0, 1, 2]

    image_rgb = np.zeros([imagedata.shape[0], imagedata.shape[1], 3])
    for channel in target_channels:
        image_rgb[:, :, channel] = imagedata

    # Saturate at the requested display range, then stretch to [0, 1]
    adjusted = display_range(image_rgb, displayrange)
    minval, maxval, nimage = normalize_image(adjusted)
    ax.imshow(nimage)

    plt.savefig(filename)
    plt.close()

    return minval, maxval, nimage

def remove_prefix(text, prefix):
    """Return ``text`` with a leading ``prefix`` stripped; unchanged if absent."""
    has_prefix = text.startswith(prefix)
    return text[len(prefix):] if has_prefix else text

class GaussianBlurCreator2D():
    """Render blurred 2D microscopy-style images of a spindle run and make movies.

    Posit files are read from ``workdir`` for the requested species
    (microtubules, kinetochores, Cut7 crosslinks). CreateData renders and
    optionally saves one image per imaged frame. CreateMovie stitches the
    saved PNGs into movies.

    Microtubules are mandatory. The frame count, the output cadence
    (n_posit / n_graph) and the planar viewing orientation all come from the
    spindle reader.
    """

    def __init__(self, workdir, mt, cut7, kc, graph, do_anaphase, anaphase_b_frame):
        """
        Args:
            workdir: run directory with the posit and yaml files.
            mt, cut7, kc: image microtubules / Cut7 / kinetochores.
            graph: write PNG images for each imaged frame.
            do_anaphase: stop imaging at ``anaphase_b_frame`` instead of
                the last frame.
            anaphase_b_frame: exclusive upper bound on the frame loop when
                ``do_anaphase`` is set.

        Raises:
            ValueError: if ``mt`` is false.
        """
        if not mt:
            raise ValueError('GaussianBlurCreator2D requires mt=True: frame counts '
                             'and the planar view come from the spindle reader')

        self.run_dir_path = os.path.realpath(workdir)
        self.mt = mt
        self.cut7 = cut7
        self.kc = kc
        self.graph = graph

        # Anaphase information
        self.do_anaphase = do_anaphase
        self.anaphase_b_frame = anaphase_b_frame

        self.mt_color = 'r'
        self.cut7_color = 'g'
        self.kc_color = 'g'

        self.spindlereader = None
        self.chromosomereader = None
        self.xlinkreader = None

        # GenerateQuaternion seeds its smoothed orientation on the first call
        self.first_call = True

        if self.mt:
            self.spindlereader = ReadPositSpindle(self.run_dir_path, "spindle_bd_mp.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
            self.spindlereader.LoadPosit()
        if self.kc:
            self.chromosomereader = ReadPositChromosomes(self.run_dir_path, "chromosomes.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
            self.chromosomereader.LoadPosit()
        if self.cut7:
            self.xlinkreader = ReadPositXlinks(self.run_dir_path, "crosslinks.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
            self.xlinkreader.LoadPosit()

        self.nposit = self.spindlereader.CheckDefaultThenEquil('n_posit')
        self.ngraph = self.spindlereader.CheckDefaultThenEquil('n_graph')

        # Image parameters. np.float64 is used because np.float_ was removed
        # in NumPy 2.0.
        self.imageParams = {'sigmaxy'    : np.float64(1.33),
                            'sigmaz'     : np.float64(4.0),
                            'A'          : np.float64(50.0),
                            'bkglevel'   : np.float64(0.0),
                            'noisestd'   : np.float64(0.0),
                            'pixelsize'  : np.float64(106.7/25.0),
                            'modeloffset': np.float64(90.0)}

    def CreateData(self, frame_dir):
        """Render every imaged frame, optionally write PNGs, and save the raw images.

        When ``self.graph`` is set, each imaged frame writes one PNG per enabled
        species plus an mt+kc merge into ``frame_dir``. The directory is
        assumed to exist; nothing here creates it. The raw images are stacked
        along a trailing frame axis and saved to ``blur_framedata.mat`` in the
        current working directory.

        Args:
            frame_dir: directory for the per-frame PNGs. It is remembered as
                ``self.frame_dir`` for CreateMovie.
        """
        self.frame_dir = frame_dir
        print("Imaging {} frames with nposit: {} and ngraph: {}".format(self.spindlereader.nframes, self.nposit, self.ngraph))

        # Raw per-frame images, saved to the .mat file at the end
        framedata_mt = []
        framedata_kc = []
        framedata_cut7 = []

        self.mt_prefix = 'mt_gaussian_2dblur'
        self.cut7_prefix = 'cut7_gaussian_2dblur'
        self.kc_prefix = 'kc_gaussian_2dblur'
        self.merge_prefix = 'merge_gaussian_2dblur'

        # Fraction of each channel's intensity span shown before saturating
        displayrange_mt = 0.25
        displayrange_cut7 = 0.3
        displayrange_kc = 0.5

        self.max_frame = self.spindlereader.nframes
        if self.do_anaphase:
            # Builtin int: np.int was removed in NumPy 1.24
            self.max_frame = int(self.anaphase_b_frame)

        # Floor division keeps the stride and the frame labels integral on
        # both Python 2 and 3. The '{:d}' filename format rejects floats.
        graph_stride = self.ngraph // self.nposit

        def png_name(prefix, trueframe):
            return '{0}/{1}_frame{2:0>5d}.png'.format(frame_dir, prefix, trueframe)

        print("max_frame = {}".format(self.max_frame))
        for frame in range(self.max_frame):
            # Advance every enabled reader on each posit frame, including
            # frames that are not imaged.
            if self.mt:
                self.spindlereader.ReadFramePosit()
            if self.kc:
                self.chromosomereader.ReadFramePosit()
            if self.cut7:
                self.xlinkreader.ReadFramePosit()
            trueframe = frame * self.nposit // self.ngraph
            if frame == 0 or frame % graph_stride != 0:
                continue

            print("Imaging frame: {}, trueframe: {}".format(frame, trueframe))
            # Orient the planar view from this frame's spindle geometry
            [rhat, tomo] = self.spindlereader.GetPlanarViewInformation()
            self.GenerateQuaternion(rhat, tomo)

            # NOTE(review): this renders the MT image only to learn the image
            # shape, and the mt branch below renders it again. It could be
            # reused if the render is deterministic; confirm before changing.
            imagedata_base = self.spindlereader.ImageMicrotubules2DPlanar(self.imageParams, self.rotation_matrix_current)
            shape2d = [imagedata_base.shape[0], imagedata_base.shape[1]]
            # Blank channels stand in for species that are not imaged
            imagedata_mt = np.zeros(shape2d)
            imagedata_kc = np.zeros(shape2d)

            if self.mt:
                imagedata_mt = self.spindlereader.ImageMicrotubules2DPlanar(self.imageParams, self.rotation_matrix_current)
                if self.graph:
                    make_image(imagedata_mt, png_name(self.mt_prefix, trueframe), self.mt_color, displayrange_mt)
                framedata_mt.append(imagedata_mt)

            if self.kc:
                imagedata_kc = self.chromosomereader.ImageKinetochores2DPlanar(self.imageParams, self.spindlereader.h, self.rotation_matrix_current)
                if self.graph:
                    make_image(imagedata_kc, png_name(self.kc_prefix, trueframe), self.kc_color, displayrange_kc)
                framedata_kc.append(imagedata_kc)

            if self.cut7:
                imagedata_cut7 = self.xlinkreader.ImageXlink2D(self.imageParams, self.spindlereader.h, 0, self.spindlereader.microtubules)
                if self.graph:
                    make_image(imagedata_cut7, png_name(self.cut7_prefix, trueframe), self.cut7_color, displayrange_cut7)
                framedata_cut7.append(imagedata_cut7)

            if self.graph:
                # The merge combines microtubules and kinetochores only
                merge_image_stack([imagedata_mt, imagedata_kc],
                                  [displayrange_mt, displayrange_kc],
                                  [self.mt_color, self.kc_color],
                                  png_name(self.merge_prefix, trueframe))

        # Stack frames on the last axis: (rows, cols, nframes). Skip this when
        # nothing was imaged, because transposing an empty list raises.
        if self.mt and framedata_mt:
            framedata_mt = np.transpose(framedata_mt, (1, 2, 0))
        if self.kc and framedata_kc:
            framedata_kc = np.transpose(framedata_kc, (1, 2, 0))
        if self.cut7 and framedata_cut7:
            framedata_cut7 = np.transpose(framedata_cut7, (1, 2, 0))
        imagedict = {'framedata_mt' : framedata_mt,
                     'framedata_kc' : framedata_kc,
                     'framedata_cut7' : framedata_cut7}
        sio.savemat('blur_framedata.mat', {'imagedict':imagedict})

        print("Done with image generation!")

    def GenerateQuaternion(self, rhat, tomo):
        """Update ``self.qr`` and ``self.rotation_matrix_current`` for this frame.

        q1 rotates ``tomo`` onto +z. ``rhat``, rotated by q1's axis-angle, is
        then rotated onto +y by q2. The target q3 is quaternion_multiply(q2, q1).
        The stored orientation is moved toward q3 with rotate_towards, using
        a step of 3.14 * 0.01. NOTE(review): presumably that step is the maximum
        rotation per call in radians; confirm in the quaternion module. The
        first call seeds the stored orientation with q3.
        """
        yaxis = np.array([0.0, 1.0, 0.0])
        zaxis = np.array([0.0, 0.0, 1.0])
        # Quaternion to align tomo with the z axis
        q1 = quaternion_between_vectors(tomo, zaxis)

        # Rotate rhat into q1's frame, then align it with the y axis
        [theta, eaxis] = axisangle_from_quaternion(q1)
        vrot = rodrigues_axisangle(rhat, theta, eaxis)
        q2 = quaternion_between_vectors(vrot, yaxis)

        q3 = quaternion_multiply(q2, q1)
        if self.first_call:
            self.first_call = False
            self.quaternion_old = q3

        qr = rotate_towards(self.quaternion_old, q3, 3.14*0.01)
        self.quaternion_old = qr

        self.qr = qr
        self.rotation_matrix_current = rotation_matrix_from_quaternion(self.qr)

    @staticmethod
    def _collect_frame_numbers(frame_dir):
        """Return the sorted, unique N of regular files named like ``*frameN*`` in frame_dir.

        Subdirectories and files without a ``frame<digits>`` tag are ignored.
        """
        frame_tag = re.compile(r'frame(\d+)')
        numbers = set()
        for fname in os.listdir(frame_dir):
            if not os.path.isfile(os.path.join(frame_dir, fname)):
                continue
            match = frame_tag.search(fname)
            if match is not None:
                numbers.add(int(match.group(1)))
        return np.array(sorted(numbers), dtype=int)

    def _write_movie_frame(self, writer, movie_path, prefix, frame_num, size, fps):
        """Append PNG ``<prefix>_frame<frame_num>`` to a movie; return (writer, size).

        The writer is created on first use. ``size`` (width, height) is fixed by
        the first frame written to any movie. Later frames whose width or height
        differ are resized to it, because cv2.VideoWriter expects every frame
        at the size it was opened with.

        Raises:
            IOError: if the PNG cannot be read.
        """
        frame_path = os.path.join(self.frame_dir, '{0}_frame{1:0>5d}.png'.format(prefix, frame_num))
        img = cv2.imread(frame_path)
        if img is None:
            raise IOError('Could not read movie frame: {}'.format(frame_path))
        if size is None:
            size = img.shape[1], img.shape[0]
        elif size[0] != img.shape[1] or size[1] != img.shape[0]:
            img = cv2.resize(img, size)
        if writer is None:
            writer = cv2.VideoWriter(movie_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
        writer.write(img)
        return writer, size

    def CreateMovie(self, name, fps=60.0):
        """Stitch the PNGs written by CreateData into mp4v-encoded .mov files.

        Call this after CreateData, because it uses self.frame_dir, the file
        prefixes and self.max_frame. Movies are written to the run directory
        as ``<species>_microscopy<name>.mov``. The merge movie is produced
        whenever microtubules are enabled.

        Args:
            name: suffix appended to each movie file name.
            fps: NOTE(review): currently ignored. Frames are always written at
                24 fps (see the body).

        Returns:
            [vidmt, vidcut7, vidkc, vidmerge]: the cv2.VideoWriter objects, or
            None for movies that were not produced. They are not released here;
            call release() on each to close the file.
        """
        print("Creating movies of gaussian 2d images!")

        # NOTE(review): the fps argument is overridden, so movies are always
        # 24 fps. This is kept as-is to preserve the existing output.
        fps = 24.0

        mtname = os.path.join(self.run_dir_path, 'microtubule_microscopy{}.mov'.format(name))
        cut7name = os.path.join(self.run_dir_path, 'cut7_microscopy{}.mov'.format(name))
        kcname = os.path.join(self.run_dir_path, 'kinetochore_microscopy{}.mov'.format(name))
        mergename = os.path.join(self.run_dir_path, 'merge_microscopy{}.mov'.format(name))

        frame_num_list = self._collect_frame_numbers(self.frame_dir)
        if self.do_anaphase:
            frame_num_list = frame_num_list[frame_num_list < self.max_frame]

        vidmt = None
        vidcut7 = None
        vidkc = None
        vidmerge = None
        size = None
        for n in frame_num_list:
            if self.mt:
                vidmt, size = self._write_movie_frame(vidmt, mtname, self.mt_prefix, n, size, fps)
            if self.cut7:
                vidcut7, size = self._write_movie_frame(vidcut7, cut7name, self.cut7_prefix, n, size, fps)
            if self.kc:
                vidkc, size = self._write_movie_frame(vidkc, kcname, self.kc_prefix, n, size, fps)
            if self.mt:
                vidmerge, size = self._write_movie_frame(vidmerge, mergename, self.merge_prefix, n, size, fps)

        return [vidmt, vidcut7, vidkc, vidmerge]



##########################################
if __name__ == "__main__":
    # Standalone driver that uses the non-planar 2D imaging methods. PNGs and
    # blur_framedata.mat are written to the current working directory.
    opts = parse_args()

    # Fail with a clear message instead of an error on None further down.
    # The spindle reader supplies the frame count and output cadence.
    if opts.workdir is None:
        sys.exit("blur_2d.py: a run directory is required (-d/--workdir)")
    if not opts.microtubules:
        sys.exit("blur_2d.py: -mt/--microtubules is required (frame counts come from the spindle posit file)")
    run_dir_path = os.path.realpath(opts.workdir)

    chromosomereader = None
    xlinkreader = None

    spindlereader = ReadPositSpindle(run_dir_path, "spindle_bd_mp.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
    spindlereader.LoadPosit()
    if opts.kinetochores:
        chromosomereader = ReadPositChromosomes(run_dir_path, "chromosomes.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
        chromosomereader.LoadPosit()
    if opts.cut7:
        xlinkreader = ReadPositXlinks(run_dir_path, "crosslinks.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
        xlinkreader.LoadPosit()

    nposit = spindlereader.CheckDefaultThenEquil('n_posit')
    ngraph = spindlereader.CheckDefaultThenEquil('n_graph')

    # Image parameters. np.float64 is used because np.float_ was removed in NumPy 2.0.
    imageParams = {'sigmaxy'    : np.float64(1.33),
                   'sigmaz'     : np.float64(4.0),
                   'A'          : np.float64(50.0),
                   'bkglevel'   : np.float64(0.0),
                   'noisestd'   : np.float64(0.0),
                   'pixelsize'  : np.float64(106.7/25.0)}

    print("Imaging {} frames with nposit: {} and ngraph: {}".format(spindlereader.nframes, nposit, ngraph))

    # Raw per-frame images, saved to the .mat file at the end
    framedata_mt = []
    framedata_kc = []
    framedata_cut7 = []

    # Floor division keeps the stride and the frame labels integral on both
    # Python 2 and 3. The '{:d}' filename format rejects floats.
    graph_stride = ngraph // nposit

    for frame in range(spindlereader.nframes):
        # Advance every enabled reader on each posit frame, imaged or not
        spindlereader.ReadFramePosit()
        if opts.kinetochores:
            chromosomereader.ReadFramePosit()
        if opts.cut7:
            xlinkreader.ReadFramePosit()
        trueframe = frame * nposit // ngraph
        if frame == 0 or frame % graph_stride != 0:
            continue

        print("Imaging frame: {}, trueframe: {}".format(frame, trueframe))

        imagedata = spindlereader.ImageMicrotubules2D(imageParams)
        if opts.graph:
            make_image(imagedata, 'mt_gaussian_2dblur_{0:0>5d}.png'.format(trueframe))
        framedata_mt.append(imagedata)

        if opts.kinetochores:
            imagedata = chromosomereader.ImageKinetochores2D(imageParams, spindlereader.h)
            if opts.graph:
                make_image(imagedata, 'kc_gaussian_2dblur_{0:0>5d}.png'.format(trueframe))
            framedata_kc.append(imagedata)

        if opts.cut7:
            imagedata = xlinkreader.ImageXlink2D(imageParams, spindlereader.h, 0, spindlereader.microtubules)
            if opts.graph:
                make_image(imagedata, 'cut7_gaussian_2dblur_{0:0>5d}.png'.format(trueframe))
            framedata_cut7.append(imagedata)

    # Stack frames on the last axis: (rows, cols, nframes). Skip this when
    # nothing was imaged, because transposing an empty list raises.
    if framedata_mt:
        framedata_mt = np.transpose(framedata_mt, (1, 2, 0))
    if opts.kinetochores and framedata_kc:
        framedata_kc = np.transpose(framedata_kc, (1, 2, 0))
    if opts.cut7 and framedata_cut7:
        framedata_cut7 = np.transpose(framedata_cut7, (1, 2, 0))
    imagedict = {'framedata_mt' : framedata_mt,
                 'framedata_kc' : framedata_kc,
                 'framedata_cut7' : framedata_cut7}
    sio.savemat('blur_framedata.mat', {'imagedict':imagedict})

    print("Done with image generation!")




