import ome
import numpy as np
import scipy.ndimage as ndi
import math
import tqdm
import multiprocessing
import warnings

def spin(omeFile,dx,spins,headPosition):
    """Render a rotating maximum-intensity projection for every timepoint.

    Each timepoint's red (RFP) and green (GCaMP) volumes are rendered by
    ``spinJob`` in a ``multiprocessing.Pool``. Timepoint ``t`` is viewed at
    angle ``th[t]`` (degrees). The angles step evenly from 0 towards
    ``360*spins`` with the endpoint excluded.

    Parameters
    ----------
    omeFile :
        Whatever the project ``ome`` module accepts for ``ome.index`` /
        ``ome.redGreen`` / ``ome.manifest``.
    dx :
        Voxel spacing. ``dx[0]`` and ``dx[1]`` must both equal ``min(dx)``,
        because only anisotropy along ``dx[2]`` (Z) is corrected.
    spins :
        Number of full turns swept over the whole sequence.
    headPosition :
        Number of ``np.rot90`` quarter turns applied to each volume
        (see the table below).

    Returns
    -------
    tuple or None
        ``(redFrames, greenFrames, redStill, greenStill, originalVolume)``.
        Each of the first four is a float32 stack with the time axis last.
        ``originalVolume`` is the shape reported by the first result.
        ``None`` is returned (after printing a message) when ``dx`` fails
        the isotropy check above.

    Raises
    ------
    ValueError
        If the file reports zero timepoints.
    """

    # Head position is:
    #   0 if at the top
    #   1 if on the right
    #   2 if at the bottom
    #   3 if at the left
    # See rot90 in the image loader.

    # Get sequence information
    if not((dx[0]==min(dx)) and (dx[1]==min(dx))):
        print('Render only corrects for anisotropy in Z.')
        return

    omeStats = ome.index(omeFile)
    dataSize = omeStats['size'][1:]
    nFrames = dataSize[0]
    if nFrames == 0:
        # Fail clearly instead of hitting an IndexError on results[0] below.
        raise ValueError('The OME file contains no timepoints to render.')
    channels = ome.redGreen(omeFile)
    print('RFP is on channel %i, GCaMP is on channel %i.' % (channels[0],channels[1]))

    # One viewing angle per timepoint. The endpoint is excluded so that a
    # full turn does not repeat its first frame.
    th = np.linspace(0,360*spins,nFrames,False)
    params = [
        (t,
         [ome.manifest(omeFile,omeStats,channels[0],t),
          ome.manifest(omeFile,omeStats,channels[1],t)],
         dataSize[1:],dx,th[t],headPosition)
        for t in range(nFrames)
    ]

    results = list()
    # The tqdm context manager closes the bar even if a worker raises.
    with tqdm.tqdm(total=nFrames) as pbar, multiprocessing.Pool() as pool:
        for x in pool.imap_unordered(spinJob,params):
            results.append(x)
            pbar.update()

    # Results arrive out of order; x[0] is the timepoint, and x[1:5] are
    # (Rt, Gt, Rs, Gs), which all share one 2-D shape.
    frameShape = results[0][1].shape
    stacks = [np.zeros((nFrames,)+frameShape,'single') for _ in range(4)]
    for x in results:
        for stack, image in zip(stacks, x[1:5]):
            stack[x[0]] = image

    # Flip array orders into Matlab style (time axis last)
    redFrames,greenFrames,redStill,greenStill = (np.transpose(s,(1,2,0)) for s in stacks)
    originalVolume = results[0][5]
    return (redFrames,greenFrames,redStill,greenStill,originalVolume)

# ---

def padToSize(V,i,j):
    """Zero-pad the 2-D array ``V`` out to shape ``(i, j)``, keeping it centred.

    When a dimension needs an odd amount of padding, the extra row/column is
    added after the data (bottom/right). The dtype of ``V`` is preserved.
    ``np.pad`` raises if ``V`` is already larger than the target.
    """
    widths = []
    for target, current in ((i, V.shape[0]), (j, V.shape[1])):
        extra = target - current
        before = math.floor(extra / 2)
        after = math.ceil(extra / 2)
        widths.append((before, after))
    return np.pad(V, tuple(widths), 'constant')
    
# ---

def _logDisplayLimits(V):
    """Return ``[30th percentile, max]`` of ``log(V)``, computing the log once."""
    # Zero voxels give log(0) = -inf; silence the warning as the later
    # log() calls in spinJob do.
    with np.errstate(divide='ignore'):
        logV = np.log(V)
    # NOTE(review): if a large share of voxels are zero, the 30th percentile
    # may come out as -inf or nan; confirm the inputs are strictly positive.
    return [np.percentile(logV,30),np.max(logV)]

