# coding=utf-8
from __future__ import division

import argparse

from brian2 import *

def run_exploration(runtime, g_na_values, I_values, target='cpp_standalone',
                    threads=0):
    """Simulate one HH neuron per (g_na, I) pair and return the firing rates.

    A single ``NeuronGroup`` with ``len(g_na_values) * len(I_values)``
    neurons is created; every neuron receives one combination from the
    Cartesian product of the two parameter arrays (built with
    ``np.meshgrid``), and its spikes are counted over ``runtime``.

    Note that this function changes Brian2's global device via
    ``set_device`` before building the network.

    Parameters
    ----------
    runtime : Quantity
        Biological duration of the simulation (a Brian2 time quantity,
        e.g. ``10000*ms``).
    g_na_values : array-like
        Sodium conductance densities as plain numbers; they are multiplied by
        ``msiemens*cm**-2`` (and the membrane area) below.
    I_values : array-like
        Injected currents as plain numbers; they are multiplied by ``pA``.
    target : str
        Code generation target, either ``'cpp_standalone'`` or
        ``'brian2genn'``.
    threads : int
        Value assigned to ``prefs.devices.cpp_standalone.openmp_threads``;
        only used for the ``'cpp_standalone'`` target.

    Returns
    -------
    all_g_na_values : ndarray
        Flattened g_na grid (plain numbers), one entry per neuron.
    all_I_values : ndarray
        Flattened I grid (plain numbers), one entry per neuron.
    rates : ndarray
        Firing rate of every neuron expressed in Hz, in the same order as
        the two flattened parameter arrays.
    took : float
        ``device._last_run_time`` after the run; the caller reports it as
        seconds (presumably wall-clock time of the generated code — confirm
        against the device implementation).

    Raises
    ------
    ValueError
        If ``target`` is neither ``'cpp_standalone'`` nor ``'brian2genn'``.
    """
    if target == 'cpp_standalone':
        set_device('cpp_standalone')
        prefs.devices.cpp_standalone.openmp_threads = threads
    elif target == 'brian2genn':
        # The name is never used; the import is needed only for its side
        # effects (presumably making the 'genn' device available).
        import brian2genn
        set_device('genn')
    else:
        raise ValueError('Unknown code generation target "{}"'.format(target))

    # HH model with injected current. Capacitance and conductances are given
    # as densities and converted to absolute values for a membrane patch of
    # the following area.
    area = 20000*umetre**2
    Cm = (1*ufarad*cm**-2) * area
    gl = (5e-5*siemens*cm**-2) * area

    # Reversal potentials of the leak, potassium and sodium currents
    El = -60*mV
    EK = -90*mV
    ENa = 50*mV
    g_kd = (30*msiemens*cm**-2) * area
    # Voltage offset used inside all gating rate functions below
    VT = -63*mV

    # The model: membrane potential v plus gating variables m, n, h. The
    # external constants above (gl, El, g_kd, EK, ENa, VT, Cm) are picked up
    # by Brian2 from this function's namespace, so their names must match the
    # equation text. I and g_na are per-neuron constants set further down.
    eqs = Equations('''
    dv/dt = (gl*(El-v)-
             g_na*(m*m*m)*h*(v-ENa)-
             g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt
    dm/dt = alpha_m*(1-m)-beta_m*m : 1
    dn/dt = alpha_n*(1-n)-beta_n*n : 1
    dh/dt = alpha_h*(1-h)-beta_h*h : 1
    alpha_m = 0.32*(mV**-1)*(13*mV-v+VT)/
             (exp((13*mV-v+VT)/(4*mV))-1.)/ms : Hz
    beta_m = 0.28*(mV**-1)*(v-VT-40*mV)/
            (exp((v-VT-40*mV)/(5*mV))-1)/ms : Hz
    alpha_h = 0.128*exp((17*mV-v+VT)/(18*mV))/ms : Hz
    beta_h = 4./(1+exp((40*mV-v+VT)/(5*mV)))/ms : Hz
    alpha_n = 0.032*(mV**-1)*(15*mV-v+VT)/
             (exp((15*mV-v+VT)/(5*mV))-1.)/ms : Hz
    beta_n = .5*exp((10*mV-v+VT)/(40*mV))/ms : Hz
    I : amp (constant)
    g_na : siemens (constant)
    ''')
    # One neuron per parameter combination. Threshold and refractoriness use
    # the same condition: a neuron stays refractory while v > -20 mV, so each
    # upward crossing is counted as a single spike.
    neuron = NeuronGroup(len(g_na_values)*len(I_values), eqs,
                         method='exponential_euler',
                         threshold='v>-20*mV', refractory='v>-20*mV')
    # Start every neuron at the leak reversal potential (m, n, h are not
    # explicitly initialized here).
    neuron.v = El
    spike_mon = SpikeMonitor(neuron)
    # Cartesian product of the two parameter arrays; with meshgrid's default
    # 'xy' indexing both grids have shape (len(I_values), len(g_na_values))
    # and are flattened to one value per neuron.
    all_g_na_values, all_I_values = np.meshgrid(g_na_values, I_values)
    all_g_na_values = all_g_na_values.flat[:]
    all_I_values = all_I_values.flat[:]
    # Conductance density -> absolute conductance for the patch area
    neuron.g_na = all_g_na_values*msiemens*cm**-2 * area
    neuron.I = all_I_values*pA

    run(runtime)
    # Private attribute of the device, read right after the run
    took = device._last_run_time
    # Spike count / duration, divided by Hz to get plain numbers in Hz
    rates = spike_mon.count/runtime/Hz
    return all_g_na_values, all_I_values, rates, took


