#!/usr/bin/env python
# encoding:utf8
#
# This source file is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This source file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Date: 2020
#
# Authors: Moritz Layer (m.layer@fz-juelich.de) [1,2]
#
# Affiliations:
# [1] Institute of Neuroscience and Medicine (INM-6) and Institute for Advanced
#     Simulation (IAS-6) and JARA Institut Brain Structure-Function
#     Relationships (INM-10), Jülich Research Centre, Jülich, Germany
# [2] RWTH Aachen University, Aachen, Germany

'''
Creates plots for DisCo manuscript for theory section and appendix

Usage: plot_theory_figs.py [options] --yaml=<yaml_filename>

Options:
    -h, --help         show this information

'''

import docopt
import yaml
import os
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mc
import matplotlib.gridspec as gridspec
import colorsys
import numpy as np
import string
import h5py_wrapper as h5w
from svgutils.compose import (Figure, Panel, SVG)


from utils import create_parameter_string, weights_2p


def mm2inch(x):
    """Convert a length given in millimetres to inches."""
    mm_per_inch = 25.4
    return x / mm_per_inch


def adjust_lightness(color, amount=0.5):
    """
    Return `color` with its HLS lightness multiplied by `amount`.

    Parameters:
    -----------
    color: matplotlib color specification
        Named colors are looked up in `matplotlib.colors.cnames` first;
        anything else is handed to `matplotlib.colors.to_rgb` unchanged.
    amount: float
        Factor applied to the lightness channel. The scaled lightness is
        clipped to the interval [0, 1].

    Returns:
    --------
    tuple of float
        The adjusted color as an (r, g, b) tuple.
    """
    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # KeyError: not a named color; TypeError: unhashable spec (e.g. list)
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    lightness = max(0, min(1, amount * lightness))
    return colorsys.hls_to_rgb(hue, lightness, saturation)


def get_data(system, settings, parameter_strings):
    """
    Loads the stored results of every setting and collects them in lists

    For each setting the file 'data/<system>/<setting>/<parameter_string>.h5'
    is read and its 'distances', 'mean' and 'var' entries are gathered.

    Parameters:
    -----------
    system: str
        Name of the system, used as first directory level below 'data/'.
    settings: list of str
        Names of the settings, used as second directory level.
    parameter_strings: str or list of str
        If list of str it needs to have same length as settings.
        A single str is used for all settings.

    Returns:
    --------
    distances, means, variances: lists
        One entry per setting, in the order of `settings`.
    """

    # a single parameter string is shared by all settings
    if isinstance(parameter_strings, str):
        parameter_strings = [parameter_strings for _ in settings]

    distance_list, mean_list, var_list = [], [], []

    for idx, setting in enumerate(settings):
        current_string = parameter_strings[idx]
        h5_path = 'data/{}/{}/{}.h5'.format(system, setting, current_string)
        loaded = h5w.load(h5_path)

        distance_entry = loaded['distances']
        mean_entry = loaded['mean']
        var_entry = loaded['var']

        mean_list.append(mean_entry)
        var_list.append(var_entry)
        distance_list.append(distance_entry)

    return distance_list, mean_list, var_list


def lam_eff(w_E, w_I, lam_E, lam_I, K_E, K_I, r):
    """
    effective decay constant of 2d2p theory

    Computes sqrt(|xi2 / (1 - r**2) + lam_E**2 / 2|) with
    xi2 = w_E**2 * K_E * lam_E**2 + w_I**2 * K_I * lam_I**2.
    Arguments may be scalars or numpy arrays (numpy broadcasting applies).
    """
    excitatory_part = w_E**2 * K_E * lam_E**2
    inhibitory_part = w_I**2 * K_I * lam_I**2
    xi2 = excitatory_part + inhibitory_part
    radicand = xi2 / (1 - r**2) + lam_E**2 / 2
    # the absolute value keeps the square root real
    return np.sqrt(np.abs(radicand))


