# SCRIPT TO SEGMENT IMAGES OF DAPI STAINED CELL NUCLEI IN A CONFLUENT MONOLAYER,
# AND TO ASSESS THE ALIGNMENT BETWEEN CELLS AS A FUNCTION OF THEIR DISTANCE
# written by: Silvanus Alt, June 2017

from __future__ import division

import multiprocessing
import os
import time

import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed
from pylab import np, plt
from scipy.special import ellipe
from scipy.stats import binned_statistic, ttest_ind
from skimage.filters import threshold_otsu, rank
from skimage.io import imread
from skimage.measure import regionprops
from skimage.morphology import remove_small_objects, remove_small_holes, label, disk, binary_erosion

sns.set_style('ticks')


def angle_between_nematics(n1, n2):
    """Return the angle (radians, in [0, pi/2]) between two nematic directors.

    Directors have no head or tail (n and -n are the same orientation), so the
    plain vector angle is folded onto the acute side.
    """
    cos_angle = np.dot(n1 / np.linalg.norm(n1), n2 / np.linalg.norm(n2))
    # clip protects arccos from round-off pushing |cos| slightly above 1
    raw_angle = np.mod(np.arccos(np.clip(cos_angle, -1.0, 1.0)), np.pi)
    return np.min([raw_angle, np.pi - raw_angle])


