#!/usr/bin/env python
# In case of poor (Sh***y) commenting contact adam.lamson@colorado.edu
# YOLO edelmaie@colorado.edu (too)
# Basic
import sys, os, pdb
import gc
## Analysis
import numpy as np
import yaml

from operator import attrgetter

from scipy import special
import scipy.misc

from read_posit_base import ReadPositBase

from quaternion import quaternion_between_vectors
from quaternion import axisangle_from_quaternion 
from quaternion import rodrigues_axisangle
from quaternion import quaternion_multiply
from quaternion import rotation_matrix_from_quaternion

# import pandas as pd
# import matplotlib.pyplot as plt
# import matplotlib as mpl
# from math import *

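'''
Name: read_posit_thermo.py
Description: Reader for the plaintext thermodynamics ("thermo") posit output.
    The file holds a HEADERINFO block (ndim, ncomp, ntypes, nanchors and the
    list of enabled potentials), followed by one line per frame containing the
    step number, the unit cell, the two SPB force vectors and one virial tensor
    per enabled potential.
Input (when run as a script): seed_path posit_name default_file equil_file
Output (when run as a script): the parsed header and the first frame, printed
    to stdout as a smoke test.
'''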

# Class definition
class Potential():
    """One potential listed in the thermo-file header, plus its latest virial.

    Attributes:
        idx: integer potential index read from the header.
        subidx: secondary index used to tell the crosslink potentials apart
            from each other and from the rest.
        name: potential name as written in the header.
        u: energy slot, initialised to 0.0 (nothing in this file updates it).
        virial: virial contribution for the current frame; 0.0 until
            SetValues() is first called.
    """

    def __init__(self, idx, subidx, name):
        self.idx = idx
        self.subidx = subidx  # For differentiating the crosslink potentials from others
        self.u = 0.0
        self.virial = 0.0
        self.name = name

    def SetValues(self, virial):
        """Store the virial contribution read for the current frame."""
        self.virial = virial

    def PrintFrame(self):
        """Print this potential's identity and current virial to stdout."""
        # Single-argument print(...) behaves identically under Python 2 and 3.
        print("      idx: {}".format(self.idx))
        print("      subidx: {}".format(self.subidx))
        print("      name: {}".format(self.name))
        print("      virial:\n{}".format(self.virial))

    def __repr__(self):
        """Return ``"<idx> <subidx> <name>"``."""
        return "{} {} {}".format(self.idx, self.subidx, self.name)

#Class definition
class ReadPositThermo(ReadPositBase):
    """Reader for the plaintext thermodynamics ("thermo") posit file.

    The file begins with a header block::

        HEADERINFO
        <label> <ndim>
        <label> <ncomp>
        <label> <ntypes>
        <label> <nanchors>
        <idx> [<subidx>] <name>        (one line per enabled potential)
        ENDHEADERINFOEND
        <column-label line>            (skipped)

    followed by one whitespace-separated line per frame (see ReadFramePosit).
    Typical use: LoadPosit(), then ReadFramePosit() once per frame, then
    UnloadPosit().
    """

    def __init__(self, seed_path, posit_name,
                 default_file='default.yaml', equil_file='equil.yaml'):
        # Explicit base call rather than super(): ReadPositBase may be an
        # old-style class under Python 2, where super() does not work.
        ReadPositBase.__init__(self, seed_path, posit_name, default_file, equil_file)

        # Names of the potentials whose virial acts on the MTs; only these are
        # summed by GenerateVirialMT.
        self.mt_virial_forces = ['brownian_sphero_neighbor_lists_mp',
                'crosslink_interaction_bd_mp_0',
                'crosslink_interaction_bd_mp_1',
                'crosslink_interaction_bd_mp_2',
                'anchor_potential_bd',
                'chromosome_mt_soft_gaussian_potential_allpairs',
                'af_mt_harmonic_potential',
                'kinetochoremesh_mt_wca_potential_allpairs',
                'wca_sphero_wall_potential_bd']

    def LoadPosit(self):
        """Open ``self.posit_path`` and parse its header.

        The file is closed again if the header cannot be parsed, so a bad
        file does not leak an open handle.
        """
        self.f_posit = open(self.posit_path)
        try:
            self.ReadPosHeader()
        except Exception:
            self.f_posit.close()
            raise

    def UnloadPosit(self):
        """Close the file opened by LoadPosit."""
        self.f_posit.close()

    def _ReadHeaderLine(self):
        """Return the next header line with trailing whitespace stripped.

        Raises:
            ValueError: at end of file. readline() returns '' forever at EOF,
                so without this check a truncated header would spin the
                ENDHEADERINFOEND loop indefinitely.
        """
        line = self.f_posit.readline()
        if not line:
            raise ValueError(
                "unexpected end of file while reading the thermo header of "
                "{}".format(self.posit_path))
        return line.rstrip()

    @staticmethod
    def _ParsePotential(fields):
        """Build a Potential from the split tokens of one header line.

        One token: ``idx`` (subidx 0, name 'unknown').
        Three tokens: ``idx subidx name``.
        Any other count: ``idx name ...`` (subidx 0, remaining tokens ignored).
        """
        if len(fields) == 3:
            return Potential(int(fields[0]), int(fields[1]), fields[2])
        if len(fields) == 1:
            return Potential(int(fields[0]), 0, 'unknown')
        return Potential(int(fields[0]), 0, fields[1])

    def ReadPosHeader(self):
        """Parse the HEADERINFO block and position the file at the first frame.

        Sets ndim, ncomp, ntypes and nanchors. It also sets ``potentials``,
        the enabled potentials in header order, which is the order their
        virial tensors appear in each frame line. cur_frame is reset to -1.

        Raises:
            ValueError: if the file does not start with HEADERINFO, or ends
                before ENDHEADERINFOEND.
        """
        # Therm file is now in plaintext, with a header describing what is
        # going on.
        if self._ReadHeaderLine() != "HEADERINFO":
            raise ValueError(
                "{} does not start with HEADERINFO".format(self.posit_path))

        # The next four lines are always present; the value is the second token.
        # int() instead of np.int: np.int was removed in NumPy 1.24.
        self.ndim = int(self._ReadHeaderLine().split()[1])
        self.ncomp = int(self._ReadHeaderLine().split()[1])
        self.ntypes = int(self._ReadHeaderLine().split()[1])
        self.nanchors = int(self._ReadHeaderLine().split()[1])

        # Now come all the special potentials that we have enabled.
        self.potentials = []
        line = self._ReadHeaderLine()
        while line != "ENDHEADERINFOEND":
            fields = line.split()
            if fields:  # tolerate blank lines inside the header
                self.potentials.append(self._ParsePotential(fields))
            line = self._ReadHeaderLine()

        # Skip the column-label line that follows the header.
        self.f_posit.readline()
        self.cur_frame = -1

    def ReadFramePosit(self):
        """Read the next frame line and update the per-frame state.

        Frame layout (whitespace separated): istep; ndim*ndim unit-cell
        entries; ndim entries for each of the two SPB force vectors; then
        ndim*ndim virial entries for every potential, in header order.

        Sets istep, h (ndim x ndim), fspb0/fspb1 (sign-flipped, see below) and
        each potential's virial (ndim x ndim), then increments cur_frame.
        Nothing is modified if the line is too short.

        Raises:
            IndexError: at end of file or when the line has too few values.
        """
        fields = self.f_posit.readline().split()
        nd = self.ndim
        nd2 = nd * nd
        nexpected = 1 + nd2 + 2 * nd + nd2 * len(self.potentials)
        if len(fields) < nexpected:
            raise IndexError(
                "thermo frame {} has {} values, expected {}".format(
                    self.cur_frame + 1, len(fields), nexpected))

        self.istep = int(fields[0])
        values = np.array(fields[1:nexpected], dtype=float)

        # Unit cell.
        self.h = values[0:nd2].reshape((nd, nd))
        offset = nd2

        # The 2 SPB force vectors.
        # XXX: There is an overall minus sign for the force on the SPB, we want
        # the force on the spindle.
        self.fspb0 = -values[offset:offset + nd]
        offset += nd
        self.fspb1 = -values[offset:offset + nd]
        offset += nd

        # One virial tensor per enabled potential.
        for pot in self.potentials:
            pot.SetValues(values[offset:offset + nd2].reshape((nd, nd)))
            offset += nd2

        self.cur_frame += 1

    ### Print functionality
    def __repr__(self):
        lines = ["ReadPositThermo",
                 "   ndim: {}".format(self.ndim),
                 "   ncomp: {}".format(self.ncomp),
                 "   ntypes: {}".format(self.ntypes),
                 "   nanchors: {}".format(self.nanchors)]
        lines.extend("   potential: {}".format(pot) for pot in self.potentials)
        return "\n".join(lines) + "\n"

    def PrintFrame(self):
        """Print the current frame's state to stdout."""
        print("****************")
        print("Frame: {}".format(self.cur_frame))
        print("  istep: {}".format(self.istep))
        print("  h:\n{}".format(self.h))
        print("  fspb0: {}".format(self.fspb0))
        print("  fspb1: {}".format(self.fspb1))
        for ipot in self.potentials:
            print("  potential:")
            ipot.PrintFrame()

    def GenerateSpindleForce(self, spindlevector):
        """Split the SPB forces into parts parallel/perpendicular to the spindle.

        Args:
            spindlevector: spindle direction vector; only its direction is used.

        Returns:
            np.array([spb0_parallel, spb0_perp, spb1_parallel, spb1_perp]).
            The parallel entries are signed projections onto the unit spindle
            vector. The perp entries are magnitudes of the remaining components.
            The SPB1 entries stay 0.0 unless nanchors > 1.
        """
        shat = spindlevector / np.linalg.norm(spindlevector)

        spb0_parallel = np.dot(self.fspb0, shat)
        spb0_perp = np.linalg.norm(self.fspb0 - spb0_parallel * shat)

        spb1_parallel = 0.0
        spb1_perp = 0.0
        if self.nanchors > 1:
            spb1_parallel = np.dot(self.fspb1, shat)
            spb1_perp = np.linalg.norm(self.fspb1 - spb1_parallel * shat)

        return np.array([spb0_parallel, spb0_perp, spb1_parallel, spb1_perp])

    def GenerateForceInformation(self, rspb0, rspb1):
        """Project the SPB forces onto in-plane tangent directions.

        Builds a rotation q3 = q2 * q1 from two steps. q1 comes from
        quaternion_between_vectors(nhat, zaxis), where nhat is the normal to
        rspb0 x rspb1. q2 comes from quaternion_between_vectors applied to the
        rotated spindle and the x axis. Presumably this aligns the SPB plane
        with xy and the spindle with x (TODO confirm quaternion_between_vectors
        rotates its first argument onto the second). The SPB positions and
        forces are rotated by q3. Each rotated force's xy part is then
        projected onto a unit vector perpendicular to that SPB's rotated xy
        position. The two tangents have opposite orientation:
        (y, -x) for SPB0 and (-y, x) for SPB1.

        Assumes 3-d vectors (np.cross and the fixed x/z axes). If rspb0 and
        rspb1 are colinear the normal has zero length and the result is NaN.

        Side effects:
            Stores self.theta, self.axis (axis-angle of q3) and self.qr (q3)
            for use by GenerateVirialMT.

        Returns:
            np.array([f0, f1]), the tangential force on SPB0 and SPB1.
        """
        rspindle = rspb1 - rspb0
        ntomo = np.cross(rspb0, rspb1)
        rhat = rspindle / np.linalg.norm(rspindle)
        nhat = ntomo / np.linalg.norm(ntomo)

        zaxis = np.array([0.0, 0.0, 1.0])
        xaxis = np.array([1.0, 0.0, 0.0])

        q1 = quaternion_between_vectors(nhat, zaxis)
        theta, eaxis = axisangle_from_quaternion(q1)
        vrot = rodrigues_axisangle(rhat, theta, eaxis)
        q2 = quaternion_between_vectors(vrot, xaxis)
        q3 = quaternion_multiply(q2, q1)

        # q3 is the combined rotation; save it off for other uses this step.
        theta_qr, axis_qr = axisangle_from_quaternion(q3)
        self.theta = theta_qr
        self.axis = axis_qr
        self.qr = q3

        # Rotated positions of and forces on the SPBs.
        rrot0 = rodrigues_axisangle(rspb0, theta_qr, axis_qr)
        rrot1 = rodrigues_axisangle(rspb1, theta_qr, axis_qr)
        frot0 = rodrigues_axisangle(self.fspb0, theta_qr, axis_qr)
        frot1 = rodrigues_axisangle(self.fspb1, theta_qr, axis_qr)

        # Only the xy components are used from here on: in-plane unit vectors
        # perpendicular to each rotated SPB position.
        nhat0 = np.array([rrot0[1], -rrot0[0]])
        nhat0 = nhat0 / np.linalg.norm(nhat0)
        nhat1 = np.array([-rrot1[1], rrot1[0]])
        nhat1 = nhat1 / np.linalg.norm(nhat1)

        # Projection of each rotated force onto its tangent direction.
        f0 = np.dot(nhat0, frot0[0:2])
        f1 = np.dot(nhat1, frot1[0:2])

        return np.array([f0, f1])

    def GenerateVirialMT(self, rspb0, rspb1):
        """Sum and print the MT virial contributions for the current frame.

        Only potentials named in self.mt_virial_forces are summed. The sum is
        stored in self.total_virial, and eigen-decompositions are printed. The
        summed virial is also printed in the frame defined by self.qr.

        GenerateForceInformation must be called first for the same frame,
        because it sets self.qr. Otherwise this raises AttributeError, or
        silently uses a stale rotation from an earlier frame.
        """
        spindlevector = rspb1 - rspb0
        print("--------")
        print("spindle = {}".format(spindlevector))
        spindlehat = spindlevector / np.linalg.norm(spindlevector)
        print("spindlehat = {}".format(spindlehat))

        # Get the virial contributions.
        self.total_virial = np.zeros((self.ndim, self.ndim))
        for pot in self.potentials:
            print(pot)
            if pot.name in self.mt_virial_forces:
                self.total_virial += pot.virial
                w1, v1 = np.linalg.eig(pot.virial)
                print("name: {}".format(pot.name))
                print("w1=\n{}".format(w1))
                print("v1=\n{}".format(v1))

        w, v = np.linalg.eig(self.total_virial)
        print("name: total MT")
        print("w=\n{}".format(w))
        print("v=\n{}".format(v))

        # We should also think about the spindle axis as the nematic director.
        # This could be useful in considering the pressures into/out of the
        # spindle as it is elongating.
        R = rotation_matrix_from_quaternion(self.qr)
        Rinv = np.linalg.inv(R)
        # Matrix product R^-1 V R. For plain ndarrays the previous
        # `Rinv * V * R` was an elementwise product; np.dot is the matrix
        # product for both ndarray and np.matrix inputs.
        # NOTE(review): whether R^-1 V R or R V R^-1 is the intended change of
        # basis depends on rotation_matrix_from_quaternion's convention.
        rotated_virial = np.dot(Rinv, np.dot(self.total_virial, R))
        print("rotated_virial=\n{}".format(rotated_virial))
        w, v = np.linalg.eig(rotated_virial)
        print("name: rotated total MT")
        print("w=\n{}".format(w))
        print("v=\n{}".format(v))



##########################################
if __name__ == "__main__":
    # Smoke test: print the parsed header and the first frame.
    # Usage: seed_path posit_name [default_file [equil_file]]
    # (the last two fall back to ReadPositThermo's defaults when omitted).
    if len(sys.argv) < 3:
        sys.exit("usage: {} seed_path posit_name [default_file [equil_file]]".format(sys.argv[0]))
    t = ReadPositThermo(*sys.argv[1:5])
    t.LoadPosit()
    try:
        print("{}".format(t))
        t.ReadFramePosit()
        t.PrintFrame()
    finally:
        # Always release the file handle, even if the frame fails to parse.
        t.UnloadPosit()
    print("TEST ONLY!!!!")




