# SCRIPT TO ANALYSE THE TYPE AND COLOUR OF JUNCTIONS, WHICH HAVE BEEN MANUALLY DETERMINED ON A RANDOMISED DATA SET AND SAVED IN AN EXCEL SHEET
# written by: Silvanus Alt, April 2017

from __future__ import division
from pylab import np, plt
import pandas as pd
import seaborn as sns
from seaborn.categorical import _CategoricalPlotter, remove_na
import matplotlib as mpl


def get_exp_type(fn):
    """Return the knock-down condition named in the file name ``fn``.

    'siCTR' is checked before 'siYapTaz'; a name containing neither tag
    yields 'unknown'.
    """
    for condition in ('siCTR', 'siYapTaz'):
        if condition in fn:
            return condition
    return 'unknown'


def get_time_point(fn):
    """Return '2h' if ``fn`` contains ' 2h ' (surrounded by spaces), otherwise '30min'."""
    is_two_hours = ' 2h ' in fn
    return '2h' if is_two_hours else '30min'


def type_numbers_to_types(tn):
    """Translate a pair of junction-morphology codes into junction-type names.

    ``tn`` is a two-element sequence (first code, second code).  When the
    second code is NaN, only the first is translated.  Codes that are not in
    the 0-6 table map to 'nothing'.

    Returns a list holding one name, or two names when the second code is set.
    """
    code_names = {
        0: 'nothing',
        1: 'straight',
        2: 'thick',
        3: 'thick JAIL',
        4: 'JAIL',
        5: 'fingers',
        6: 'dots',
    }
    first_name = code_names.get(tn[0], "nothing")
    if not np.isnan(tn[1]):
        return [first_name, code_names.get(tn[1], "nothing")]
    return [first_name]


def colour_numbers_to_colours(tn):
    """Translate a pair of colour codes into colour names.

    ``tn`` is a two-element sequence (first code, second code).  When the
    second code is NaN, only the first is translated.  1/2/3 map to
    green/yellow/red, and any other code maps to 'nothing'.

    Returns a list holding one name, or two names when the second code is set.
    """
    code_names = {1: 'green', 2: 'yellow', 3: 'red'}
    first_name = code_names.get(tn[0], "nothing")
    if not np.isnan(tn[1]):
        return [first_name, code_names.get(tn[1], "nothing")]
    return [first_name]