def var_of_x_for_different_r(ax, parameters, settings, parameter_strings, r_list, fontsize=11):
    """
    Plots variance of x for different spectral bounds.

    Parameters:
    -----------
    ax: matplotlib axes
        Axes to draw into.
    parameters: dict
        Parameter dictionary whose values are lists; 'system' and 'lam_E'
        are read (first entry each).
    settings: list of str
        Settings passed on to get_data, one per curve.
    parameter_strings: str or list of str
        Parameter strings passed on to get_data.
    r_list: list of float
        Values used for the legend labels ('R=...'), one per curve, and for
        the number of colors taken from the 'Blues' colormap.
    fontsize: int
        Base font size; legend and tick labels are drawn slightly smaller.
    """

    system = parameters['system'][0]
    distances, _, variances = get_data(system, settings, parameter_strings)
    # parameter values are lists; use the scalar like the rest of the file
    lam = parameters['lam_E'][0]

    # skip the lightest shades of the colormap and use every second color
    n_colors = 2 * len(r_list) + 2
    cm = plt.get_cmap('Blues')
    colors = [cm(1. * i / n_colors) for i in range(n_colors)]
    ax.set_prop_cycle(color=colors[3::2])

    for i, (distance, variance) in enumerate(zip(distances, variances)):
        # only the first variance component of each data set is shown
        ax.semilogy(distance / lam, variance[0],
                    label='R={}'.format(r_list[i]))

    ax.set_ylim([1e-13, 1e2])
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, +1.31), ncol=2,
              fontsize=fontsize-2, handletextpad=0.2, handlelength=1.5,
              labelspacing=0.1, columnspacing=1)
    ax.tick_params(axis='both', which='major', labelsize=fontsize-1)
    ax.set_xlabel('$x/d$')
    ax.set_ylabel('var(cov)', labelpad=-3)


def lam_eff_of_r(ax, parameters, rs, fontsize=11):
    """
    Plots effective decay constant as function of spectral bound.

    Parameters:
    -----------
    ax: matplotlib axes
        Axes to draw into (used as an inset by theory_plot).
    parameters: dict
        Parameter dictionary whose values are lists; 'K_E', 'K_I', 'lam_E',
        'lam_I' and 'g' are read (first entry each).
    rs: array-like
        Spectral bounds at which the effective decay constant is evaluated.
    fontsize: int
        Base font size; the legend is drawn slightly smaller.
    """

    # NOTE(review): int() truncates non-integer values - confirm that these
    # parameters are always integral.
    K_E = int(parameters['K_E'][0])
    K_I = int(parameters['K_I'][0])
    lam_E = int(parameters['lam_E'][0])
    lam_I = int(parameters['lam_I'][0])
    g = int(parameters['g'][0])

    w_E, w_I = weights_2p(rs, K_E, K_I, g)
    # effective decay constant in units of lam_I
    lam_eff_r = lam_eff(w_E, w_I, lam_E, lam_I, K_E, K_I, rs) / lam_I

    # raw string: '\m' is not a valid escape sequence in a normal string
    ax.plot(rs, lam_eff_r, label=r'$d_\mathrm{eff}$')
    ax.axvline(x=1, linestyle='--', color='orange', label='$R=1$')
    ax.legend(loc='lower left', fontsize=fontsize-2,
              bbox_to_anchor=(0.94, -0.12), handletextpad=0.2,
              handlelength=1.5)
    ax.tick_params(axis='both', which='major', labelsize=fontsize-1)
    # the inset only shows the qualitative shape, hence no ticks
    ax.set_xticks([])
    ax.set_yticks([])


