import ome
import tqdm
import pickle
import os.path
import hdf5storage
import numpy as np
import multiprocessing

from scipy import io
from heapq import merge
from scipy import stats
from numpy import linalg
from scipy import ndimage
from munkres import munkres

# ---

def load(fname):
    """Unpickle and return the object stored in the file `fname`.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    that come from a trusted source.
    """
    with open(fname,'rb') as handle:
        obj = pickle.load(handle)
    return obj

# ---

def save(fname,x):
    """Pickle the object `x` into the file `fname`, overwriting any existing file.

    Returns None.
    """
    with open(fname,'wb') as handle:
        pickle.dump(x,handle)

# ---

def savemat(fname,mdict):
    """Write the dictionary `mdict` to the MATLAB file `fname`.

    Thin wrapper that forwards both arguments unchanged to
    hdf5storage.savemat.
    """
    # Previously ...
    #   io.savemat(fname,mdict,do_compression=True)
    # Now upgraded to use the hdf5storage library so it can handle larger files.
    hdf5storage.savemat(fname,mdict)

# ---

def makeShell(dx,inner,outer):
    """Return the voxel offsets that make up a spherical shell.

    dx    -- voxel size, indexed as (x, y, z).
    inner -- inner radius of the shell (same units as dx), inclusive.
    outer -- outer radius of the shell (same units as dx), inclusive.

    Returns a list of three equal-length tuples [dz, dy, dx] holding the
    integer offsets, relative to the shell centre, of every voxel whose
    physical distance d from the centre satisfies inner <= d <= outer.
    """
    # Half-width of the bounding box in voxels, ordered (z, y, x).
    half = np.ceil((outer/dx[2],outer/dx[1],outer/dx[0])).astype(int)

    # Open (broadcastable) index grids instead of three dense volumes.
    zi,yi,xi = np.meshgrid(np.arange(-half[0],half[0]+1),
                           np.arange(-half[1],half[1]+1),
                           np.arange(-half[2],half[2]+1),
                           indexing='ij',sparse=True)
    xs = xi * dx[0]
    ys = yi * dx[1]
    zs = zi * dx[2]
    d2 = xs*xs + ys*ys + zs*zs

    inside = (d2>=inner*inner) & (d2<=outer*outer)
    hits = np.where(inside)
    # np.where yields array indices; subtract the half-width to get offsets.
    return [tuple(idx-h) for idx,h in zip(hits,half)]

# ---

def moveShell(shell,tgt,limits):
    """Translate a shell of offsets to a target voxel, clipping to the volume.

    shell  -- three equal-length sequences of (z, y, x) offsets.
    tgt    -- (z, y, x) index of the voxel the shell is centred on.
    limits -- (z, y, x) size of the volume.

    Returns three lists (zsh, ysh, xsh) with the absolute indices of the
    shell voxels that fall inside the volume; voxels falling outside on any
    axis are dropped.
    """
    zsh,ysh,xsh = [],[],[]
    for dz,dy,dx in zip(shell[0],shell[1],shell[2]):
        z = tgt[0]+dz
        y = tgt[1]+dy
        x = tgt[2]+dx
        # Keep the voxel only if it is in range on every axis.
        if 0<=z<limits[0] and 0<=y<limits[1] and 0<=x<limits[2]:
            zsh.append(z)
            ysh.append(y)
            xsh.append(x)
    return zsh,ysh,xsh

# ---

def pca(V,dx):
    """Intensity-weighted centroid and principal axes of a 3-D volume.

    V  -- 3-D array indexed (z, y, x); its values are used as weights.
    dx -- voxel size, indexed as (x, y, z).

    Returns ((xc, yc, zc), vec, val, CVM) where the centroid is in physical
    units, CVM is the 3x3 matrix of weighted second moments about the
    centroid (rows/columns ordered x, y, z) divided by the number of voxels,
    and val, vec are the eigenvalues and eigenvectors of CVM as returned by
    numpy.linalg.eig.
    """
    shape = V.shape
    weightSum = np.sum(V)

    # Open grids broadcast against V, avoiding three dense coordinate volumes.
    zi,yi,xi = np.meshgrid(np.arange(shape[0]),np.arange(shape[1]),np.arange(shape[2]),
                           indexing='ij',sparse=True)
    xm = xi * dx[0]
    ym = yi * dx[1]
    zm = zi * dx[2]

    # Weighted centroid.
    xc = np.sum(V*xm)/weightSum
    yc = np.sum(V*ym)/weightSum
    zc = np.sum(V*zm)/weightSum

    # Coordinates relative to the centroid.
    xm = xm - xc
    ym = ym - yc
    zm = zm - zc

    # Weighted second moments.
    Sxx = np.sum(V*xm*xm)
    Sxy = np.sum(V*xm*ym)
    Sxz = np.sum(V*xm*zm)
    Syy = np.sum(V*ym*ym)
    Syz = np.sum(V*ym*zm)
    Szz = np.sum(V*zm*zm)

    moments = [[Sxx,Sxy,Sxz],
               [Sxy,Syy,Syz],
               [Sxz,Syz,Szz]]
    CVM = np.array(moments)/V.size
    val,vec = linalg.eig(CVM)
    return (xc,yc,zc),vec,val,CVM

