# -*- coding: utf-8 -*-
#
# This source file is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This source file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division, print_function

import numpy as np
import scipy.stats as sst
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb

# Colour look-up shared by the plotting helpers: maps a short label to a
# matplotlib colour specification (named colour or hex string).
coldict = dict(
    all='darkseagreen',
    exc='steelblue',
    excexc='steelblue',
    excinh='orange',
    inhexc='orange',
    inhinh='#FF1E1E',
    inh='#FF1E1E',
    S1='steelblue',
    S2='steelblue',
    P1='cyan',
    P2='cyan',
)


# functions for loading # # # # # # # # # # # # # # # # # # # # # # # # # # # #
def ei_lists(consFile, eiThres):
    '''
    Defines putative excitatory and inhibitory units based on given consistency
    threshold.

    Units whose consistency is below eiThres are returned as inhibitory,
    units whose consistency is above 1 - eiThres as excitatory; all other
    units appear in neither list.

    Parameters
    ----------
    consFile: str
        file with consistency values for all units (anything accepted by
        numpy.loadtxt); values are read as float16
    eiThres: float
        defines width of tail of consistency distribution
        float [0, 0.5]
        for 0 only values < 0 or > 1 would be selected, i.e. no unit is
        classified if all consistencies lie within [0, 1]

    Returns 2 lists of indices, 1st for excitatory units and 2nd for inhibitory.
    '''
    # ndmin=1 keeps a file with a single value as a 1D array; without it the
    # result is 0-dimensional and np.where below cannot index it
    consistency = np.loadtxt(consFile, dtype=np.float16, ndmin=1)

    iIds = np.where(consistency < eiThres)[0].tolist()
    eIds = np.where(consistency > 1 - eiThres)[0].tolist()

    return eIds, iIds


# functions for data analysis # # # # # # # # # # # # # # # # # # # # # # # # #
def dist_dep_xcov(ccm, dist, ids1=None, ids2=None, NT=None, minPairs=10):
    '''
    Calculate distance-resolved average and variance of pairwise
    cross-covariance between neurons.

    If NT is defined, additionally applies a correction to the resulting
    variance biased by finite number of measurements [1].

    Parameters
    ----------
    ccm : numpy array
        Square array of pairwise cross-covariances between neurons.
    dist : numpy array
        Array of the same shape as ccm, each entry being a physical distance
        between a respective pair of neurons.
    ids1 : 1D numpy array
        Indices of rows in ccm to be considered. If None (default), the whole
        ccm is used (ids2 is then ignored).
    ids2 : 1D numpy array
        Indices of columns in ccm to be considered. If None (default), takes
        the same value as ids1.
    NT : int
        Number of bins used during calculation of pairwise covariances, used
        for bias correction. If None (default), no correction is performed.
    minPairs : int
        Minimal number of entries per distance to be considered for bias
        correction. If the number of entries at a given distance is
        <minPairs, the bias-corrected variance is set to NaN.
        If minPairs <= 0, the correction is applied to every distance, but
        without the 1 / (1 - 1 / Npairs) rescaling.

    Returns
    -------
    xmu_xvar : numpy array
        Of dim: number_unique_distances x 2; average (0th column) and variance
        (1st column) of cross-covariance per distance.
    dVec : numpy array
        Vector of unique distances in increasing order. Matches the order of
        distances in xmu_xvar.
    dCov : list
        List of 1D arrays of cross-covariance values per distance, ordered the
        same as dVec and xmu_xvar.

    Notes
    -----
    Entries on the diagonal of the full matrix (auto-covariances) are always
    excluded. No triangle of the matrix is singled out, so entries (i, j) and
    (j, i) are both collected whenever both belong to the selection.

    References
    ----------
    [1] Dahmen D., Gruen, S., Diesmann, M., Helias, M. (2018). Two types of
    criticality in the brain. arXiv:1711.10930v2 [cond-mat.dis-nn]
    '''
    assert(ccm.shape == dist.shape)

    # the distance-independent mean auto- and cross-covariances only enter
    # the bias correction, so they are computed only if it is requested
    if NT is not None:
        auto1, auto2, cross = mean_cov(ccm, ids1, ids2)
        bias = (auto1 * auto2 - cross ** 2) / (NT - 1.)

    if ids2 is None:
        ids2 = ids1

    # mask removing auto-covariances (diagonal of the full matrix)
    offDiag = np.ones(dist.shape, dtype=bool)
    offDiag[np.diag_indices_from(offDiag)] = False

    # restrict everything to the requested rows/columns once, instead of
    # repeating the fancy indexing for every distance
    if ids1 is not None:
        subCcm = ccm[ids1, :][:, ids2]
        subDist = dist[ids1, :][:, ids2]
        offDiag = offDiag[ids1, :][:, ids2]
    else:
        subCcm, subDist = ccm, dist

    # vector of all available distances (np.unique returns them sorted)
    dVec = np.unique(dist)
    # arrays of all cross-covariance values per distance
    dCov = []
    # distance-dependent mean and variance of cross-covariance
    xmu_xvar = np.zeros((len(dVec), 2))

    for i, d in enumerate(dVec):
        # collect cross-covs for distance d
        cov = subCcm[(subDist == d) & offDiag]
        dCov.append(cov)

        # calculate biased mean and variance of cross-cov
        xmu_xvar[i, 0] = np.nanmean(cov)
        xmu_xvar[i, 1] = np.nanvar(cov, ddof=1)

        # apply bias-correction to variance of cross-covariance
        if NT is not None:
            Npairs = len(cov)  # number of collected matrix entries

            if minPairs > 0:
                if Npairs >= minPairs:
                    unbiased = xmu_xvar[i, 1] / (1. - 1. / Npairs) - bias
                else:
                    unbiased = np.nan
            else:
                unbiased = xmu_xvar[i, 1] - bias
            print('Corrected experimental variance by {:0.5f}%'.format(
                (1. - unbiased / xmu_xvar[i, 1]) * 100.))
            xmu_xvar[i, 1] = unbiased

    return xmu_xvar, dVec, dCov


