# coding=utf-8
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pandas as pd
import numpy as np


if __name__ == '__main__':
    # Two-panel figure: on the left, the result of the Brian 2 parameter
    # exploration over the (g_Na, I) grid; on the right, the best wall-clock
    # time each simulator configuration needed for that exploration.
    plt.rcParams.update({'axes.labelsize': 8,
                         'xtick.labelsize': 6,
                         'ytick.labelsize': 6,
                         'axes.titlesize': 10,
                         'font.size': 8,
                         'legend.fontsize': 6,
                         'legend.frameon': False,
                         'lines.linewidth': 1.4,
                         'axes.spines.right': False,
                         'axes.spines.top': False})
    fig, (exploration_ax, runtime_ax) = plt.subplots(1, 2, figsize=(5, 2.2))  # size in inches

    # Results of parameter exploration

    exploration = pd.read_csv('results/Brian2_parameter_exploration.csv.gz', compression='gzip')
    I_values = sorted(exploration.I.unique())
    g_Na_values = sorted(exploration.g_Na.unique())
    # Map each (I, g_Na) pair onto integer grid indices. This assumes both
    # parameters were sampled on a regular grid, so the first step size is
    # the step size everywhere.
    exploration['I_index'] = ((exploration['I'] - np.min(I_values)) / np.diff(I_values)[0]).round().astype('int')
    exploration['Na_index'] = ((exploration['g_Na'] - np.min(g_Na_values)) / np.diff(g_Na_values)[0]).round().astype('int')
    # Rows are indexed by I, columns by g_Na; grid points without a result stay NaN.
    matrix = np.full((len(I_values), len(g_Na_values)), np.nan)
    matrix[exploration['I_index'].to_numpy(),
           exploration['Na_index'].to_numpy()] = exploration['firing_rate'].to_numpy()
    # Discrete colour bins with boundaries at the integers 0..18.
    # NOTE(review): values >= 18 fall outside these bins and all share the
    # colormap's "over" colour -- confirm the data stay below this limit.
    norm = mpl.colors.BoundaryNorm(np.arange(0, 19), plt.cm.viridis.N)
    m = exploration_ax.imshow(matrix, norm=norm, origin='lower')
    # We do manual ticks, easier than using extent and getting the scaling right.
    # Tick positions are grid indices; only keep those that exist on each
    # axis so that a smaller grid does not raise an IndexError.
    tick_candidates = (0, 99, 199, 299)
    x_ticks = [i for i in tick_candidates if i < len(g_Na_values)]
    y_ticks = [i for i in tick_candidates if i < len(I_values)]
    exploration_ax.set(xticks=x_ticks, xticklabels=['%.1f' % g_Na_values[i] for i in x_ticks],
                       yticks=y_ticks, yticklabels=['%.1f' % I_values[i] for i in y_ticks],
                       xlabel=u'$g_{Na}$ (mS/cm²)', ylabel=u'$I$ (pA)')
    divider = make_axes_locatable(exploration_ax)
    cax = divider.append_axes('top', size='5%', pad=0.05)
    cbar = fig.colorbar(m, cax=cax, orientation='horizontal',
                        ticklocation='top')
    cbar.set_label('number of spikes')

    # Time it took to get results

    benchmark = pd.read_csv('results/parameter_exploration_benchmark_results.csv')
    # Runs without a thread count are grouped under threads == 0.
    benchmark['threads'] = benchmark['threads'].fillna(0.0)
    # Fastest ('min') run time per (simulator, threads) configuration,
    # as a Series indexed by that pair.
    min_time = benchmark.groupby(['simulator', 'threads'])['took'].min()
    configurations = [('Brian2GeNN (GPU)', 'Brian2GeNN', 0.0),
                      ('Brian 2 (12 threads)', 'Brian2', 12.0),
                      ('Brian 2 (single thread)', 'Brian2', 0.0)]
    labels = [label for label, _, _ in configurations]
    runtimes = [min_time.loc[(simulator, threads)]
                for _, simulator, threads in configurations]
    rects = runtime_ax.barh(labels, runtimes)
    runtime_ax.set_yticks([])
    runtime_ax.set_xticks([])
    runtime_ax.spines['bottom'].set_visible(False)
    runtime_ax.set_title('Simulation time')

    # Bars shorter than this (in seconds, the x-axis data unit) are too
    # narrow to hold their label, which is then printed to the right of the bar.
    inside_threshold = 40
    for rect, label, took in zip(rects, labels, runtimes):
        # Use the exact (float) width so the text is anchored at the bar's end.
        width = rect.get_width()
        if width < inside_threshold:
            # Shift the text to the right of the bar's end; black on the white background.
            xloc, clr, align = 5, 'black', 'left'
        else:
            # Shift the text to the left of the bar's end; white on the bar colour.
            xloc, clr, align = -5, 'white', 'right'

        # Center the text vertically in the bar
        yloc = rect.get_y() + rect.get_height() / 2
        runtime_ax.annotate('{}\n{:.1f}s'.format(label, took),
                            xy=(width, yloc),
                            xytext=(xloc, 0),
                            textcoords="offset points",
                            ha=align, va='center',
                            color=clr, clip_on=True)
    fig.tight_layout()
    fig.savefig('parameter_exploration.pdf', transparent=True)