# ---

def otsu(V):
    """Pick an intensity threshold for V using Otsu's method with a floor.

    The threshold starts at the 99.95th percentile of V.  A one-level Otsu
    threshold is computed from a 256-bin histogram; if it exceeds the
    percentile it is returned.  Otherwise a two-level split is computed and
    its upper threshold is returned if that exceeds the percentile.  If
    neither does, the percentile itself is returned.

    V -- array of intensities (any shape).
    Returns the threshold as a scalar in the units of V.
    """
    nbins = 256
    thresh = np.percentile(V,99.95)

    h, binEdges = np.histogram(V,nbins)
    h = h.astype(float)
    p = h/np.sum(h)                           # bin probabilities
    w = np.cumsum(p)                          # cumulative class weight
    mu = np.cumsum(p*(np.arange(nbins)+1))    # cumulative first moment (1-based bins)
    mu_t = mu[-1]

    def binToIntensity(index):
        # Map a bin index linearly onto the histogram's intensity range
        # (index 0 -> first edge, index nbins-1 -> last edge).
        return binEdges[0] + index * (binEdges[-1]-binEdges[0]) / (nbins-1)

    # --- One-level split: maximise the between-class variance.
    with np.errstate(divide='ignore',invalid='ignore'):
        sbsq = (mu_t*w - mu)*(mu_t*w - mu)/(w*(1-w))
        sbsq[np.isnan(sbsq)] = 0
    level = binToIntensity(np.argmax(sbsq))
    if level>thresh:
        return level

    # --- Two-level split: evaluate every ordered pair of bins (i, j), i <= j.
    with np.errstate(divide='ignore',invalid='ignore'):
        w0 = np.tile(w[:,None],(1,nbins))
        mu_0_t = mu_t - mu/w
        mu_0_t = np.tile(mu_0_t[:,None],(1,nbins))
        w1 = w - w[:,None]
        mu_1_t = mu_t - np.divide((mu-mu[:,None]),w1)
        # Pairs with i > j are not valid splits; blank them out.
        invalid = np.arange(nbins)[:,None] > np.arange(nbins)
        w0[invalid] = np.nan
        w1[invalid] = np.nan
        term1 = w0 * mu_0_t * mu_0_t
        term2 = w1 * mu_1_t * mu_1_t
        w2 = 1 - (w0 + w1)
        w2[w2<=0] = np.nan
        term3 = np.divide(np.power((w0 * mu_0_t + w1 * mu_1_t),2),w2)
        sbsq = term1 + term2 + term3
        sbsq[np.isnan(sbsq)] = 0
    # Only the upper of the two thresholds (the column index) is used.
    upperBin = np.unravel_index(sbsq.argmax(), sbsq.shape)[-1]
    level = binToIntensity(upperBin)
    if level>thresh:
        thresh = level
    return thresh

# ---

def LoG3D(r,dx,V):
    """Filter the image stack V with a Laplacian-of-Gaussian kernel.

    r  -- spot size; the Gaussian width is sigma = r/sqrt(3).
    dx -- voxel size [dx, dy, dz].
    V  -- 3-D volume indexed (z, y, x).

    The 3-D kernel is applied as three separable passes (Huertas and Medioni
    separability; see Huertas1986, Sage2005), with each 1-D kernel extending
    about 4 sigma in each direction.  The result is divided by C, the sum of
    all taps of the 3-D kernel.
    """
    sigma = r/np.sqrt(3)
    sigma2 = sigma * sigma
    sigma4 = sigma2 * sigma2

    # Kernel half-width in voxels along x, y, z.
    N = np.floor(4*sigma/dx)

    def axisKernels(axis):
        # Gaussian factor (hB) and second-derivative factor (hF) on one axis.
        u = dx[axis]*np.arange(-N[axis],1+N[axis])
        hB = np.exp(-u*u/(2*sigma2))
        hF = ((u*u - sigma2)/sigma4)*hB
        return hB,hF

    hBX,hFX = axisKernels(0)
    hBY,hFY = axisKernels(1)
    hBZ,hFZ = axisKernels(2)

    # In-plane sums of the three separable terms.
    A = np.zeros((3,1))
    A[0] = np.sum(hFX[:,None]*hBY)
    A[1] = np.sum(hBX[:,None]*hFY)
    A[2] = np.sum(hBX[:,None]*hBY)

    # Accumulate the normalisation tap by tap along z.  The order of the
    # additions is kept fixed because the terms largely cancel.
    C = 0
    for bz,fz in zip(hBZ,hFZ):
        C = C + A[0] * bz
        C = C + A[1] * bz
        C = C + A[2] * fz

    # One separable pass per second-derivative direction; kernels are listed
    # in array-axis order (z, y, x).
    passes = ([hBZ,hBY,hFX],[hBZ,hFY,hBX],[hFZ,hBY,hBX])
    Q = conv3sep(V,passes[0])
    for kernels in passes[1:]:
        Q = Q + conv3sep(V,kernels)
    return Q/C

