# -*- coding: utf-8 -*-
#
# This source file is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This source file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import matplotlib.pyplot as plt
import math


def sap_plt(ax,
            z_val,
            ExId,
            InId,
            positions,
            reference_neuron=None,
            plt_cb=False,
            cbarlabel=None,
            vmin=None,
            vmax=None,
            # Define Plot-Appearance
            plt_SensID=True,
            plt_snswdth=True,
            ticksensor=False,
            plot_size=True,
            linewidth=0.5,
            markersize=100,
            min_msize=20,
            fontsize=10,
            cmap='PiYG',
            c_edge=(0.8, 0.8, 0.8),
            sens_p_row=10,
            sens_p_col=10,
            x_up_lim=4000,
            x_lo_lim=0,
            y_up_lim=4000,
            y_lo_lim=0,
            verbose=False,
            cb_shrink=1
            ):
    """
    Scatter plot showing covariance or other z-values of neurons on a sensor
    array (10 by 10 sensors by default, each sensor field 400 x 400 units).

    The neurons are ordered by neuron ID in each sensor field and placed in
    the following order:
    1. center
    2. upper right
    3. lower right
    4. lower left
    5. upper left.
    At most five distinct positions exist per sensor field; further neurons
    on the same sensor are drawn on top of the fifth one (a warning is
    issued).
    The z-values for each neuron are coded by the color of the plotted
    markers. Additionally, absolute z-values can be depicted by marker size.
    If a 2d array is given, the z-values are taken from the row of the
    reference neuron.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes in which the plot will be drawn.
    z_val : np.array (1d or 2d)
        Covariance or other z-data (e.g. rate). Input has to be ordered by
        neuron id.
        In case of a 2d array, `reference_neuron` selects the data row taken
        for the z-values. The z-value of the reference neuron itself is set
        to zero.
    ExId, InId : array-like or None
        Ids of excitatory/inhibitory neurons. Used to set the marker type
        ('^' for excitatory, 'o' for inhibitory). If None, no neurons of
        that type are plotted. Neurons in neither group are not plotted.
    positions : np.array (2d)
        [x, y] sensor-grid positions of neurons ordered by neuron id, with
        [0, 0] being the bottom left. The position assigns each neuron to a
        sensor.
    reference_neuron : integer or None
        ID of the reference neuron. Required if `z_val` is 2d. If given, the
        neuron is highlighted with a white marker (0 is a valid ID).
    plt_cb : boolean
        If True, plot a color bar next to ax, default=False.
    cbarlabel : string
        Label of the color bar, default=None.
    vmin, vmax : float or None
        Limits for the color scale. If None, the -/+ maximum absolute plotted
        value is used.
    plt_SensID : boolean
        If True (default), write the sensor number on each sensor field.
    plt_snswdth : boolean
        If True (default), plot an arrow depicting the width of a sensor
        field.
    ticksensor : boolean
        If True, name axis ticks according to sensor IDs instead of physical
        length, useful for testing.
        If False (default), write axis tick labels in millimeters.
    plot_size : boolean
        If True (default), marker size depends on the absolute z-value.
        Marker size is normalized by the maximal absolute z-value.
    linewidth : float
        Line width of the marker border, default=0.5.
    markersize : integer
        Basic marker size (as defined in ax.scatter()), default=100.
    min_msize : integer
        Minimum marker size used for z-value-dependent marker size,
        default=20.
    fontsize : float
        Font size used for axis and color bar labels; tick labels, sensor
        numbers and the scale bar use (fontsize - 1). Default=10.
    cmap : str
        Name of a matplotlib color map used to indicate z-values,
        default='PiYG'.
    c_edge : tuple or None or 'none'
        Marker edge color forwarded to ax.scatter(),
        default=(0.8, 0.8, 0.8) (grey).
        If None, takes the value from rcParams;
        if 'none', no edge is drawn.
    sens_p_row, sens_p_col : integer
        Number of sensors per row/column, default=10.
    x_up_lim, x_lo_lim : int or float
        Upper/lower limit for the x axis, defaults: 4000, 0.
    y_up_lim, y_lo_lim : int or float
        Upper/lower limit for the y axis, defaults: 4000, 0.
    verbose : boolean
        If True, print additional information while creating the plot.
    cb_shrink : float
        Shrink factor forwarded to the color bar, default=1.

    Returns
    -------
    im : list
        List of matplotlib.collections.PathCollection (output of ax.scatter)
        for the excitatory and/or inhibitory neurons, in that order.

    Raises
    ------
    ValueError
        If `z_val` is neither 1d nor 2d with a given `reference_neuron`.
    """
    import warnings

    # -------------------------------------------------------------------------
    # select the z-values to plot
    if reference_neuron is not None and len(z_val.shape) == 2:
        helper = np.copy(z_val[reference_neuron, :])
        helper[reference_neuron] = 0.0
    elif len(z_val.shape) == 1:
        helper = np.copy(z_val)
    else:
        raise ValueError('The z-values must be either 1d or 2d; '
                         'in the latter case reference_neuron has to be'
                         ' a valid array index.')

    # Linear (0-based) sensor index of every neuron: row * sensors per row
    # + column. This must agree with the row wrap in the placement loop below.
    positions = np.asarray(positions)
    sensor_of_neuron = positions[:, 1] * sens_p_row + positions[:, 0]

    # check what will be plotted; sets give O(1) membership tests
    plt_exc = ExId is not None
    plt_inh = InId is not None
    exc_ids = set(np.ravel(ExId).tolist()) if plt_exc else set()
    inh_ids = set(np.ravel(InId).tolist()) if plt_inh else set()

    # set limits for z-value
    max_cov = np.nanmax(np.absolute(helper))
    if vmin is None:
        vmin = -max_cov
    if vmax is None:
        vmax = max_cov

    # Define plot appearance
    # number of ticks and labels to plot
    tick_no = 5
    nb_sens = sens_p_row * sens_p_col
    field_width = 400  # width/height of one sensor field (plot units)

    # Offsets (dx, dy) of the j-th neuron on a sensor, relative to the upper
    # right corner of its sensor field.
    slot_offsets = ((-200, -200),   # center
                    (-100, -100),   # upper right
                    (-100, -300),   # lower right
                    (-300, -300),   # lower left
                    (-300, -100))   # upper left

    # -------------------------------------------------------------------------
    # coordinates and z-values for excitatory (1) and inhibitory (2) neurons
    list1x, list1y, list1z = [], [], []
    list2x, list2y, list2z = [], [], []

    # upper right corner of the current sensor field (starts at sensor 1)
    xposn = field_width
    yposn = field_width

    # plotting style of the reference neuron
    x_ref = y_ref = None
    ref_n_type_exi = False
    ref_marker = 'D'  # diamond, in case ref is neither exc nor inh

    for sens_id in range(nb_sens):
        on_sensor = np.where(sensor_of_neuron == sens_id)[0]
        if verbose:
            print('Found:', np.size(on_sensor),
                  ' neuron(s) on sensor', sens_id + 1)
        if np.size(on_sensor) > len(slot_offsets):
            warnings.warn('Sensor {} holds {} neurons, but only {} distinct '
                          'marker positions exist; surplus neurons overlap '
                          'the last position.'.format(sens_id + 1,
                                                      np.size(on_sensor),
                                                      len(slot_offsets)))

        for j, nid in enumerate(on_sensor):
            dx, dy = slot_offsets[min(j, len(slot_offsets) - 1)]
            xposn_temp = xposn + dx
            yposn_temp = yposn + dy
            # explicit None check: neuron ID 0 is a valid reference
            is_ref = reference_neuron is not None and nid == reference_neuron

            if is_ref:
                x_ref = xposn_temp
                y_ref = yposn_temp
                if verbose:
                    print('\ncoordinates of reference_neuron',
                          nid, 'is set to: X', xposn_temp, 'Y',
                          yposn_temp, 'on Sensor', sens_id + 1, '\n')

            # Check neurons: excitatory or inhibitory?
            if int(nid) in exc_ids:
                if verbose:
                    print('Excitatory', 'NeuronID', nid,
                          'covariance is:', helper[nid])
                list1x.append(xposn_temp)
                list1y.append(yposn_temp)
                list1z.append(helper[nid])
                if is_ref:
                    ref_n_type_exi = True
                    ref_marker = '^'

            elif int(nid) in inh_ids:
                if verbose:
                    print('Inhibitory NeuronID', nid,
                          'covariance is:', helper[nid])
                list2x.append(xposn_temp)
                list2y.append(yposn_temp)
                list2z.append(helper[nid])
                if is_ref:
                    ref_marker = 'o'

        # move to the next sensor field; wrap to the next row after
        # sens_p_row fields
        if xposn == field_width * sens_p_row:
            xposn = field_width
            yposn += field_width
        else:
            xposn += field_width

    # create marker sizes (points^2) depending on the absolute z-value;
    # guard against division by zero when all z-values are zero
    if plot_size:
        size_norm = max_cov if max_cov > 0 else 1.0
        s_Ecov = [(np.sqrt(markersize) * np.abs(z) / size_norm
                   + np.sqrt(min_msize)) ** 2 for z in list1z]
        s_Icov = [(np.sqrt(markersize) * np.abs(z) / size_norm
                   + np.sqrt(min_msize)) ** 2 for z in list2z]
    else:
        s_Ecov = markersize
        s_Icov = markersize

    # -------------------------------------------------------------------------
    # plot the neurons
    im = []  # this will be returned
    if plt_exc:
        im.append(ax.scatter(x=list1x,
                             y=list1y,
                             c=list1z,
                             marker='^',
                             s=s_Ecov,
                             cmap=cmap,
                             vmin=vmin,
                             vmax=vmax,
                             linewidths=linewidth,
                             edgecolors=c_edge))

    if plt_inh:
        im.append(ax.scatter(x=list2x,
                             y=list2y,
                             c=list2z,
                             marker='o',
                             s=s_Icov,
                             cmap=cmap,
                             vmin=vmin,
                             vmax=vmax,
                             linewidths=linewidth,
                             edgecolors=c_edge))

    # plot reference neuron in the foreground
    if x_ref is not None:
        ax.scatter(x=x_ref,
                   y=y_ref,
                   c='white',
                   marker=ref_marker,
                   s=2 * markersize,
                   linewidths=2.0,
                   edgecolor='black')

        if ref_n_type_exi:
            # shifted down so the small triangle sits visually centered
            ax.scatter(x=x_ref,
                       y=y_ref - markersize / 5,
                       c='black',
                       marker=ref_marker,
                       s=markersize / 10)
        else:
            ax.scatter(x=x_ref,
                       y=y_ref,
                       c='black',
                       marker=ref_marker,
                       s=markersize / 5)
    elif reference_neuron is not None:
        warnings.warn('Reference neuron {} was not found on any sensor and '
                      'is not highlighted.'.format(reference_neuron))

    # -------------------------------------------------------------------------
    # plot colorbar (only possible if something was plotted)
    if plt_cb and im:
        cb = plt.colorbar(cmap=cmap, mappable=im[0], ax=ax, pad=0.02,
                          shrink=cb_shrink)
        cb.set_label(cbarlabel, fontsize=fontsize)

    # define tick labels of the plot
    if ticksensor:
        # use names of electrodes as tick labels
        tickpos = np.linspace(200, 3800, 10)
        minortickpos = np.linspace(400, 4000, 10)
        sensor_x = ['y1', 'y2', 'y3', 'y4', 'y5', 'y6', 'y7', 'y8', 'y9', 'y0']
        sensor_y = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    else:
        # use mm as tick labels
        tickpos = np.linspace(0, 4000, tick_no).astype(int)
        minortickpos = np.linspace(400, 4000, 10)
        sensor_x = [0, 1, 2, 3, 4]
        sensor_y = [0, 1, 2, 3, 4]

    # -------------------------------------------------------------------------
    # Create sensor grid, manage axis limits and ticks

    # write sensor numbers: sens_p_col rows of sens_p_row sensors each
    if plt_SensID:
        sens_i = 1  # starting sensor number
        ypos = 305  # starting y coordinate
        for _ in range(sens_p_col):
            xpos = 10  # starting x coordinate for every row
            for _ in range(sens_p_row):
                ax.text(xpos,
                        ypos,
                        sens_i,
                        fontsize=fontsize - 1,
                        color='lightgrey')
                xpos += field_width
                sens_i += 1
            ypos += field_width

    # plot sensor grid
    # (ax.grid() does not work here, because matplotlib removes minor ticks
    # that overlap with major ticks)
    ax.vlines(minortickpos, 0, 4000, color='lightgrey', linestyle=':',
              linewidth=1, zorder=0.5)
    ax.hlines(minortickpos, 0, 4000, color='lightgrey', linestyle=':',
              linewidth=1, zorder=0.5)

    # add annotations for the scale:
    if plt_snswdth:
        ax.annotate('',
                    xytext=(0, 70),
                    xy=(400, 70),
                    arrowprops=dict(color='black',
                                    arrowstyle='<->', shrinkA=0, shrinkB=0)
                    )
        ax.text(50,
                150,
                '400 µm',
                fontsize=fontsize - 1)

    # set limits and labels
    ax.set_xlim(x_lo_lim, x_up_lim)
    ax.set_ylim(y_lo_lim, y_up_lim)
    ax.set_xlabel('x [mm]', fontsize=fontsize)
    ax.set_ylabel('y [mm]', fontsize=fontsize)
    ax.set_xticks(tickpos)
    ax.set_xticklabels(sensor_x, fontsize=fontsize - 1)
    ax.set_yticks(tickpos)
    ax.set_yticklabels(sensor_y, fontsize=fontsize - 1)

    return im


