import scipy.io as io
import numpy as np
import matplotlib.pyplot as plt
import tqdm
import ffmpeg
import shutil
import tempfile
import hdf5storage

# ---

def _rotate_traces(traces, originalVolume, headDir):
    # Apply the quarter-turn correction selected by headDir % 4 to trace columns 1 and 2, in place.
    trX = traces[:, :, 2].copy()
    trY = traces[:, :, 1].copy()
    turn = headDir % 4
    if turn == 1:
        traces[:, :, 1] = originalVolume[0] - trX
        traces[:, :, 2] = trY
    elif turn == 2:
        traces[:, :, 1] = originalVolume[0] - trY
        traces[:, :, 2] = originalVolume[1] - trX
    elif turn == 3:
        traces[:, :, 1] = trX
        traces[:, :, 2] = originalVolume[1] - trY
    return traces


def _normalize_channel(channel, limits):
    # Map channel linearly so limits[0] -> 0 and limits[1] -> 1, clipped to [0, 1].
    lo, hi = limits[0], limits[1]
    if hi != lo:
        return np.clip((channel - lo) / (hi - lo), 0, 1)
    # Degenerate window (e.g. a flat frame): the division above would give
    # NaN for pixels equal to `lo` and +/-inf elsewhere. Threshold instead,
    # matching the finite results of the original (inf -> 1, -inf -> 0).
    return np.where(channel > lo, 1.0, 0.0)


def _encode_video(tmpdir, tag, videoSet, kind, mpSize):
    # Encode tmpdir/'{tag}-{kind}-frameNNNNN.jpg' into '{videoSet}-{kind}.mp4' (best effort).
    try:
        ffmpeg \
            .input('{0}/{1}-{2}-frame%05d.jpg'.format(tmpdir, tag, kind), r=40, f='image2') \
            .output('{0}-{1}.mp4'.format(videoSet, kind), vf=mpSize, vcodec='libx264', crf=28) \
            .run(overwrite_output=True)
    except (ffmpeg.Error, OSError):
        # ffmpeg.Error: ffmpeg exited non-zero; OSError: ffmpeg executable not found.
        # Report and continue so one failed encode does not prevent the other.
        print('Could not render {0}-{1}.mp4'.format(tag, kind))