def segment_nucleus_orientation(im_filename, erosion_width=3, threshold_prefactor=0.6, save_as_csv=True,
                                save_image=False, nucleus_min_area=100):
    """Segment DAPI-stained nuclei in one image and measure their position, shape and orientation.

    The nucleus channel is binarised at ``threshold_prefactor`` times its Otsu
    threshold. Objects and holes smaller than 40 px are removed, and the mask is
    eroded with a disk of radius ``erosion_width``. The remaining connected
    regions are measured with ``skimage.measure.regionprops``.

    Parameters
    ----------
    im_filename : str
        Path of the image file. A 2D image is used directly. For a 3D image,
        ``im[:, :, 2]`` is used. Otherwise ``im[1, :, :, 2]`` is used.
    erosion_width : int, optional
        Radius (px) of the disk used for the binary erosion.
    threshold_prefactor : float, optional
        Factor applied to the Otsu threshold before binarising.
    save_as_csv : bool, optional
        If True, write the result table to ``im_filename[:-3] + 'csv'``.
    save_image : bool, optional
        If True, save an overlay of centroids and major axes to
        ``im_filename[:-3] + '_seg.tif'``.
    nucleus_min_area : int, optional
        Nuclei whose area (px) is not larger than this are discarded.

    Returns
    -------
    pandas.DataFrame
        One row per retained nucleus, indexed by 'number' (1-based, in
        regionprops order). Columns:

        - 'midpoint_x': row coordinate of the centroid
        - 'midpoint_y': column coordinate of the centroid
        - 'image_name', 'area_nuclei', 'orientation_nuclei'
        - 'major_axis_length_nuclei', 'minor_axis_length_nuclei'
        - 'perimeter_ellipse_nucleus', 'perimeter_nucleus'
        - 'roughness': measured perimeter / fitted-ellipse perimeter - 1

        If no nucleus is found, the table is empty.
    """
    im = imread(im_filename)

    # pick the nucleus channel depending on the image dimensionality
    if len(np.shape(im)) == 2:
        im_nucleus = im
    elif len(np.shape(im)) == 3:
        im_nucleus = im[:, :, 2]
    else:
        im_nucleus = im[1, :, :, 2]

    # 1. FIND ALL NUCLEI
    # binarize, then remove small objects and small holes
    im_nucleus_bin = im_nucleus > (threshold_prefactor * threshold_otsu(im_nucleus))
    im_nucleus_bin = remove_small_objects(im_nucleus_bin, 40)
    im_nucleus_bin = remove_small_holes(im_nucleus_bin, 40)
    # erosion presumably separates touching nuclei before labelling
    im_nucleus_bin_eroded = binary_erosion(im_nucleus_bin, selem=disk(erosion_width))

    # every connected region of the eroded mask is one nucleus
    nuclei_props = regionprops(label(im_nucleus_bin_eroded))
    # reshape keeps a (0, 2) array when no nucleus is found, so the DataFrame can still be built
    mps_nuclei = np.array([[p.centroid[0], p.centroid[1]] for p in nuclei_props]).reshape(-1, 2)

    # one row per nucleus with all relevant measurements
    df_cells = pd.DataFrame(mps_nuclei, columns=['midpoint_x', 'midpoint_y'])
    df_cells['image_name'] = im_filename

    df_cells['number'] = range(1, len(mps_nuclei) + 1)
    df_cells.set_index('number', inplace=True)
    df_cells['area_nuclei'] = np.array([p.area for p in nuclei_props])
    df_cells['orientation_nuclei'] = np.array([p.orientation for p in nuclei_props])
    df_cells['major_axis_length_nuclei'] = np.array([p.major_axis_length for p in nuclei_props])
    df_cells['minor_axis_length_nuclei'] = np.array([p.minor_axis_length for p in nuclei_props])
    # Perimeter of the fitted ellipse = 4 * a * E(e), with semi-major axis a = major_axis_length / 2.
    # scipy's ellipe(m) expects the parameter m = e**2, not the eccentricity e itself.
    df_cells['perimeter_ellipse_nucleus'] = [2 * p.major_axis_length * ellipe(p.eccentricity ** 2)
                                             for p in nuclei_props]
    df_cells['perimeter_nucleus'] = [p.perimeter for p in nuclei_props]
    df_cells['roughness'] = df_cells['perimeter_nucleus'] / df_cells['perimeter_ellipse_nucleus'] - 1.

    # drop nuclei that are too small
    df_cells = df_cells[df_cells['area_nuclei'] > nucleus_min_area]

    if save_as_csv:
        # NOTE: [:-3] assumes a 3-character extension such as 'tif'
        df_cells.to_csv(im_filename[:-3] + 'csv')

    if save_image:
        fig, ax = plt.subplots(1, 1)
        ax.imshow(im_nucleus)
        # midpoint_x is the row index, so it goes on the vertical (y) axis of the plot
        ax.scatter(df_cells['midpoint_y'], df_cells['midpoint_x'], s=20, c='g', marker='o')

        for _, row in df_cells.iterrows():
            # half major axis, drawn on both sides of the centroid
            theta = row['orientation_nuclei'] - np.pi / 2
            d_nucleus = 0.5 * row['major_axis_length_nuclei'] * np.array([np.cos(theta), np.sin(theta)])
            ax.plot([row['midpoint_y'] - d_nucleus[1],
                     row['midpoint_y'] + d_nucleus[1]],
                    [row['midpoint_x'] - d_nucleus[0],
                     row['midpoint_x'] + d_nucleus[0]],
                    linewidth=2, color='red')

        fig.savefig(im_filename[:-3] + '_seg.tif')
        plt.close(fig)

    return df_cells


def delta_angles(phi1, phi2):
    """Return the nematic difference between two orientation angles, divided by pi/2.

    The two angles are treated as head-tail symmetric orientations, so a gap d
    and a gap pi - d are equivalent; the smaller one is used. When
    |phi1 - phi2| <= pi the result lies in [0, 1]. It is 0 for parallel and 1
    for perpendicular orientations.
    """
    lower = np.min([phi1, phi2])
    upper = np.max([phi1, phi2])
    direct_gap = np.abs(upper - lower)
    wrapped_gap = np.abs(lower + np.pi - upper)
    return np.min([direct_gap, wrapped_gap]) / (np.pi / 2)