if __name__ == '__main__':
    # Project-local helpers, only needed for this demo.
    import h5py_wrapper as h5w
    from helpers import ei_lists

    # load some preprocessed data
    RScovs_path = 'preprocessed/RScovs_E2.h5'
    cons_path = 'preprocessed/consistency_E2.txt'

    measures = h5w.load(RScovs_path)
    # pairwise covariance matrix N_units x N_units
    covm = measures['covm']
    # physical positions of electrodes on the Utah array per SU
    positions = measures['positions']  # N_units x 2
    # threshold to separate putative excitatory and inhibitory SUs
    eiThres = measures['eiThres']
    # lists of putative excitatory/inhibitory SUs IDs
    eIds, iIds = ei_lists(cons_path, eiThres)

    ref_neuron = 96  # use exact Neuron ID here
    # e.g. 96 for an excitatory neuron, 100 for inhibitory

    figsize = (6, 6)  # alternative: (10, 10)
    plt_cb = True
    if plt_cb:
        # reduce the height to 4/5 of the width, presumably so that the
        # array stays roughly square next to the colour bar
        figsize = (figsize[0], figsize[0] / 50 * 40)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111)

    im = sap_plt(ax,
                 covm,
                 eIds,
                 iIds,
                 positions,
                 reference_neuron=ref_neuron,
                 plt_cb=plt_cb,
                 cbarlabel='spike count covariance',
                 # Define Plot-Appearance
                 plt_SensID=False,
                 ticksensor=False,
                 plot_size=True,
                 linewidth=0.5,
                 markersize=100,
                 min_msize=20,
                 fontsize=12,
                 cmap='PiYG',
                 c_edge=(0.8, 0.8, 0.8),
                 sens_p_row=10,
                 sens_p_col=10,
                 x_up_lim=4000,
                 x_lo_lim=0,
                 y_up_lim=4000,
                 y_lo_lim=0,
                 # additional information while creating plot
                 verbose=False
                 )

    filename = 'test_salt_and_pepper_ref_neuron_' + str(ref_neuron) + '.png'

    # Only non-default options are passed. Keywords such as 'papertype' and
    # 'orientation' are only meaningful for PostScript output, and the others
    # (format, transparent, bbox_inches, metadata) were explicit defaults.
    fig.savefig(filename, dpi=600, facecolor='w', edgecolor='w',
                pad_inches=0.1)
    plt.close(fig)