# ---

def conv3sep(V,H):
    """Apply a separable 3-D filter as three successive 1-D correlations.

    V -- 3-D volume.
    H -- sequence of three 1-D kernels; H[n] is applied along axis n.

    Edges are handled by repeating the nearest voxel.  Returns the filtered
    volume as a new array.

    ndimage.correlate1d writes its result in the dtype of its input, so an
    integer volume would have the (floating point) filter response cast back
    to integers.  Integer and boolean input is therefore promoted to float
    first; floating point (and complex) input is passed through unchanged.
    """
    W = np.asarray(V)
    if not np.issubdtype(W.dtype,np.inexact):
        W = W.astype(float)
    for axis in range(3):
        W = ndimage.correlate1d(W,H[axis],axis=axis,mode='nearest')
    return W

# ---

def findPoints(r,dx,pmax,pad,U):
    """Greedily pick the strongest, mutually separated voxels of U.

    Voxels are visited from the largest value of U downwards.  Each accepted
    voxel blocks every voxel within a radius of (1 + pad/100) * r around it
    from being accepted later.

    r    -- spot radius.
    dx   -- voxel size, indexed as (x, y, z).
    pmax -- stop once this many points have been accepted.
    pad  -- extra exclusion radius, as a percentage of r.
    U    -- 3-D response volume indexed (z, y, x).

    Returns the accepted points as a list of (z, y, x) index tuples, in
    order of acceptance.
    """
    order = np.argsort(U.reshape([-1]))
    blocked = np.zeros(U.shape,dtype=bool)
    exclusion = makeShell(dx,0,(1+0.01*pad)*r)

    points = []
    for flat in order[::-1]:
        tgt = np.unravel_index(flat,U.shape)
        if blocked[tgt[0],tgt[1],tgt[2]]:
            # Too close to a point that was already accepted.
            continue
        points.append(tgt)
        zsh,ysh,xsh = moveShell(exclusion,tgt,U.shape)
        blocked[zsh,ysh,xsh] = True
        if len(points)==pmax:
            break
    return points

# ---

def getBrightness(points,V):
    """Return a list with the value of V at each index tuple in `points`."""
    return [V[p] for p in points]

# ---

def phaseOneJob(processArgs):
    """Worker for phase one: analyse the volume of a single time step.

    processArgs is the tuple (t, manifest, shape, r, dx, pmax, pad, napier):
      t        -- time step, passed back so results can be re-ordered.
      manifest -- manifest handed to ome.readVolumeWithManifest.
      shape    -- expected shape of the volume.
      r, dx, pmax, pad -- spot radius, voxel size, maximum number of points
                  and exclusion padding (see LoG3D and findPoints).
      napier   -- if true, work on the natural logarithm of the intensities.

    Returns (t, result).  result is a dict with keys 'thresh', 'centroid',
    'vecs', 'vals', 'CVM', 'points', 'brightness', 'LoGB' and 'ok' (True),
    or None if the analysis of this time step failed.

    Raises ValueError if the volume read does not have the expected shape.
    """
    # Unpack everything
    t,manifest,shape,r,dx,pmax,pad,napier = processArgs

    V = ome.readVolumeWithManifest(manifest,shape)
    if V.shape != shape:
        raise ValueError('Inconsistent stack sizes encountered.')
    if napier:
        # Uses the logarithm of the raw intensity values.
        # Correct any values less than 1 to 1 so that the
        # log is nowhere negative.
        V[V<1] = 1
        V = np.log(V)

    try:
        thresh = otsu(V)
        centroid,vecs,vals,CVM = pca(V*(V>thresh),dx)
        U = LoG3D(r,dx,V)
        points = findPoints(r,dx,pmax,pad,U)
        brightness = getBrightness(points,V)
        LoGB = getBrightness(points,U)
    except Exception:
        # It's possible that task may fail if there's nothing to track.
        # i.e. in the light stimulation, everything's washed out.
        # Only ordinary errors are absorbed here; KeyboardInterrupt and
        # SystemExit still propagate so the worker can be stopped.
        return (t, None)

    return (t, {'thresh': thresh, 'centroid': centroid, 'vecs': vecs, 'vals':vals, 'CVM':CVM, 'points': points, 'brightness': brightness, 'LoGB': LoGB, 'ok':True})

