from __future__ import division

import argparse

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

def get_args(**additional_arguments):
    """Parse the benchmark's command-line options from ``sys.argv``.

    Besides the fixed options (``--n_neurons``, ``--n_rec``, ``--runtime`` and
    ``-g/--heterogeneous``), callers may register extra options as keyword
    arguments of the form ``name=(flag, add_argument_kwargs)``. Each extra
    value is read back with ``getattr(namespace, name)``, so the option's
    argparse destination must be exactly ``name``.

    Returns
    -------
    dict
        Parsed values keyed by option name: the four fixed options first,
        then one entry per extra option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_neurons', type=int, default=4000,
                        help='Total number of neurons')
    parser.add_argument('--n_rec', type=int, default=500,
                        help='Number of excitatory and inhibitory neurons to '
                             'record from')
    parser.add_argument('--runtime', type=int, default=10000,
                        help='Biological runtime (in ms) of the model')
    parser.add_argument('-g', '--heterogeneous', action='store_true',
                        help='Use heterogeneous parameters for the neuron '
                             'population')
    # Only the (flag, kwargs) pairs are needed to register the extra options;
    # the keyword names are used afterwards to read the values back.
    for extra_flag, extra_options in additional_arguments.values():
        parser.add_argument(extra_flag, **extra_options)

    namespace = parser.parse_args()
    fixed_names = ('n_neurons', 'n_rec', 'runtime', 'heterogeneous')
    parsed = dict((name, getattr(namespace, name)) for name in fixed_names)
    parsed.update((name, getattr(namespace, name))
                  for name in additional_arguments)
    return parsed


def plot_results(args, results, offset=None):
    """Show a spike raster of the excitatory and inhibitory recordings.

    Parameters
    ----------
    args : dict
        Needs ``'n_rec'`` (recorded neurons per population) and ``'runtime'``
        (in ms, as parsed by ``get_args``).
    results : mapping
        Needs ``'times_e'``/``'indices_e'`` and ``'times_i'``/``'indices_i'``
        with the spike times and neuron indices of each population.
    offset : number, optional
        Amount added to the inhibitory indices so the two populations are
        drawn one above the other; defaults to ``args['n_rec']``.
    """
    if offset is None:
        offset = args['n_rec']

    def _rate(indices):
        # Spikes per recorded neuron per second (runtime is given in ms).
        return len(indices) / args['n_rec'] / (args['runtime'] / 1000)

    rate_e = _rate(results['indices_e'])
    rate_i = _rate(results['indices_i'])

    plt.plot(results['times_e'], results['indices_e'], '|', label='excitatory')
    plt.plot(results['times_i'], np.array(results['indices_i']) + offset, '|',
             label='inhibitory')
    title = ('Excitatory rate: {:.2f}Hz / Inhibitory rate: {:.2f}Hz'
             .format(rate_e, rate_i))
    plt.title(title)
    plt.legend()
    plt.xlabel('time (ms)')
    plt.ylabel('neuron index')
    plt.show()


def evaluate_subset(results, simulator, target, threads=None):
    """Summarise benchmark results for one simulator/target combination.

    The mean ``simtime`` of the runs with ``runtime == 0`` is treated as the
    preparation time of each (n_neurons, simulator, target, threads)
    configuration. It is subtracted from the ``simtime`` of the runs with
    ``runtime > 0``, and the difference is divided by ``runtime / 1000``.
    Configurations without a 0-runtime run get no correction.

    Parameters
    ----------
    results : pandas.DataFrame
        One row per simulation run, with at least the columns ``n_neurons``,
        ``simulator``, ``target``, ``threads``, ``runtime``, ``simtime``,
        ``n_synapses``, ``rate_exc`` and ``rate_inh``.
        NOTE: missing values in ``results['target']`` are replaced by
        ``'N/A'`` in place (the caller's frame is modified).
    simulator, target : str
        Regular expressions matched with ``Series.str.match`` (i.e. anchored
        at the start of the string) against the ``simulator`` and ``target``
        columns.
    threads : optional
        If given, only runs whose ``threads`` value equals it are kept.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``n_neurons``, with the column ``rel_simtime`` (minimum over
        the selected runs) followed by ``n_synapses_mean``, ``n_synapses_std``,
        ``rate_exc_mean``, ``rate_exc_std``, ``rate_inh_mean`` and
        ``rate_inh_std``.
    """
    key_columns = ['n_neurons', 'simulator', 'target', 'threads']

    # groupby drops NaN keys and NaN never equals a pattern, so give missing
    # targets an explicit label.
    results['target'] = results['target'].fillna('N/A')

    # Only the keys and simtime are needed for the preparation time; averaging
    # every other column (as a bare ``.mean()`` would) fails on non-numeric
    # columns with pandas >= 2.0. The copy avoids writing into a slice.
    zero_sims = results.loc[results.runtime == 0,
                            key_columns + ['simtime']].copy()
    # We cannot do groupby with NaN values: encode "no thread count" as -1
    # for the grouping and turn it back into NaN afterwards.
    zero_sims['threads'] = zero_sims['threads'].fillna(-1)
    prep_times = zero_sims.groupby(key_columns)['simtime'].mean().reset_index()
    prep_times['threads'] = prep_times['threads'].replace(-1, np.nan)
    prep_times = prep_times.set_index(key_columns)
    prep_times = prep_times.rename(columns={'simtime': 'prep_time'})

    # Attach the mean preparation time of the matching configuration to every
    # run with a non-zero runtime ('prep_time' column).
    nonzero_sims = results[results.runtime > 0].join(prep_times,
                                                     on=key_columns)
    nonzero_sims['prep_time'] = nonzero_sims['prep_time'].fillna(0.0)
    nonzero_sims['rel_simtime'] = (
        (nonzero_sims['simtime'] - nonzero_sims['prep_time']) /
        (nonzero_sims['runtime'] / 1000.))

    mask = (nonzero_sims.simulator.str.match(simulator) &
            nonzero_sims.target.str.match(target))
    if threads is not None:
        mask = mask & (nonzero_sims.threads == threads)
    subset = nonzero_sims.loc[mask]

    grouped = subset.groupby('n_neurons')

    # We are interested in the minimum runtime
    rel_simtime = grouped['rel_simtime'].aggregate('min')

    # And mean + stddev for rate and number of synapses (for validation)
    validation = grouped[['n_synapses',
                          'rate_exc', 'rate_inh']].aggregate(['mean', 'std'])
    validation.columns = ['n_synapses_mean', 'n_synapses_std',
                          'rate_exc_mean', 'rate_exc_std',
                          'rate_inh_mean', 'rate_inh_std']
    return pd.concat([rel_simtime, validation], axis=1)