def dist_orientation_analysis(fn_csv):
    """Compute distance and nematic alignment for every pair of nuclei in one image.

    Parameters
    ----------
    fn_csv : str
        CSV with at least the columns 'image_name', 'midpoint_x', 'midpoint_y'
        and 'orientation_nuclei' (as written by segment_nucleus_orientation).
        All rows must come from the same image.

    Returns
    -------
    pandas.DataFrame
        One row per unordered pair (i, j) with j < i, in the order
        (1, 0), (2, 0), (2, 1), (3, 0), ... Columns:

        - 'distance': Euclidean distance between the midpoints, in pixels
        - 'angle': 0.5 minus the nematic orientation difference divided by
          pi/2, so 0.5 means parallel and -0.5 perpendicular
        - 'image_name'

        An empty CSV yields an empty table.

    Raises
    ------
    ValueError
        If the CSV contains nuclei from more than one image.
    """
    df = pd.read_csv(fn_csv)[['image_name', 'midpoint_x', 'midpoint_y', 'orientation_nuclei']]

    if df.empty:
        # no nuclei -> no pairs
        return pd.DataFrame(columns=['distance', 'angle', 'image_name'])

    if len(df['image_name'].unique()) != 1:
        raise ValueError('csv file contains data from more than one image!')

    x = df['midpoint_x'].values
    y = df['midpoint_y'].values
    phi = df['orientation_nuclei'].values

    # all index pairs (i, j) with j < i, row-major: same order as "for i: for j in range(i)"
    i, j = np.tril_indices(len(df), k=-1)

    dist = np.sqrt((x[i] - x[j]) ** 2 + (y[i] - y[j]) ** 2)

    # element-wise form of delta_angles(): nematic angle difference divided by pi/2
    phi_max = np.maximum(phi[i], phi[j])
    phi_min = np.minimum(phi[i], phi[j])
    delta = np.minimum(np.abs(phi_max - phi_min), np.abs(phi_min + np.pi - phi_max)) / (np.pi / 2)

    a = pd.DataFrame({'distance': dist, 'angle': 0.5 - delta}, columns=['distance', 'angle'])
    a['image_name'] = df['image_name'].iloc[0]

    return a


def get_exp_type(fn):
    """Return the siRNA condition encoded in a file name, or 'unknown'.

    The conditions are tested in the order 'siCTR', 'siYap', 'siTaz',
    'siYapTaz'. The first one that appears in *fn* enclosed in underscores
    (e.g. '_siTaz_') is returned.
    """
    for condition in ('siCTR', 'siYap', 'siTaz', 'siYapTaz'):
        if '_' + condition + '_' in fn:
            return condition
    return 'unknown'


def get_exp_number(fn, n_chars=4):
    """Return the experiment identifier: the first *n_chars* characters of the file's base name.

    For example '/some/dir/Exp1_170130_siTaz_5.tif' -> 'Exp1'. os.path.basename
    is used instead of searching for '/', so paths built with os.path.join also
    work on Windows.
    """
    return os.path.basename(fn)[:n_chars]


def prepare_data(input_dir='/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/cell_alignment_analysis/images/DAPI_channel_max_proj',
                 output_csv='/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/cell_alignment_analysis/images/all_dists.csv',
                 n_jobs=6, save_as_csv=True, save_image=True):
    """Segment all DAPI images below *input_dir* and build the pairwise distance/alignment table.

    Steps:
      1. collect every file below input_dir whose name contains '.tif' and not '._';
      2. segment each image once, in parallel, with segment_nucleus_orientation;
      3. run dist_orientation_analysis on every '.csv' file below input_dir;
      4. concatenate the results, add 'exp_type' / 'exp_number' and write them to output_csv.

    Parameters
    ----------
    input_dir : str
        Root directory that is searched recursively for images and per-image CSVs.
    output_csv : str
        Path of the combined table. It is excluded from step 3.
    n_jobs : int, optional
        Number of joblib workers.
    save_as_csv : bool, optional
        Passed to segment_nucleus_orientation. Step 3 reads these per-image
        CSVs, so with False it relies on CSVs left over from an earlier run.
    save_image : bool, optional
        Passed to segment_nucleus_orientation (segmentation overlay images).

    Returns
    -------
    pandas.DataFrame
        The combined table that was written to output_csv.

    Raises
    ------
    ValueError
        If no per-image CSV file is found below input_dir.
    """
    pathname_contains = ''

    # 1. image files. The '._' test skips macOS resource-fork files and the
    #    '<name>._seg.tif' overlays written by segment_nucleus_orientation(save_image=True).
    file_list_im = _find_files_below(input_dir, pathname_contains,
                                     lambda name: '.tif' in name and '._' not in name)

    # 2. a single parallel segmentation pass (verbose prints progress)
    Parallel(n_jobs=n_jobs, verbose=5)(
        delayed(segment_nucleus_orientation)(fn_, save_image=save_image, save_as_csv=save_as_csv)
        for fn_ in file_list_im)

    # 3. per-image CSVs; skip macOS resource-fork files ('._*') and the combined output itself
    output_abs = os.path.abspath(output_csv)
    file_list_csv = [fn_ for fn_ in _find_files_below(input_dir, pathname_contains,
                                                      lambda name: name.endswith('.csv')
                                                      and not name.startswith('._'))
                     if os.path.abspath(fn_) != output_abs]
    if not file_list_csv:
        raise ValueError('no per-image csv files found below ' + input_dir)

    # 4. pairwise distance/alignment data of all images in one DataFrame
    all_dist = Parallel(n_jobs=n_jobs)(delayed(dist_orientation_analysis)(fn_) for fn_ in file_list_csv)
    all_nuclei_dist_df = pd.concat(all_dist)
    all_nuclei_dist_df['exp_type'] = [get_exp_type(fn_) for fn_ in all_nuclei_dist_df['image_name']]
    all_nuclei_dist_df['exp_number'] = [get_exp_number(fn_) for fn_ in all_nuclei_dist_df['image_name']]

    all_nuclei_dist_df.to_csv(output_csv)

    return all_nuclei_dist_df


