import matplotlib.pyplot as plt
import pandas as pd

from benchmark_utils import evaluate_subset

if __name__ == '__main__':
    # Draw a two-panel benchmark figure (one panel per population type) from
    # the CSV files under ``results/`` and save it as a PDF.

    # Small fonts and no top/right spines or legend frame.
    rc_overrides = {
        'axes.labelsize': 8,
        'xtick.labelsize': 6,
        'ytick.labelsize': 6,
        'legend.fontsize': 8,
        'axes.titlesize': 8,
        'legend.frameon': False,
        'lines.linewidth': 1.4,
        'axes.spines.right': False,
        'axes.spines.top': False,
    }
    plt.rcParams.update(rc_overrides)

    # One row per curve:
    # (legend label, simulator, target, threads, line style, colour).
    # The middle three values are passed to ``evaluate_subset`` unchanged to
    # select the matching rows of the results table.
    curves = [
        ('NEST (12 threads)', 'NEST', 'N/A', 12, '-', '#1f77b4'),
        ('NEURON', 'NEURON', 'N/A', None, '-', '#ff7f03'),
        ('Brian 1', 'Brian 1', 'numpy', None, '-', '#d62728'),
        ('Brian 2: runtime', 'Brian 2', 'weave', None, '-', '#1a601a'),
        ('Brian 2: standalone (single thread)',
         'Brian 2', 'cpp_standalone', 0.0, '-', '#2ca02c'),
        ('Brian 2: standalone (12 threads)',
         'Brian 2', 'cpp_standalone', 12.0, '-', '#6fd76f'),
    ]

    # Figure size is given in inches; both panels share the y axis.
    fig, axes = plt.subplots(1, 2, figsize=(5.35, 3), sharey=True)

    for panel, population in zip(axes, ['homogeneous', 'heterogeneous']):
        csv_path = 'results/benchmark_results_' + population + '.csv'
        results_table = pd.read_csv(csv_path)

        for label, simulator, target, threads, line_style, colour in curves:
            # Brian 1 is left out of the heterogeneous panel because it uses
            # Euler instead of exact integration.
            skip_curve = population == 'heterogeneous' and label == 'Brian 1'
            if skip_curve:
                continue

            subset = evaluate_subset(results_table,
                                     simulator=simulator,
                                     target=target,
                                     threads=threads)
            panel.plot(subset.index, subset['rel_simtime'],
                       label=label,
                       marker='.',
                       linestyle=line_style,
                       color=colour,
                       markeredgecolor='white')

        panel.set_xlabel('Number of neurons', labelpad=2)
        panel.set_title(population + ' population')
        panel.set_xscale('log')
        panel.set_yscale('log')

        # Dotted reference line at a ratio of 1 on the y axis.
        panel.axhline(1, color='gray', linestyle=':', linewidth=1)

    # Shrink the axes area to leave room above it for the shared legend.
    fig.subplots_adjust(top=0.75)

    # Legend box in figure coordinates: spans the full width of the subplot
    # area and sits slightly above its top edge.
    layout = fig.subplotpars
    legend_box = (layout.left,
                  layout.top + 0.075,
                  layout.right - layout.left,
                  .1)
    axes[0].legend(bbox_to_anchor=legend_box,
                   mode="expand",
                   loc="lower left",
                   ncol=2,
                   borderaxespad=0.,
                   bbox_transform=fig.transFigure)
    axes[0].set_ylabel('execution time / biological time', labelpad=2)

    fig.savefig('../../figures/benchmarks.pdf')
