import bisect
import copy
import csv
import math

import arivis
import browser
import imagecore

# parameters
PATH = 'D:/ExamplePath/Sample01/Spindle01/'
CSV_FILENAME1 = PATH + 'chromosomes/P/p_pole1.csv'  # first end point of every line (x;y;z in pixels, one row per frame)
CSV_FILENAME2 = PATH + 'chromosomes/P/p_pole2.csv'  # second end point of every line, same row order
CSV_DELIMITER = ';'

CHANNEL          = 0      # channel to sample intensity values from
LINE_RADIUS      = 9      # half-width of the sampled cross-section, in line pixels
USE_GAUSSIAN     = True   # Gaussian-weight the cross-section; otherwise the plain mean is used
START_FRAME      = 0      # frame index assigned to the first row of the CSV files
LINE_OUTPUT      = False  # True to also export the resampled lines as an image stack
LINE_PIXEL_SIZE  = 0.1    # size of one pixel in line space, in microns
LINE_EXTRAPOLATE = 1.0    # how far the line is extended beyond both end points, in microns

LINE_OUTPUT_FILENAME = PATH + 'chromosomes/P/P_kymo_w0.sis'
OUTPUT_FILENAME = PATH + 'chromosomes/P/P_kymo_w0.csv'  # intensity profiles, one ';'-terminated row per line

######
# kymograph stage: it reads the profiles written by the line stage, so its
# input is tied to OUTPUT_FILENAME instead of repeating the path literal
LINE_FILENAME = OUTPUT_FILENAME
KYMO_FILENAME = PATH + 'chromosomes/P/P_kymograph_w0.sis'
KYMO_PIXELTYPE = imagecore.ScopeBase.Pixeltype_USHORT
BORDER         = 0      # border around the kymograph in pixels (e.g. 2)
COLS_PER_LINE  = 1      # columns per intensity profile, i.e. the width of one profile (e.g. 20)
ROWS_PER_VALUE = 1      # rows per intensity value, i.e. the height of one value (e.g. 5)
INTERPOLATE    = False  # linearly interpolate between intensity values instead of using the stored values directly
STRETCH        = False  # stretch every intensity profile over the whole image height
CENTER_VALUE   = 65535  # value marking the center of each profile in the second channel
######

class Gaussian:
  """
  Separable 2D Gaussian weight table for a square w x w neighbourhood.

  The 1D kernel is sampled at the w offsets centered on 0 and peaks at 1.0;
  the values are NOT normalized to sum to one. The constructor assumes an
  odd w = 2*r + 1, which is how this file builds it.
  """

  def __init__(self, w):
    # sigma from the kernel size, same heuristic as OpenCV's getGaussianKernel
    sigma = 0.3 * ((w-1)*0.5 - 1.0) + 0.8
    # floor division keeps radius an int (usable as an index offset)
    self.radius = (w - 1) // 2

    # NOTE(review): exp(-x^2 / (4 sigma^2)) has standard deviation
    # sqrt(2)*sigma; the textbook form uses 2*sigma^2 -- confirm intent.
    omega = 4.0 * sigma * sigma
    self.values = []
    for i in range(w):
      x = float(i) - (w - 1) * 0.5
      self.values.append(math.exp(-x * x / omega))

  def eval(self, x, y):
    """
    Return the 2D weight at the integer offset (x, y) from the center,
    with -radius <= x, y <= radius.
    """
    # values[i] holds the kernel at offset i - radius, so shift by +radius
    # (the old abs(x) - radius produced negative indices and an off-center
    # kernel whose center weight was g(1) instead of g(0)).
    return self.values[x + self.radius] * self.values[y + self.radius]

class AreaBuffer :
  """
  Pairs a pixel buffer (an arivis.Buffer in this file) with the rectangle
  it covers. Both objects are stored as given; nothing is copied.
  """

  def __init__(self, buffer, bounds) :
    self.bounds = bounds   # rectangle described by the buffer
    self.buffer = buffer   # the pixel data itself

