import sympy, brian2

# Name -> object lookup table merged from the brian2 unit namespaces.
# Where a name occurs in more than one namespace, the later one wins.
units = {
  **vars(brian2.units),
  **vars(brian2.units.allunits),
  **vars(brian2.units.fundamentalunits),
}

def rsubs(eq,*args,**kwargs):
  """
  from https://itbgit.biologie.hu-berlin.de/compneurophys/brianutils

  Repeated substitution in a sympy expression.

  Every pass walks through the given dictionaries in order. For each
  dictionary the current expression is converted to a string, parsed
  again by sympy, and all (key, value) pairs of the dictionary are
  substituted, with keys and values converted to strings first.

  INPUT
    eq     : equation (sympy or str)
    args   : any number of substitution dictionaries
    r      : keyword argument, number of passes (default: r=1)
  OUTPUT
    the substituted sympy expression; `eq` is returned unchanged if
    r is 0 or no dictionaries were given
  """
  import sympy

  passes = kwargs.get("r", 1)
  for _ in range(passes):
    for mapping in args:
      replacements = [(str(old), str(new))
                      for old, new in dict(mapping).items()]
      expression = sympy.S(str(eq))
      eq = expression.subs(replacements)

  return eq


def _dimension_string(expression):
  """
  Evaluate `expression` with the module-level `units` table as namespace
  and return the unit string used in equation definitions: "1" for a
  dimensionless value, otherwise the repr of its brian2 dimensions.
  """
  # NOTE(review): eval() executes arbitrary code; model definitions must
  # come from a trusted source.
  value = eval(expression, units)
  if brian2.is_dimensionless(value):
    return "1"
  return repr(brian2.get_dimensions(value))


def _match_time_derivative(odestr, statevar_list):
  """
  Return (state_variable, "d<state_variable>/dt") for the first state
  variable whose time derivative occurs in `odestr`.
  """
  import re

  for name in statevar_list:
    derivative = "d{}/dt".format(name)
    # The lookbehind rejects hits inside a longer identifier, e.g.
    # "da/dt" must not match within "dda/dt".
    if re.search(r"(?<![A-Za-z0-9_])" + re.escape(derivative), odestr):
      return name, derivative
  # IndexError is what the former `[...][0]` lookup raised here.
  raise IndexError(
      'no time derivative of a state variable found in "{}"'.format(odestr))


def load_model(model_dict, bifpar=None, substitution_depth=4):
  """
  from https://itbgit.biologie.hu-berlin.de/compneurophys/brianutils
  Loads dictionary holding model definitions into a
      brian2.Equations object.

      INPUT
        model_dict         : dictionary with model specifications; must
                             hold the keys "ode", "definitions",
                             "functions", "parameters", "init_states"
        bifpar             : list (or dict) with names of bifurcation
                             parameters. They are removed from the
                             substituted parameters and declared as
                             free parameters of the equations instead.
                             (default: none)
        substitution_depth : recursion depth for substitution
      OUTPUT
        ode/sde            : brian2.Equations object holding the
                             bifurcation parameter declarations
                             followed by one equation per ODE
      RAISES
        AssertionError     : a mandatory key is missing in model_dict
        KeyError           : a bifurcation parameter is not listed in
                             model_dict["parameters"]
        IndexError         : an ODE holds no time derivative of a state
                             variable or cannot be solved for it
  """

  model_dict= dict(model_dict)
  mandatory_key_list= ["ode", "definitions", "functions", "parameters", "init_states"]

  for key in mandatory_key_list:
    # Raised explicitly so the check is not stripped by `python -O`.
    if key not in model_dict:
      raise AssertionError('model_dict is missing key "{}"'.format(key))

  statevar_list= list(model_dict["init_states"])
  # The unit of each state variable is taken from its initial value.
  state_units= {name: _dimension_string(value)
                for name, value in model_dict["init_states"].items()}
  parameter_dict= dict(model_dict["parameters"])

  brian_ode= brian2.Equations("")
  for key in bifpar or ():
    # Popping keeps the parameter symbolic in the substitution below.
    par_unit= _dimension_string(parameter_dict.pop(key))
    brian_ode+=brian2.Equations("{0}:{1}".format(key,par_unit))

  for ode in model_dict["ode"]:
    # Rewrite "lhs = rhs" as "(lhs)-(rhs)", an expression sympy can
    # solve for the time derivative.
    odestr= "({})-({})".format(*ode.split("="))
    odestr= odestr.replace(" ","")
    state_variable, time_derivative= _match_time_derivative(odestr, statevar_list)

    solutions= sympy.solve(odestr,time_derivative)
    if not solutions:
      raise IndexError(
          'could not solve "{}" for {}'.format(ode, time_derivative))
    # Only the first solution is used; the ODE is expected to be linear
    # in its time derivative.
    ode_rhs= rsubs(sympy.S(solutions[0]),
        model_dict["definitions"],
        model_dict["functions"],
        parameter_dict,
        r=substitution_depth)

    brian_ode+=brian2.Equations("{0} = {1} : {2}".format(
      time_derivative,
      ode_rhs,
      state_units[state_variable])
      )

  return brian_ode
