from __future__ import division
import matplotlib
matplotlib.use('Agg')
import resource
import time

from brian import *
from brian.globalprefs import set_global_preferences
from brian.utils.progressreporting import ProgressReporter

from benchmark_utils import get_args, plot_results


def _set_codegen_preferences(use_codegen):
    """Turn all of Brian's weave/code-generation preferences on or off together.

    ``use_codegen`` is applied uniformly to every flag passed below.
    """
    set_global_preferences(useweave=use_codegen,
                           usecodegen=use_codegen,
                           usecodegenstateupdate=use_codegen,
                           usecodegenweave=use_codegen,
                           usenewpropagate=use_codegen,
                           usecodegenthreshold=use_codegen)


def _mean_rate(n_spikes, n_recorded, duration):
    """Return ``n_spikes / n_recorded / duration`` or NaN if it is undefined.

    NaN (instead of a ``ZeroDivisionError``) is returned when no neuron was
    recorded or when ``duration`` (in seconds) is not positive.
    """
    if n_recorded <= 0 or duration <= 0:
        return np.nan
    return n_spikes / n_recorded / duration


def run_benchmark(n_neurons, n_rec=500, runtime=1000, target='weave',
                  measure_memory=False, get_spikes=False, heterogeneous=False):
    """Build and simulate an excitatory/inhibitory network and time the run.

    Parameters
    ----------
    n_neurons : int
        Total number of neurons; the first 80% are excitatory, the rest
        inhibitory.
    n_rec : int
        Maximum number of neurons per population covered by a SpikeMonitor.
    runtime : number
        Simulated duration in milliseconds (multiplied by ``ms`` below).
    target : str
        ``'weave'`` enables all weave/code-generation preferences; any other
        value disables them.
    measure_memory : bool
        If true, store the summed ``ru_maxrss`` of this process and its
        children under ``'memory'`` (platform-dependent units).
    get_spikes : bool
        If true, store the recorded spikes (times in ms) under ``'spikes'``.
    heterogeneous : bool
        If true, every neuron gets its own membrane and synaptic time
        constants, drawn uniformly within +/-10% of the nominal values.

    Returns
    -------
    dict
        ``'rate'`` (excitatory, inhibitory) mean rate in Hz, NaN when
        undefined; ``'n_synapses'``; ``'time'`` wall-clock seconds spent in
        ``run``; plus optional ``'memory'`` and ``'spikes'``.
    """
    reinit_default_clock()
    clear(True, True)
    # Note: the 'weave' (code generation) target is not benchmarked for now
    # because it is much slower than running without code generation.
    _set_codegen_preferences(target == 'weave')

    Ne = int(n_neurons * 0.8)
    runtime = runtime * ms

    taum_0 = 20 * ms
    taue_0 = 5 * ms
    taui_0 = 10 * ms
    Vt = -50 * mV
    Vr = -60 * mV
    El = -49 * mV

    if heterogeneous:
        eqs = '''
        dv/dt  = (ge+gi-(v-El))/taum : volt
        dge/dt = -ge/taue : volt
        dgi/dt = -gi/taui : volt
        taum : second
        taue : second
        taui : second
        '''
    else:
        eqs = '''
        dv/dt  = (ge+gi-(v-El))/taum_0 : volt
        dge/dt = -ge/taue_0 : volt
        dgi/dt = -gi/taui_0 : volt
        '''

    P = NeuronGroup(n_neurons, eqs, threshold='v>Vt', reset='v = Vr',
                    refractory=5 * ms, method='linear')
    Pe = P[:Ne]
    Pi = P[Ne:]
    # Random initial membrane potentials between reset and threshold.
    P.v = Vr + np.random.rand(len(P)) * (Vt - Vr)
    if heterogeneous:
        # Per-neuron time constants within +/-10% of the nominal values.
        P.taum = (0.9 + np.random.rand(len(P)) * 0.2) * taum_0
        P.taue = (0.9 + np.random.rand(len(P)) * 0.2) * taue_0
        P.taui = (0.9 + np.random.rand(len(P)) * 0.2) * taui_0
    P.ge = 0 * mV
    P.gi = 0 * mV

    we = (60 * 0.27 / 10) * mV  # excitatory synaptic weight (voltage)
    wi = (-20 * 4.5 / 10) * mV  # inhibitory synaptic weight
    Ce = Synapses(Pe, P, 'w : volt', pre='ge += w')
    Ci = Synapses(Pi, P, 'w : volt', pre='gi += w')
    # A float assigned to S[:, :] connects with that probability, so each
    # neuron receives ~80 inputs on average regardless of network size.
    Ce[:, :] = 80. / n_neurons
    Ci[:, :] = 80. / n_neurons
    Ce.w[:, :] = '(0.9 + rand() * 0.2) * we'
    Ci.w[:, :] = '(0.9 + rand() * 0.2) * wi'
    Ce.delay[:, :] = 0.1*ms
    Ci.delay[:, :] = 0.1*ms

    n_rec_e = min(n_rec, len(Pe))
    n_rec_i = min(n_rec, len(Pi))
    s_mon_e = SpikeMonitor(Pe[:n_rec_e])
    s_mon_i = SpikeMonitor(Pi[:n_rec_i])

    start_time = time.time()
    run(runtime)
    end_time = time.time()

    results = {}
    duration = float(runtime)  # seconds
    results['rate'] = (_mean_rate(s_mon_e.nspikes, n_rec_e, duration),
                       _mean_rate(s_mon_i.nspikes, n_rec_i, duration))
    results['n_synapses'] = len(Ce) + len(Ci)
    if measure_memory:
        # Peak RSS of this process plus its children; the unit of ru_maxrss
        # is platform dependent (kilobytes on Linux, bytes on macOS).
        max_mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss +
                   resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss)
        results['memory'] = max_mem
    if get_spikes:
        indices_e, times_e = s_mon_e.it
        indices_i, times_i = s_mon_i.it
        # Spike times are converted from seconds to milliseconds.
        spikes = {'indices_e': indices_e, 'times_e': times_e * 1000,
                  'indices_i': indices_i, 'times_i': times_i * 1000}
        results['spikes'] = spikes
    results['time'] = end_time - start_time
    return results


if __name__ == '__main__':
    # Command-line entry point: run one benchmark, print a summary and plot
    # the recorded spikes.
    import sys
    import matplotlib.pyplot as plt
    args = get_args(target=('--target',
                            {'help': 'code generation target',
                             'default': 'numpy'}),
                    )
    results = run_benchmark(get_spikes=True, measure_memory=True, **args)
    # args['runtime'] is in milliseconds; report it in seconds.
    print('Simulating {} neurons for {}s took {:.2f}s.'.format(args['n_neurons'],
                                                               args['runtime']/1000,
                                                               results['time']))
    print('Firing rates: {:.2f}Hz/{:.2f}Hz'.format(*results['rate']))
    print('Total number of synapses: {}'.format(results['n_synapses']))
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS, so
    # the divisor needed to obtain (approximate) megabytes depends on the OS.
    units_per_mb = 1e6 if sys.platform == 'darwin' else 1e3
    print('Maximum memory usage: {}MB'.format(results['memory']/units_per_mb))
    plot_results(args, results['spikes'])
    plt.show()
