#!/usr/bin/env python
# Basic
import sys, os, pdb
import gc
import argparse
# Analysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from math import *
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'Lib'))

from read_posit_spindle import ReadPositSpindle
from read_posit_chromosomes import ReadPositChromosomes
from read_posit_xlinks import ReadPositXlinks

import scipy.io as sio

def parse_args():
    """Build the blur_3d.py command-line parser and parse ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Boolean attributes ``microtubules``, ``kinetochores``, ``cut7`` and
        ``graph`` (False unless the flag is given), plus ``workdir`` (a str,
        or None when ``-d/--workdir`` is omitted).
    """
    parser = argparse.ArgumentParser(prog='blur_3d.py')

    # Species switches: (short flag, long flag, help text), in help order.
    species_switches = (
        ('-mt', '--microtubules', 'Run image generation for microtubules'),
        ('-kc', '--kinetochores', 'Run image generation for kinetochores'),
        ('-cut7', '--cut7', 'Run image generation for Cut7'),
    )
    for short_flag, long_flag, help_text in species_switches:
        parser.add_argument(short_flag, long_flag, action='store_true',
                            help=help_text)

    parser.add_argument('-d', '--workdir', type=str, help='Directory name')
    parser.add_argument('-g', '--graph', action='store_true',
                        help='Generate Python Images')

    return parser.parse_args()

def make_image(imagedata, filename):
    """Save a 2D array as a borderless 1x1-inch grayscale image.

    One Axes spanning the whole figure (with axis decorations hidden) holds
    the image, so the written file contains only the rendered data.  The
    figure is closed after saving.

    Parameters
    ----------
    imagedata : array-like
        Data passed to ``imshow``.
    filename : str
        Output path passed to ``savefig``.
    """
    figure = plt.figure()
    figure.set_size_inches(1, 1)

    # Axes rectangle [left, bottom, width, height] covering the full figure.
    canvas_axes = plt.Axes(figure, [0., 0., 1., 1.])
    canvas_axes.set_axis_off()
    figure.add_axes(canvas_axes)
    canvas_axes.imshow(imagedata, cmap='gray')

    figure.savefig(filename)
    plt.close(figure)

def make_image_3d(imagedata, filename, vmin, vmax):
    """Save a panel grid showing every z-slice of a stack plus its max projection.

    Panels are filled row-major, three per row: slices ``0 .. nz-1`` followed
    by the maximum-intensity projection along the third axis.  For a stack
    with 5 slices this is exactly a 2x3 grid; other slice counts get as many
    rows as needed, with unused trailing panels left blank.

    Parameters
    ----------
    imagedata : array-like, 3-D
        Stack indexed as ``imagedata[:, :, z]``.
    filename : str
        Output path passed to ``savefig``.
    vmin, vmax : float
        Grayscale limits shared by every panel, so all panels (and all frames
        rendered with the same limits) use one intensity scale.

    Raises
    ------
    ValueError
        If *imagedata* is not 3-D or has no z-slices.
    """
    imagedata = np.asarray(imagedata)
    if imagedata.ndim != 3 or imagedata.shape[2] == 0:
        raise ValueError(
            "make_image_3d expects a 3-D stack with at least one z-slice, "
            "got shape {}".format(imagedata.shape))

    nz = imagedata.shape[2]
    ncols = 3
    npanels = nz + 1  # every slice plus the max projection
    nrows = (npanels + ncols - 1) // ncols  # ceiling division

    # squeeze=False keeps axarr 2-D even for a single row, so ravel() always
    # yields a flat, row-major list of panels.
    fig, axarr = plt.subplots(nrows, ncols, squeeze=False)
    panels = axarr.ravel()

    display = dict(cmap='gray', vmin=vmin, vmax=vmax)
    for iz in range(nz):
        panels[iz].imshow(imagedata[:, :, iz], **display)
    panels[nz].imshow(np.amax(imagedata, axis=2), **display)

    for ax in panels:
        ax.set_axis_off()

    fig.savefig(filename)
    plt.close(fig)