class VolumeBuffer :
  """
  Sub-volume of a scope stored as a list of 2D planes: one AreaBuffer per
  local z index, each row-major with bounds.width pixels per row.
  LineModel.getVolumeBuffer fills it; all coordinates are relative to the
  first pixel of the sub-volume.
  """

  def __init__(self) :
    self.buffers = []   # AreaBuffer per plane; list index == local z

  def addBuffer(self, buffer, bounds) :
    """Append a plane (pixel buffer plus its bounds rectangle) as the next z index."""
    self.buffers.append(AreaBuffer(buffer, bounds))

  def getValue(self, x, y, z) :
    """
    Return the pixel at the integer position (x, y, z).

    Raises IndexError if any coordinate lies outside the volume; negative
    indices are rejected instead of silently wrapping around.
    """
    if z < 0 or len(self.buffers) <= z :
      raise IndexError('Z index out of range')

    buf = self.buffers[z]
    if x < 0 or y < 0 or buf.bounds.width <= x or buf.bounds.height <= y :
      raise IndexError()

    return buf.buffer[y * buf.bounds.width + x]

  def fetch(self, x, y, z) :
    """Return the pixel at (x, y, z), clamping each coordinate into the volume."""
    iz = int(max(0.0, min(z, len(self.buffers)-1)))
    buf = self.buffers[iz]

    ix = int(max(0, min(x, buf.bounds.width-1)))
    iy = int(max(0, min(y, buf.bounds.height-1)))
    return buf.buffer[iy * buf.bounds.width + ix]

  def sample(self, x, y, z):
    """
    Trilinearly interpolate the volume at the fractional position (x, y, z)
    and return the result truncated to int. Neighbours outside the volume
    are clamped to the border (see fetch).
    """
    def mix(a, b, t):
      # linearly blend a and b
      return a*(1.0-t) + b*t

    u = x - math.floor(x)
    v = y - math.floor(y)
    w = z - math.floor(z)

    # floor, not int(): int() truncates toward zero, which for negative
    # coordinates picks a cell that does not match the weights u, v, w
    ix = int(math.floor(x))
    iy = int(math.floor(y))
    iz = int(math.floor(z))

    c00 = mix(self.fetch(ix  , iy  , iz  ), self.fetch(ix+1, iy  , iz  ), u)
    c01 = mix(self.fetch(ix  , iy+1, iz  ), self.fetch(ix+1, iy+1, iz  ), u)
    c10 = mix(self.fetch(ix  , iy  , iz+1), self.fetch(ix+1, iy  , iz+1), u)
    c11 = mix(self.fetch(ix  , iy+1, iz+1), self.fetch(ix+1, iy+1, iz+1), u)

    c0 = mix(c00, c01, v)
    c1 = mix(c10, c11, v)

    return int(mix(c0, c1, w))

	
class Vector3D :
  """
  Small 3D vector / point helper. The arithmetic methods return a new
  Vector3D and leave both operands unchanged.
  """

  def __init__(self, x, y, z) :
    self.x, self.y, self.z = x, y, z

  def sub(self, p):
    """Return the component-wise difference self - p."""
    return Vector3D(self.x - p.x, self.y - p.y, self.z - p.z)

  def add(self, p):
    """Return the component-wise sum self + p."""
    return Vector3D(self.x + p.x, self.y + p.y, self.z + p.z)

  def mul(self, s):
    """Return this vector scaled by the scalar s."""
    return Vector3D(self.x * s, self.y * s, self.z * s)

  def mag2(self):
    """Return the squared Euclidean length."""
    return self.x*self.x + self.y*self.y + self.z*self.z

  def mag(self):
    """Return the Euclidean length; 0.0 unless the squared length is positive."""
    squared = self.mag2()
    return math.sqrt(squared) if squared > 0.0 else 0.0

  def normalized(self):
    """
    Return a unit-length copy of this vector. A vector whose squared length
    is not positive (e.g. the zero vector) comes back as an unscaled copy.
    """
    squared = self.mag2()
    if not squared > 0.0:
      return Vector3D(self.x, self.y, self.z)
    scale = 1.0 / math.sqrt(squared)
    return Vector3D(self.x*scale, self.y*scale, self.z*scale)

  def distanceTo(self, p):
    """Return the Euclidean distance between this point and p."""
    return p.sub(self).mag()

	
