import tifffile
import logging
import numpy as np
from lxml import etree
from collections import defaultdict
import os
import glob
import tqdm
import tictoc

# Only let the 'tifffile' logger emit ERROR and above; its lower-severity
# messages (warnings, info, debug) are suppressed for this process.
logging.getLogger('tifffile').setLevel(logging.ERROR)

# ---

def getMetadata(omeFile):
    """Return the OME-XML metadata for a dataset as bytes suitable for lxml.

    If a file named ``OMEXMLMetadata.ome`` exists in the same directory as
    `omeFile`, its contents are used in preference to anything embedded in the
    TIFF. Otherwise the OME-XML stored in `omeFile` itself is read with tifffile.

    Parameters
    ----------
    omeFile : str
        Path to the leading OME-TIFF file of the dataset.

    Returns
    -------
    bytes
        The metadata text, encoded as UTF-8.

    Raises
    ------
    ValueError
        If there is no sidecar file and `omeFile` carries no OME metadata.
    """
    sidecar = os.path.join(os.path.dirname(omeFile), 'OMEXMLMetadata.ome')

    # See if there's an OMEXMLMetadata.ome file in the directory, and if so then read that.
    # `with` guarantees the handle is closed even if the read itself fails.
    try:
        with open(sidecar, 'rb') as f:
            raw = f.read()
    except FileNotFoundError:
        raw = None

    if raw is not None:
        # NOTE(review): decoding as Latin-1 and re-encoding as UTF-8 is lossless
        # for ASCII content, but a UTF-8 sidecar containing non-ASCII characters
        # would be double-encoded here -- confirm the sidecar's encoding.
        metadata = str(raw, 'ISO-8859-1')
    else:
        # Read the metadata out of the leading tiff file.
        with tifffile.TiffFile(omeFile) as f:
            metadata = f.ome_metadata
        if metadata is None:
            # Plain (non-OME) TIFFs report no OME metadata; fail with a clear message
            # instead of an AttributeError on .encode() below.
            raise ValueError('{0} contains no OME-XML metadata'.format(omeFile))

    return metadata.encode()

# ---

def findAll(astr,sub):
    """Yield the start offset of each non-overlapping occurrence of `sub` in `astr`.

    Matching proceeds left to right and resumes immediately after each hit,
    so ``findAll('aaaa', 'aa')`` yields 0 and 2. An empty `sub` yields nothing.
    Works for both ``str`` and ``bytes``.
    """
    if len(sub) == 0:
        return
    width = len(sub)
    hit = astr.find(sub)
    while hit != -1:
        yield hit
        hit = astr.find(sub, hit + width)
        
# ---

def patch(astr,index,repl):
    """Return a copy of `astr` with `repl` written over it starting at `index`.

    Exactly ``len(repl)`` characters are overwritten, so the result keeps the
    length of `astr` unless `repl` runs past the end, in which case it grows.
    Works on ``str`` and ``bytes`` alike.

    A negative `index` counts from the end of `astr`, as in slicing.
    """
    if index < 0:
        # Convert to an absolute offset (clamped at 0, as slicing does) so that
        # `index + len(repl)` below cannot land on 0 or a positive value and
        # make the tail slice restart from the beginning of the string.
        index = max(0, index + len(astr))
    return astr[:index] + repl + astr[index + len(repl):]

# ---

def _localName(element):
    """Return an lxml element's tag without its namespace, or '' for comments/PIs."""
    tag = element.tag
    # Comments and processing instructions carry a non-string tag, so they
    # cannot be matched by name (and have no .endswith/.rpartition).
    if not isinstance(tag, str):
        return ''
    return tag.rpartition('}')[2]