##########################################
def main():
    """Run the blur-imaging pipeline selected by the command-line flags.

    Reads posit output from ``--workdir``.  It images every
    ``n_graph // n_posit``-th posit frame, skipping frame 0, and with ``-g``
    renders one PNG per imaged frame via ``make_image_3d``.  All stacks are
    saved to ``blur_framedata_3d.mat`` in the current directory.

    The frame count and the n_posit/n_graph cadence are read from the spindle
    reader, so ``-mt`` is required; the script exits with an error otherwise.
    """
    opts = parse_args()
    if opts.workdir is None:
        sys.exit("blur_3d.py: error: -d/--workdir is required")
    if not opts.microtubules:
        # nframes, n_posit and n_graph are all taken from the spindle reader
        # below; without it there is nothing to drive the frame loop.
        sys.exit("blur_3d.py: error: -mt/--microtubules is required "
                 "(frame timing is read from the spindle output)")
    run_dir_path = os.path.realpath(opts.workdir)

    chromosomereader = None
    xlinkreader = None

    spindlereader = ReadPositSpindle(run_dir_path, "spindle_bd_mp.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
    spindlereader.LoadPosit()
    if opts.kinetochores:
        chromosomereader = ReadPositChromosomes(run_dir_path, "chromosomes.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
        chromosomereader.LoadPosit()
    if opts.cut7:
        xlinkreader = ReadPositXlinks(run_dir_path, "crosslinks.posit", "spindle_bd_mp.default.yaml", "spindle_bd_mp.equil.yaml")
        xlinkreader.LoadPosit()

    nposit = spindlereader.CheckDefaultThenEquil('n_posit')
    ngraph = spindlereader.CheckDefaultThenEquil('n_graph')
    # Posit frames between imaged frames.  Floor division equals Python 2's
    # integer '/' and stays integral on Python 3.  A zero step would make the
    # modulo below divide by zero, so reject it up front.
    step = ngraph // nposit
    if step < 1:
        sys.exit("blur_3d.py: error: n_graph ({}) must be >= n_posit ({})".format(
            ngraph, nposit))

    # Image parameters.  np.float64 is the type np.float_ aliased; np.float_
    # itself was removed in NumPy 2.0.
    imageParams = {'sigmaxy'    : np.float64(1.33),
                   'sigmaz'     : np.float64(4*0.125), # check this at some point
                   'A'          : np.float64(50.0),
                   'bkglevel'   : np.float64(0.0),
                   'noisestd'   : np.float64(0.0),
                   'pixelsize'  : np.float64(106.7/25.0),
                   'nzstacks'   : np.int_(5)}
    #imageParamsKC = {'sigmaxy'    : np.float64(1.33),
    #               'sigmaz'     : np.float64(4.0),
    #               'A'          : np.float64(150.0),
    #               'bkglevel'   : np.float64(0.0),
    #               'noisestd'   : np.float64(0.0),
    #               'pixelsize'  : np.float64(106.7/25.0)}

    print("Imaging {} frames with nposit: {} and ngraph: {}".format(
        spindlereader.nframes, nposit, ngraph))

    # Information to save off
    framedata_mt = []
    framedata_kc = []
    framedata_cut7 = []
    # trueframe label of each stack appended to framedata_mt, in the same
    # order, so the rendering pass can pair every stack with its own label.
    trueframes_mt = []

    for frame in range(spindlereader.nframes):
        # Every loaded reader's ReadFramePosit() is called on every frame,
        # including frames that are not imaged.
        spindlereader.ReadFramePosit()
        if opts.kinetochores:
            chromosomereader.ReadFramePosit()
        if opts.cut7:
            xlinkreader.ReadFramePosit()

        if frame == 0 or frame % step != 0:
            continue

        trueframe = frame * nposit // ngraph
        print("Imaging frame: {}, trueframe: {}".format(frame, trueframe))

        imagedata = spindlereader.ImageMicrotubules3D(imageParams)
        framedata_mt.append(imagedata)
        trueframes_mt.append(trueframe)

        # NOTE: kinetochore and Cut7 imaging is currently disabled; their
        # readers are loaded and advanced, but framedata_kc/framedata_cut7
        # stay empty.  Previous 2-D implementation kept for reference:
        #if opts.kinetochores:
        #    imagedata = chromosomereader.ImageKinetochores2D(imageParams, spindlereader.h)
        #    if opts.graph:
        #        filename = 'kc_gaussian_2dblur_{0:0>5d}.png'.format(trueframe)
        #        make_image(imagedata, filename)
        #    framedata_kc.append(imagedata)
        #if opts.cut7:
        #    imagedata = xlinkreader.ImageXlink2D(imageParams, spindlereader.h, 0, spindlereader.microtubules)
        #    if opts.graph:
        #        filename = 'cut7_gaussian_2dblur_{0:0>5d}.png'.format(trueframe)
        #        make_image(imagedata, filename)
        #    framedata_cut7.append(imagedata)

    if framedata_mt:
        # Stack the per-frame 3-D images and move the frame axis last:
        # (nimages, X, Y, Z) -> (X, Y, Z, nimages).
        framedata_mt = np.transpose(framedata_mt, (1, 2, 3, 0))
        mt_imin = np.amin(framedata_mt)
        mt_imax = np.amax(framedata_mt)
        print("MT limits {}, {}".format(mt_imin, mt_imax))

        # All imaging must be done first so that every rendered frame shares
        # one global intensity range (mt_imin, mt_imax).  Iterate over the
        # collected stacks: stack idx was imaged at trueframes_mt[idx].
        if opts.graph:
            for idx, trueframe in enumerate(trueframes_mt):
                filename = 'mt_gaussian_3dblur_{0:0>5d}.png'.format(trueframe)
                make_image_3d(framedata_mt[:, :, :, idx], filename, mt_imin, mt_imax)
    else:
        print("No frames were imaged; nothing to render")

    #if opts.kinetochores:
    #    framedata_kc = np.transpose(framedata_kc, (1, 2, 0))
    #if opts.cut7:
    #    framedata_cut7 = np.transpose(framedata_cut7, (1, 2, 0))
    imagedict = {'framedata_mt' : framedata_mt,
                 'framedata_kc' : framedata_kc,
                 'framedata_cut7' : framedata_cut7}
    sio.savemat('blur_framedata_3d.mat', {'imagedict':imagedict})

    print("Done with image generation!")


if __name__ == "__main__":
    main()