class BBox :
  """
  Axis-aligned 3D bounding box stored as its minimum corner (pmin) and
  maximum corner (pmax). The integer accessors truncate corner coordinates
  with int(); the extents count both end pixels (hence the + 1).
  """

  def __init__(self, p1, p2) :
    lo = Vector3D(min(p1.x, p2.x), min(p1.y, p2.y), min(p1.z, p2.z))
    hi = Vector3D(max(p1.x, p2.x), max(p1.y, p2.y), max(p1.z, p2.z))
    self.pmin = lo
    self.pmax = hi

  # integer faces of the box
  def left(self):
    return int(self.pmin.x)

  def right(self):
    return int(self.pmax.x)

  def top(self):
    return int(self.pmin.y)

  def bottom(self):
    return int(self.pmax.y)

  def front(self):
    return int(self.pmin.z)

  def back(self):
    return int(self.pmax.z)

  # integer extents of the box
  @staticmethod
  def _span(lo, hi):
    """Number of pixels covered from lo to hi, inclusive of both ends."""
    return int(hi - lo) + 1

  def width(self):
    return BBox._span(self.pmin.x, self.pmax.x)

  def height(self):
    return BBox._span(self.pmin.y, self.pmax.y)

  def depth(self):
    return BBox._span(self.pmin.z, self.pmax.z)

  def union(self, p):
    """Grow the box in place so that it also contains the point p."""
    lo, hi = self.pmin, self.pmax
    self.pmin = Vector3D(min(p.x, lo.x), min(p.y, lo.y), min(p.z, lo.z))
    self.pmax = Vector3D(max(p.x, hi.x), max(p.y, hi.y), max(p.z, hi.z))
	