def index(omeFile):
    """Parse a dataset's OME-XML into a plane lookup table.

    Only the first ``Image`` element of the metadata is used.

    Parameters
    ----------
    omeFile : str
        Path to the leading OME-TIFF file; its metadata is read via getMetadata.

    Returns
    -------
    dict
        ``'size'``  : tuple of ints (SizeC, SizeT, SizeZ, SizeY, SizeX).
        ``'names'`` : list of the Channel ``Name`` attributes in document
                      order, with None where the attribute is absent.
        ``'index'`` : nested list indexed ``[t][c][z]``; each entry is a
                      ``(filename, ifd)`` tuple, or None if no TiffData element
                      describes that plane. ``filename`` is relative to the
                      directory containing `omeFile`.

    Raises
    ------
    ValueError
        If the metadata has no Image element, or the first Image has no Pixels.
    """
    root = etree.fromstring(getMetadata(omeFile))

    # Only process the first Image.
    image = next((el for el in root if _localName(el) == 'Image'), None)
    if image is None:
        raise ValueError('No Image element in OME metadata of {0}'.format(omeFile))
    pixels = next((el for el in image if _localName(el) == 'Pixels'), None)
    if pixels is None:
        raise ValueError('No Pixels element in first Image of {0}'.format(omeFile))

    # Figure out the size of the dataset and allocate storage.
    attr = pixels.attrib
    ctzyxSize = tuple(int(attr['Size'+ax]) for ax in 'CTZYX')
    sizeC, sizeT, sizeZ = ctzyxSize[:3]
    planes = [[[None]*sizeZ for _c in range(sizeC)] for _t in range(sizeT)]

    # TiffData entries that name no file (no UUID child, or a UUID without a
    # FileName attribute) are attributed to omeFile itself.
    defaultFile = os.path.basename(omeFile)

    # Parse the Channel and TiffData fields.
    names = []
    for data in pixels:
        tag = _localName(data)
        if tag == 'Channel':
            names.append(data.attrib.get('Name', None))
        elif tag == 'TiffData':
            attr = data.attrib
            # FirstC/FirstT/FirstZ and IFD are optional attributes; treat absent ones as 0.
            c, t, z = (int(attr.get('First'+ax, 0)) for ax in 'CTZ')
            ifd = int(attr.get('IFD', 0))
            # Reset per element so a previous TiffData's filename is never reused.
            filename = defaultFile
            for uuid in data:
                if _localName(uuid) == 'UUID':
                    filename = uuid.attrib.get('FileName', defaultFile)
                    break
            # NOTE: PlaneCount is ignored; each TiffData is assumed to describe one plane.
            planes[t][c][z] = (filename, ifd)

    return {'size':ctzyxSize, 'names':names, 'index':planes}

# ---

def manifest(omeFile,idx,c,t):
    """Group the planes of one (channel, timepoint) volume by the file holding them.

    Parameters
    ----------
    omeFile : str
        Path to the leading OME-TIFF file; plane filenames in `idx` are
        resolved relative to its directory.
    idx : dict
        Lookup table as returned by ``index``; only ``idx['index']`` is read,
        indexed ``[t][c][z]`` with ``(filename, ifd)`` entries.
    c, t : int
        Channel and timepoint to collect.

    Returns
    -------
    dict
        Maps each full file path to a list of ``(ifd, z)`` tuples, in increasing z.

    Raises
    ------
    ValueError
        If any z-plane of the requested volume has no entry (None) in `idx`.
    """
    dirname = os.path.dirname(omeFile)
    retrieval = defaultdict(list)
    for z,location in enumerate(idx['index'][t][c]):
        if location is None:
            # Fail with an explicit message rather than "'NoneType' object is not subscriptable".
            raise ValueError('No TiffData entry for plane c={0} t={1} z={2} of {3}'.format(c, t, z, omeFile))
        retrieval[os.path.join(dirname,location[0])].append((location[1],z))
    return dict(retrieval)

# ---