class _CategoricalStatPlotter(_CategoricalPlotter):
    """Categorical plotter that reduces every plot cell to one estimated statistic.

    Subclass of seaborn's private ``_CategoricalPlotter``.  The estimator is
    called as ``estimator(values, total_n)``.  ``total_n`` is the
    normalisation count for the cell:

    * without hue nesting, the number of values over all groups;
    * with hue nesting, the number of entries of the cell's hue level over
      all groups.
    """

    @property
    def nested_width(self):
        """A float with the width of plot elements when hue nesting is used."""
        return self.width / len(self.hue_names)

    def estimate_statistic(self, estimator, ci, n_boot):
        """Fill ``self.statistic`` and ``self.confint`` for every plot cell.

        Parameters
        ----------
        estimator : callable
            Called as ``estimator(values, total_n)`` (see the class docstring).
        ci : float or None
            Size of the bootstrap confidence interval.  ``None`` skips interval
            estimation, which leaves ``self.confint`` without entries.
        n_boot : int
            Number of bootstrap resamples; only used when ``ci`` is not None.
        """
        if ci is not None:
            # Only needed for confidence intervals, so import them lazily.
            from seaborn.algorithms import bootstrap
            from seaborn import utils

        if self.hue_names is None:
            statistic = []
            confint = []
        else:
            statistic = [[] for _ in self.plot_data]
            confint = [[] for _ in self.plot_data]

        # The normalisation counts are the same for every group, so compute
        # them once instead of re-concatenating all groups inside the loops.
        if self.plot_hues is None:
            total_n = sum(len(group) for group in self.plot_data)
        elif len(self.plot_hues):
            all_hues = np.concatenate(self.plot_hues)
            hue_totals = [(all_hues == hue_level).sum() for hue_level in self.hue_names]
        else:
            hue_totals = [0] * len(self.hue_names)

        for i, group_data in enumerate(self.plot_data):
            # Option 1: we have a single layer of grouping
            # --------------------------------------------

            if self.plot_hues is None:

                if self.plot_units is None:
                    stat_data = remove_na(group_data)
                    unit_data = None
                else:
                    unit_data = self.plot_units[i]
                    have = pd.notnull(np.c_[group_data, unit_data]).all(axis=1)
                    stat_data = group_data[have]
                    unit_data = unit_data[have]

                # Estimate a statistic from the vector of data
                if not stat_data.size:
                    statistic.append(np.nan)
                else:
                    statistic.append(estimator(stat_data, total_n))

                # Get a confidence interval for this estimate
                if ci is not None:

                    if stat_data.size < 2:
                        confint.append([np.nan, np.nan])
                        continue

                    # The estimator needs the normalisation count as a second
                    # argument; bind it so bootstrap can pass just the resample.
                    # NOTE(review): assumes seaborn's bootstrap calls `func` with
                    # the resampled array only -- confirm for the installed version.
                    boots = bootstrap(stat_data,
                                      func=lambda sample, n=total_n: estimator(sample, n),
                                      n_boot=n_boot,
                                      units=unit_data)
                    confint.append(utils.ci(boots, ci))

            # Option 2: we are grouping by a hue layer
            # ----------------------------------------

            else:
                for j, hue_level in enumerate(self.hue_names):
                    if not self.plot_hues[i].size:
                        statistic[i].append(np.nan)
                        if ci is not None:
                            confint[i].append((np.nan, np.nan))
                        continue

                    hue_mask = self.plot_hues[i] == hue_level
                    group_total_n = hue_totals[j]
                    if self.plot_units is None:
                        stat_data = remove_na(group_data[hue_mask])
                        unit_data = None
                    else:
                        group_units = self.plot_units[i]
                        have = pd.notnull(
                            np.c_[group_data, group_units]
                        ).all(axis=1)
                        stat_data = group_data[hue_mask & have]
                        unit_data = group_units[hue_mask & have]

                    # Estimate a statistic from the vector of data
                    if not stat_data.size:
                        statistic[i].append(np.nan)
                    else:
                        statistic[i].append(estimator(stat_data, group_total_n))

                    # Get a confidence interval for this estimate
                    if ci is not None:

                        if stat_data.size < 2:
                            confint[i].append([np.nan, np.nan])
                            continue

                        boots = bootstrap(stat_data,
                                          func=lambda sample, n=group_total_n: estimator(sample, n),
                                          n_boot=n_boot,
                                          units=unit_data)
                        confint[i].append(utils.ci(boots, ci))

        # Save the resulting values for plotting
        self.statistic = np.array(statistic)
        self.confint = np.array(confint)

        # Rename the value label to reflect the estimation
        if self.value_label is not None:
            self.value_label = "{}({})".format(estimator.__name__,
                                               self.value_label)

    def draw_confints(self, ax, at_group, confint, colors,
                      errwidth=None, capsize=None, **kws):
        """Draw one error bar per (position, (low, high), colour) triple onto ``ax``.

        The bars are vertical when ``self.orient == "v"`` and horizontal
        otherwise.  ``errwidth`` sets the line width; the default is 1.8 times
        matplotlib's ``lines.linewidth``.  ``capsize``, if given, adds end caps
        of that total width.
        """
        if errwidth is not None:
            kws.setdefault("lw", errwidth)
        else:
            kws.setdefault("lw", mpl.rcParams["lines.linewidth"] * 1.8)

        for at, (ci_low, ci_high), color in zip(at_group,
                                                confint,
                                                colors):
            if self.orient == "v":
                ax.plot([at, at], [ci_low, ci_high], color=color, **kws)
                if capsize is not None:
                    ax.plot([at - capsize / 2, at + capsize / 2],
                            [ci_low, ci_low], color=color, **kws)
                    ax.plot([at - capsize / 2, at + capsize / 2],
                            [ci_high, ci_high], color=color, **kws)
            else:
                ax.plot([ci_low, ci_high], [at, at], color=color, **kws)
                if capsize is not None:
                    ax.plot([ci_low, ci_low],
                            [at - capsize / 2, at + capsize / 2],
                            color=color, **kws)
                    ax.plot([ci_high, ci_high],
                            [at - capsize / 2, at + capsize / 2],
                            color=color, **kws)