def _find_files_below(root, pathname_contains, accept):
    """Return paths of files below *root* whose directory contains *pathname_contains* and whose name passes *accept*."""
    found = []
    for path, _subdirs, files in os.walk(root):
        if pathname_contains in path:
            found.extend(os.path.join(path, name) for name in files if accept(name))
    return found


def basic_analysis(input_dir='/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/images/DAPI_channel_max_proj',
                   pixel_size_um=0.6251700, image_size_px=1024, n_jobs=6, results_dir=None):
    """Count nuclei per image and compare cell number and average cell area between conditions.

    Parameters
    ----------
    input_dir : str
        Directory searched recursively for '.tif' images (names containing '._' are skipped).
    pixel_size_um : float, optional
        Pixel size in micrometres.
    image_size_px : int, optional
        Edge length of the (assumed square) images in pixels.
    n_jobs : int, optional
        Number of joblib workers used for segmentation.
    results_dir : str or None, optional
        If given, the two box plots are saved there as 'cell_number.pdf' and 'cell_areas.pdf'.

    Returns
    -------
    pandas.DataFrame
        One row per image with columns 'image_name', 'cell_number', 'exp_type'
        and 'average_area' (image area in um^2 divided by the number of nuclei).
    """
    filename_contains = '.tif'
    pathname_contains = ''

    # get list of all relevant image files
    file_list_im = []
    for path, subdirs, files in os.walk(input_dir):
        if pathname_contains in path:
            for name in files:
                if filename_contains in name and '._' not in name:
                    file_list_im.append(os.path.join(path, name))

    # segment without writing anything to disk; only the nucleus count per image is needed
    cells_per_image = Parallel(n_jobs=n_jobs)(
        delayed(segment_nucleus_orientation)(fn_, save_image=False, save_as_csv=False) for fn_ in file_list_im)

    # segment_nucleus_orientation returns one DataFrame per image (one row per nucleus),
    # so the per-image summary is built here; joblib returns results in input order
    all_dist = pd.DataFrame({'image_name': file_list_im,
                             'cell_number': [len(df_) for df_ in cells_per_image]},
                            columns=['image_name', 'cell_number'])
    all_dist['exp_type'] = [get_exp_type(fn_) for fn_ in all_dist['image_name']]
    all_dist['average_area'] = (image_size_px * pixel_size_um) ** 2 / all_dist['cell_number']

    # summary statistics per condition (numeric column selected before aggregating)
    grouped = all_dist.groupby('exp_type')
    print(grouped['average_area'].std())
    print(grouped['image_name'].count())

    plt.figure()
    sns.boxplot(y='cell_number', x='exp_type', data=all_dist)
    plt.ylim([0, 450])
    plt.ylabel('# cells per image')
    if results_dir is not None:
        plt.savefig(os.path.join(results_dir, 'cell_number.pdf'))

    plt.figure()
    sns.boxplot(y='average_area', x='exp_type', data=all_dist)
    plt.ylim([0, 2400])
    plt.ylabel('average cell area [um^2]')
    if results_dir is not None:
        plt.savefig(os.path.join(results_dir, 'cell_areas.pdf'))

    return all_dist