def readVolumeWithManifest(mf,sz):
    """Assemble a z-stack from the TIFF pages listed in a manifest.

    Parameters
    ----------
    mf : dict
        Maps a TIFF file path to a sequence of ``(ifd, z)`` pairs, as built by ``manifest``.
    sz : sequence of int
        The first three values give the (Z, Y, X) shape of the output.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(sz[0], sz[1], sz[2])``; z-slices not named
        in `mf` are left as zeros.
    """
    depth, height, width = sz[0], sz[1], sz[2]
    volume = np.zeros((depth, height, width), dtype='float')
    for path in mf:
        # OME and ImageJ handling are switched off; only raw page access is needed here.
        with tifffile.TiffFile(path, is_ome=False, is_imagej=False) as tif:
            for entry in mf[path]:
                page, z = entry[0], entry[1]
                volume[z] = tif.pages[page].asarray()
    return volume

# ---

def readVolumeWithWavelength(omeFile,wavelength):
    """Return timepoint 0 of the first channel whose name encodes `wavelength`.

    The wavelength is the integer in the second '-'-separated field of the
    channel name. Channels with no name, too few fields, or a non-integer
    second field are skipped.

    Parameters
    ----------
    omeFile : str
        Path to the leading OME-TIFF file.
    wavelength : int
        Wavelength to look for, compared with ``==`` against the parsed integer.

    Returns
    -------
    numpy.ndarray or None
        float64 volume of shape (SizeZ, SizeY, SizeX), or None if no channel matches.
    """
    idx = index(omeFile)
    s = idx['size']
    channels = idx['names']
    for c,channelName in enumerate(channels):
        # Channel names from Micromanager are expected to be something like:
        # ['Left_Camera-405-i', 'Left_Camera-488-i', 'Left_Camera-561-i', 'Left_Camera-642-i']
        # So it's a question of matching the number on the right.
        # And then we want the volume at timestep 0 (which may be the only timestep).
        if channelName is None:
            continue
        fields = channelName.split('-')
        if len(fields) < 2:
            continue
        try:
            wl = int(fields[1])
        except ValueError:
            # Second field is not a number; this channel cannot match.
            continue
        if wavelength == wl:
            return readVolumeWithManifest(manifest(omeFile,idx,c,0),(s[2],s[3],s[4]))

    return None

# ---

def redGreen(omeFile, red='561', green='488'):
    """Find the RFP and GCaMP channel numbers, usually for a 2-channel activity sequence.

    Parameters
    ----------
    omeFile : str
        Path to the leading OME-TIFF file.
    red, green : str, optional
        Substrings identifying the red (RFP) and green (GCaMP) channel names;
        default '561' and '488'.

    Returns
    -------
    list
        ``[redChannel, greenChannel]``. Each element is the index of the last
        channel whose name contains the corresponding substring, or None if
        there is no such channel. Channels without a name are ignored.
    """
    idx = index(omeFile)
    output = [None, None]
    for c,channelName in enumerate(idx['names']):
        # Channels with no Name attribute appear as None and cannot match.
        if channelName is None:
            continue
        if red in channelName:
            output[0] = c
        if green in channelName:
            output[1] = c
    return output

# ---

def readVolumeWithPartialName(omeFile,name):
    """Return timepoint 0 of the first channel whose name contains `name`.

    Parameters
    ----------
    omeFile : str
        Path to the leading OME-TIFF file.
    name : str
        Substring to search for in each channel name. Channels without a name are skipped.

    Returns
    -------
    numpy.ndarray or None
        float64 volume of shape (SizeZ, SizeY, SizeX), or None if no channel matches.
    """
    idx = index(omeFile)
    s = idx['size']
    channels = idx['names']
    for c,channelName in enumerate(channels):
        # Channel names from Micromanager are expected to be something like:
        # ['Left_Camera-405-NP-BFP','Left_Camera-488-NP-CyOFP','Left_Camera-642-NP-mNeptune','Left_Camera-561-RFP','Left_Camera-488-GCaMP']
        # So it's a question of finding the partial name in the channel name.
        # And then we want the volume at timestep 0 (which may be the only timestep).
        # Unnamed channels are None; `name in None` would raise TypeError.
        if channelName is not None and name in channelName:
            return readVolumeWithManifest(manifest(omeFile,idx,c,0),(s[2],s[3],s[4]))

    return None