def mean_cov(ccm, ids1=None, ids2=None):
    '''
    Calculate mean auto- and cross-covariance from an array of pairwise
    covariances. NaN entries are ignored in all averages.

    Parameters
    ----------
    ccm : numpy array
        Square array of pairwise cross-covariances.
    ids1 : 1D numpy array
        Indices of rows in ccm to be considered. If None (default), the whole
        ccm is used (ids2 is then ignored).
    ids2 : 1D numpy array
        Indices of columns in ccm to be considered. If None (default) or the
        very same object as ids1, takes the same columns as rows.

    Returns
    -------
    auto1, auto2 : float
        Average auto-covariances of units defined by ids1 and ids2
        respectively.
    cross : float
        Average cross-covariance of considered units. For a single population
        it is the mean over the strict upper triangle of the selected
        sub-matrix; for two populations it is the mean over all entries of
        the ids1 x ids2 sub-matrix.
    '''
    # the same object passed twice means a single population
    if ids1 is ids2:
        ids2 = None

    if ids1 is not None and ids2 is not None:
        auto1 = np.nanmean(np.diag(ccm[ids1, :][:, ids1]))
        auto2 = np.nanmean(np.diag(ccm[ids2, :][:, ids2]))
        cross = np.nanmean(ccm[ids1, :][:, ids2])
        return auto1, auto2, cross

    # single population: either the selected sub-matrix or the whole ccm
    sub = np.asarray(ccm if ids1 is None else ccm[ids1, :][:, ids1])
    auto1 = auto2 = np.nanmean(np.diag(sub))
    # average over pairs above the diagonal; nanmean keeps the NaN handling
    # consistent with the auto-covariances and the two-population case
    cross = np.nanmean(sub[np.triu_indices_from(sub, k=1)])
    return auto1, auto2, cross


def fexp(x, a, l):
    '''
    Exponential decay: returns a * exp(-x / l) evaluated at x.
    '''
    decay = np.exp(-x / l)
    return a * decay


def residuals1exp1err(to_fit, x, y, w=None):
    '''
    Residuals between y and three exponentials sharing one decay constant.

    to_fit: [a1, a2, a3, l]
        amplitudes of the three curves followed by the common decay constant
    x: positions at which the exponentials are evaluated
    y: data compared against the 3 x len(x) model array
    w: optional weights multiplied onto the residuals

    Residuals at positions where y is NaN or not positive are set to 0.
    Returns the rows of the (weighted) residuals stacked into one vector.
    '''
    modelled = np.zeros((3, len(x)))
    for row, amplitude in enumerate(to_fit[:3]):
        modelled[row] = fexp(x, amplitude, to_fit[-1])
    residuals = y - modelled

    # NaN and non-positive data points do not contribute
    ignored = np.nonzero((y <= 0) | np.isnan(y))
    residuals[ignored] = 0

    weighted = residuals if w is None else residuals * w
    return np.hstack(weighted)