# ---

def phaseOne(omeFile,r,dx,pmax,pad,napier=True,tlimit=None,seedAxes=np.eye(3)):
    """Phase 1: scan every time step of an ome.tif file for candidate points.

    omeFile  -- file to scan; the channel ome.redGreen reports first is used.
    r        -- spot radius.
    dx       -- voxel size, indexed as (x, y, z).
    pmax     -- maximum number of points to find per time step.
    pad      -- exclusion padding, as a percentage of r.
    napier   -- if true, work on the logarithm of the intensities.
    tlimit   -- iterable of time steps to process; None means all of them.
    seedAxes -- 3x3 matrix of reference axes used to keep the PCA axes
                consistent over time.  Passing None uses the identity as the
                reference and leaves the coordinates unrotated.  The default
                array is only read, never modified.

    Returns a list with one dict per time step (see phaseOneJob), extended
    with 'vecsOrig' and 'coords'.  Time steps whose analysis failed carry a
    copy of a neighbouring step's data with 'ok' set to False.

    Raises ValueError if the analysis failed for every time step.
    """
    pmax = int(pmax)
    omeStats = ome.index(omeFile)
    channels = ome.redGreen(omeFile)

    tracks = list()
    if tlimit is None:
        trange = range(omeStats['size'][1])
    else:
        trange = tlimit

    # Create job parameters.
    # channels[0] gives the channel that holds the RFP data in the ome.tif file.
    params = [(t,ome.manifest(omeFile,omeStats,channels[0],t),omeStats['size'][2:],r,dx,pmax,pad,napier) for t in trange]

    print('Phase 1 - Scanning')
    print('RFP is on channel %i, GCaMP is on channel %i.' % (channels[0],channels[1]))

    with multiprocessing.Pool() as pool:
        for result in tqdm.tqdm(pool.imap_unordered(phaseOneJob,params),total=len(params)):
            tracks.append(result)
    # Results arrive in arbitrary order; restore time order.  Sort on the
    # time step alone so that the payloads are never compared.
    tracks = [x[1] for x in sorted(tracks,key=lambda item:item[0])]

    # Repair missed steps, replacing None with the previous step's data.
    # Steps that failed before the first good one have no predecessor, so
    # they are filled from the first good step instead.
    firstGood = next((i for i,item in enumerate(tracks) if item is not None),None)
    if tracks and firstGood is None:
        raise ValueError('Phase 1 failed on every time step; there is nothing to track.')
    for i,item in enumerate(tracks):
        if item is None:
            donor = tracks[i-1] if i>firstGood else tracks[firstGood]
            tracks[i] = donor.copy()
            tracks[i]['ok'] = False

    resolveToPCA = True
    if seedAxes is None:
        seedAxes = np.eye(3)
        resolveToPCA = False
    align(tracks,seedAxes,dx,resolveToPCA)
    return tracks

# ---

def straighten(tracks,seedAxes):
    """Make the principal axes of consecutive frames consistent, in place.

    For every frame the columns of track['vecs'] are reordered and, where
    needed, negated so that they line up with the axes of the previous frame
    (seedAxes for the first frame).  The untouched eigenvectors are kept in
    track['vecsOrig'].

    tracks   -- list of per-frame dicts holding a 3x3 'vecs' array.
    seedAxes -- 3x3 array of reference axes (as columns) for the first frame.

    Returns None.  Prints a warning if any resulting axis system is
    left-handed.
    """
    prior = seedAxes
    for track in tracks:
        # Maintain best alignment with a pre-existing system.
        track['vecsOrig'] = track['vecs'].copy()
        vecs = track['vecs'].copy()

        # (1) Keep the axes in the same order: greedily pair each prior axis
        # with the new axis it overlaps most (largest squared dot product).
        dot2 = np.power(prior.transpose() @ vecs,2)
        pairs = [(dot2[row][col],row,col) for row in range(3) for col in range(3)]
        pairs = sorted(pairs,key=lambda pair:-pair[0])
        order = [None,None,None]
        while pairs:
            _,row,col = pairs[0]
            order[row] = col
            # Both axes of the chosen pair are now used up.
            pairs = [pair for pair in pairs if pair[1]!=row and pair[2]!=col]
        vecs = vecs[:,order]

        # (2) Correct any reversals in direction sense.
        for j in range(3):
            vecs[:,j] *= np.sign(np.dot(prior[:,j],vecs[:,j]))

        track['vecs'] = vecs.copy()
        prior = track['vecs']

    # Do a sanity check to make sure there are no axis inversions.
    # i.e. check using determinants that the co-ordinate system never
    # mirrors from right-handed to left-handed.
    if any(linalg.det(track['vecs'])<0 for track in tracks):
        print('Warning: Axes became left-handed.')