# prepare_data()

# pairwise nucleus distance / alignment table written by prepare_data()
all_nuclei_dist_df = pd.read_csv(
    '/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/cell_alignment_analysis/images/all_dists.csv')
# (re)derive the experiment labels from the image file names
all_nuclei_dist_df['exp_type'] = all_nuclei_dist_df['image_name'].map(get_exp_type)
all_nuclei_dist_df['exp_number'] = all_nuclei_dist_df['image_name'].map(get_exp_number)
# pixel -> micron (0.6251700 um per pixel)
all_nuclei_dist_df['distance [um]'] = 0.6251700 * all_nuclei_dist_df['distance']
# alignment = 2 * angle: +1 for parallel, -1 for perpendicular nuclei
all_nuclei_dist_df['alignment'] = 2 * all_nuclei_dist_df['angle']
#
#
# wt_nuclei = all_nuclei_dist_df[all_nuclei_dist_df['exp_type'] == 'siCTR']
# yaptaz_nuclei = all_nuclei_dist_df[all_nuclei_dist_df['exp_type'] == 'siYapTaz']
# yap_nuclei = all_nuclei_dist_df[all_nuclei_dist_df['exp_type'] == 'siYap']
# taz_nuclei = all_nuclei_dist_df[all_nuclei_dist_df['exp_type'] == 'siTaz']
#
# bin_edges = range(20, 500, 20)
# bin_mps = (np.array(bin_edges[1:]) + np.array(bin_edges[:-1])) / 2
#
# mean_orientation_wt, _, _ = binned_statistic(x=wt_nuclei['distance [um]'], values=wt_nuclei['alignment'],
#                                              bins=bin_edges)
# mean_orientation_yt, _, _ = binned_statistic(x=yaptaz_nuclei['distance [um]'], values=yaptaz_nuclei['alignment'],
#                                              bins=bin_edges)
# mean_orientation_y, _, _ = binned_statistic(x=yap_nuclei['distance [um]'], values=yap_nuclei['alignment'],
#                                             bins=bin_edges)
# mean_orientation_t, _, _ = binned_statistic(x=taz_nuclei['distance [um]'], values=taz_nuclei['alignment'],
#                                             bins=bin_edges)
#
# std_orientation_wt, _, _ = binned_statistic(x=wt_nuclei['distance [um]'], values=wt_nuclei['alignment'],
#                                             bins=bin_edges, statistic=np.std)
# std_orientation_yt, _, _ = binned_statistic(x=yaptaz_nuclei['distance [um]'], values=yaptaz_nuclei['alignment'],
#                                             bins=bin_edges, statistic=np.std)
# std_orientation_y, _, _ = binned_statistic(x=yap_nuclei['distance [um]'], values=yap_nuclei['alignment'],
#                                            bins=bin_edges, statistic=np.std)
# std_orientation_t, _, _ = binned_statistic(x=taz_nuclei['distance [um]'], values=taz_nuclei['alignment'],
#                                            bins=bin_edges, statistic=np.std)
#
# count_orientation_wt, _, _ = binned_statistic(x=wt_nuclei['distance [um]'], values=wt_nuclei['alignment'],
#                                               bins=bin_edges, statistic=len)
# count_orientation_yt, _, _ = binned_statistic(x=yaptaz_nuclei['distance [um]'], values=yaptaz_nuclei['alignment'],
#                                               bins=bin_edges, statistic=len)
# count_orientation_y, _, _ = binned_statistic(x=yap_nuclei['distance [um]'], values=yap_nuclei['alignment'],
#                                              bins=bin_edges, statistic=len)
# count_orientation_t, _, _ = binned_statistic(x=taz_nuclei['distance [um]'], values=taz_nuclei['alignment'],
#                                              bins=bin_edges, statistic=len)
#
# sem_orientation_wt = std_orientation_wt / np.sqrt(count_orientation_wt)
# sem_orientation_yt = std_orientation_yt / np.sqrt(count_orientation_yt)
# sem_orientation_y = std_orientation_y / np.sqrt(count_orientation_y)
# sem_orientation_t = std_orientation_t / np.sqrt(count_orientation_t)
#
# # plt.figure()
# # plt.plot(bin_mps, mean_orientation_wt, c='g', label='wt')
# # plt.plot(bin_mps, mean_orientation_y, c='r', label='Yap')
# # plt.plot(bin_mps, mean_orientation_t, c='b', label='Taz')
# # plt.plot(bin_mps, mean_orientation_yt, c='purple', label='YapTaz')
# # plt.ylim([-0.01, 0.3])
# # plt.legend()
#
# sns.set_context('poster')
# plt.figure()
# plt.errorbar(x=bin_mps, y=mean_orientation_wt, yerr=std_orientation_wt, c='g', label='wt')
# plt.errorbar(x=bin_mps, y=mean_orientation_y, yerr=std_orientation_y, c='r', label='Yap')
# plt.errorbar(x=bin_mps, y=mean_orientation_t, yerr=std_orientation_t, c='b', label='Taz')
# plt.errorbar(x=bin_mps, y=mean_orientation_yt, yerr=std_orientation_yt, c='purple', label='YapTaz')
# # plt.errorbar(x=bin_mps, y=mean_orientation_wt, yerr=sem_orientation_wt, c='g', label='wt')
# # plt.errorbar(x=bin_mps, y=mean_orientation_y, yerr=sem_orientation_y, c='r', label='Yap')
# # plt.errorbar(x=bin_mps, y=mean_orientation_t, yerr=sem_orientation_t, c='b', label='Taz')
# # plt.errorbar(x=bin_mps, y=mean_orientation_yt, yerr=sem_orientation_yt, c='purple', label='YapTaz')
# plt.ylim([-0.03, 0.6])
# plt.xlabel('distance [um]')
# plt.ylabel('alignment')
# plt.legend()
# plt.savefig('results/alignment_plot_SD.pdf')
# plt.savefig('results/alignment_plot.pdf')
#
# plt.figure()
# plt.plot(bin_mps, count_orientation_wt, 'o', c='g', label='wt')
# plt.plot(bin_mps, count_orientation_y, 'o', c='r', label='Yap')
# plt.plot(bin_mps, count_orientation_t, 'o', c='b', label='Taz')
# plt.plot(bin_mps, count_orientation_yt, 'o', c='purple', label='YapTaz')
# # plt.xlim([0, 500])
# # plt.ylim([0, 50000])
# plt.xlabel('distance [um]')
# plt.ylabel('# cell pairs')
# plt.legend()
# plt.savefig('results/cell_pair_numbers.pdf')
#
# # # fit exponential curve to decay -> probably not a good assumption
# # wt_decay = np.mean(np.log(mean_orientation_wt[:15])/bin_mps[:15])
# # yt_decay = np.mean(np.log(mean_orientation_yt[:15])/bin_mps[:15])
# # y_decay = np.mean(np.log(mean_orientation_y[:15])/bin_mps[:15])
# # t_decay = np.mean(np.log(mean_orientation_t[:15])/bin_mps[:15])
# #
# # wt_fit = np.polyfit(np.log(mean_orientation_wt[:30]), bin_mps[:30], 1)
# # yt_fit = np.polyfit(np.log(mean_orientation_yt), bin_mps, 1)
# # t_fit = np.polyfit(np.log(mean_orientation_t), bin_mps, 1)
# # y_fit = np.polyfit(np.log(mean_orientation_y[:30]), bin_mps[:30], 1)
# #
# # plt.figure()
# # plt.plot(bin_mps, np.exp((bin_mps-wt_fit[0])/wt_fit[1]), c='g', label='wt')
# # plt.errorbar(x=bin_mps, y=mean_orientation_wt, yerr=sem_orientation_wt, c='g', label='wt')
# #
# # plt.plot(bin_mps, np.exp((bin_mps-yt_fit[0])/yt_fit[1]), c='purple', label='YapTaz')
# # plt.errorbar(x=bin_mps, y=mean_orientation_yt, yerr=sem_orientation_yt, c='purple')
# #
# # plt.plot(bin_mps, np.exp((bin_mps-t_fit[0])/t_fit[1]), c='b', label='Taz')
# # plt.errorbar(x=bin_mps, y=mean_orientation_t, yerr=sem_orientation_t, c='b')
# #
# # plt.plot(bin_mps, np.exp((bin_mps-y_fit[0])/y_fit[1]), c='r', label='Yap')
# # plt.errorbar(x=bin_mps, y=mean_orientation_y, yerr=sem_orientation_y, c='r')