def residuals3exp1err(to_fit, x, y, w=None):
    '''
    Residuals between y and three exponentials, each with its own decay
    constant.

    to_fit: a1, a2, a3, l1, l2, l3
        amplitudes of the three curves followed by their decay constants
    x: positions at which the exponentials are evaluated
    y: data compared against the 3 x len(x) model array
    w: optional weights multiplied onto the residuals

    Residuals at positions where y is NaN or not positive are set to 0.
    Returns the rows of the (weighted) residuals stacked into one vector.
    '''
    n_curves = 3
    modelled = np.zeros((n_curves, len(x)))
    for row in range(n_curves):
        amplitude = to_fit[row]
        decay = to_fit[row + n_curves]
        modelled[row] = fexp(x, amplitude, decay)
    residuals = y - modelled

    # NaN and non-positive data points do not contribute
    ignored = np.nonzero((y <= 0) | np.isnan(y))
    residuals[ignored] = 0

    weighted = residuals if w is None else residuals * w
    return np.hstack(weighted)


# functions for fast statistical comparisons # # # # # # # # # # # # # # # # #
def compare2samples(x, y, alpha=0.05, paired=None, alternative='two-sided',
                    compare_var=True, nan_policy='propagate'):
    """
    Test equality of means (or medians) of two independent distributions.
    Equality of variance is tested as well by default.
    """
    message = 'Original size of samples: len(x)={}, len(y)={}.\n'.format(
        len(x), len(y))
    x = x[~np.isnan(x)]
    y = y[~np.isnan(y)]
    message += 'Non-nan values in samples: len(x)={}, len(y)={}.\n'.format(
        len(x), len(y))

    normal, mess = test_normality(x, y, alpha)
    message += mess

    if normal:
        message += test_with_normality(x, y, alpha, nan_policy)

    else:
        message += test_without_normality(x, y, alpha, paired, compare_var)

    return message


def test_normality(x, y, alpha=0.05):
    """
    Performs Shapiro-Wilk test to test normality of each sample.
    """
    message = ''
    message += '\n Test normality (Shapiro-Wilk test) \n'
    sx, px = sst.shapiro(x)
    sy, py = sst.shapiro(y)
    message += 'p-values: px={:.2e}, py={:.2e}.'.format(px, py)
    if any(np.array([px, py]) < alpha):
        message += '\nAt least one sample is not normally distributed.\n'
        return False, message
    else:
        message += '\nBoth distributions are normal!\n'
        return True, message


def test_with_normality(x, y, alpha=0.05, nan_policy='propagate'):
    """
    Tests equality of variances and
    performs appropriate version of t-test to compare means.

    nan_policy: str
        parameter of t-test, default 'propagate' returns nan when input
        contains nan, 'omit' ignores nan values and calculates statistics
    """
    message = '\nH0: two distributions have equal variance\n'

    message += 'F-test\n'
    F, p = F_test(x, y)
    message += 'Statistics and its p-value: F={:.2f}, p={:.2e}'.format(F, p)
    if p > alpha:
        equal = True
    else:
        equal = False

    message += '\nBartlett test\n'
    s, p = sst.bartlett(x, y)
    message += 'Statistics and its p-value: s={:.2f}, p={:.2e}'.format(s, p)

    message += '\nH0: Two independent distributions have the same expected' \
               'value\n'

    if equal:
        message += "Student's t-test\n"
    else:
        message += "Welch's t-test\n"
    # t, p = sst.ttest_ind(x, y, equal_var = equal, nan_policy = nan_policy)
    t, p = sst.ttest_ind(x, y, equal_var=equal)
    message += 'Statistics and its p-value: t={:.2f}, p={:.2e}'.format(t, p)
    return message