class LineModel :
  """
  Describes a straight line by two end points (pixel coordinates) and a
  sampling radius, and samples the scope along it.

  The line is extended by LINE_EXTRAPOLATE at both ends. It is resampled in
  a local "line space" whose z-axis runs along the line, where one unit
  equals LINE_PIXEL_SIZE.
  """

  def __init__(self, frame, channel, x1, y1, z1, x2, y2, z2, spacing, radius) :
    """
    frame, channel -- frame / channel indices passed to the scope accessors
    x1..z2         -- the two end points in pixel coordinates
    spacing        -- Vector3D of pixel sizes per axis (presumably microns,
                      matching LINE_PIXEL_SIZE -- confirm)
    radius         -- half-width of the sampled cross-section in line pixels
    """
    self.frame   = int(frame)
    self.channel = int(channel)
    self.spacing = spacing
    self.radius  = int(radius)

    # end points in calibrated (world) coordinates
    p1 = Vector3D(spacing.x * x1, spacing.y * y1, spacing.z * z1)
    p2 = Vector3D(spacing.x * x2, spacing.y * y2, spacing.z * z2)

    # extend the line by LINE_EXTRAPOLATE beyond both end points
    extent = p2.sub(p1).normalized().mul(LINE_EXTRAPOLATE)
    self.wp1 = p1.sub(extent)   # extended start point (world)
    self.wp2 = p2.add(extent)   # extended end point (world)

    # orientation of the line: azimuth in the xy-plane, polar angle from +z
    d = self.wp2.sub(self.wp1)
    self.length   = d.mag()
    self.planes   = int(math.ceil(self.length / LINE_PIXEL_SIZE))  # samples along the line
    self.azimutal = math.atan2(d.y, d.x)
    if self.length > 0.0:
      # clamp: rounding can push |d.z / length| slightly above 1 (acos ValueError)
      self.polar = -math.acos(max(-1.0, min(1.0, d.z / self.length)))
    else:
      # zero-length line (identical end points, no extension): there are no
      # planes to sample, so any orientation will do
      self.polar = 0.0

  def getScopePoint(self, x, y, z):
    """
    Transform a point from line (local) space to scope pixel space.

    In line space the line runs along +z from the extended start point wp1.
    """
    # scale line pixels to world units
    p = Vector3D(x, y, z).mul(LINE_PIXEL_SIZE)

    # rotate in the zx-plane (around y) by the polar angle
    sin_t = math.sin(self.polar)
    cos_t = math.cos(self.polar)
    p = Vector3D(p.x*cos_t - p.z*sin_t, p.y, p.z*cos_t + p.x*sin_t)

    # rotate in the xy-plane (around z) by the azimuth
    sin_t = math.sin(self.azimutal)
    cos_t = math.cos(self.azimutal)
    p = Vector3D(p.x*cos_t - p.y*sin_t, p.y*cos_t + p.x*sin_t, p.z)

    # move the origin to the extended start point, then convert to pixels
    p = p.add(self.wp1)
    return Vector3D(p.x / self.spacing.x, p.y / self.spacing.y, p.z / self.spacing.z)

  def getModelBounds(self):
    """Bounds in line space: the cross-section square times 0..planes along z."""
    return BBox(Vector3D(-self.radius, -self.radius, 0), Vector3D(self.radius, self.radius, self.planes))

  def getBounds(self):
    """Bounds of the line volume in scope pixel space, padded for rounding."""
    local = self.getModelBounds()
    bounds = None
    # the 8 corners of the line-space box enclose every sampled point
    for lz in (local.pmin.z, local.pmax.z):
      for ly in (local.pmin.y, local.pmax.y):
        for lx in (local.pmin.x, local.pmax.x):
          p = self.getScopePoint(lx, ly, lz)
          if bounds is None:
            bounds = BBox(p, p)
          else:
            bounds.union(p)

    bounds.pmin = bounds.pmin.sub(Vector3D(1, 1, 1))
    bounds.pmax = bounds.pmax.add(Vector3D(2, 2, 2))  # due to rounding
    return bounds

  def _getBufferOrigin(self, bounds):
    # getVolumeBuffer reads the region starting at the integer corner
    # (left, top, front) of the bounds, so sample coordinates must be made
    # relative to that corner, not to the fractional bounds.pmin.
    return Vector3D(bounds.left(), bounds.top(), bounds.front())

  def getVolumeBuffer(self, scope):
    """
    Read the scope region covered by getBounds() (channel self.channel,
    frame self.frame) into a VolumeBuffer, one plane per z slice.
    """
    ret = VolumeBuffer()
    volBounds = self.getBounds()
    bufBounds = arivis.Rect(volBounds.left(), volBounds.top(), volBounds.width(), volBounds.height())
    planeBounds = scope.get_bounding_rect(self.frame)

    npixels = bufBounds.width * bufBounds.height

    for plane in range(volBounds.depth()):
      buffer = arivis.Buffer(scope.get_pixel_type(), npixels)
      scope.get_channeldata(planeBounds, bufBounds, buffer, self.channel, volBounds.front()+plane, self.frame)
      ret.addBuffer(buffer, bufBounds)

    return ret

  def draw(self, scope):
    """Debug helper: paint the Gaussian weights of the line volume into channel 0 of scope."""
    G = Gaussian(self.radius*2+1)
    for z in range(self.planes):
      for y in range(-self.radius, self.radius+1):
        for x in range(-self.radius, self.radius+1):
          p = self.getScopePoint(x, y, z)
          scope.set_pixel(arivis.Point(int(p.x), int(p.y)), G.eval(x, y) * 65535.0, 0, int(p.z), self.frame)

  def drawLocal(self, lineScope, frame, sourceScope=None):
    """
    Resample the line volume into line space and write it to lineScope:
    plane z of (channel 0, frame) receives the (2r+1) x (2r+1)
    cross-section at line position z.

    sourceScope -- scope to sample from; when omitted the module-level
                   `scope` is used (the previous implicit behaviour).
    """
    if sourceScope is None:
      sourceScope = scope

    bounds = self.getBounds()
    buffer = self.getVolumeBuffer(sourceScope)
    origin = self._getBufferOrigin(bounds)

    size = self.radius * 2 + 1
    bufrect = arivis.Rect(0, 0, size, size)
    buf = arivis.Buffer(lineScope.get_pixel_type(), size*size)

    for z in range(self.planes):
      for y in range(-self.radius, self.radius+1):
        for x in range(-self.radius, self.radius+1):
          pp = self.getScopePoint(x, y, z).sub(origin)
          buf[(x + self.radius) + (y + self.radius) * size] = buffer.sample(pp.x, pp.y, pp.z)

      lineScope.set_channeldata(bufrect, buf, 0, z, frame)

  def getIntensityProfile(self, scope):
    """
    Return the intensity profile along the line: a list of self.planes ints.

    Each entry is the weighted mean of the (2r+1) x (2r+1) cross-section at
    that line position. The weights are Gaussian if USE_GAUSSIAN is set and
    uniform otherwise; values are trilinearly interpolated from the scope.
    """
    bounds = self.getBounds()
    buffer = self.getVolumeBuffer(scope)
    origin = self._getBufferOrigin(bounds)
    G = Gaussian(self.radius*2+1)

    line = []
    for z in range(self.planes):
      total = 0.0
      weight = 0.0
      for y in range(-self.radius, self.radius+1):
        for x in range(-self.radius, self.radius+1):
          w = G.eval(x, y) if USE_GAUSSIAN else 1.0
          pp = self.getScopePoint(x, y, z).sub(origin)
          total += buffer.sample(pp.x, pp.y, pp.z) * w
          weight += w

      line.append(int(total / weight))

    return line
  


