from __future__ import division
import matplotlib
matplotlib.use('Agg')
import resource

from brian2 import *

from benchmark_utils import get_args, plot_results

def run_benchmark(n_neurons, n_rec=500, runtime=1000, target='weave',
                  threads=0, measure_memory=False, get_spikes=False,
                  heterogeneous=False):
    """Build and simulate a random E/I integrate-and-fire network.

    The network is a single NeuronGroup whose first 80% of neurons are
    excitatory and the rest inhibitory. Synaptic inputs ``ge``/``gi`` are
    in volts and decay exponentially. Each population projects randomly
    onto the whole group with connection probability ``80 / n_neurons``.

    Parameters
    ----------
    n_neurons : int
        Total number of neurons.
    n_rec : int
        Maximum number of neurons per population (excitatory and
        inhibitory separately) whose spikes are recorded.
    runtime : float
        Simulated time in milliseconds.
    target : str
        ``'cpp_standalone'``, ``'brian2genn'``, or otherwise a runtime
        code-generation target that is assigned to ``prefs.codegen.target``.
    threads : int
        OpenMP thread count; only used when ``target == 'cpp_standalone'``.
    measure_memory : bool
        If true, store the summed peak resident set size of this process
        and its children, in kilobytes, under ``'memory'``.
    get_spikes : bool
        If true, store the recorded spike indices and times (in ms) under
        ``'spikes'``.
    heterogeneous : bool
        If true, each neuron gets its own time constants, drawn uniformly
        from 90%-110% of the standard values.

    Returns
    -------
    dict
        ``'rate'``: (excitatory, inhibitory) mean rate of the recorded
        neurons in Hz; ``'n_synapses'``: total number of synapses;
        ``'time'``: ``device._last_run_time``; plus ``'memory'`` and/or
        ``'spikes'`` when requested.
    """
    # Select the Brian2 device / code-generation backend.
    if target == 'cpp_standalone':
        set_device('cpp_standalone')
        prefs.devices.cpp_standalone.openmp_threads = threads
    elif target == 'brian2genn':
        # Imported only for its side effects; presumably registers the
        # 'genn' device used below.
        import brian2genn  # noqa: F401
        set_device('genn')
    else:
        set_device('runtime')
        prefs.codegen.target = target
    Ne = int(n_neurons * 0.8)  # first 80% of the group are excitatory
    runtime = runtime * ms

    # NOTE: the constants below are referenced *by name* inside the equation,
    # threshold/reset and initialisation strings. Brian2 looks them up in the
    # scope where those strings are used and where run() is called, so they
    # must remain local variables of this function.
    # The "_0" suffix keeps the standard time constants distinct from the
    # per-neuron parameters taum/taue/taui of the heterogeneous model.
    taum_0 = 20 * ms
    taue_0 = 5 * ms
    taui_0 = 10 * ms
    Vt = -50 * mV
    Vr = -60 * mV
    El = -49 * mV

    if heterogeneous:
        # Time constants are per-neuron, fixed-for-the-run parameters.
        eqs = '''
        dv/dt  = (ge+gi-(v-El))/taum : volt (unless refractory)
        dge/dt = -ge/taue : volt (unless refractory)
        dgi/dt = -gi/taui : volt (unless refractory)
        taum  : second (constant)
        taue : second (constant)
        taui : second (constant)
        '''
    else:
        # Time constants are shared scalars taken from the local namespace.
        eqs = '''
        dv/dt  = (ge+gi-(v-El))/taum_0 : volt (unless refractory)
        dge/dt = -ge/taue_0 : volt (unless refractory)
        dgi/dt = -gi/taui_0 : volt (unless refractory)
        '''

    P = NeuronGroup(n_neurons, eqs, threshold='v > Vt', reset='v = Vr',
                    refractory=5*ms, method='exact')
    Pe = P[:Ne]
    Pi = P[Ne:]
    # Random initial membrane potential between reset and threshold.
    P.v = 'Vr + rand() * (Vt - Vr)'
    if heterogeneous:
        # Parameters for each neuron vary between 90% and 110% of standard values
        P.taum = '(0.9 + 0.2*rand()) * taum_0'
        P.taue = '(0.9 + 0.2*rand()) * taue_0'
        P.taui = '(0.9 + 0.2*rand()) * taui_0'

    we = (60 * 0.27 / 10) * mV  # excitatory synaptic weight (voltage)
    wi = (-20 * 4.5 / 10) * mV  # inhibitory synaptic weight
    # Random connectivity with ~80 expected inputs per neuron from each
    # population; weights are jittered to 90%-110% of the nominal value.
    Ce = Synapses(Pe, P, 'w : volt (constant)',
                  on_pre='ge += w', delay=0.1*ms)
    Ce.connect(p=80. / n_neurons)
    Ce.w = '(0.9 + 0.2*rand())*we'
    Ci = Synapses(Pi, P, 'w: volt (constant)',
                  on_pre='gi += w', delay=0.1*ms)
    Ci.connect(p=80. / n_neurons)
    Ci.w = '(0.9 + 0.2*rand())*wi'

    # Record at most n_rec neurons from each population.
    n_rec_e = min(n_rec, len(Pe))
    n_rec_i = min(n_rec, len(Pi))
    s_mon_e = SpikeMonitor(Pe[:n_rec_e])
    s_mon_i = SpikeMonitor(Pi[:n_rec_i])
    run(runtime)
    execution_time = device._last_run_time

    results = {}
    # float(runtime) converts the Quantity to seconds, so rates are in Hz.
    results['rate'] = (s_mon_e.num_spikes / n_rec_e / float(runtime),
                       s_mon_i.num_spikes / n_rec_i / float(runtime))
    results['n_synapses'] = len(Ce) + len(Ci)
    if measure_memory:
        import sys
        # Sum of this process's peak RSS and the peak RSS of its
        # waited-for children (presumably covers a separately executed
        # simulation binary, e.g. in standalone mode -- confirm).
        max_mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss +
                   resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss)
        # ru_maxrss is reported in kilobytes on Linux but in bytes on
        # macOS; normalise so callers always receive kilobytes.
        if sys.platform == 'darwin':
            max_mem //= 1024
        results['memory'] = max_mem
    if get_spikes:
        spikes = {'indices_e': s_mon_e.i[:], 'times_e': s_mon_e.t[:] / ms,
                  'indices_i': s_mon_i.i[:], 'times_i': s_mon_i.t[:] / ms}
        results['spikes'] = spikes
    results['time'] = execution_time
    return results


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    # Extra command-line options on top of those get_args defines itself.
    args = get_args(target=('--target',
                            {'help': 'code generation target',
                             'default': 'cython'}),
                    threads=('--threads',
                             {'help': 'number of threads (only for C++ '
                                      'standalone mode)',
                              'default': 0,
                              'type': int}))
    results = run_benchmark(get_spikes=True, measure_memory=True, **args)
    # args['runtime'] is in milliseconds (run_benchmark multiplies it by ms).
    print('Simulating {} neurons for {}s took {:.2f}s.'.format(
        args['n_neurons'], args['runtime'] / 1000, results['time']))
    print('Firing rates: {:.2f}Hz/{:.2f}Hz'.format(*results['rate']))
    print('Total number of synapses: {}'.format(results['n_synapses']))
    # Assumes results['memory'] is in kilobytes (ru_maxrss units on Linux),
    # so dividing by 1024 yields mebibytes.
    print('Maximum memory usage: {:.1f}MB'.format(results['memory'] / 1024))
    plot_results(args, results['spikes'])
    # NOTE(review): matplotlib.use('Agg') at import time selects a
    # non-interactive backend, so show() likely displays nothing here;
    # presumably plot_results saves the figure -- confirm.
    plt.show()