def test_without_normality(x, y, alpha=0.05, paired=None,
                           alternative='two-sided', compare_var=True):
    """
    Performs a number of non-parametric tests to compare the distributions.

    paired: boolean or None
        if True, perform Wilcoxon test,
        if False, perform Mann-Whitney test
        if None, do not perform any additional tests
    alternative: 'less', 'two-sided' or 'greater'
        parameter of the Mann-Whitney U-test
    compare_var: boolean
        if True, perform Levene and Flinger-Killeen tests on equality
        of variance
    """
    message = '\nH0: two samples are drawn from the same distribution\n'

    message += 'Kolmogorov-Smirnov test\n'
    D, p = sst.ks_2samp(x, y)
    message += 'Statistics and p-value: D={:.2f}, p={:.2e}'.format(D, p)

    if paired:
        message += '\nWilcoxon test\n'
        z, p = sst.wilcoxon(x, y)
        message += 'Statistics and its p-value: z={:.2f}, p={:.2e}'.format(z,
                                                                           p)
    elif not paired:
        message += '\nMann-Whitney U-test\n'
        # U, p = sst.mannwhitneyu(x, y, alternative=alternative)
        U, p = sst.mannwhitneyu(x, y)
        message += 'Statistics and its p-value: U={:.2f}, p={:.2e}'.format(U,
                                                                           p)

    message += '\nH0: two distributions have equal median\n'

    message += 'Kruskal-Wallis test\n'
    H, p = sst.kruskal(x, y)
    message += 'Statistics and p-value: H={:.2f}, p={:.2e}'.format(H, p)

    if compare_var:
        message += '\nH0: two distributions have equal variance'

        message += '\nLevene test\n'
        s, p = sst.levene(x, y)
        message += 'Statistics and its p-value: s={:.2f}, p={:.2e}'.format(s,
                                                                           p)

        message += '\nFligner-Killeen test\n'
        s, p = sst.fligner(x, y)
        message += 'Statistics and its p-value: s={:.2f}, p={:.2e}'.format(s,
                                                                           p)

    return message


def F_test(x, y):
    """
    Two-sided F-test for equality of variances.

    Assumes that x and y are normally distributed.
    Very sensitive to normality violation!

    Returns
    -------
    F : float
        Ratio of the smaller to the larger unbiased sample variance
        (so F <= 1).
    p : float
        Two-sided p-value of H0 "both variances are equal".
    """
    var_x = np.var(x, ddof=1)
    var_y = np.var(y, ddof=1)
    if var_x < var_y:
        F = var_x / var_y
        df1 = len(x) - 1
        df2 = len(y) - 1
    else:
        F = var_y / var_x
        df2 = len(x) - 1
        df1 = len(y) - 1
    # The smaller variance is always put in the numerator, i.e. the direction
    # of the ratio is chosen from the data. The lower-tail probability
    # therefore has to be doubled to obtain the two-sided p-value (capped at
    # 1; np.minimum keeps a NaN if the variances are degenerate).
    p = np.minimum(1., 2. * sst.f.cdf(F, df1, df2))

    return F, p


# functions for beautiful and easy plotting # # # # # # # # # # # # # # # # # #
def mm2inch(n):
    '''
    Convert a length given in millimetres to inches (1 inch = 25.4 mm).
    '''
    mm_per_inch = 25.4
    return n / mm_per_inch