if __name__ == '__main__':
    # Command-line entry point: simulate a resolution x resolution grid of
    # (sodium conductance density, injected current) pairs, report how long
    # the simulation took and optionally plot the resulting firing rates.
    parser = argparse.ArgumentParser()
    parser.add_argument('--resolution', default=100,
                        help='Number of values to test per parameter. Defaults '
                             'to 100, i.e. 10000 values in total.', type=int)
    parser.add_argument('--runtime', default=10000,
                        help='Biological runtime (in ms) of the model. '
                             'Defaults to 10000 (i.e. 10s).',
                        type=int)
    # Restrict the choices to the targets run_exploration actually accepts.
    parser.add_argument('--target', default='cpp_standalone',
                        choices=['cpp_standalone', 'brian2genn'],
                        help='Code generation target to use, should be either '
                             '"cpp_standalone" or "brian2genn". Defaults to '
                             '"cpp_standalone"')
    # type=int is required: the value is compared with 0 below and forwarded
    # as Brian2's openmp_threads preference.
    parser.add_argument('--threads', default=0, type=int,
                        help='Number of OpenMP threads to use (only for the '
                             '"cpp_standalone" target). Defaults to 0, i.e. '
                             'OpenMP is not used.')
    parser.add_argument('--plot', default=False,
                        help='Show a plot of the results in the end.',
                        action="store_true")
    args = parser.parse_args()
    if args.resolution < 1:
        parser.error('--resolution must be at least 1')
    if args.runtime <= 0:
        parser.error('--runtime must be positive')
    if args.threads < 0:
        parser.error('--threads must not be negative')

    # Explored parameter ranges; the plot extent and aspect are derived from
    # the same values so they cannot drift apart.
    g_na_min, g_na_max = 10, 100  # sodium conductance density (mS/cm²)
    I_min, I_max = 0, 20  # injected current (pA)
    g_na_values = np.linspace(g_na_min, g_na_max, num=args.resolution)
    I_values = np.linspace(I_min, I_max, num=args.resolution)

    # run_exploration returns four values; the flattened parameter grids are
    # not needed here because the rates are reshaped onto the same grid below.
    _, _, rates, took = run_exploration(args.runtime*ms, g_na_values,
                                        I_values, target=args.target,
                                        threads=args.threads)
    # The neurons are laid out by np.meshgrid(g_na_values, I_values) with the
    # default 'xy' indexing, so rows correspond to I and columns to g_na. That
    # is exactly the layout imshow needs for g_na on x and I on y.
    matrix = np.reshape(rates, (len(I_values), len(g_na_values)))
    if args.target == 'cpp_standalone':
        target_str = 'C++ standalone'
        if args.threads > 0:
            target_str += ' ({} threads)'.format(args.threads)
    else:
        target_str = 'Brian2GeNN'
    print('Simulation with target {} took {:.2f}s'.format(target_str,
                                                          took))

    if args.plot:
        # Data aspect ratio that renders the parameter rectangle as a square
        aspect = (g_na_max - g_na_min) / (I_max - I_min)
        fig, ax = plt.subplots()
        img = ax.imshow(matrix, extent=(g_na_min, g_na_max, I_min, I_max),
                        origin='lower',
                        interpolation='none', aspect=aspect)
        ax.set(xlabel=u'sodium conductance density (mS/cm²)',
               ylabel=u'injected current (pA)',
               title=target_str)
        cb = fig.colorbar(img)
        cb.set_label('firing rate (Hz)')
        plt.show()
