from __future__ import division
import matplotlib
matplotlib.use('Agg')
import resource
import time

import numpy as np
from neuron import h

from benchmark_utils import get_args, plot_results


class LeakyIFModel:
    """A single NEURON ``IF3`` cell configured for the benchmark network.

    Parameters
    ----------
    record : bool
        When true, a ``NetCon`` with no target is attached to the cell and
        records its spike times into ``self.spk`` (an ``h.Vector``).
        When false, ``self.recorder`` and ``self.spk`` are both ``None``.
    heterogeneous : bool
        When true, every time constant is multiplied by its own uniform
        factor drawn from [0.9, 1.1). When false, the nominal values are used.
    """

    def __init__(self, record=False, heterogeneous=False):
        self.cell = h.IF3()
        # Nominal time constants in ms. Iterating in this fixed order keeps
        # the sequence of np.random draws (and thus seeded runs) stable.
        for param_name, nominal in (('taum', 20), ('taue', 5), ('taui', 10)):
            if heterogeneous:
                value = nominal * (0.9 + 0.2*np.random.rand())
            else:
                value = nominal
            setattr(self.cell, param_name, value)
        self.cell.b = 1.1
        # Initial state drawn uniformly from [0, 1).
        self.cell.m_init = np.random.rand()
        # Refractory period is 5 ms by default (see if3.mod)
        if not record:
            self.recorder = None
            self.spk = None
            return
        self.recorder = h.NetCon(self.cell, None)
        self.spk = h.Vector()
        self.recorder.record(self.spk)


def _connect(sources, all_cells, n_outgoing, n_neurons):
    """Create NetCons from each cell in *sources* to randomly drawn targets.

    For ``sources[k]``, ``n_outgoing[k]`` target indices are drawn uniformly,
    with replacement, from ``range(n_neurons)``. Returns all NetCons in
    creation order.
    """
    connections = []
    for source, n_targets in zip(sources, n_outgoing):
        # Targets drawn with replacement, for simplicity.
        targets = np.random.randint(0, n_neurons, size=n_targets)
        connections.extend(h.NetCon(source.cell, all_cells[target].cell)
                           for target in targets)
    return connections


def _set_synapses(connections, base_weight, delay):
    """Set delay and a weight jittered by a factor in [0.9, 1.1) on each NetCon."""
    for conn in connections:
        conn.weight[0] = (0.9 + 0.2*np.random.rand()) * base_weight
        conn.delay = delay


def _collect_spikes(cells):
    """Flatten recorded spike trains into parallel ``(indices, times)`` lists.

    Indices are 0-based positions within *cells*, not global cell ids.
    """
    indices = []
    times = []
    for idx, cell in enumerate(cells):
        spikes = cell.spk.as_numpy()
        indices.extend([idx] * len(spikes))
        times.extend(spikes)
    return indices, times


def _mean_rate(n_spikes, n_cells, runtime):
    """Return the mean firing rate in Hz of *n_cells* cells over *runtime* ms.

    Returns NaN when no cells were recorded, instead of dividing by zero.
    """
    if n_cells == 0:
        return float('nan')
    return n_spikes / n_cells / (runtime / 1000.)