def labels(ax, fontsize=16, title=None, xlabel=None, ylabel=None, zlabel=None,
           legend=True, color='k', xcolor='k', ycolor='k', zcolor='k',
           suptitle=None, fig=None, legendfontsize=None):
    """
    Manages contents and font sizes of most of textual elements of a matplotlib
    figure.

    ax:
        axes to decorate
    fontsize: int
        font size of title and axis labels; suptitle uses fontsize + 1,
        tick labels fontsize - 1
    title, xlabel, ylabel, zlabel: str or None
        texts to set; elements left as None (or empty) are not touched
        (zlabel requires axes providing set_zlabel)
    legend: boolean
        if True, a legend is added
    color, xcolor, ycolor, zcolor:
        colors of title/suptitle and of the respective axis labels; xcolor
        and ycolor are also applied to ticks and tick labels
    suptitle: str or None
        To add a suptitle provide also a figure handle 'fig'.
    legendfontsize:
        if given, a legend with this font size is added (also when
        legend=False)
    """
    if title:
        ax.set_title(title, fontsize=fontsize, color=color)

    if suptitle:
        if fig:
            fig.suptitle(suptitle, fontsize=fontsize + 1, color=color)
        else:
            print('Provide a figure handle to add a suptitle')

    if xlabel:
        ax.set_xlabel(xlabel, fontsize=fontsize, color=xcolor)

    if ylabel:
        ax.set_ylabel(ylabel, fontsize=fontsize, color=ycolor)

    if zlabel:
        ax.set_zlabel(zlabel, fontsize=fontsize, color=zcolor)

    # build the legend only once; fontsize=None keeps matplotlib's default
    if legend or legendfontsize:
        ax.legend(loc='best', handlelength=1,
                  fontsize=legendfontsize if legendfontsize else None)

    # 'colors' sets both tick and tick-label color, in either branch
    if xcolor == ycolor:
        ax.tick_params('both', labelsize=fontsize - 1, colors=xcolor)
    else:
        ax.tick_params('x', labelsize=fontsize - 1, colors=xcolor)
        ax.tick_params('y', labelsize=fontsize - 1, colors=ycolor)


def distribution(ax, data, bins=None, histtype='bar', color='darkseagreen',
                 lw=1., linestyle='-', normed=False, label=None):
    """
    Plots a histogram of data in given axes.

    data may be any array-like of numbers. If it is multidimensional, it is
    flattened and the distribution of all (N) numbers is plotted. NaN values
    are not plotted.
    If bins are not given, divides the distribution into floor(sqrt(N)) bins
    (at least one), where N counts all entries of data including NaNs, and
    prints N and the number of bins.

    normed: boolean
        passed to the histogram as 'density'
    label: str or None
        legend label of the histogram

    Axis limits are set to the range of the bins (x) and to
    [0, 1.1 * highest bin] (y).

    Returns the bin contents and the bin edges.
    """
    # accept lists as well as arrays of any dimensionality
    data = np.asarray(data, dtype=float).ravel()
    valid = data[~np.isnan(data)]

    if bins is None:
        # at least one bin, so that empty input does not request 0 bins
        bins = max(1, int(np.floor(np.sqrt(np.size(data)))))
        print(('size:', np.size(data), 'bins:', bins))

    n, bins, _ = ax.hist(valid, bins=bins, histtype=histtype, lw=lw,
                         linestyle=linestyle, color=color,
                         density=normed, label=label if label else None)

    ax.set_xlim([min(bins), max(bins)])
    ax.set_ylim([0, max(n) * 1.1])

    return n, bins


def dynamics(ax, data, t=None, color='k', lw=1, ls='solid', label=False):
    """
    Plots a time course (optionally with a band of +- SD) in given axes.

    data: array-like
        1D vector or 2D array with 2 rows, in the latter case treats first
        row as an average and second row as SD dynamics
    t: array-like
        time vector; if None, numbers of samples will be given (0, 1, 2, 3...)
    color:
        any matplotlib color specification
    lw, ls:
        line width and line style; ls is used for 1D data only, the average
        of 2D data is always drawn solid
    label: str or False
        legend label of the line; no label is set if it is False (default)
        or empty

    The x-limits are set to the first and last entry of t. Data of any other
    dimensionality is not plotted and a message is printed.
    """
    if not isinstance(data, np.ndarray):
        data = np.array(data)

    if data.ndim not in (1, 2):
        print('Data has to be either 1D or 2D')
        return

    if t is None:
        # sample numbers along the last axis
        t = np.arange(data.shape[-1])

    # a False/empty label must not reach matplotlib, which would otherwise
    # show the text 'False' in the legend
    line_label = label if label else None

    if data.ndim == 1:
        ax.plot(t, data, color=color, lw=lw, ls=ls, label=line_label)
    else:
        ax.plot(t, data[0], color=color, lw=lw, ls='solid', label=line_label)
        # SD band in the line color blended with white, which looks like
        # alpha = 0.2 on a white background without using transparency
        a = 0.2
        rgb = np.asarray(to_rgb(color))
        newcolor = tuple(1 - a + a * rgb)
        ax.fill_between(t, data[0] + data[1], data[0] - data[1],
                        color=newcolor)

    ax.set_xlim([t[0], t[-1]])