all_grouped = all_nuclei_dist_df.groupby(['exp_type', 'exp_number'])

# distance bins (um) and their midpoints
bin_edges = range(20, 500, 10)
bin_mps = (np.array(bin_edges[1:]) + np.array(bin_edges[:-1])) / 2

list_plots = []

# grey level used for each condition in the plots
color_dict = {'siCTR': [208 / 256., 208 / 256., 208 / 256.],
              'siYapTaz': [110 / 256., 110 / 256., 110 / 256.],
              'siTaz': [58 / 256., 58 / 256., 58 / 256.],
              'siYap': [0 / 256., 0 / 256., 0 / 256., ]}
# colour for files get_exp_type() could not classify ('unknown')
_UNKNOWN_COLOR = 'r'


def _binned_alignment(df_group, bin_edges):
    """Bin 'alignment' by 'distance [um]'; return (mean, std, count, sem) per bin."""
    common = dict(x=df_group['distance [um]'], values=df_group['alignment'], bins=bin_edges)
    mean_, _, _ = binned_statistic(statistic='mean', **common)
    std_, _, _ = binned_statistic(statistic=np.std, **common)
    count_, _, _ = binned_statistic(statistic=len, **common)
    return mean_, std_, count_, std_ / np.sqrt(count_)


# 1) one curve per experiment: (condition, experiment number)
plt.figure()
for name, df_group in all_grouped:
    mean_orientation, std_orientation, count_orientation, sem_orientation = _binned_alignment(df_group, bin_edges)

    plt.errorbar(x=bin_mps, y=mean_orientation, yerr=sem_orientation, label=name,
                 c=color_dict.get(name[0], _UNKNOWN_COLOR))