# ---

def makeCoords(tracks,dx,resolveToPCA):
    """Attach physical coordinates to the points of every frame, in place.

    For each frame, track['coords'] becomes a list with one 4-vector per
    point: its position (x, y, z) in physical units relative to the frame's
    centroid, followed by its brightness.  If resolveToPCA is true the
    position is expressed in the frame's axes (track['vecs']).

    tracks       -- list of per-frame dicts with 'points' ((z, y, x) index
                    tuples), 'centroid', 'vecs' and 'brightness'.
    dx           -- voxel size, indexed as (x, y, z).
    resolveToPCA -- whether to rotate into the frame's axes.

    Returns None.
    """
    for track in tracks:
        track['coords'] = list()
        for i,p in enumerate(track['points']):
            # Points are stored (z, y, x); coordinates are ordered (x, y, z).
            offset = np.array((p[2]*dx[0],p[1]*dx[1],p[0]*dx[2])) - track['centroid']
            position = offset @ track['vecs'] if resolveToPCA else offset
            track['coords'].append(np.append(position,track['brightness'][i]))

# ---

def align(tracks,seedAxes,dx,resolveToPCA=True):
    """Align the axes of all frames and compute point coordinates, in place.

    Calls straighten(tracks, seedAxes) followed by
    makeCoords(tracks, dx, resolveToPCA).  Returns None.
    """
    straighten(tracks,seedAxes)
    makeCoords(tracks,dx,resolveToPCA)

# ---

def phaseTwoAJob(processArgs):
    """Worker for phase 2A: optimally assign the loci of one frame to the
    candidate points of another.

    processArgs is the tuple
    (fromX, fromY, fromZ, fromB, toCoords, weight, ri, fromFrame, toFrame):
      fromX..fromB -- column vectors with the x, y, z and brightness of the
                      loci in the source frame.
      toCoords     -- 4 x nsources array with the same quantities for the
                      candidate points of the target frame.
      weight       -- four scale factors, one per quantity.
      ri           -- row indices used to read the cost of the assignment.
      fromFrame, toFrame -- frame numbers, passed back with the result.

    Returns (cost, (fromFrame, toFrame), ci) where ci[k] is the index of the
    candidate assigned to locus k and cost is the summed squared weighted
    distance of the assignment.
    """
    fromX,fromY,fromZ,fromB,toCoords,weight,ri,fromFrame,toFrame = processArgs

    # Squared weighted distance between every locus and every candidate.
    sources = (fromX,fromY,fromZ,fromB)
    d2 = np.square(weight[0]*(sources[0] - toCoords[0,:]))
    for k in range(1,4):
        d2 += np.square(weight[k]*(sources[k] - toCoords[k,:]))

    # munkres marks the chosen candidate in each row of the result.
    assignment = munkres(d2)
    ci = [tuple(row).index(True) for row in assignment]
    cost = d2[ri,ci].sum()
    return (cost,(fromFrame,toFrame),ci)

# ---