def plot_r2g(ax):
    '''
    Draws a labelled time line of trial events into the given axes: an arrow
    along y = 0 with one thick tick per event and two light-grey dashed
    segments on top of it.

    Looks good in e.g. 8x2 figure.

    Event labels:
    TS - trial start
    WS - warning signal
    CUE-ON, CUE-OFF, GO
    SR - switch release
    RW - reward
    '''
    event_pos = [0, 4, 8, 11, 21, 22.5, 30]
    event_names = ['TS', 'WS', 'CUE-ON', 'CUE-OFF', 'GO', 'SR', 'RW']
    x_range = [-1, event_pos[-1] + 5]
    dash_pattern = (0, (3, 2))

    ax.set_xlim(x_range)
    ax.set_ylim([-1.5, 1.5])
    ax.set_yticks([])
    for side in ('top', 'right', 'left'):
        ax.spines[side].set_visible(False)

    # the bottom spine is moved to y = 0 and made invisible, so that only
    # its ticks and labels remain
    baseline = ax.spines['bottom']
    baseline.set_position(('data', 0))
    baseline.set_linewidth(0)

    ax.set_xticks(event_pos)
    ax.set_xticklabels(event_names)
    ax.xaxis.set_tick_params(width=5, direction='inout', length=20,
                             labelsize=12, labelrotation=90, pad=3)

    # NOTE(review): the arrow length (dx) is the upper x-limit, not the
    # width of the x-range - confirm this is intended
    ax.arrow(x_range[0], 0, x_range[-1], 0, length_includes_head=True,
             width=0.05, head_width=0.15, head_length=1, fc='k')

    dashed_segments = ((event_pos[-3], event_pos[-1] - 5),
                       (event_pos[-1], x_range[-1] - 2))
    for start, stop in dashed_segments:
        ax.plot([start, stop], [0, 0], color='lightgrey',
                linewidth=2, linestyle=dash_pattern)

def savef(fig, fname, formats, ifSave=True, latex=True, **kwargs):
    r"""
    Saves the figure with given name and closes it afterwards. If latex=True,
    adjusts the name to be compatible with latex \includegraphics{} command.

    fig:
        a figure handle (figure to save)
    fname: str
        name of a file (full directory) without extension
    formats:
        a list of strings indicating all formats in which the figure
        is to be saved (should work with png, pdf, ps, eps and svg,
        depending on an active backend); a single string is accepted as well
        png, pdf and eps are saved with bbox_inches='tight' unless
        overridden in kwargs; eps is additionally saved without the figure
        background unless facecolor/edgecolor are given in kwargs
    ifSave: boolean
        if False, does not do anything
    latex: boolean
        if True, replaces all spaces in the file name (not in the
        directory part of fname) with hyphens and periods with commas
    kwargs:
        passed on to fig.savefig

    POTENTIAL PROBLEM: eps does not support transparencies (alpha < 1), so
    when saving as eps, all transparency info is lost and a picture is saved
    in full colors. Possible workarounds:
        - rasterize the plot before saving to prevent the transparency
        (fig.set_rasterized(True)); but this may worsen the quality of the image
        - save as pdf and than run in the terminal (under Linux):
            pdftops -eps your_image.pdf your_image.eps
        that should preserve the quality, but you have to do it by hand with
        each single figure
        - https://stackoverflow.com/questions/19638773/matplotlib-plots-lose-transparency-when-saving-as-ps-eps
    """
    import os  # local import, only needed here

    if not ifSave:
        return

    if latex:
        # checks on latex \includegraphics compatibility; only the file name
        # is changed, so that paths like './figs' or '../figs' stay valid
        folder, base = os.path.split(fname)
        base = base.replace(' ', '-').replace('.', ',')
        fname = os.path.join(folder, base)
    print(fname)

    # a bare string would otherwise be iterated character by character
    if isinstance(formats, str):
        formats = [formats]

    # actual saving
    for f in formats:
        options = dict(kwargs)
        if f in ('png', 'pdf', 'eps'):
            options.setdefault('bbox_inches', 'tight')
        if f == 'eps':
            # replacement of the 'frameon=False' keyword, which is no longer
            # accepted by savefig: do not draw the figure background
            options.setdefault('facecolor', 'none')
            options.setdefault('edgecolor', 'none')
        fig.savefig(fname + '.' + f, **options)
    plt.close(fig)