# ---

def readVolume(omeFile,c,t):
    """Load the z-stack for channel `c` at timepoint `t`.

    Returns a float64 array of shape (SizeZ, SizeY, SizeX).
    """
    meta = index(omeFile)
    # 'size' is (C, T, Z, Y, X); the volume shape is its last three entries.
    zyx = meta['size'][2:5]
    planeMap = manifest(omeFile, meta, c, t)
    return readVolumeWithManifest(planeMap, zyx)

# ---

def dataMax(omeFile):
    """Return the dataset extent without the channel axis: (SizeT, SizeZ, SizeY, SizeX)."""
    # index() reports the size in CTZYX order; dropping the leading C leaves TZYX.
    ctzyx = index(omeFile)['size']
    return ctzyx[1:]
    
# ---

def nChannels(omeFile):
    """Return SizeC, the channel count declared in the dataset's OME metadata."""
    # The size tuple is ordered CTZYX, so the channel count comes first.
    sizeC = index(omeFile)['size'][0]
    return sizeC

# ---

def channelNames(omeFile):
    """Return the channel Name attributes in document order (None where a name is absent)."""
    return index(omeFile)['names']

# ---

def master(dirname):
    """Return the path of the leading OME-TIFF file in `dirname`.

    The leading file is taken to be the first ``*.ome.tif`` file in sorted order.

    Parameters
    ----------
    dirname : str
        Dataset directory. Glob metacharacters in it (e.g. ``[``) are matched literally.

    Returns
    -------
    str
        Path of the chosen file, prefixed with `dirname`.

    Raises
    ------
    FileNotFoundError
        If `dirname` contains no ``*.ome.tif`` files.
    """
    # Escape the directory so names such as 'run[1]' are not treated as patterns.
    pattern = os.path.join(glob.escape(dirname), '*.ome.tif')
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError('No *.ome.tif files found in {0}'.format(dirname))
    return files[0]

# ---

def swapChannels(dirname, first=0, second=1):
    """Exchange two channel indices in a dataset's OME-XML and save it as a sidecar.

    Every TiffData element with ``FirstC == first`` is relabelled `second`, and
    vice versa, so the planes recorded under one channel are attributed to the
    other. TiffData elements for other channels are left unchanged. The result
    is written to ``OMEXMLMetadata.ome`` next to the leading OME-TIFF file.
    getMetadata reads that file in preference to the embedded metadata, so
    running this twice on the same directory swaps the channels back.

    Parameters
    ----------
    dirname : str
        Directory containing the dataset's ``*.ome.tif`` files.
    first, second : int, optional
        Channel indices to exchange (default 0 and 1).
    """
    omeFile = os.path.realpath(master(dirname))
    print('Parsing {0}'.format(dirname))
    dirname = os.path.dirname(omeFile)
    root = etree.fromstring(getMetadata(omeFile))
    xmlns = etree.QName(root.tag).namespace
    # A root without a namespace would otherwise produce the tag '{None}TiffData'.
    tiffDataTag = '{%s}TiffData' % xmlns if xmlns else 'TiffData'
    for tag in root.findall('.//' + tiffDataTag):
        # FirstC is an optional attribute; treat an absent one as channel 0.
        c = int(tag.attrib.get('FirstC', 0))
        if c == first:
            tag.attrib['FirstC'] = str(second)
        elif c == second:
            tag.attrib['FirstC'] = str(first)

    # `with` flushes and closes the file before success is reported.
    with open(os.path.join(dirname,'OMEXMLMetadata.ome'),'wb') as f:
        f.write(etree.tostring(root,xml_declaration=True))
    print('Done')