class _BarPlotter(_CategoricalStatPlotter):
    """Draw one bar per plot cell, optionally with confidence-interval whiskers."""

    def __init__(self, x, y, hue, data, order, hue_order,
                 estimator, ci, n_boot, units,
                 orient, color, palette, saturation, errcolor, errwidth=None,
                 capsize=None):
        """Parse the data, pick colours and estimate one statistic per cell."""
        # establish_variables / establish_colors are inherited from seaborn's
        # _CategoricalPlotter; estimate_statistic comes from _CategoricalStatPlotter.
        self.establish_variables(x, y, hue, data, orient,
                                 order, hue_order, units)
        self.establish_colors(color, palette, saturation)
        self.estimate_statistic(estimator, ci, n_boot)

        # Styling of the error bars drawn by draw_confints.
        self.errcolor, self.errwidth, self.capsize = errcolor, errwidth, capsize

    def draw_bars(self, ax, kws):
        """Draw the bars, and any confidence intervals, onto ``ax``."""
        draw = ax.bar if self.orient == "v" else ax.barh
        centers = np.arange(len(self.statistic))

        if self.plot_hues is not None:
            # One bar series per hue level, shifted by that level's offset.
            for level_idx, level_name in enumerate(self.hue_names):
                shifted = centers + self.hue_offsets[level_idx]
                draw(shifted, self.statistic[:, level_idx], self.nested_width,
                     color=self.colors[level_idx], align="center",
                     label=level_name, **kws)
                if not self.confint.size:
                    continue
                self.draw_confints(ax,
                                   shifted,
                                   self.confint[:, level_idx],
                                   [self.errcolor] * len(shifted),
                                   self.errwidth,
                                   self.capsize)
            return

        draw(centers, self.statistic, self.width,
             color=self.colors, align="center", **kws)
        self.draw_confints(ax,
                           centers,
                           self.confint,
                           [self.errcolor] * len(centers),
                           self.errwidth,
                           self.capsize)

    def plot(self, ax, bar_kws):
        """Draw the bars and label the axes; put horizontal plots top-down."""
        self.draw_bars(ax, bar_kws)
        self.annotate_axes(ax)
        if self.orient != "h":
            return
        ax.invert_yaxis()


def percentageplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
                   orient=None, color=None, palette=None, saturation=.75,
                   ax=None, **kwargs):
    """Bar plot of category frequencies expressed as percentages.

    Pass exactly one of ``x`` (vertical bars) or ``y`` (horizontal bars).  The
    ``orient`` argument is ignored; the orientation follows from that choice.
    Each bar height is ``100 * len(cell) / total``, where ``total`` is set by
    ``_CategoricalStatPlotter.estimate_statistic``.  No confidence intervals
    are drawn.  Extra keyword arguments go to the matplotlib bar call.

    Returns the Axes that was drawn on (the current Axes if ``ax`` is None).

    Raises TypeError if both or neither of ``x`` and ``y`` are given.
    """
    def estimator(values, total):
        # Share of the `total` observations that fall into this cell, in percent.
        return float(len(values) / total) * 100

    # Exactly one of x / y must be given; it decides the orientation.
    if x is not None and y is not None:
        raise TypeError("Cannot pass values for both `x` and `y`")
    if x is None and y is None:
        raise TypeError("Must pass values for either `x` or `y`")
    if x is None:
        orient, x = "h", y
    else:
        orient, y = "v", x

    plotter = _BarPlotter(x, y, hue, data, order, hue_order,
                          estimator, None, 0, None,
                          orient, color, palette, saturation,
                          None)
    plotter.value_label = "Percentage"

    target_ax = plt.gca() if ax is None else ax
    plotter.plot(target_ax, kwargs)
    return target_ax