# 2) one curve per condition with SEM error bars.
# Group by the column name, not a one-element list: since pandas 2.0 a one-element
# list yields 1-tuples as group keys, which would break the color_dict lookup.
plt.figure()
for name, df_group in all_nuclei_dist_df.groupby('exp_type'):
    mean_orientation, std_orientation, count_orientation, sem_orientation = _binned_alignment(df_group, bin_edges)

    plt.errorbar(x=bin_mps, y=mean_orientation, yerr=sem_orientation, label=name,
                 c=color_dict.get(name, _UNKNOWN_COLOR))

plt.legend()

# semi-transparent (alpha 0.3) version of color_dict, derived from it so that a shaded
# band gets the same grey level as the line of the same condition
color_dict_opaque = {key: list(rgb) + [0.3] for key, rgb in color_dict.items()}

# 3) mean curve per condition (shaded SEM band currently disabled)
plt.figure()
for name, df_group in all_nuclei_dist_df.groupby('exp_type'):
    mean_orientation, std_orientation, count_orientation, sem_orientation = _binned_alignment(df_group, bin_edges)

    # plt.fill_between(bin_mps, mean_orientation+sem_orientation, mean_orientation-sem_orientation,
    #                  color=color_dict_opaque[name])
    plt.plot(bin_mps, mean_orientation, label=name, c=color_dict.get(name, _UNKNOWN_COLOR), lw=2)

plt.legend()

plt.savefig('shaded_errorbars.pdf')

# check significance for a given distance