def covariance_scatter_plot_with_distance(ax, parameters, C, X,
                                          sample_size=150, sample_seed=42,
                                          markersize=0.01, fontsize=11,
                                          threshold=None):
    """
    Plots distance resolved covariance scatter plot with mean and variance.

    For every distinct distance the mean and standard deviation of the
    covariances are drawn as lines. In addition a random subset of the
    individual covariances is drawn as scatter points for 200 distances
    spread evenly over the available range.

    Parameters:
    -----------
    ax: matplotlib axes
        Axes to draw into.
    parameters: dict
        Parameter dictionary whose values are lists; 'lam_E' and 'g' are
        read (first entry each).
    C: 2d array
        Covariance matrix.
    X: 2d array
        Pairwise distances, same shape as C. The array is not modified.
    sample_size: int
        Maximum number of scatter points per sampled distance. If fewer
        pairs exist at a distance, all of them are drawn.
    sample_seed: int
        Seed passed to np.random.seed before sampling (sets the global
        numpy random state).
    markersize: float
        Marker size of the scatter points.
    fontsize: int
        Currently unused.
    threshold: float or None
        If given (and non-zero), only pairs with X < d * threshold are kept.

    Returns:
    --------
    ax: the axes passed in.
    """

    # used anatomical decay constant
    d = parameters['lam_E'][0]
    g = parameters['g'][0]

    # Normalize into a new array: an in-place division would alter the
    # caller's data and fails for integer distance arrays.
    X = np.asarray(X) / d
    # only take upper triangle because covariance matrix is symmetric
    X = np.triu(X, 1)
    C = np.triu(C, 1)
    # keep every g-th row and column (excitatory neurons)
    X = X[::g, ::g]
    C = C[::g, ::g]
    # take out zeros
    nonzero = X != 0
    C = C[nonzero]
    X = X[nonzero]
    # if threshold is given only take values below given distance
    # NOTE(review): X is already in units of d here, so comparing against
    # d * threshold mixes units - confirm the intended cutoff.
    if threshold:
        below = X < d * threshold
        C = C[below]
        X = X[below]

    distances = np.unique(X)
    mean = np.array([C[X == x].mean() for x in distances])
    std = np.array([C[X == x].std() for x in distances])
    np.random.seed(sample_seed)

    def find_nearest(array, value):
        """Return the entry of `array` closest to `value`."""
        array = np.asarray(array)
        idx = (np.abs(array - value)).argmin()
        return array[idx]

    # existing distances closest to 200 evenly spaced positions
    xs = [find_nearest(distances, x)
          for x in np.linspace(distances.min(), distances.max(), 200)]

    sampled = []
    for x in xs:
        all_indices = np.where(X == x)[0]
        # never request more samples than there are pairs at this distance
        n_draw = min(sample_size, all_indices.size)
        sampled.append(np.random.choice(all_indices, size=n_draw,
                                        replace=False))
    sample_indices = np.concatenate(sampled).astype(int)

    X_sample = X[sample_indices]
    C_sample = C[sample_indices]

    ax.scatter(X_sample, C_sample, color='steelblue', s=markersize)
    ax.plot(distances, mean, color='cyan', linewidth=1.5)
    ax.plot(distances, mean + std, color='orange', linewidth=1.5)
    ax.plot(distances, mean - std, color='orange', linewidth=1.5)

    # raw string: '\m' is not a valid escape sequence in a normal string
    ax.set_xlabel(r'$x/d_\mathrm{E}$')
    ax.set_ylabel('cov', labelpad=-15)

    return ax


def var_of_x_for_different_connection_types(ax, parameters, settings, fontsize=11):
    """
    Plot var of x for different kinds of connections (EE, EI, II).

    Parameters:
    -----------
    ax: matplotlib axes
        Axes to draw into.
    parameters: dict
        Parameter dictionary whose values are lists. It must contain the
        keys 'setting' and 'ext', which are left out of the parameter
        string; 'system' and 'lam_E' are read (first entry each).
    settings: list of str
        Settings passed on to get_data; only the data of the first one is
        plotted.
    fontsize: int
        Font size of the legend; tick labels are drawn slightly smaller.
    """

    # get data
    system = parameters['system'][0]
    # 'setting' and 'ext' are not part of the data file name
    parameters_temp = parameters.copy()
    parameters_temp.pop('setting')
    parameters_temp.pop('ext')
    parameter_string = create_parameter_string(parameters_temp)
    distances, _, variances = get_data(system, settings, parameter_string)

    lam_E = parameters['lam_E'][0]
    distances = np.array(distances[0]) / lam_E

    # drawn in this order so that the legend reads II, EI, EE
    ax.semilogy(distances, variances[0][2], label='II', color='#FF1E1E')
    ax.semilogy(distances, variances[0][1], label='EI', color='orange')
    ax.semilogy(distances, variances[0][0], label='EE', color='steelblue')
    ax.legend(loc='upper right', fontsize=fontsize,
              bbox_to_anchor=(1.02, 1.02))
    ax.tick_params(axis='both', which='major', labelsize=fontsize-1)
    ax.tick_params(axis='y', pad=-1)
    # raw string: '\m' is not a valid escape sequence in a normal string
    ax.set_xlabel(r'$x/d_\mathrm{E}$')
    ax.set_ylabel('var(cov)', labelpad=0)


