# -*- coding: utf-8 -*-
#
# This source file is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This source file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Date: 2020
#
# Authors: Paulina Dąbrowska (p.dabrowska@fz-juelich.de) [1,2]
#
# Affiliations:
# [1] Institute of Neuroscience and Medicine (INM-6) and Institute for Advanced
#     Simulation (IAS-6) and JARA Institute Brain Structure-Function
#     Relationships (INM-10), Jülich Research Centre, Jülich, Germany
# [2] RWTH Aachen University, Aachen, Germany

from __future__ import division, print_function

import numpy as np
import h5py_wrapper as h5w
import matplotlib.pyplot as plt
from matplotlib import gridspec
from string import ascii_uppercase
from itertools import combinations_with_replacement as it_cwr

import helpers as hel
from helpers import coldict

# parameters
monkey = 'E2'  # recording session: E1, E2, N1 or N2 (E2 is shown in the manuscript)
formats = ['pdf', 'svg']  # output formats for saving, any of 'pdf', 'svg', 'eps', 'png'
ifSave = True  # save the figure to disk; when False it is shown interactively instead
fonts = 11  # base font size used for labels, legends and the table
dpi = 300  # figure resolution
vertical = False  # True: single-column vertical layout, False: horizontal layout

###############################################################################
# load the preprocessed fit results of the selected recording session
fits_path = 'preprocessed/RSfits_{}.h5'.format(monkey)
measures = h5w.load(fits_path)

# scipy.optimize.least_squares outcomes of the two fit variants:
#   efit3 - separate decay constant per pair type (plotted as the 3x1 panel)
#   efit1 - one decay constant shared by all pair types (the 1x1 panel)
efit1, efit3 = measures['efit1'], measures['efit3']

# every discrete inter-electrode distance on the Utah array, sorted
distances = measures['distances']
# variances of the distance-resolved covariances, rows: II, IE, EE pairs
# (3 x N_distances)
igrek = measures['igrek']
# indices of the distances that enter the fit
fitrange = measures['fitrange']
# per-distance fit weights for II, IE, EE pairs (3 x N_considered_distances)
ws = measures['ws']

###############################################################################
# composite figure 4
# 2 versions of fit and a table with fitted decay constants

# figure dimensions are specified in mm and converted to inches
if vertical:
    width_mm, height_mm = 88, 180  # single column figure
else:
    width_mm, height_mm = 180, 70  # horizontal layout
width = hel.mm2inch(width_mm)
height = hel.mm2inch(height_mm)

fig = plt.figure(figsize=[width, height], dpi=dpi)

if vertical:
    # 3x1 fit, 1x1 fit and table stacked in a single column
    gs = gridspec.GridSpec(3, 1)
    gs.update(top=0.93, bottom=0.1, right=0.95, left=0.2, hspace=0.3)
    panel_specs = [gs[0], gs[1], gs[2]]
else:
    # 7 columns: 3 for the 3x1 fit, 3 for the 1x1 fit, 1 for the table
    gs = gridspec.GridSpec(1, 7, top=0.9, bottom=0.18, left=0.1, right=0.8,
                           wspace=0.4)
    panel_specs = [gs[0, :3], gs[0, 3:-1], gs[0, -1]]

ax_3x1, ax_1x1, ax_tab = [plt.subplot(spec) for spec in panel_specs]


# label the panels A, B, C; every position is in ax_3x1 axes coordinates,
# so letters for the other panels are placed outside ax_3x1's unit box
if vertical:
    letter_positions = [(-0.22, 1.05), (-0.22, -0.22), (-0.22, -1.7)]
else:
    letter_positions = [(-0.22, 1.05), (1.05, 1.05), (2.2, 1.05)]

for letter, (x, y) in zip(ascii_uppercase, letter_positions):
    ax_3x1.text(x, y, letter, transform=ax_3x1.transAxes,
                fontdict={'size': fonts + 1, 'weight': 'bold'})


# NOTE: these settings must be applied only after the panel letters were
#    added, otherwise their bold font weight does not take effect
latex_rc = {
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{sfmath}',  # sans-serif math
    'font.family': 'sans-serif',
    'font.sans-serif': 'DejaVu Sans',  # matplotlib's default sans-serif font
}
plt.rcParams.update(latex_rc)

# remove axes from table axes
ax_tab.set_frame_on(False)
ax_tab.set_xticks([])
ax_tab.set_yticks([])

# create and plot LaTeX table with fitting results
# A single session list drives both the data loading and the row labels, so
# a value can never end up in the row of a different session.
sessions = ['E1', 'E2', 'N1', 'N2']