class CSVReader :
  """
  Reads the two end-point CSV files and builds one LineModel per row pair.

  Row k of each file holds one end point as x, y, z (pixel coordinates,
  separated by CSV_DELIMITER). The resulting line gets frame START_FRAME + k.
  """

  def parse(self, fileName1, fileName2, spacing) :
    """
    Parse both files and store the LineModel objects in self.lineModels.

    spacing -- Vector3D of pixel sizes, forwarded to every LineModel.
    Completely empty rows (e.g. a blank trailing line) are skipped.
    Raises ValueError if the files contain a different number of rows.
    """
    self.lineModels = []

    rows1 = self._readRows(fileName1)
    rows2 = self._readRows(fileName2)
    if len(rows1) != len(rows2):
      raise ValueError("Point lists must have same length")

    for index, (i1, i2) in enumerate(zip(rows1, rows2)):
      frame = START_FRAME + index
      self.lineModels.append(LineModel(frame, CHANNEL,
                                       float(i1[0]), float(i1[1]), float(i1[2]),
                                       float(i2[0]), float(i2[1]), float(i2[2]),
                                       spacing, LINE_RADIUS))

  def _readRows(self, fileName):
    """Return all non-empty rows of fileName as lists of strings; the file is closed afterwards."""
    with open(fileName, mode='r') as f:
      return [row for row in csv.reader(f, delimiter = CSV_DELIMITER) if row]

  def open(self, fileName) :
    """
    Open fileName and return a csv.reader over it (kept for compatibility).

    The file object itself is not returned, so it cannot be closed
    explicitly; parse() uses _readRows instead.
    """
    file = open(fileName, mode='r')
    reader = csv.reader(file, delimiter = CSV_DELIMITER)
    return reader
	
	
# main

# retrieve viewer and scope
viewer = browser.get_active_viewer()
scope = viewer.get_active_scope()

# parse the end-point files into LineModel objects; the scope's pixel sizes
# convert the pixel coordinates into calibrated coordinates
reader = CSVReader()
reader.parse(CSV_FILENAME1, CSV_FILENAME2, Vector3D(scope.get_pixel_width(), scope.get_pixel_height(), scope.get_pixel_depth()))

# create an output scope for the resampled lines if requested: one frame per
# line, as many planes as the longest line needs
lineScope = None
if LINE_OUTPUT:
  width = LINE_RADIUS * 2 + 1
  nplanes = max([1] + [line.planes for line in reader.lineModels])
  lineViewer = browser.create_imagestack(LINE_OUTPUT_FILENAME, None, scope.get_pixel_type(), width, width, len(reader.lineModels), nplanes, 1)
  lineScope = lineViewer.get_active_scope()

# sample every line and write one ';'-terminated row of values per line;
# the with-block closes the file even if sampling fails
with open(OUTPUT_FILENAME, 'w') as outputFile:
  for frame, line in enumerate(reader.lineModels):
    values = line.getIntensityProfile(scope)
    if lineScope is not None:
      line.drawLocal(lineScope, frame)
    outputFile.write("".join("%s;" % value for value in values) + "\n")

######