def data_preparation(xls_sheet_name, csv_sheetname):
    """Load one experiment's manual scoring and return one row per junction.

    Parameters
    ----------
    xls_sheet_name : str
        Excel file with the manual counts.  Its column header is on row 10
        (``header=9``).  The columns 'Junction Morphology',
        'Junction Morphology.1', 'Colour' and 'Colour.1' are used.
    csv_sheetname : str
        CSV file with per-patch information, including 'image_name'.

    Returns
    -------
    pandas.DataFrame
        Columns 'exp_type', 'time_point', 'colour' and 'type', with a fresh
        0..n-1 index.  A patch scored with two junctions contributes two rows,
        each with its own type and colour.

    Raises
    ------
    ValueError
        If a patch has two junction types but not exactly two colours.
    """
    # Read the manual counts and the per-patch information.
    df_count = pd.read_excel(xls_sheet_name, header=9)
    df_patches_groups = pd.read_csv(csv_sheetname)

    # Join the two tables on all columns they share (presumably 'patch_number').
    df_count = pd.merge(df_count, df_patches_groups)

    # Translate the numeric codes into name lists and derive the metadata.
    df_count['types'] = [type_numbers_to_types([first, second]) for first, second in
                         zip(df_count['Junction Morphology'], df_count['Junction Morphology.1'])]
    df_count['colours'] = [colour_numbers_to_colours([first, second]) for first, second in
                           zip(df_count['Colour'], df_count['Colour.1'])]
    df_count['exp_type'] = [get_exp_type(name) for name in df_count['image_name']]
    df_count['time_point'] = [get_time_point(name) for name in df_count['image_name']]

    # Keep only patches with at least one scored junction.  The first
    # morphology code must be present, and the first type and first colour
    # must both map to a real category (not 'nothing').  Boolean arrays keep
    # this a row filter even when the table is empty.
    is_scored = (df_count['Junction Morphology'].notnull().values
                 & np.array([types[0] != 'nothing' for types in df_count['types']], dtype=bool)
                 & np.array([colours[0] != 'nothing' for colours in df_count['colours']], dtype=bool))
    df_count = df_count[is_scored]

    # Expand to one row per junction.  Rows are collected in a list and turned
    # into a DataFrame once; repeated DataFrame.append is quadratic (and was
    # removed in pandas 2.0).
    junction_rows = []
    for exp_type, time_point, types, colours in zip(df_count['exp_type'], df_count['time_point'],
                                                     df_count['types'], df_count['colours']):
        if len(types) == 1:
            junction_rows.append({'exp_type': exp_type, 'time_point': time_point,
                                  'colour': colours[0], 'type': types[0]})
        elif len(types) == 2:
            if len(colours) != 2:
                raise ValueError('patch has 2 junction types but {} colour(s): {!r}'.format(
                    len(colours), colours))
            # Pair each junction type with its own colour.
            for type_name, colour in zip(types, colours):
                junction_rows.append({'exp_type': exp_type, 'time_point': time_point,
                                      'colour': colour, 'type': type_name})

    return pd.DataFrame(junction_rows, columns=['exp_type', 'time_point', 'colour', 'type'])


df_junctions_Exp1 = data_preparation(
    xls_sheet_name='/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/junction_analysis/Ve-Cad dynamics - 30 min VeCad647 pulse/unordered_patches_Exp1/VeCaddynamics_08052017.xls',
    csv_sheetname='/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/junction_analysis/Ve-Cad dynamics - 30 min VeCad647 pulse/unordered_patches_Exp1/patch_info.csv')
df_junctions_Exp23 = data_preparation(
    xls_sheet_name='/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/junction_analysis/Ve-Cad dynamics - 30 min VeCad647 pulse/unordered_patches_Exp23/YapTaz knockdown - VeCad morphology patches - Exp 2 and 3.xlsx',
    csv_sheetname='/Users/silvanus/Dropbox/angiogenesis/Yap_Taz/junction_analysis/Ve-Cad dynamics - 30 min VeCad647 pulse/unordered_patches_Exp23/patch_info.csv')

# Pool the junctions of all experiments into the single table that every
# plot below reads.
# NOTE(review): assumes Exp1 and Exp2/3 are meant to be analysed together.
df_junctions = pd.concat([df_junctions_Exp1, df_junctions_Exp23], ignore_index=True)

# define and order interesting subsets of junction colours and types
important_junction_types = ['straight', 'thick', 'thick JAIL', 'JAIL', 'fingers']
all_colours = ['green', 'yellow', 'red']
important_colours = all_colours