def phaseTwoA(tracks,nloci,nsources=None,bweight=0,headOrient='T',keyframe=0,seed=[],matchRange=None):
    """Phase 2A: chain the points of all frames into nloci tracks, in place.

    tracks     -- per-frame dicts from phase one (uses 'coords' and 'points').
    nloci      -- number of tracks to create.
    nsources   -- number of candidate points per frame to consider; defaults
                  to the number of points in the keyframe.
    bweight    -- scaling factor applied to the difference in brightness.
    headOrient -- orientation of the head in the image, used to number the
                  loci in the keyframe: a string starting with R/E, L/W, B/S
                  or T/N (any case) or the integer 1, 3, 2 or 0 respectively.
                  Anything else leaves the loci in stored order.
    keyframe   -- frame from which matching starts when no seed is given.
    seed       -- a preceding, already chained set of tracks to register
                  against; overrides keyframe/headOrient.  The default list
                  is only read, never modified.
    matchRange -- if given, frames are only matched against frames at most
                  this many frames away.

    Every frame receives 'chain' (the candidate index of each locus),
    'chainRMSE' and 'chainPair' (the (from, to) frame pair the chain was
    taken from).  Returns None.
    """
    # Chain together over the available points, creating nloci tracks.
    # bweight is scaling factor applied to difference in brightness.
    # Can initialize by default with:
    #     the coords for the n brightest points in the zeroth frame.
    # or, specify coordinates as prime with which to begin (overrides n).
    # Starts from the designated keyframe and blooms out to frames
    # that are the most similar.

    print('Phase 2A - Matching')
    if nsources is None:
        nsources = len(tracks[keyframe]['coords'])
    weight = np.array([1,1,1,bweight])
    ri = list(range(nloci))
    coordTable = np.zeros((4,len(tracks),nsources))
    for t in range(len(tracks)):
        for n in range(nsources):
            coordTable[:,t,n] = tracks[t]['coords'][n]

    if len(seed)!=0:
        # A preceding set of tracks was supplied, which we should use to optimally
        # cross-register this track set.
        costBoard = []
        nseedsources = len(seed[0]['coords'])
        seedTable = np.zeros((4,len(seed),nseedsources))
        for t in range(len(seed)):
            for n in range(nseedsources):
                seedTable[:,t,n] = seed[t]['coords'][n]
        pbar = tqdm.tqdm(total = len(seed)*len(tracks))
        with multiprocessing.Pool() as pool:
            for fromFrame in range(len(seed)):
                fromSeed = seedTable[:,fromFrame,seed[fromFrame]['chain']]
                fromX = fromSeed[0,:][:,None]
                fromY = fromSeed[1,:][:,None]
                fromZ = fromSeed[2,:][:,None]
                fromB = fromSeed[3,:][:,None]
                # Seed frames are labelled with negative frame numbers.
                params = [(fromX,fromY,fromZ,fromB,coordTable[:,toFrame,:],weight,ri,-1-len(seed)+fromFrame,toFrame) for toFrame in list(range(len(tracks)))]
                for result in pool.imap_unordered(phaseTwoAJob,params):
                    costBoard.append(result)
                pbar.update(len(tracks))
        costBoard.sort()
        pbar.close()

    else:
        # Initialize the costBoard with the brightest points in the keyframe
        # Sort the list dependent on the head orientation in the image so that
        # the most cephalic (sensory) neurons have the lowest numbers.
        try:
            headOrient = headOrient.upper()[0]
        except (AttributeError,IndexError,TypeError):
            # Not a non-empty string (e.g. an integer code); use it as given.
            pass

        if headOrient=='R' or headOrient=='E' or headOrient==1:
            # The head is at the right of the image, so number with decreasing X.
            print('Head orientation right/east')
            ci = sorted(ri,key=lambda n:-tracks[keyframe]['points'][n][2])
        elif headOrient=='L' or headOrient=='W' or headOrient==3:
            # The head is at the left of the image, so number with increasing X.
            print('Head orientation left/west')
            ci = sorted(ri,key=lambda n:tracks[keyframe]['points'][n][2])
        elif headOrient=='B' or headOrient=='S' or headOrient==2:
            # The head is at the bottom of the image, so number with decreasing Y.
            print('Head orientation bottom/south')
            ci = sorted(ri,key=lambda n:-tracks[keyframe]['points'][n][1])
        elif headOrient=='T' or headOrient=='N' or headOrient==0:
            # The head is at the top of the image, so number with increasing Y.
            print('Head orientation top/north')
            ci = sorted(ri,key=lambda n:tracks[keyframe]['points'][n][1])
        else:
            # No orientation was requested.
            print('No orientation, ordered by brightness.')
            ci = ri[:]

        costBoard = [(0,(keyframe,keyframe),ci)]

    if matchRange is None:
        print('Matching all frames against each other.')
    else:
        print('Matching frames within %i frames of each other.' % matchRange)

    pbar = tqdm.tqdm(total = (len(tracks)*(len(tracks)+1))//2)
    toDo = list(range(len(tracks)))
    with multiprocessing.Pool() as pool:
        while len(toDo):
            # costBoard is kept sorted, so its head is the cheapest known
            # match onto a frame that has not been chained yet.
            fromFrame = costBoard[0][1][1]
            tracks[fromFrame]['chain'] = costBoard[0][2]
            tracks[fromFrame]['chainRMSE'] = np.sqrt(costBoard[0][0]/nloci)
            tracks[fromFrame]['chainPair'] = costBoard[0][1]
            toDo = [x for x in toDo if x!=fromFrame]
            costBoard = [x for x in costBoard if x[1][1]!=fromFrame]

            # Match the newly chained frame against the remaining frames.
            fromCoords = coordTable[:,fromFrame,tracks[fromFrame]['chain']]
            fromX = fromCoords[0,:][:,None]
            fromY = fromCoords[1,:][:,None]
            fromZ = fromCoords[2,:][:,None]
            fromB = fromCoords[3,:][:,None]

            params = [(fromX,fromY,fromZ,fromB,coordTable[:,toFrame,:],weight,ri,fromFrame,toFrame) for toFrame in toDo if matchRange is None or abs(toFrame-fromFrame)<=matchRange]
            costBoardNewItems = list()
            for result in pool.imap_unordered(phaseTwoAJob,params):
                costBoardNewItems.append(result)
            costBoard = list(merge(costBoard,sorted(costBoardNewItems)))
            pbar.update(len(toDo)+1)
    pbar.close()
    return

# ---

def phaseTwoB(tracks):
    """Phase 2B: collect the point list of every chain.

    tracks -- per-frame dicts holding 'points' and 'chain' (see phaseTwoA).

    Returns a list with one entry per chain; each entry is the list of that
    chain's points, one per frame in frame order.  The number of chains is
    taken from the first frame.
    """
    print('Phase 2B - Chaining')

    traceList = []
    for c in tqdm.trange(len(tracks[0]['chain'])):
        traceList.append([track['points'][track['chain'][c]] for track in tracks])
    return traceList

# ---

def phaseTwoC(traces,dx,matchRange=None):
    """Phase 2C: repair outlying points by consensus between frames.

    After calculating the apparent transformation from each frame to the
    current frame, each frame votes on where it thinks each neuron should be.
    If the position of a neuron in the current frame is further from the
    consensus than 3 SD of the votes' own distances, it is outvoted and
    replaced with the (rounded) centroid of the votes.  This helps repair
    tracking errors from misidentification, occlusion or signal loss.

    Frames are processed in order and each repaired frame is used when
    voting on the frames after it.

    traces     -- one sequence of (z, y, x) points per neuron, one point per
                  frame.
    dx         -- voxel size, indexed as (x, y, z).
    matchRange -- if given, only frames at most this many frames away vote.

    Returns the repaired traces in the same layout, with integer points.
    """
    frames = list(zip(*traces))
    # Voxel size in (z, y, x) order, matching the points.
    scale = np.array(dx[::-1])
    print('Phase 2C - Consensus')

    nframes = len(frames)
    for src in tqdm.trange(nframes):
        here = np.array(frames[src]).astype('double')
        hereMean = here.mean(axis=0)

        votes = list()
        for dst in range(nframes):
            if matchRange is not None and abs(dst-src)>matchRange:
                continue
            there = np.array(frames[dst]).astype('double')
            centred = there - there.mean(axis=0)
            # Least-squares linear map from the other frame onto this one.
            R = linalg.lstsq(centred,here-hereMean,rcond=None)[0]
            votes.append(centred@R + hereMean)
        votes = np.array(votes)

        consensus = votes.mean(axis=0)
        spread = linalg.norm(scale*(votes-consensus),axis=2)
        limit = 3 * np.std(spread,axis=0)
        miss = linalg.norm(scale*(here-consensus),axis=1)
        outvoted = miss>limit
        here[outvoted,:] = np.round(consensus[outvoted,:])
        frames[src] = tuple(tuple(p) for p in here.astype('int'))

    return list(zip(*frames))

# ---

def phaseTwoD(traces,wsize,sigma):
    """Phase 2D: smooth traces in time to control for occasional glitches.

    Each trace is filtered along the time axis with a normalised Gaussian
    window; the ends are padded by repeating the first/last point.

    traces -- one sequence of points per neuron, one point per frame.
    wsize  -- half-width of the window in frames (2*wsize+1 taps).
    sigma  -- standard deviation of the Gaussian, in frames.

    Returns a list with one tuple of point tuples per trace.  Integer traces
    stay integer: they are filtered in floating point and rounded to the
    nearest voxel.  (Filtering them directly would make ndimage cast the
    result back to the integer input type, discarding the fractional part.)
    Non-integer traces are filtered as they are.
    """
    taps = np.arange(-wsize,wsize+1)
    y = stats.norm.pdf(taps/sigma)
    y = y/y.sum()

    print('Phase 2D - Smoothing')
    smoothed = list()
    for n in tqdm.trange(len(traces)):
        t = np.array(traces[n])
        if np.issubdtype(t.dtype,np.integer):
            s = ndimage.correlate1d(t.astype(float),y,axis=0,mode='nearest')
            t = np.rint(s).astype(t.dtype)
        else:
            t = ndimage.correlate1d(t,y,axis=0,mode='nearest')
        smoothed.append(tuple(tuple(p) for p in t))

    return smoothed

# ---

def phaseTwo(tracks,dx,nloci,nsources=None,bweight=0,headOrient='T',keyframe=0,seed=[],matchRange=None):
    """Run phases 2A to 2D in sequence and return the resulting traces.

    tracks is updated in place by phaseTwoA; the remaining arguments are
    forwarded to phaseTwoA (nloci, nsources, bweight, headOrient, keyframe,
    seed, matchRange) and phaseTwoC (dx, matchRange).
    """
    phaseTwoA(tracks,nloci,nsources,bweight,headOrient,keyframe,seed,matchRange)
    traces = phaseTwoB(tracks)
    traces = phaseTwoC(traces,dx,matchRange)
    # Smooth with a window half-width of 5 frames and a sigma of 2 frames.
    traces = phaseTwoD(traces,5,2)
    return traces

# ---

def phaseThreeJob(processArgs):
    """Worker for phase three: read the activity at every locus for one
    time step.

    processArgs is the tuple (t, manifests, shape, loci, shell, thresh, napier):
      t         -- time step, passed back so results can be re-ordered.
      manifests -- (red, green) pair of manifests for
                   ome.readVolumeWithManifest.
      shape     -- expected shape of each volume.
      loci      -- (z, y, x) positions to read.
      shell     -- shell offsets (see makeShell) averaged over for GCaMP.
      thresh    -- RFP threshold used to mask the GCaMP volume, or None for
                   no masking.
      napier    -- if true, thresh is in log units and is exponentiated
                   before being compared with the raw RFP intensities.

    Returns (t, [rfp, gcamp]) where rfp and gcamp are lists with one value
    per locus: the RFP intensity at the locus itself, and the mean GCaMP
    intensity over the shell around it (NaN if no voxel is available).

    Raises ValueError if a volume does not have the expected shape.
    """
    t,manifests,shape,loci,shell,thresh,napier = processArgs
    center = [[0],[0],[0]]
    activityAtTime = [list(),list()]

    # Manifests are given as a (red,green) pair.
    # We read the RFP channel first because we need those values
    # to calculate the masking over the GCaMP.
    for i in [0,1]:

        V = ome.readVolumeWithManifest(manifests[i],shape)
        if V.shape!=shape:
            raise ValueError('Inconsistent stack sizes encountered.')

        # Implement masking if an RFP threshold was supplied.
        if thresh is not None:
            if i==0:
                # Calculate a mask from the RFP data.
                if napier:
                    mask = V>np.exp(thresh)
                else:
                    mask = V>thresh
            else:
                # Put NaNs over masked points in the GCaMP data.
                # NaN cannot be stored in an integer array, so promote
                # non-floating volumes to float first.
                if not np.issubdtype(V.dtype,np.floating):
                    V = V.astype(float)
                V[mask] = np.nan

        for locus in loci:
            # Get the center point for RFP, but get the shell average for GCaMP.
            zsh,ysh,xsh = moveShell(shell if i==1 else center,locus,shape)
            readings = V[zsh,ysh,xsh]
            if readings.size == 0:
                activityAtTime[i].append(np.nan)
            else:
                # nanmean will return nan only if all values are nan
                activityAtTime[i].append(np.nanmean(readings))

    return (t,activityAtTime)

# ---

def phaseThree(omeFile,traces,dx,inner,outer,threshs=None,napier=True):
    """Phase 3: extract activity along the traces from the ome.tif file.

    omeFile -- file to read.
    traces  -- one sequence of (z, y, x) points per neuron, one per frame.
    dx      -- voxel size, indexed as (x, y, z).
    inner, outer -- radii of the shell averaged over for the GCaMP signal.
    threshs -- optional list of RFP thresholds, one per time step, used to
               screen out nuclear interference in the GCaMP signal.
               Thresholds can be obtained from the tracks structure.
    napier  -- whether the thresholds are in log units (see phaseThreeJob).

    Returns {'GCaMP': ..., 'RFP': ...}; each value is a tuple with one tuple
    of readings per neuron, one reading per time step.
    """
    shell = makeShell(dx,inner,outer)
    lociAtTime = tuple(zip(*traces))
    omeStats = ome.index(omeFile)
    channels = ome.redGreen(omeFile)

    print('Phase 3 - Extracting')
    print('RFP is on channel %i, GCaMP is on channel %i.' % (channels[0],channels[1]))

    nsteps = len(lociAtTime)
    if threshs is None:
        threshs = [None] * nsteps

    # Manifests are transmitted as a (red,green) pair.
    params = list()
    for t in range(nsteps):
        red = ome.manifest(omeFile,omeStats,channels[0],t)
        green = ome.manifest(omeFile,omeStats,channels[1],t)
        params.append((t,[red,green],omeStats['size'][2:],lociAtTime[t],shell,threshs[t],napier))

    results = list()
    with multiprocessing.Pool() as pool:
        for result in tqdm.tqdm(pool.imap_unordered(phaseThreeJob,params),total=len(params)):
            results.append(result)

    # Results arrive in arbitrary order; restore time order, then split the
    # (RFP, GCaMP) pairs and transpose to one series per neuron.
    activity = [readings for _,readings in sorted(results)]
    RFP = [pair[0] for pair in activity]
    GCaMP = [pair[1] for pair in activity]
    return {'GCaMP':tuple(zip(*GCaMP)), 'RFP':tuple(zip(*RFP))}

# ---