class PLF:
  """
  Piecewise linear function through sample values placed at ascending
  positions. By default the positions are spread evenly over [0, 1].
  """

  def __init__(self, values, positions=None):
    """
    values    -- sample values
    positions -- optional ascending positions, one per value; when omitted
                 the values are spaced evenly over [0, 1]
    Raises ValueError if positions has a different length than values.
    """
    # (previously two __init__ definitions existed and the second silently
    # replaced the first, so the positions form could not be used)
    self.values = values
    self.len = len(values)
    if positions is None:
      if self.len == 1:
        positions = [0.0]   # a single sample: avoid 0 / 0
      else:
        positions = [float(i) / float(self.len - 1) for i in range(self.len)]
    elif len(positions) != self.len:
      raise ValueError('values and positions must have the same length')
    self.positions = positions

  def sample(self, x):
    """
    Return the linearly interpolated value at position x, truncated to int.

    x outside the sampled range is clamped to the first/last value.
    An empty function yields 0.
    """
    if self.len == 0:
      return 0

    # first position strictly greater than x, clamped to the last segment
    i = min(bisect.bisect_right(self.positions, x), self.len - 1)
    if i == 0:
      return self.values[0]

    a = self.positions[i-1]
    b = self.positions[i]
    if b <= a:
      t = 1.0   # duplicate positions: take the later value instead of dividing by 0
    else:
      t = max(0.0, min(1.0, (x - a) / (b - a)))
    return int(self.values[i-1]*(1.0-t) + self.values[i]*t)

  def fetch(self, x):
    """
    Return the stored value whose position is nearest to x (ties go to the
    lower position), without interpolation. An empty function yields 0.
    """
    if self.len == 0:
      return 0

    # previously this returned the value *after* x, so values[0] was never
    # returned and the profile appeared shifted by one sample
    k = bisect.bisect_right(self.positions, x)
    if k == 0:
      return self.values[0]
    if k == self.len:
      return self.values[-1]
    if x - self.positions[k-1] <= self.positions[k] - x:
      return self.values[k-1]
    return self.values[k]
	
class CSVLineReader :
  """
  Reads a CSV file of intensity profiles (one profile per row, integer
  values separated by CSV_DELIMITER) and turns each row into a PLF.
  """

  def parse(self, fileName) :
    """
    Return a list of PLF objects, one per row of fileName, in file order.
    Empty fields (e.g. from the trailing delimiter) are ignored.
    """
    return [PLF([int(s) for s in row if s != '']) for row in self._readRows(fileName)]

  def _readRows(self, fileName):
    """Return all rows of fileName as lists of strings; the file is closed afterwards."""
    with open(fileName, mode='r') as f:
      return list(csv.reader(f, delimiter = CSV_DELIMITER))

  def open(self, fileName) :
    """
    Open fileName and return a csv.reader over it (kept for compatibility).

    The file object itself is not returned, so it cannot be closed
    explicitly; parse() uses _readRows instead.
    """
    file = open(fileName, mode='r')
    reader = csv.reader(file, delimiter = CSV_DELIMITER)
    return reader




# main

# load the intensity profiles (one PLF per CSV row)
reader = CSVLineReader()
plfs = reader.parse(LINE_FILENAME)

# number of profiles (columns) and length of the longest profile (rows)
numProfiles = len(plfs)
maxProfileLength = max([0] + [plf.len for plf in plfs])

# create dataset: the kymograph goes to channel 0, center marks to channel 1
width  = numProfiles * COLS_PER_LINE + BORDER * 2
height = maxProfileLength * ROWS_PER_VALUE + BORDER * 2
viewer = browser.create_imagestack(KYMO_FILENAME, None, KYMO_PIXELTYPE, width, height, 1, 1, 2)
scope = viewer.get_active_scope()

# fill a local buffer and write it with a single call at the end
buf = arivis.Buffer(KYMO_PIXELTYPE, width*height)

for j, plf in enumerate(plfs):
  x = BORDER + j * COLS_PER_LINE

  if STRETCH:
    # every profile spans the full kymograph height
    colsize = maxProfileLength * ROWS_PER_VALUE
    y = BORDER
  else:
    # keep the profile's own length and center it vertically
    # (// keeps the pixel offset an int on Python 2 and 3)
    colsize = plf.len * ROWS_PER_VALUE
    y = BORDER + ((maxProfileLength - plf.len) // 2) * ROWS_PER_VALUE

  for i in range(colsize):
    # normalized position along the profile; a one-row column would divide by 0
    p = float(i) / float(colsize - 1) if colsize > 1 else 0.0
    v = plf.sample(p) if INTERPOLATE else plf.fetch(p)

    # draw the line segment
    for k in range(COLS_PER_LINE):
      buf[(x + k) + (y + i) * width] = v

  # row of the profile's center point
  if STRETCH:
    y = BORDER + (maxProfileLength // 2) * ROWS_PER_VALUE   # center of the image
  else:
    y = y + (plf.len // 2) * ROWS_PER_VALUE                 # advance to the profile's center

  # mark the center point in the second channel
  for col in range(COLS_PER_LINE):
    for row in range(ROWS_PER_VALUE):
      scope.set_pixel(arivis.Point(x + col, y + row), CENTER_VALUE, 1)

# write buffer
scope.set_channeldata(scope.get_bounding_rect(), buf, 0)