# Colour distribution per knock-down condition, one panel per time point.
f, axarr = plt.subplots(2, 1)
for panel_ax, panel_tp, panel_title in zip(axarr, ['30min', '2h'], ['30min TP', '2h TP']):
    percentageplot(data=df_junctions[df_junctions['time_point'] == panel_tp], x='colour', hue='exp_type',
                   order=['red', 'yellow', 'green'], hue_order=['siCTR', 'siYapTaz'],
                   palette={'siCTR': 'green', 'siYapTaz': 'red'}, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('')
plt.savefig('colour_percentages.pdf')

# Colour distribution per time point, one panel per knock-down condition.
f, axarr = plt.subplots(2, 1)
for panel_ax, panel_exp in zip(axarr, ['siCTR', 'siYapTaz']):
    percentageplot(data=df_junctions[df_junctions['exp_type'] == panel_exp], x='colour', hue='time_point',
                   order=['red', 'yellow', 'green'], hue_order=['30min', '2h'], ax=panel_ax,
                   palette={'30min': 'lightgray', '2h': 'darkgray'})
    panel_ax.set_title(panel_exp)
    panel_ax.set_xlabel('')
plt.savefig('colour_percentage_change.pdf')

# Junction-type distribution per knock-down condition, one panel per time point.
f, axarr = plt.subplots(2, 1)
for panel_ax, panel_tp, panel_title in zip(axarr, ['30min', '2h'], ['30min TP', '2h TP']):
    percentageplot(data=df_junctions[df_junctions['time_point'] == panel_tp], x='type', hue='exp_type',
                   order=important_junction_types, hue_order=['siCTR', 'siYapTaz'],
                   palette={'siCTR': 'green', 'siYapTaz': 'red'}, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('')
plt.savefig('shape_percentages.pdf')

# Junction-type distribution per time point, one panel per knock-down condition.
f, axarr = plt.subplots(2, 1)
for panel_ax, panel_exp in zip(axarr, ['siCTR', 'siYapTaz']):
    percentageplot(data=df_junctions[df_junctions['exp_type'] == panel_exp], x='type', hue='time_point',
                   order=important_junction_types, hue_order=['30min', '2h'], ax=panel_ax,
                   palette={'30min': 'lightgray', '2h': 'darkgray'})
    panel_ax.set_title(panel_exp)
    panel_ax.set_xlabel('')
plt.savefig('shape_percentage_shape.pdf')

# do mixed analysis - shape & colour
# One panel per (time point, condition): junction types on the x axis,
# split by colour.  Rows are time points, columns are conditions.
f, axarr = plt.subplots(2, 2)
for (row, col), panel_tp, panel_exp, panel_title in [((0, 0), '30min', 'siCTR', '30min TP - siCTR'),
                                                      ((1, 0), '2h', 'siCTR', '2h TP - siCTR'),
                                                      ((0, 1), '30min', 'siYapTaz', '30min TP - siYapTaz'),
                                                      ((1, 1), '2h', 'siYapTaz', '2h TP - siYapTaz')]:
    panel_mask = (df_junctions['time_point'] == panel_tp) & (df_junctions['exp_type'] == panel_exp)
    percentageplot(data=df_junctions[panel_mask],
                   x='type', hue='colour', order=important_junction_types, hue_order=['red', 'yellow', 'green'],
                   palette={'green': 'green', 'red': 'red', 'yellow': 'yellow'}, ax=axarr[row, col])
    axarr[row, col].set_title(panel_title)
    axarr[row, col].set_xlabel('')
plt.tight_layout()
plt.savefig('colour_and_shape.pdf')

# Same panels as the shape & colour figure, but with colours on the x axis,
# split by junction type.  Rows are time points, columns are conditions.
f, axarr = plt.subplots(2, 2)
for (row, col), panel_tp, panel_exp, panel_title in [((0, 0), '30min', 'siCTR', '30min TP - siCTR'),
                                                      ((1, 0), '2h', 'siCTR', '2h TP - siCTR'),
                                                      ((0, 1), '30min', 'siYapTaz', '30min TP - siYapTaz'),
                                                      ((1, 1), '2h', 'siYapTaz', '2h TP - siYapTaz')]:
    panel_mask = (df_junctions['time_point'] == panel_tp) & (df_junctions['exp_type'] == panel_exp)
    percentageplot(data=df_junctions[panel_mask],
                   x='colour', hue='type', hue_order=important_junction_types, order=['red', 'yellow', 'green'],
                   ax=axarr[row, col])
    axarr[row, col].set_title(panel_title)
    axarr[row, col].set_xlabel('')
plt.tight_layout()
plt.savefig('colour_and_shape_2.pdf')

# Count junctions per (type, colour) combination for each condition and time point.
# The level lists fix the row/column order of the count matrices.  Using the
# same lists as tick labels guarantees that the labels line up with the matrix.
junction_type_levels = list(df_junctions['type'].unique())
junction_colour_levels = list(df_junctions['colour'].unique())
junction_types_numbers = {name: idx for idx, name in enumerate(junction_type_levels)}
junction_colours_numbers = {name: idx for idx, name in enumerate(junction_colour_levels)}

count_shape = (len(junction_type_levels), len(junction_colour_levels))
all_counts_30min = {'siCTR': np.zeros(count_shape), 'siYapTaz': np.zeros(count_shape)}
all_counts_2h = {'siCTR': np.zeros(count_shape), 'siYapTaz': np.zeros(count_shape)}

# NOTE(review): a junction whose exp_type is 'unknown' raises KeyError here.
for time_point, counts in (('30min', all_counts_30min), ('2h', all_counts_2h)):
    for i, patch in df_junctions[df_junctions['time_point'] == time_point].iterrows():
        counts[patch['exp_type']][
            junction_types_numbers[patch['type']], junction_colours_numbers[patch['colour']]] += 1

plt.matshow(all_counts_2h['siCTR'])
plt.xticks(np.arange(len(junction_colour_levels)), junction_colour_levels)
plt.yticks(np.arange(len(junction_type_levels)), junction_type_levels)
plt.colorbar()
