# SCRIPT TO READ ALL IMAGE FILES IN A DIRECTORY, TO CREATE A SUBFOLDER, PLACE RANDOM PATCHES OF ALL IMAGES WITHIN THIS SUBFOLDER,
# AND TO SAVE THE ORIGIN OF THE RANDOM PATCHES IN A CSV FILE FOR LATER ANALYSIS
# written by: Silvanus Alt, April 2017


from __future__ import division
from pylab import np
from skimage.io import imread, imshow
import scipy.misc
import os
import scipy.misc
from joblib import Parallel, delayed
import random
import shutil
from skimage.exposure import equalize_adapthist
import itertools
import csv


def clean_up(dir_name):
    """Delete every directory below ``dir_name`` whose path ends with '_patches'.

    The tree is walked top-down with ``os.walk``. ``dir_name`` itself is
    removed as well if its own path ends with '_patches'.

    Parameters
    ----------
    dir_name : str
        Root directory of the tree to clean.
    """
    for path, subdirs, _files in os.walk(dir_name):
        if path.endswith('_patches'):
            shutil.rmtree(path)
            # The whole subtree is gone; empty the list in place so that
            # os.walk does not try to descend into the deleted directories.
            del subdirs[:]


def patch_image(image_filename, N_patch=50, lx_patch=50, ly_patch=50):
    """Cut random, contrast-enhanced patches out of one plane of an image file.

    The image is read with ``imread`` and the plane at index 1 of the first
    axis is selected (original author's note: junctions are stored in the
    second channel). That plane is passed through ``equalize_adapthist``,
    ``N_patch`` mid-points are drawn at random, and a patch is cut around each
    of them. Every patch is passed through ``equalize_adapthist`` once more.

    Parameters
    ----------
    image_filename : str
        Path of the image file to read.
    N_patch : int
        Number of patches to cut from the image.
    lx_patch, ly_patch : int
        Requested box size along axis 0 / axis 1. Even values are increased
        by one so that the box has an integer mid-point.

    Returns
    -------
    list of dict
        One dict per patch with the keys 'image_name' (file name without its
        directory), 'patch_number' (index of the patch within this image) and
        'patch' (the patch array).

    Raises
    ------
    ValueError
        Raised by ``random.randint`` when the image is smaller than the box
        along an axis, i.e. the interval of admissible mid-points is empty.
    """
    # box size should be odd so that they have an integer mid-point
    if lx_patch % 2 == 0:
        lx_patch += 1
    if ly_patch % 2 == 0:
        ly_patch += 1

    # half-widths of the box on either side of the mid-point
    dx = int((lx_patch - 1) // 2)
    dy = int((ly_patch - 1) // 2)

    image_filename_short = os.path.basename(image_filename)

    im_whole = imread(image_filename)

    # Select plane 1 of the first axis; for 4-D input additionally take index 0
    # of the last axis.
    # NOTE(review): assumes the channel axis comes first — confirm for new data.
    if im_whole.ndim == 4:
        im_whole = im_whole[1, :, :, 0]
    else:
        im_whole = im_whole[1, :, :]

    # increase the contrast using CLAHE
    im_whole = equalize_adapthist(im_whole)

    lx_im, ly_im = im_whole.shape[:2]

    # Re-seed from a fresh entropy source on every call.
    random.seed()

    # Mid-points must keep the whole box inside the image. Each axis uses its
    # own limits: axis 0 is bounded by lx_im/dx, axis 1 by ly_im/dy.
    mid_x = [random.randint(dx, lx_im - dx - 1) for _ in range(N_patch)]
    mid_y = [random.randint(dy, ly_im - dy - 1) for _ in range(N_patch)]

    patch_list = []
    for i, (cx, cy) in enumerate(zip(mid_x, mid_y)):
        # NOTE: the half-open slices span 2*dx by 2*dy pixels, i.e. one pixel
        # less than the odd box size per axis.
        patch_i = equalize_adapthist(im_whole[cx - dx:cx + dx, cy - dy:cy + dy])
        patch_list.append({'image_name': image_filename_short, 'patch_number': i, 'patch': patch_i})

    return patch_list


def patch_image_rgb(image_filename, N_patch=50, lx_patch=50, ly_patch=50, CLAHE=True):
    """Cut random patches, keeping all channels, out of an image file.

    The image is read with ``imread`` and passed through
    ``equalize_adapthist``. ``N_patch`` mid-points are drawn at random and a
    patch containing every entry of the third axis is cut around each of them.

    Parameters
    ----------
    image_filename : str
        Path of the image file to read.
    N_patch : int
        Number of patches to cut from the image.
    lx_patch, ly_patch : int
        Requested box size along axis 0 / axis 1. Even values are increased
        by one so that the box has an integer mid-point.
    CLAHE : bool
        If true, every patch is passed through ``equalize_adapthist`` again
        after it has been cut out.

    Returns
    -------
    list of dict
        One dict per patch with the keys 'image_name' (file name without its
        directory), 'patch_number' (index of the patch within this image) and
        'patch' (the patch array).

    Raises
    ------
    ValueError
        Raised by ``random.randint`` when the image is smaller than the box
        along an axis, i.e. the interval of admissible mid-points is empty.
    """
    # box size should be odd so that they have an integer mid-point
    if lx_patch % 2 == 0:
        lx_patch += 1
    if ly_patch % 2 == 0:
        ly_patch += 1

    # half-widths of the box on either side of the mid-point
    dx = int((lx_patch - 1) // 2)
    dy = int((ly_patch - 1) // 2)

    image_filename_short = os.path.basename(image_filename)

    im_whole = imread(image_filename)

    # increase the contrast of the whole image using CLAHE
    # NOTE(review): this runs even when CLAHE is False; the flag only controls
    # the additional per-patch pass below — confirm that this is intended.
    im_whole = equalize_adapthist(im_whole)

    lx_im, ly_im = im_whole.shape[:2]

    # Re-seed from a fresh entropy source on every call.
    random.seed()

    # Mid-points must keep the whole box inside the image. Each axis uses its
    # own limits: axis 0 is bounded by lx_im/dx, axis 1 by ly_im/dy.
    mid_x = [random.randint(dx, lx_im - dx - 1) for _ in range(N_patch)]
    mid_y = [random.randint(dy, ly_im - dy - 1) for _ in range(N_patch)]

    patch_list = []
    for i, (cx, cy) in enumerate(zip(mid_x, mid_y)):
        # NOTE: the half-open slices span 2*dx by 2*dy pixels, i.e. one pixel
        # less than the odd box size per axis.
        patch_i = im_whole[cx - dx:cx + dx, cy - dy:cy + dy, :]
        if CLAHE:  # enhance contrast of the individual patch
            patch_i = equalize_adapthist(patch_i)

        patch_list.append({'image_name': image_filename_short, 'patch_number': i, 'patch': patch_i})

    return patch_list


def generate_random_patches(folder_name, N_patches_per_image=100, lx_patch=100, ly_patch=100,
                            patch_folder='unordered_patches', CLAHE=True, n_jobs=4):
    """Write shuffled random patches of all '.tif' images below ``folder_name``.

    An existing ``folder_name/patch_folder`` is deleted and recreated. Every
    file below ``folder_name`` whose name ends with '.tif' is patched with
    ``patch_image_rgb``; the patches of all images are pooled, shuffled and
    saved as 'patch_<number>.tiff' in the patch folder. The file
    'patch_info.csv' in the same folder records, for every saved patch, the
    image it was cut from and its index within that image.

    Parameters
    ----------
    folder_name : str
        Top directory that is searched recursively for '.tif' files.
    N_patches_per_image : int
        Number of patches cut from every image.
    lx_patch, ly_patch : int
        Requested patch size, forwarded to ``patch_image_rgb``.
    patch_folder : str
        Name of the output sub-folder inside ``folder_name``.
    CLAHE : bool
        Forwarded to ``patch_image_rgb``.
    n_jobs : int
        Number of parallel jobs handed to ``joblib.Parallel``.
    """
    out_dir = folder_name + '/' + patch_folder

    # start by cleaning up existing patch folder
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    # get list of image names below the top dir
    file_list_im = [os.path.join(path, name)
                    for path, _subdirs, files in os.walk(folder_name)
                    for name in files
                    if name.endswith('.tif')]

    # run the patching for the images in parallel
    patches_per_image = Parallel(n_jobs=n_jobs)(
        delayed(patch_image_rgb)(fn_, N_patch=N_patches_per_image, lx_patch=lx_patch, ly_patch=ly_patch, CLAHE=CLAHE)
        for fn_ in file_list_im)

    # put them all in one list and shuffle it, so that the saved patch number
    # carries no information about the image of origin
    all_patches = list(itertools.chain.from_iterable(patches_per_image))
    random.shuffle(all_patches)

    # zero-pad the patch numbers to a common width
    width = len(str(len(all_patches)))

    # create the file patch_info.csv containing all the data; the context
    # manager guarantees that the file is flushed and closed
    with open(out_dir + '/patch_info.csv', 'a') as finfo:
        info_writer = csv.writer(finfo)
        info_writer.writerow(['patch_number', 'image_name', 'patch_number_in_image'])

        # save the patches stored in list to patch_folder
        for no, patch in enumerate(all_patches):
            # NOTE(review): scipy.misc.imsave is not available in recent SciPy
            # releases; a replacement must reproduce its intensity scaling.
            scipy.misc.imsave(out_dir + '/patch_' + str(no).zfill(width) + '.tiff',
                              patch['patch'])

            info_writer.writerow([no, patch['image_name'], patch['patch_number']])


# set top directory
# generate_random_patches('/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/junction_images', N_patches_per_image=100)

# Only start the patching run when this file is executed as a script, so that
# importing the module does not trigger it.
if __name__ == '__main__':
    generate_random_patches(
        '/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/junction_analysis/wetransfer-5281e6',
        N_patches_per_image=100, patch_folder='unordered_patches_no_clahe', CLAHE=False)
