# encoding:utf8
#
# This source file is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This source file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""
Collection of often used functions.

Functions:
----------
distance_matrix_1d:
    Calculate distance matrix for 1d network with periodic boundary conditions
    for all neurons.
distance_matrix_2d:
    Calculate distance matrix for 2d network with periodic boundary conditions
    for a single neuron.
prob_distr_1d:
    Defines 1d connectivity profile for one neuron.
prob_distr_1d_discrete:
    Defines 1d connectivity profile for one neuron in discrete mesh.
prob_distr_2d:
    Defines 2d connectivity profile for one neuron.
prob_distr_2d_discrete:
    Defines 2d connectivity profile for one neuron in discrete mesh.
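weights_1p:
    Calc weights for network with one population of inhibitory neurons.
weights_2p:
    Calc weights for network with inh and exc population of neurons.
create_parameter_string:
    Build a compact, sorted string from a dictionary of parameters.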
"""

import numpy as np

def distance_matrix_1d(N):
    """
    Calculate distance matrix for 1d network with periodic boundary conditions
    for all neurons.

    Entry ``[i, j]`` is the ring distance between neurons ``i`` and ``j``,
    i.e. ``min(|i - j|, N - |i - j|)``.

    Parameters:
    -----------
    N: int (odd)
        Number of neurons.

    Returns:
    --------
    distance_matrix: np.ndarray
        (N, N) float array of 1d distances.
    """
    idx = np.arange(N)
    # plain index separation |i - j| for every pair of neurons
    separation = np.abs(np.subtract.outer(idx, idx))
    # periodic boundary: going the other way round the ring may be shorter
    distance_matrix = np.zeros((N, N))
    distance_matrix[...] = np.minimum(separation, N - separation)
    return distance_matrix


def distance_matrix_2d(N_x, N_y):
    """
    Calculate distance matrix for 2d network with periodic boundary conditions
    for a single neuron.

    The reference neuron sits at index ``(0, 0)``; entry ``[j, i]`` is the
    Euclidean distance to the neuron at x-index ``i`` and y-index ``j``,
    where each axis offset wraps around the periodic boundary.

    Parameters:
    -----------
    N_x: int (odd)
        Number of neurons in x direction.
    N_y: int (odd)
        Number of neurons in y direction.

    Returns:
    --------
    distance_matrix: np.ndarray
        2d distance matrix of shape (N_y, N_x).
    """
    # wrapped offset along y: min(j, N_y - j)
    y_steps = np.arange(N_y)
    y_ring = np.minimum(y_steps, N_y - y_steps)

    # wrapped offset along x: min(i, N_x - i)
    x_steps = np.arange(N_x)
    x_ring = np.minimum(x_steps, N_x - x_steps)

    # broadcast rows (y) against columns (x) to get every grid point at once
    squared = x_ring[np.newaxis, :] ** 2 + y_ring[:, np.newaxis] ** 2
    return np.sqrt(squared)


def prob_distr_1d(x, lam, distr='exp'):
    """
    Defines 1d connectivity profile for one neuron.

    Profiles are normalised to 1, i.e. they integrate to 1 over the real line.

    Parameters:
    -----------
    x: numerical
        distance of neurons given as numerical or array.
    lam: int or float
        Constant of decay / length scale.
    distr: str
        Decides which distribution is used. Options: 'exp', 'gauss'

    Returns:
    --------
    np.ndarray
        Array of connection probabilities.

    Raises:
    -------
    ValueError
        If `distr` is not one of the supported options.
    """

    if distr == 'exp':
        # Laplace density 1/(2 lam) * exp(-|x| / lam)
        return 1/(2*lam) * np.exp(-np.absolute(x) / lam)
    if distr == 'gauss':
        # normal density with standard deviation lam
        return 1/(np.sqrt(2*np.pi)*lam) * np.exp(- np.power(x, 2) / (2 * np.power(lam, 2)))
    # previously this only printed and returned None, which made callers fail
    # later with an unrelated error; fail loudly here instead
    raise ValueError(f"no such distribution implemented: {distr!r}")


def prob_distr_1d_discrete(x, lam, N, distr='exp'):
    """
    Defines 1d connectivity profile for one neuron in discrete mesh.

    Profiles are normalised so that the profile summed over the periodic
    distances to all N neurons on the ring equals 1.

    Parameters:
    -----------
    x: numerical
        distance of neurons given as numerical or array.
    lam: int or float
        Constant of decay / length scale.
    N: int
        Number of neurons (assumed to be odd).
    distr: str
        Decides which distribution is used. Options: 'exp', 'gauss'

    Returns:
    --------
    np.ndarray
        Array of connection probabilities.
    """

    # Periodic distances from neuron 0 to every neuron on the ring,
    # min(d, N - d). This equals row 0 of distance_matrix_1d(N), but is
    # computed in O(N) instead of building the full N x N matrix just to
    # read a single row.
    offsets = np.arange(N)
    x_all = np.minimum(offsets, N - offsets).astype(float)

    # normalise by sum over all distances
    return prob_distr_1d(x, lam, distr) / prob_distr_1d(x_all, lam, distr).sum()


def prob_distr_2d(x, lam, distr='exp'):
    """
    Defines 2d profile of connection probability between two neurons.

    Profiles are normalised to 1, i.e. they integrate to 1 over the plane
    when `x` is the radial distance.

    Parameters:
    -----------
    x: numerical
        distance of neurons given as numerical or array.
    lam: int or float
        Constant of decay / length scale.
    distr: str
        Decides which distribution is used. Options: 'exp', 'gauss'

    Returns:
    --------
    np.ndarray
        Array of connection probabilities.

    Raises:
    -------
    ValueError
        If `distr` is not one of the supported options.
    """

    # both radial profiles integrate to 2*pi*lam**2 over the plane, hence
    # the shared prefactor
    if distr == 'exp':
        return 1/(2*np.pi*lam**2) * np.exp(-np.absolute(x) / lam)
    if distr == 'gauss':
        return 1/(2*np.pi*lam**2) * np.exp(- np.power(x, 2) / (2 * np.power(lam, 2)))
    # previously this only printed and returned None, which made callers fail
    # later with an unrelated error; fail loudly here instead
    raise ValueError(f"no such distribution implemented: {distr!r}")


def prob_distr_2d_discrete(x, lam, N_x, N_y, distr='exp'):
    """
    Defines 2d connectivity profile for one neuron in discrete mesh.

    Profiles are normalised so that the profile summed over the periodic
    distances to all N_x * N_y neurons of the grid equals 1.

    Parameters:
    -----------
    x: numerical
        distance of neurons given as numerical or array.
    lam: int or float
        Constant of decay / length scale.
    N_x: int
        Number of neurons in x direction (assumed to be odd).
    N_y: int
        Number of neurons in y direction (assumed to be odd).
    distr: str
        Decides which distribution is used. Options: 'exp', 'gauss'

    Returns:
    --------
    np.ndarray
        Array of connection probabilities.
    """
    # distances from the reference neuron to every grid point
    mesh = distance_matrix_2d(N_x, N_y)

    # unnormalised profile at the requested distances
    profile = prob_distr_2d(x, lam, distr)
    # discrete normalisation constant: profile summed over the whole grid
    total = prob_distr_2d(mesh, lam, distr).sum()
    return profile / total


def weights_1p(r, K):
    """
    Calc weights for network with one population of inhibitory neurons.

    Returns ``-r / sqrt(K)``; the sign is negative because the population
    is inhibitory.
    """
    scale = np.sqrt(K)
    return -r / scale


def weights_2p(r, K_E, K_I, g):
    """
    Calc weights for network with inh and exc population of neurons.

    Both weights share the denominator ``sqrt(K_E + g**2 * K_I)``; the
    inhibitory weight is additionally scaled by ``-g``.

    Returns:
    --------
    (w_E, w_I): tuple
        Excitatory and inhibitory weight.
    """
    denominator = np.sqrt(K_E + g**2 * K_I)
    excitatory = r / denominator
    inhibitory = -g * r / denominator
    return excitatory, inhibitory


def create_parameter_string(parameters):
    """
    Create a parameter string from a dictionary of parameters.

    The entries 'system', 'setting', 'plot_setting' and 'ext' are skipped.
    Each remaining key, in sorted order, contributes ``str(key)`` followed by
    ``str(value[0])``; the pieces are joined with underscores. The input
    dictionary itself is left unmodified (a shallow copy is filtered).

    Parameters:
    -----------
    parameters: dict
        Mapping from parameter name to an indexable value; only the first
        element of each value is used.

    Returns:
    --------
    str
        The underscore-joined parameter string ('' if no keys remain).
    """
    remaining = parameters.copy()
    for skipped in ('system', 'setting', 'plot_setting', 'ext'):
        remaining.pop(skipped, None)

    pieces = [str(name) + str(remaining[name][0]) for name in sorted(remaining)]
    return '_'.join(pieces)