def eigenvalue_cloud(ax, parameters, eigvals, fontsize=11, markersize=1.5):
    """
    Plot eigenvalue cloud of network.

    Eigenvalues are drawn in the complex plane. The eigenvalue with the
    smallest real part is annotated as 'population eigenvalue' and a dashed
    vertical line marks the spectral bound.

    Parameters:
    -----------
    ax: matplotlib axes
        Axes to draw into.
    parameters: dict
        Parameter dictionary whose values are lists; 'r' is read (first
        entry) and used as position of the spectral bound line.
    eigvals: 1d numpy array
        Eigenvalues to be plotted.
    fontsize: int
        Base font size; annotations and tick labels are drawn slightly
        smaller.
    markersize: float
        Marker size of eigenvalues with real part larger than -1.
        Eigenvalues with real part smaller than -1 always use size 1.5.
    """

    r = parameters['r'][0]
    eigval_cloud = eigvals[eigvals.real > -1]
    outliers = eigvals[eigvals.real < -1]
    ax.scatter(eigval_cloud.real, eigval_cloud.imag, s=markersize,
               color='steelblue')
    # outliers keep a fixed, larger marker so they stay visible
    ax.scatter(outliers.real, outliers.imag, s=1.5, color='steelblue')
    ax.axvline(x=r, linestyle='--', color='orange')

    # eigenvalue with the smallest real part
    outlier = eigvals[np.argmin(eigvals.real)]
    ax.annotate('population\neigenvalue',
                xytext=(outlier.real + 0.3, outlier.imag - 1.85),
                xy=(outlier.real + 0.05, outlier.imag - 0.08),
                fontsize=fontsize - 1,
                arrowprops={'arrowstyle': '-|>', 'facecolor': 'black'})
    ax.annotate('spectral bound', xytext=(-3, 1.5),
                xy=(0.95, 1), fontsize=fontsize - 1,
                arrowprops={'arrowstyle': '-|>', 'facecolor': 'black'})
    ax.set_aspect('equal')
    # raw strings: '\m' and '\l' are not valid escape sequences
    ax.set_xlabel(r'$\mathrm{Re}(\lambda)$')
    ax.set_ylabel(r'$\mathrm{Im}(\lambda)$')
    ax.set_xlim([outlier.real - 0.25, 1.25])
    ax.set_ylim([-2.25, 2.25])
    ax.tick_params(axis='both', which='major', labelsize=fontsize - 1)


def simulation_vs_theory(ax, parameters, N_x=None, N_y=None, r=None, seed=None, fontsize=11):
    """
    Plot variances from simulation against variances from discrete theory.

    Data of the settings 'simulation' and 'disc_theory' are loaded for the
    same parameter string and drawn for the connection types II, EI and EE.

    Parameters:
    -----------
    ax: matplotlib axes
        Axes to draw into.
    parameters: dict
        Parameter dictionary whose values are lists. It must contain the
        keys 'setting' and 'ext', which are left out of the parameter
        string; 'system' and 'lam_E' are read (first entry each).
    N_x, N_y, r, seed: optional
        If not None, the value replaces the corresponding entry of
        `parameters` when building the parameter string. Zero is a valid
        override.
    fontsize: int
        Font size of the legend; tick labels are drawn slightly smaller.
    """

    system = parameters['system'][0]
    lam_E = parameters['lam_E'][0]

    # 'setting' and 'ext' are not part of the data file name
    parameters_temp = parameters.copy()
    parameters_temp.pop('setting')
    parameters_temp.pop('ext')

    # compare against None so that falsy overrides (e.g. seed=0) are kept
    overrides = (('N_x', N_x), ('N_y', N_y), ('r', r), ('seed', seed))
    for key, value in overrides:
        if value is not None:
            parameters_temp[key] = [value]

    parameter_string = create_parameter_string(parameters_temp)

    distances_sim, _, vars_sim = get_data(system, ['simulation'],
                                          parameter_string)
    distances_sim = np.array(distances_sim[0]) / lam_E

    distances_thy, _, vars_thy = get_data(system, ['disc_theory'],
                                          parameter_string)
    distances_thy = np.array(distances_thy[0]) / lam_E

    # (index into variance data, label prefix, color)
    connection_types = ((2, 'II', 'red'),
                        (1, 'EI', 'orange'),
                        (0, 'EE', 'steelblue'))
    for index, name, color in connection_types:
        # simulation in a darker shade, theory drawn on top
        ax.semilogy(distances_sim, vars_sim[0][index], 'o',
                    color=adjust_lightness(color, 0.5),
                    label='{} sim'.format(name), markersize=1)
        ax.semilogy(distances_thy, vars_thy[0][index], 'o',
                    color=color,
                    label='{} thy'.format(name), markersize=1)

    # list theory before simulation for each connection type
    label_order = [1, 0, 3, 2, 5, 4]
    handles, labels = ax.get_legend_handles_labels()
    handles = [handles[idx] for idx in label_order]
    labels = [labels[idx] for idx in label_order]

    ax.legend(handles, labels, fontsize=fontsize, markerscale=6,
              handlelength=1.5, handletextpad=0., borderpad=0.2)
    ax.tick_params(axis='both', which='major', labelsize=fontsize-1)

    ax.set_xlabel('$x/d$')
    ax.set_ylabel('var(cov)', labelpad=-10)


