from __future__ import division
import matplotlib
matplotlib.use('Agg')
import time
import resource

import numpy as np
import nest
import nest.raster_plot

from benchmark_utils import get_args, plot_results


def _firing_rate(n_spikes, n_recorded, runtime):
    """Return the mean firing rate (Hz) of ``n_recorded`` neurons.

    ``runtime`` is the simulated time in ms. Returns NaN instead of raising
    when no neuron was recorded or no time was simulated.
    """
    if n_recorded <= 0 or runtime <= 0:
        return np.nan
    return n_spikes / runtime * 1000.0 / n_recorded


def run_benchmark(n_neurons, n_rec=500, runtime=1000, threads=1,
                  measure_memory=False, get_spikes=False,
                  heterogeneous=False):
    """Build and simulate a random E/I network of ``iaf_psc_exp`` neurons.

    80% of the neurons are excitatory, 20% inhibitory. Every neuron pair is
    connected independently with probability ``80 / n_neurons``
    (``pairwise_bernoulli``). Weights are drawn uniformly within +/-10% of
    16.2 (excitatory) or -90 (inhibitory).

    Parameters
    ----------
    n_neurons : int
        Total number of neurons.
    n_rec : int
        Number of neurons per population connected to a spike detector
        (capped at the population size).
    runtime : float
        Simulated time in ms, passed to ``nest.Simulate``.
    threads : int
        Value for the kernel's ``local_num_threads``.
    measure_memory : bool
        If True, store the peak resident set size (``ru_maxrss`` of this
        process plus its children) in ``results['memory']``. Units are
        those of ``ru_maxrss``: kilobytes on Linux, bytes on macOS.
    get_spikes : bool
        If True, store the recorded spike senders and times in
        ``results['spikes']``.
    heterogeneous : bool
        If True, scale ``tau_m``, ``tau_syn_ex`` and ``tau_syn_in`` of every
        neuron by an independent factor drawn uniformly from [0.9, 1.1).

    Returns
    -------
    dict
        ``'time'``: wall-clock seconds spent in ``nest.Simulate``;
        ``'build_time'``: wall-clock seconds spent building the network;
        ``'n_synapses'``: number of neuron-to-neuron connections;
        ``'rate'`` (only if ``n_rec > 0``): (excitatory, inhibitory) mean
        rates in Hz, NaN where undefined; plus ``'memory'`` / ``'spikes'``
        when requested.
    """
    nest.ResetKernel()

    startbuild = time.time()

    dt = 0.1  # the resolution in ms
    delay = 0.1  # synaptic delay in ms

    epsilon = 80. / n_neurons  # connection probability

    NE = int(n_neurons * 0.8)  # number of excitatory neurons
    NI = n_neurons - NE  # number of inhibitory neurons

    neuron_params = {"C_m": 200.0,
                     "tau_m": 20.0,
                     "t_ref": 5.0,
                     "E_L": -49.0,
                     "V_reset": -60.0,
                     "V_m": -49.0,
                     "V_th": -50.0,
                     "tau_syn_ex": 5.0,
                     "tau_syn_in": 10.0}

    nest.SetKernelStatus({"resolution": dt, "overwrite_files": True,
                          "local_num_threads": threads})

    nest.SetDefaults("iaf_psc_exp", neuron_params)

    nodes_ex = nest.Create("iaf_psc_exp", NE)
    nodes_in = nest.Create("iaf_psc_exp", NI)
    all_nodes = nodes_ex + nodes_in

    # Random initial membrane potentials, uniform in [-60, -50) mV.
    nest.SetStatus(nodes_ex, "V_m", -60.0 + 10.0 * np.random.rand(NE))
    nest.SetStatus(nodes_in, "V_m", -60.0 + 10.0 * np.random.rand(NI))

    if heterogeneous:
        # One independent factor per neuron (not one scalar per population),
        # applied to both populations, so parameters really vary per neuron.
        for nodes, size in ((nodes_ex, NE), (nodes_in, NI)):
            for param in ("tau_m", "tau_syn_ex", "tau_syn_in"):
                jitter = 0.9 + 0.2 * np.random.rand(size)
                nest.SetStatus(nodes, param, jitter * neuron_params[param])

    espikes = nest.Create("spike_detector")
    ispikes = nest.Create("spike_detector")

    nest.SetStatus(espikes, [{"label": "cuba_e",
                              "withtime": True,
                              "withgid": True,
                              "to_file": False}])

    nest.SetStatus(ispikes, [{"label": "cuba_i",
                              "withtime": True,
                              "withgid": True,
                              "to_file": False}])

    n_rec_e = min(n_rec, NE)
    n_rec_i = min(n_rec, NI)
    nest.Connect(nodes_ex[:n_rec_e], espikes, syn_spec="static_synapse")
    nest.Connect(nodes_in[:n_rec_i], ispikes, syn_spec="static_synapse")

    conn_params_ex = {'rule': 'pairwise_bernoulli', 'p': epsilon}
    nest.Connect(nodes_ex, all_nodes, conn_params_ex,
                 syn_spec={"model": "static_synapse",
                           "weight": {"distribution": "uniform",
                                      "low": 0.9*16.2, "high": 1.1*16.2},
                           "delay": delay})

    conn_params_in = {'rule': 'pairwise_bernoulli', 'p': epsilon}
    nest.Connect(nodes_in, all_nodes, conn_params_in,
                 syn_spec={"model": "static_synapse",
                           "weight": {"distribution": "uniform",
                                      "low": 1.1*-90, "high": 0.9*-90},
                           "delay": delay})
    endbuild = time.time()

    nest.Simulate(runtime)

    endsimulate = time.time()

    results = {}
    if measure_memory:
        max_mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss +
                   resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss)
        results['memory'] = max_mem

    if n_rec > 0:
        n_spikes_e = nest.GetStatus(espikes, keys="n_events")[0]
        n_spikes_i = nest.GetStatus(ispikes, keys="n_events")[0]
        # NaN (rather than ZeroDivisionError) when a population is empty
        # or runtime is not positive.
        results['rate'] = (_firing_rate(n_spikes_e, n_rec_e, runtime),
                           _firing_rate(n_spikes_i, n_rec_i, runtime))

    # Restrict targets to the neurons so the neuron -> spike_detector
    # recording connections are not counted as network synapses.
    results['n_synapses'] = len(nest.GetConnections(source=all_nodes,
                                                    target=all_nodes))

    if get_spikes:
        e_events = nest.GetStatus(espikes, keys="events")[0]
        i_events = nest.GetStatus(ispikes, keys="events")[0]
        results['spikes'] = {'indices_e': e_events["senders"],
                             'times_e': e_events["times"],
                             'indices_i': i_events["senders"],
                             'times_i': i_events["times"]}
    results['build_time'] = endbuild - startbuild
    results['time'] = endsimulate - endbuild
    return results


if __name__ == '__main__':
    # Script entry point: run one benchmark with spikes and memory recorded,
    # print a summary and plot the recorded spikes.
    import sys

    import matplotlib.pyplot as plt

    args = get_args(threads=('--threads',
                             {'help': 'number of threads',
                              'default': 1,
                              'type': int}))
    results = run_benchmark(get_spikes=True, measure_memory=True, **args)
    print('Simulating {} neurons for {}s took {:.2f}s.'.format(
        args['n_neurons'], args['runtime'] / 1000, results['time']))
    # run_benchmark only reports rates when at least one neuron is recorded.
    if 'rate' in results:
        print('Firing rates: {:.2f}Hz/{:.2f}Hz'.format(*results['rate']))
    print('Total number of synapses: {}'.format(results['n_synapses']))
    # ru_maxrss is in bytes on macOS but in kilobytes on Linux.
    mem_kib = results['memory']
    if sys.platform == 'darwin':
        mem_kib = mem_kib / 1024
    print('Maximum memory usage: {:.1f}MB'.format(mem_kib / 1024))
    plot_results(args, results['spikes'],
                 offset=-int(args['n_neurons'] * 0.8) + args['n_rec'])
    plt.show()