def run_benchmark(n_neurons, n_rec=500, runtime=1000,
                  measure_memory=False, get_spikes=False, heterogeneous=False):
    """Build and simulate a random E/I network of ``IF3`` cells in NEURON.

    Parameters
    ----------
    n_neurons : int
        Total number of cells; ``int(0.8 * n_neurons)`` are excitatory and
        the rest inhibitory. Must be >= 80 so the connection probability
        ``80 / n_neurons`` is a valid binomial probability.
    n_rec : int
        Maximum number of cells per population whose spikes are recorded.
    runtime : float
        Simulated duration in ms.
    measure_memory : bool
        If true, add ``results['memory']``: the sum of ``ru_maxrss`` for this
        process and its children (units are platform dependent).
    get_spikes : bool
        If true, add ``results['spikes']`` holding ``indices_e``/``times_e``
        and ``indices_i``/``times_i`` lists.
    heterogeneous : bool
        Forwarded to ``LeakyIFModel``.

    Returns
    -------
    dict
        ``n_synapses``; ``rate`` as an (E, I) tuple of mean rates in Hz, NaN
        for a population with no recorded cells; ``time``, the wall-clock
        simulation time in s; ``build_time``, the wall-clock network
        construction time in s; plus the optional keys above.
    """
    start_build = time.time()

    Ne = int(n_neurons * 0.8)
    Ni = n_neurons - Ne
    n_rec_e = min(n_rec, Ne)
    n_rec_i = min(n_rec, Ni)
    # Gives an expected out-degree of 80 (binomial mean n_neurons * epsilon).
    epsilon = 80. / n_neurons

    # Record the first n_rec_e excitatory and first n_rec_i inhibitory cells.
    all_cells = [LeakyIFModel(record=(x < n_rec_e or Ne <= x < Ne + n_rec_i),
                              heterogeneous=heterogeneous)
                 for x in range(n_neurons)]

    # The random draws happen in a fixed order so a given seed always yields
    # the same network: out-degrees, E targets, I targets, E weights, I weights.
    outgoing_E = np.random.binomial(n_neurons, epsilon, size=Ne)
    outgoing_I = np.random.binomial(n_neurons, epsilon, size=Ni)
    conn_E = _connect(all_cells[:Ne], all_cells, outgoing_E, n_neurons)
    conn_I = _connect(all_cells[Ne:], all_cells, outgoing_I, n_neurons)
    _set_synapses(conn_E, 0.02551, delay=0.1)
    _set_synapses(conn_I, -0.225, delay=0.1)

    h.dt = 0.1
    h.finitialize()
    h.fcurrent()
    h.frecord_init()

    end_build = time.time()

    # h.t accumulates dt in floating point, so `h.t < runtime` can trigger
    # one extra step past runtime; stop within half a step instead.
    while h.t < runtime - 0.5 * h.dt:
        h.fadvance()

    end_simulate = time.time()
    sim_time = end_simulate - end_build

    results = {}
    results['n_synapses'] = len(conn_E) + len(conn_I)
    if measure_memory:
        max_mem = (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss +
                   resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss)
        results['memory'] = max_mem

    indices_e, times_e = _collect_spikes(all_cells[:n_rec_e])
    indices_i, times_i = _collect_spikes(all_cells[Ne:Ne + n_rec_i])
    results['rate'] = (_mean_rate(len(indices_e), n_rec_e, runtime),
                       _mean_rate(len(indices_i), n_rec_i, runtime))
    if get_spikes:
        results['spikes'] = {'indices_e': indices_e, 'times_e': times_e,
                             'indices_i': indices_i, 'times_i': times_i}

    results['time'] = sim_time
    results['build_time'] = end_build - start_build

    return results


if __name__ == '__main__':
    import sys
    import matplotlib.pyplot as plt
    args = get_args()
    results = run_benchmark(get_spikes=True, measure_memory=True, **args)
    print('Simulating {} neurons for {}s took {:.2f}s.'.format(args['n_neurons'],
                                                               args['runtime']/1000,
                                                               results['time']))
    print('Firing rates: {:.2f}Hz/{:.2f}Hz'.format(*results['rate']))
    print('Total number of synapses: {}'.format(results['n_synapses']))
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS
    # (getrusage(2)), so normalise to megabytes per platform.
    bytes_per_unit = 1 if sys.platform == 'darwin' else 1024
    memory_mb = results['memory'] * bytes_per_unit / (1024 ** 2)
    print('Maximum memory usage: {:.1f}MB'.format(memory_mb))
    plot_results(args, results['spikes'])
    # NOTE(review): the module selects the non-interactive 'Agg' backend at
    # import time, so this show() cannot display a window -- confirm whether
    # plot_results saves the figure to disk.
    plt.show()