def buildVideo(frameSet,traceSet,dx,r,spins,tag,headDir,pctwin,videoSet,videoSize,uniform=False):
    """Render a red/green frame sequence to two MP4s: raw, and with trace overlays.

    For each frame, a JPEG without overlays and a JPEG with numbered circles
    at each trace position are written to a temporary directory. Both image
    sequences are then encoded with ffmpeg (libx264, 40 fps, crf 28) to
    '{videoSet}-untracked.mp4' and '{videoSet}-tracked.mp4'.

    Parameters
    ----------
    frameSet : str
        Path to a MAT file read with ``hdf5storage.loadmat``. It must contain:
        'redFrames' and 'greenFrames', with frames stacked along axis 2; and
        'originalVolume', possibly wrapped in singleton levels, whose first
        three entries are used as volume extents.
    traceSet : str
        Path to a MAT file read with ``scipy.io.loadmat``. It must contain
        'traces', indexed as ``traces[trace, frame, coord]`` with at least
        three coords. Column 1 is the vertical image coordinate, column 2 the
        horizontal one, and column 0 the depth used for the rotating
        projection. (Axis meaning inferred from usage below; confirm against
        the trace producer.)
    dx : sequence of float
        Voxel spacing with at least three entries. ``min(dx)`` converts `r`
        to pixels, and ``dx[2]/dx[0]`` rescales column 0 of the traces.
        (Presumably per-axis spacing; confirm the axis order.)
    r : float
        Radius of the circle drawn around each trace, in the units of `dx`.
    spins : float
        Number of full rotations of the projection angle over the sequence.
    tag : str
        Label used in temporary frame filenames and in the on-frame caption.
    headDir : int
        Orientation code. ``headDir % 4`` selects which 90-degree
        coordinate correction is applied to the traces; 0 means none.
    pctwin : sequence of two floats
        Percentiles passed to ``np.percentile`` to set the contrast window.
    videoSet : str
        Output path prefix for the two MP4 files.
    videoSize : (int, int)
        Output (width, height) in pixels.
    uniform : bool, optional
        If True, one contrast window is computed over the whole stack.
        Otherwise a window is computed per frame.
    """
    # Load the data for this sequence.
    # io.loadmat was used previously, but larger files need hdf5storage.
    imageData = hdf5storage.loadmat(frameSet)
    traceData = io.loadmat(traceSet)
    redFrames = imageData['redFrames']
    greenFrames = imageData['greenFrames']
    nframes = redFrames.shape[2]
    dpi = 150

    # Unpack originalVolume, allowing for variations in formatting
    # (it may arrive wrapped in nested singleton arrays).
    originalVolume = imageData['originalVolume']
    while len(originalVolume) == 1:
        originalVolume = originalVolume[0]

    # Correct the trace geometry for the 90-degree rotation.
    traces = _rotate_traces(traceData['traces'], originalVolume, headDir)

    # Shell outline (61 points around a circle), radius converted to pixels.
    shellRadius = r/min(dx)
    phi = np.linspace(0,2*np.pi,61)
    shx = shellRadius*np.cos(phi)
    shy = shellRadius*np.sin(phi)

    # Projection geometry: these are loop-invariant, so compute them once.
    zscale = dx[2]/dx[0]
    d = np.ceil(np.sqrt(originalVolume[1]**2 + (zscale*originalVolume[2])**2))

    # Colorbar limits for the whole stack, when requested.
    if uniform:
        redLimits = np.percentile(redFrames,pctwin)
        greenLimits = np.percentile(greenFrames,pctwin)

    # Scale the longer side to the requested size. The other side is derived
    # from the aspect ratio; '-2' (not '-1') keeps it even, as libx264 requires.
    if videoSize[0]>videoSize[1]:
        mpSize = 'scale=%i:-2' % videoSize[0]
    else:
        mpSize = 'scale=-2:%i' % videoSize[1]

    # Render the frames to a temporary directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        th = np.linspace(0,2*np.pi*spins,nframes,endpoint=False)
        for i in tqdm.trange(nframes):
            red = redFrames[:,:,i]
            green = greenFrames[:,:,i]
            if not uniform:
                redLimits = np.percentile(red,pctwin)
                greenLimits = np.percentile(green,pctwin)
            red = _normalize_channel(red,redLimits)
            green = _normalize_channel(green,greenLimits)
            frame = np.dstack((red,green,np.zeros(red.shape)))

            fig = plt.figure()
            try:
                fig.set_size_inches((videoSize[0]/dpi,videoSize[1]/dpi))
                fig.set_facecolor((0.0,0.0,0.0))
                ax = plt.Axes(fig,[0.0,0.0,1.0,1.0])
                ax.set_axis_off()
                ax.set_facecolor((0.0,0.0,0.0))
                ax.imshow(frame, aspect='equal')
                ax.text(25,25,'Frame %i: %s' % (i+1,tag),fontsize='x-small',color=(1.0,1.0,1.0))
                fig.add_axes(ax)
                fig.savefig('{0}/{1}-untracked-frame{2:05d}.jpg'.format(tmpdir,tag,i+1),dpi=dpi,facecolor=(0.0,0.0,0.0))

                # Freeze the image extent so the overlays cannot rescale the axes.
                ax.autoscale(False)
                cosT = np.cos(th[i])
                sinT = np.sin(th[i])
                for n,t in enumerate(traces[:,i,:],1):
                    # Rotate the (x, z) position about the volume centre and
                    # project it onto the horizontal image axis.
                    x = t[2] - 0.5*originalVolume[1]
                    z = zscale * (t[0] - 0.5*originalVolume[2])
                    xp = 0.5*d+(x*cosT - z*sinT)
                    ax.text(xp,t[1],'%i'%n,fontsize='xx-small',color=(1.0,1.0,1.0),ha='center',va='center')
                    ax.plot(shx+xp,shy+t[1],'w-',linewidth=0.25)
                fig.savefig('{0}/{1}-tracked-frame{2:05d}.jpg'.format(tmpdir,tag,i+1),dpi=dpi,facecolor=(0.0,0.0,0.0))
            finally:
                # Always release the figure, even if rendering fails.
                plt.close(fig)

        # Convert the frame files to MP4s while tmpdir still exists.
        _encode_video(tmpdir,tag,videoSet,'untracked',mpSize)
        _encode_video(tmpdir,tag,videoSet,'tracked',mpSize)

# ---