def simulation_vs_pade(ax, parameters, N_x=None, N_y=None, r=None, fontsize=11):
    """
    Plot variance from discrete theory against the Padé approximations.

    Data of the settings 'disc_theory', 'pade_theory' and
    'pade_theory_higher_order' are loaded for the same parameter string.
    Of each data set the first variance component (index 0) is drawn.

    Parameters:
    -----------
    ax: matplotlib axes
        Axes to draw into.
    parameters: dict
        Parameter dictionary whose values are lists. It must contain the
        keys 'setting' and 'ext', which are left out of the parameter
        string; 'system' and 'lam_E' are read (first entry each).
    N_x, N_y, r: optional
        If not None, the value replaces the corresponding entry of
        `parameters` when building the parameter string. Zero is a valid
        override.
    fontsize: int
        Font size of the legend; tick labels are drawn slightly smaller.

    TODO:
    - remove magic numbers at the end of this function
    """

    system = parameters['system'][0]
    lam_E = parameters['lam_E'][0]

    # 'setting' and 'ext' are not part of the data file name
    parameters_temp = parameters.copy()
    parameters_temp.pop('setting')
    parameters_temp.pop('ext')

    # compare against None so that falsy overrides (e.g. r=0) are kept
    overrides = (('N_x', N_x), ('N_y', N_y), ('r', r))
    for key, value in overrides:
        if value is not None:
            parameters_temp[key] = [value]

    parameter_string = create_parameter_string(parameters_temp)

    color_EE = 'steelblue'

    # (setting, positional format args, keyword args) for each curve
    curves = (
        ('disc_theory', ('o',),
         dict(label='thy', color=adjust_lightness(color_EE, 1.5),
              markersize=1)),
        ('pade_theory', (),
         dict(label='Padé', color=adjust_lightness(color_EE, 1))),
        ('pade_theory_higher_order', (),
         dict(label='Padé h.o.', color=adjust_lightness(color_EE, 0.5))),
    )
    for setting, fmt, style in curves:
        distances, _, variances = get_data(system, [setting],
                                           parameter_string)
        distances = np.array(distances[0]) / lam_E
        ax.semilogy(distances, variances[0][0], *fmt, **style)

    ax.set_xlabel('$x/d$')
    ax.set_ylabel('var(cov)')
    ax.legend(fontsize=fontsize, markerscale=6,)

    ax.tick_params(axis='both', which='major', labelsize=fontsize-1)

    # hide every second y tick label to avoid clutter
    for label in ax.yaxis.get_ticklabels()[::2]:
        label.set_visible(False)