def spinJob(processArgs):
    """Render one timepoint: a rotated MIP pair plus an unrotated still pair.

    ``spin`` maps this over timepoints in a ``multiprocessing.Pool``.

    Parameters
    ----------
    processArgs : tuple
        ``(t, manifests, shape, dx, th, headPosition)``:

        * ``t`` is the timepoint index, echoed back so the caller can
          reorder results.
        * ``manifests`` is a ``(red, green)`` pair, each passed with
          ``shape`` to ``ome.readVolumeWithManifest``.
        * ``dx`` is the voxel spacing; axis 0 is resampled by
          ``dx[2]/min(dx)``.
        * ``th`` is the viewing angle in degrees.
        * ``headPosition`` is the number of ``np.rot90`` quarter turns
          over axes (1, 2).

    Returns
    -------
    tuple
        ``(t, Rt, Gt, Rs, Gs, originalVolume)``.

        * ``Rt``/``Gt`` are the rotated maximum projections.
        * ``Rs``/``Gs`` are maximum projections along axis 0 of the
          unrotated volumes.

        All four are float32, log-scaled, clipped to per-channel limits and
        zero-padded to ``(RFP.shape[1], d)``. ``originalVolume`` is the
        loaded volume's shape reordered as (axis 1, axis 2, axis 0).
    """
    # Unpack everything
    t,manifests,shape,dx,th,headPosition = processArgs

    # Load and rotate the input data
    # Suppress this version warning:
    #   From scipy 0.13.0, the output shape of zoom() is calculated with round() instead
    #   of int() - for these inputs the size of the returned array has changed.
    warnings.filterwarnings('ignore', '.*output shape of zoom.*')

    # Manifests for loading are passed in as a (red,green) pair
    RFP = np.single(np.rot90(ome.readVolumeWithManifest(manifests[0],shape),headPosition,(1,2)))
    GCaMP = np.single(np.rot90(ome.readVolumeWithManifest(manifests[1],shape),headPosition,(1,2)))

    # Display limits in log space: from the 30th percentile up to the max.
    RFPlimits = _logDisplayLimits(RFP)
    GCaMPlimits = _logDisplayLimits(GCaMP)

    # Make the data isometric.
    # For efficiency, this routine only corrects for anisotropy in Z.
    zoomFactors = (dx[2]/min(dx),1,1)

    # Scan in raster fashion down the Y direction.
    for i in range(RFP.shape[1]):

        # Extract the XZ slice at this Y level (kept 3-D with a unit Y axis).
        Rslice = ndi.zoom(RFP[:,i,None],zoomFactors,order=1)
        Gslice = ndi.zoom(GCaMP[:,i,None],zoomFactors,order=1)

        # Rotate to the viewing angle and flatten.
        Rmip = np.amax(ndi.rotate(Rslice,-th,axes=(0,2),order=1),0)
        Gmip = np.amax(ndi.rotate(Gslice,-th,axes=(0,2),order=1),0)

        # Create storage variables on the first pass.
        if i == 0:
            # d is the zoomed slice's diagonal (+1). Padding every angle to
            # this width gives all frames one common size.
            d = 1 + math.ceil(np.sqrt(Rslice.shape[0]**2+Rslice.shape[2]**2))
            Rt = np.zeros((RFP.shape[1],Rmip.shape[1]))
            Gt = np.zeros((RFP.shape[1],Rmip.shape[1]))

        Rt[i,:] = Rmip
        Gt[i,:] = Gmip

    # Pad to desired size
    Rt = padToSize(Rt,RFP.shape[1],d)
    Gt = padToSize(Gt,RFP.shape[1],d)
    # log(0) in the zero padding gives -inf, which the clip raises to the lower limit.
    with np.errstate(divide='ignore'):
        Rt = np.single(np.clip(np.log(Rt),RFPlimits[0],RFPlimits[1]))
        Gt = np.single(np.clip(np.log(Gt),GCaMPlimits[0],GCaMPlimits[1]))

    # Make the still images (projection along axis 0, no rotation)
    Rs = padToSize(np.amax(RFP,0),RFP.shape[1],d)
    Gs = padToSize(np.amax(GCaMP,0),RFP.shape[1],d)
    with np.errstate(divide='ignore'):
        Rs = np.single(np.clip(np.log(Rs),RFPlimits[0],RFPlimits[1]))
        Gs = np.single(np.clip(np.log(Gs),GCaMPlimits[0],GCaMPlimits[1]))

    # Report the loaded shape reordered as (axis 1, axis 2, axis 0).
    originalVolume = RFP.shape
    originalVolume = (originalVolume[1],originalVolume[2],originalVolume[0])

    return (t,Rt,Gt,Rs,Gs,originalVolume)

# ---