table_rows = []
for k, monk in enumerate(sessions):
    # the other sessions' files are found by swapping the session name in
    # fits_path (NOTE(review): assumes `monkey` occurs only in the file name)
    fitmeasures = h5w.load(fits_path.replace(monkey, monk))
    # fitted decay constant shared by all pair types (in mm)
    decay = fitmeasures['efit1']['x'][-1]
    # ratio of mean squared residuals: shared-decay fit (efit1, panel B)
    # over per-pair-decay fit (efit3, panel A)
    relErr = np.mean(fitmeasures['efit1']['fun']**2) \
        / np.mean(fitmeasures['efit3']['fun']**2)
    # the first data row carries an invisible strut for spacing below the
    # header; every row but the last gets 1mm of extra vertical space
    label = monk + ('\\rule{0pt}{1.5em}' if k == 0 else '')
    row_end = '\\\\' if k == len(sessions) - 1 else '\\\\[1mm]'
    table_rows.append(' {} & {:0.3f} & {:0.4f} {}'.format(label, decay,
                                                           relErr, row_end))

# \toprule / \bottomrule (booktabs) do not compile here, hence \hline
tab_header = ('${'
              '\\setlength\\arrayrulewidth{1pt}'
              '\\begin{tabular}{| c | c | c |}'
              ' \\hline'
              '           & fitted   & \\\\'
              ' recording & decay    & error$_b$/ \\\\'
              ' session   & constant & error$_a$ \\\\'
              '           & [mm]     & \\\\[1mm]'
              ' \\hline')
tab_footer = (' \\hline'
              ' \\end{tabular}}$')
tabtxt = tab_header + ''.join(table_rows) + tab_footer

# anchor point (lower-left corner) of the table text in ax_tab axes coords
if vertical:
    ax_tab.text(0.1, -0.3, tabtxt, size=fonts, transform=ax_tab.transAxes)
else:
    ax_tab.text(0.2, 0.08, tabtxt, size=fonts, transform=ax_tab.transAxes)

# panel A: a separate exponential fit per pair type (3 slopes - 1 error).
# efit3['x'][i] is passed to hel.fexp as its second argument (presumably the
# amplitude) and efit3['x'][i + 3] is the pair's decay constant in mm.
pair_types = list(it_cwr(['inh', 'exc'], 2))  # II, IE, EE
for i, (nt1, nt2) in enumerate(pair_types):
    fit_d = distances[fitrange]
    pair_color = coldict[nt1 + nt2]
    decay_mm = efit3['x'][i + 3]
    # marker area grows with the fit weight of each distance
    ax_3x1.scatter(fit_d, igrek[i, fitrange], marker='.', c=pair_color,
                   label='', s=(ws[i] * 10)**2)
    # NOTE: this call alone makes saving report a LaTeX error,
    #       yet the figure is still saved correctly
    pair_name = '{}-{}'.format(nt1[0].upper(), nt2[0].upper())
    ax_3x1.semilogy(fit_d, hel.fexp(fit_d, efit3['x'][i], decay_mm),
                    color=pair_color,
                    label='{}: {:0.2f} mm'.format(pair_name, decay_mm))

ax_3x1.set_xlim([distances[0], distances[-1] + 0.4])
if not vertical:
    hel.labels(ax_3x1, xlabel='distance [mm]', ylabel=r'$\sigma_\mathrm{cov}$',
               fontsize=fonts)
else:
    # stacked layout: the distance label is only drawn on the panel below
    hel.labels(ax_3x1, ylabel=r'$\sigma_\mathrm{cov}$',
               fontsize=fonts)
    ax_3x1.set_xticklabels([])
ax_3x1.legend(loc='lower left', fontsize=fonts)

# panel B: one decay constant shared by all pair types (1 slope - 1 error).
# efit1['x'][i] is the per-pair argument passed to hel.fexp and
# efit1['x'][-1] is the common decay constant in mm.
shared_decay = efit1['x'][-1]
for i, (nt1, nt2) in enumerate(it_cwr(['inh', 'exc'], 2)):
    fit_d = distances[fitrange]
    pair_color = coldict[nt1 + nt2]
    # marker area grows with the fit weight of each distance
    ax_1x1.scatter(fit_d, igrek[i, fitrange], marker='.', c=pair_color,
                   label='', s=(ws[i] * 10)**2)
    ax_1x1.semilogy(fit_d, hel.fexp(fit_d, efit1['x'][i], shared_decay),
                    color=pair_color,
                    label='{}-{}'.format(nt1[0].upper(), nt2[0].upper()))

ax_1x1.set_xlim([distances[0], distances[-1] + 0.4])
if not vertical:
    hel.labels(ax_1x1, xlabel='distance [mm]',
               fontsize=fonts)
    # side-by-side layout: the y label is only drawn on panel A to the left
    ax_1x1.tick_params(axis='y', which='both', left=False, labelleft=False)
else:
    hel.labels(ax_1x1, xlabel='distance [mm]', ylabel=r'$\sigma_\mathrm{cov}$',
               fontsize=fonts)

# the legend title reports the shared decay constant
legend = ax_1x1.legend(loc='lower left',
                       title='{:0.3f} mm'.format(shared_decay),
                       fontsize=fonts)
plt.setp(legend.get_title(), fontsize=fonts)

# display the figure interactively only when it is not being saved
if not ifSave:
    plt.show()

# saving # # # # # # # # # # # #
namef = 'plots/fig4_{}'.format(monkey)
hel.savef(fig, namef, formats, ifSave, dpi=dpi)

# turn off the LaTeX text rendering that was enabled via rcParams above
plt.rc('text', usetex=False)