def theory_plot(parameters, fontsize=11):
    """
    Plots theory composition for DisCo manuscript.

    The figure consists of an empty panel reserved for the network sketch,
    the eigenvalue cloud, the distance resolved covariance scatter plot,
    the variance for different spectral bounds (with an inset showing the
    effective decay constant) and the variance for the different connection
    types. The figure is written to 'fig3.svg' and afterwards combined with
    the sketch 'fig3a.svg' into the same file.

    This plotting routine is optimized for the following parameter set:
    D:1, K_E:100, K_I:50, N_x:201, N_y:201, distr:exp, g:4, lam_E:20, lam_I:10,
    r:0.95, seed:42, system:2d2p.
    """

    data_dir = 'data/'

    # figure layout: 2 rows x 3 columns
    fig = plt.figure(figsize=(mm2inch(180), mm2inch(120)),
                     constrained_layout=True)
    grid = gridspec.GridSpec(2, 3, figure=fig)
    cells = (grid[0, :1], grid[0, 2], grid[1, 0], grid[1, 1], grid[1, 2])
    panels = [fig.add_subplot(cell) for cell in cells]
    ax_sketch, ax_eig, ax_cov, ax_var_r, ax_var_conn = panels

    # the sketch panel stays empty, the svg sketch is added at the end
    ax_sketch.set_axis_off()
    print('plot whitespace done')

    # panel with the eigenvalue cloud
    eigenvalues = np.load(data_dir + 'eigenvalues.npy')
    ax_eig.set_rasterized(True)
    eigenvalue_cloud(ax_eig, parameters, eigenvalues, markersize=0.005,
                     fontsize=fontsize)
    print('plot eigenvalues done')

    # panel with covariances over distance
    covariances = np.load(data_dir + 'covariances.npy')
    pair_distances = np.load(data_dir + 'distances.npy')
    ax_cov.set_rasterized(True)
    covariance_scatter_plot_with_distance(ax_cov, parameters, covariances,
                                          pair_distances, fontsize=fontsize)
    print('plot covariance with distance done')

    # panel with one variance curve per spectral bound
    r_list = [0.99, 0.95, 0.8, 0.3]
    # same parameters for every curve, only r is replaced
    parameter_strings = [create_parameter_string(dict(parameters, r=[r]))
                         for r in r_list]
    settings = ['cont_theory'] * len(r_list)
    var_of_x_for_different_r(ax_var_r, parameters, settings,
                             parameter_strings, r_list, fontsize=fontsize)
    print('plot variance of r done')

    # inset in the lower left corner: effective decay constant over r
    inset = ax_var_r.inset_axes([0.015, 0.015, 0.35, 0.35])
    lam_eff_of_r(inset, parameters, np.arange(0.1, 1, 0.01),
                 fontsize=fontsize)
    print('plot inset done')

    # panel with variance for the connection types EE, EI and II
    var_of_x_for_different_connection_types(ax_var_conn, parameters,
                                            ['cont_theory'],
                                            fontsize=fontsize)
    print('plot variance of different connection types done')

    # panel letters A, B, C, ... above the upper left corner
    for letter, panel in zip(string.ascii_uppercase, panels):
        panel.text(-0.1, 1.1, letter, transform=panel.transAxes,
                   size=fontsize + 1)

    # write figure to the current directory
    out_dir = ''
    out_file = out_dir + 'fig3.svg'
    plt.savefig(out_file, bbox_inches="tight", dpi=300)
    plt.close()

    # overlay the network sketch and overwrite the saved figure
    plot_panel = Panel(SVG(out_file).move(20, 20))
    sketch_panel = Panel(SVG("fig3a.svg").move(45, 20).scale(1.2))
    Figure("15cm", "10cm", plot_panel, sketch_panel).save(out_file)


def appendix_plot(parameters, fontsize=11):
    """
    Plots the two-panel appendix figure and saves it.

    Panel A compares simulation with discrete theory, panel B compares
    discrete theory with the Padé approximations. The figure is written to
    'figures/supp_fig4.png'; the directory is created if necessary.

    Parameters:
    -----------
    parameters: dict
        Parameter dictionary passed on to the panel plotting functions.
    fontsize: int
        Base font size; panel letters are drawn slightly larger.
    """

    figsize = (mm2inch(180), mm2inch(180)/2)
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    gs = gridspec.GridSpec(1, 2, figure=fig)
    ax0 = fig.add_subplot(gs[0, 0])
    ax1 = fig.add_subplot(gs[0, 1])

    # plot data from simulation vs data from discrete theory
    simulation_vs_theory(ax0, parameters, N_x=61, N_y=61, r=0.8, seed=42,
                         fontsize=fontsize)

    # plot data from simulation vs slope of Padé approximation
    simulation_vs_pade(ax1, parameters, N_x=1001, N_y=1001, r=0.95,
                       fontsize=fontsize)

    # add letters for reference
    for letter, ax in zip(string.ascii_uppercase, [ax0, ax1]):
        ax.text(-0.1, 1.1, letter, transform=ax.transAxes,
                size=fontsize+1)

    # final formatting and save
    path = 'figures/'
    filename = 'supp_fig4.png'
    # exist_ok avoids the race between checking for and creating the folder
    os.makedirs(path, exist_ok=True)
    plt.savefig(path + filename, bbox_inches="tight")
    plt.close()


if __name__ == '__main__':

    # command line interface is defined by the module docstring
    args = docopt.docopt(__doc__)
    yaml_file = args['--yaml']

    # load parameters from yaml file
    with open(yaml_file, 'r') as stream:
        try:
            parameters = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # without parameters nothing can be plotted, so stop here
            # instead of failing later with an undefined name
            raise SystemExit(
                'Could not parse {}: {}'.format(yaml_file, exc))

    matplotlib.rcParams.update({'font.size': 11})

    theory_plot(parameters)
    # appendix_plot(parameters)
