# -*- coding: utf-8 -*-
#
# This source file is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This source file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

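'''
Creates data for fig. 3b and 3c.

Draws one realisation of a 2d network with an excitatory and an inhibitory
population and writes to the directory data/ the pairwise covariances and
distances between excitatory neurons (covariances.npy, distances.npy) and the
eigenvalues of the connectivity matrix (eigenvalues.npy).

Usage: create_pairwise_covariances_and_eigenvalues.py [options] --yaml=<yaml_filename>

Options:
    -h, --help      show this information
    --yaml=<yaml_filename>  YAML file containing the network parameters.
'''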
import docopt

import numpy as np
import yaml
import os
import scipy.sparse, scipy.sparse.linalg


def simulation_2d2p(N_x, N_y, g, D, K_E, K_I, r, lam_E, lam_I, distr, seed, return_pairwise=False):
    """
    Calc mean and var of covariances for a realisation of a 2d neuronal network
    with two populations.

    Draws one realisation of a spatially organised 2d neuronal network with
    periodic boundary conditions and an excitatory (E) and an inhibitory (I)
    population. From the effective connectivity matrix W the covariance matrix
        C = D * (1 - W)^{-1} (1 - W)^{-T}
    is calculated following the work of Dahmen et al. (2018). Depending on
    `return_pairwise`, either the raw distance/covariance blocks or the mean
    and variance of the covariances as functions of distance are returned.

    For simplicity we assume that neurons do not distinguish between excitatory
    and inhibitory target neurons and hence
    K_EE = K_IE = K_E, K_II = K_EI = K_I, w_EE = w_IE = w_E, w_II = w_EI = w_I.
    Additionally, we assume balance
    w_I = -g * w_E.

    NOTE: N_x and N_y are assumed to be odd.

    Parameters:
    -----------
    N_x: int
        Number of inhibitory neurons in x direction (assumed to be odd).
    N_y: int
        Number of inhibitory neurons in y direction (assumed to be odd).
    g: int
        Ratio of number of excitatory neurons to number of inhibitory neurons.
        Also sets the relative inhibitory weight, w_I = -g * w_E.
    D: float
        Squared amplitude of white Gaussian input noise.
    K_E: int
        Number of connections from E to E or I population (assumed to be the
        same).
    K_I: int
        Number of connections from I to E or I population (assumed to be the
        same).
    r: float
        Spectral radius of effective connectivity matrix.
    lam_E: int
        Decay constant of connectivity profile of E connections.
    lam_I: int
        Decay constant of connectivity profile of I connections.
    distr: str
        Decides which connectivity profile is used. Options: 'exp', 'gauss'.
    seed: int
        Seed used when drawing the connectivity matrix.
    return_pairwise: bool
        If True, return the pairwise distance and covariance blocks instead of
        their distance-resolved statistics.

    Returns:
    --------
    If return_pairwise is False:
    c_mean: list of np.ndarray
        Means of covariances for each entry of `distances`, one array per
        population pair in the order [EE, EI, II].
    c_var: list of np.ndarray
        Variances of covariances for each entry of `distances`, same order.
    distances: np.ndarray
        Unique non-zero distances in network, sorted increasingly.

    If return_pairwise is True:
    X: list of np.ndarray
        Distance matrices [X_EE, X_EI, X_II].
    C: list of np.ndarray
        Covariance matrices [C_EE, C_EI, C_II].
    """

    # calculate weights from spectral radius assuming that neurons do not
    # distinguish between excitatory and inhibitory target neurons
    w_E, w_I = weights_2p(r, K_E, K_I, g)

    # enforce assumption of not distinguished target neurons
    K_EE = K_E
    K_IE = K_E
    K_EI = K_I
    K_II = K_I
    w_EE = w_E
    w_IE = w_E
    w_EI = w_I
    w_II = w_I
    lam_EE = lam_E
    lam_IE = lam_E
    lam_EI = lam_I
    lam_II = lam_I

    # draw a connectivity matrix
    W, X, distances = draw_binomial_2d2p(N_x, N_y, g,
                                         K_EE, K_EI, K_IE, K_II,
                                         w_EE, w_EI, w_IE, w_II,
                                         lam_EE, lam_EI, lam_IE, lam_II,
                                         distr, seed)

    # excitatory neurons occupy the first g*N_x*N_y rows/columns
    n_E = g*N_x*N_y

    # calculate covariance matrix C = D * (1 - W)^-1 (1 - W)^-T
    A = np.identity((g+1)*N_x*N_y) - W
    inv_A = np.linalg.inv(A)
    C = D*np.matmul(inv_A, inv_A.T)

    # divide C and X into their population blocks; C is symmetric, so the IE
    # block is the transpose of the EI block and is not needed separately
    C_E, C_I = np.split(C, [n_E])
    C_EE, C_EI = np.split(C_E, [n_E], 1)
    _, C_II = np.split(C_I, [n_E], 1)
    X_E, X_I = np.split(X, [n_E])
    X_EE, X_EI = np.split(X_E, [n_E], 1)
    _, X_II = np.split(X_I, [n_E], 1)

    C_blocks = [C_EE, C_EI, C_II]
    X_blocks = [X_EE, X_EI, X_II]

    if return_pairwise:
        return X_blocks, C_blocks

    # statistics of the covariances, one array per population pair
    c_mean = [np.zeros(len(distances)) for _ in C_blocks]
    c_var = [np.zeros(len(distances)) for _ in C_blocks]

    for i, (C_block, X_block) in enumerate(zip(C_blocks, X_blocks)):
        for j, x in enumerate(distances):
            # `distances` are taken from the same distance matrix X is built
            # from, so exact float comparison selects the matching pairs
            c = C_block[X_block == x]
            c_mean[i][j] = np.mean(c)
            c_var[i][j] = np.var(c)

    return c_mean, c_var, distances


def prob_distr_2d(x, lam, distr='exp'):
    """
    Defines 2d profile of connection probability between two neurons.

    Both profiles carry the prefactor 1 / (2 pi lam^2), so that their integral
    over the 2d plane is 1.

    Parameters:
    -----------
    x: numerical
        distance of neurons given as numerical or array.
    lam: int or float
        Constant of decay / length scale.
    distr: str
        Decides which distribution is used. Options: 'exp', 'gauss'

    Returns:
    --------
    np.ndarray
        Array of connection probabilities.

    Raises:
    -------
    ValueError
        If `distr` is not one of the implemented profiles.
    """

    if distr == 'exp':
        return 1/(2*np.pi*lam**2) * np.exp(-np.absolute(x) / lam)
    elif distr == 'gauss':
        return 1/(2*np.pi*lam**2) * np.exp(- np.power(x, 2) / (2 * np.power(lam, 2)))
    # fail loudly instead of returning None, which only crashed later with an
    # unrelated error in the callers
    raise ValueError('no such distribution implemented: {!r}'.format(distr))


def prob_distr_2d_discrete(x, lam, N_x, N_y, distr='exp'):
    """
    Evaluate the 2d connectivity profile on the discrete periodic grid.

    The continuous profile `prob_distr_2d` is rescaled such that its values
    at all grid positions seen from a single neuron (including the neuron
    itself at distance 0) sum to 1.

    Parameters:
    -----------
    x: numerical
        distance of neurons given as numerical or array.
    lam: int or float
        Constant of decay / length scale.
    N_x: int
        Number of neurons in x direction (assumed to be odd).
    N_y: int
        Number of neurons in y direction (assumed to be odd).
    distr: str
        Decides which distribution is used. Options: 'exp', 'gauss'

    Returns:
    --------
    np.ndarray
        Array of connection probabilities.
    """

    # distances from one neuron to every site of the periodic grid
    grid_distances = distance_matrix_2d(N_x, N_y)

    unnormalised = prob_distr_2d(x, lam, distr)
    # total profile mass over the whole grid of one neuron
    total = prob_distr_2d(grid_distances, lam, distr).sum()
    return unnormalised / total


def distance_matrix_2d(N_x, N_y):
    """
    Euclidean distances from a single neuron to all sites of a 2d grid with
    periodic boundary conditions.

    The reference neuron sits at index [0, 0]. Entry [iy, ix] holds the
    distance to the site iy rows and ix columns away, taking the shorter way
    around the periodic boundary in each direction.

    Parameters:
    -----------
    N_x: int (odd)
        Number of neurons in x direction.
    N_y: int (odd)
        Number of neurons in y direction.

    Returns:
    --------
    distance_matrix: np.ndarray
        2d distance matrix of shape (N_y, N_x).
    """

    def _wrapped_offsets(n):
        # on a ring of n sites, offset k and offset n - k reach the same
        # neighbour; keep the shorter of the two
        steps = np.arange(n)
        return np.minimum(steps, n - steps)

    dx = _wrapped_offsets(N_x)
    dy = _wrapped_offsets(N_y)

    # broadcast to shape (N_y, N_x): rows index y, columns index x
    return np.sqrt(dx[np.newaxis, :]**2 + dy[:, np.newaxis]**2)


def draw_binomial_2d2p(N_x, N_y, g,
                       K_EE, K_EI, K_IE, K_II,
                       w_EE, w_EI, w_IE, w_II,
                       lam_EE, lam_EI, lam_IE, lam_II,
                       distr, seed=42):
    """
    Calculates connectivity matrix of a 2d network with two populations,
    periodic boundary conditions, and binomially distributed weights. Multiple
    connections between pairs of neurons are allowed.

    Neurons are ordered with all g*N_x*N_y excitatory neurons first, followed
    by the N_x*N_y inhibitory neurons. Inhibitory neuron n sits at grid
    position (y, x) = divmod(n, N_x). Each grid site is shared by g excitatory
    neurons with consecutive indices.

    Parameters:
    -----------
    N_x: int
        Number of inhibitory neurons in x-direction (assumed to be odd).
    N_y: int
        Number of inhibitory neurons in y-direction (assumed to be odd).
    g: int
        Ratio of number of excitatory neurons to number of inhibitory neurons.
    K_EE: int
        Number of connections from E to E.
    K_EI: int
        Number of connections from I to E.
    K_IE: int
        Number of connections from E to I.
    K_II: int
        Number of connections from I to I.
    w_EE: float
        Connectivity weights of EE connections.
    w_EI: float
        Connectivity weights of EI connections.
    w_IE: float
        Connectivity weights of IE connections.
    w_II: float
        Connectivity weights of II connections.
    lam_EE: int
        Decay constant of connectivity profile of EE connections.
    lam_EI: int
        Decay constant of connectivity profile of EI connections.
    lam_IE: int
        Decay constant of connectivity profile of IE connections.
    lam_II: int
        Decay constant of connectivity profile of II connections.
    distr: str
        Decides which connectivity profile is used. Options: 'exp', 'gauss'.
    seed: int
        Seed of the random number generator used to draw the connections.
        A private np.random.RandomState is used, so the global NumPy random
        state is not modified.

    Returns:
    --------
    np.ndarray
        Connectivity matrix of shape ((g+1)*N_x*N_y, (g+1)*N_x*N_y). Rows are
        target neurons, columns are source neurons.
    np.ndarray
        Distance matrix, same shape and neuron ordering as the connectivity.
    np.ndarray
        Unique non-zero distances, sorted increasingly.
    """

    # calculate Euclidean distances for a single neuron, shape (N_y, N_x)
    distance_matrix_single = distance_matrix_2d(N_x, N_y)

    # Build the distance matrix of the N_x*N_y grid. Neuron n sits at
    # (y, x) = divmod(n, N_x), matching the row-major flattening of
    # distance_matrix_single. The periodic distance between two neurons only
    # depends on their coordinate offsets modulo the grid size, so it can be
    # looked up in the single-neuron matrix. This stays correct for
    # non-square grids (N_x != N_y).
    y_pos, x_pos = np.divmod(np.arange(N_x*N_y), N_x)
    dy = (y_pos[np.newaxis, :] - y_pos[:, np.newaxis]) % N_y
    dx = (x_pos[np.newaxis, :] - x_pos[:, np.newaxis]) % N_x
    distance_matrix = distance_matrix_single[dy, dx]

    # get unique distances, dropping the self-distance 0
    distances = np.unique(distance_matrix)[1:]

    # create distance matrices taking into account g exc neurons per inh neuron
    X_EE = np.repeat(np.repeat(distance_matrix, g, 0), g, 1)
    X_EI = np.repeat(distance_matrix, g, 0)
    X_IE = np.repeat(distance_matrix, g, 1)
    X_II = distance_matrix

    # connection probabilty
    P_EE = prob_distr_2d_discrete(X_EE, lam_EE, N_x, N_y, distr)
    P_EI = prob_distr_2d_discrete(X_EI, lam_EI, N_x, N_y, distr)
    P_IE = prob_distr_2d_discrete(X_IE, lam_IE, N_x, N_y, distr)
    P_II = prob_distr_2d_discrete(X_II, lam_II, N_x, N_y, distr)

    # weight matrices
    # divide connection probability of EE and IE connections to account for
    # g excitatory neurons per grid site (their columns are repeated g times)
    # A seeded RandomState yields the same draws as np.random.seed(seed) did,
    # without touching the global random state.
    rng = np.random.RandomState(seed)
    W_EE = w_EE * rng.binomial(K_EE, P_EE/g)
    W_EI = w_EI * rng.binomial(K_EI, P_EI)
    W_IE = w_IE * rng.binomial(K_IE, P_IE/g)
    W_II = w_II * rng.binomial(K_II, P_II)

    # build connectivity matrix from submatrices
    W = np.vstack([np.hstack([W_EE, W_EI]), np.hstack([W_IE, W_II])])
    # build distance matrix from submatrices
    X = np.vstack([np.hstack([X_EE, X_EI]), np.hstack([X_IE, X_II])])

    return W, X, distances


def weights_2p(r, K_E, K_I, g):
    """
    Calc weights for network with inh and exc population of neurons.

    The excitatory weight w_E and the inhibitory weight w_I = -g * w_E are
    scaled such that K_E * w_E**2 + K_I * w_I**2 == r**2.

    Returns:
    --------
    tuple of float
        (w_E, w_I)
    """
    norm = np.sqrt(K_E + g**2 * K_I)
    exc_weight = r / norm
    inh_weight = -g * r / norm
    return exc_weight, inh_weight


if __name__ == '__main__':

    args = docopt.docopt(__doc__)
    yaml_file = args['--yaml']

    # load parameters from yaml file; every parameter below comes from it, so
    # a file that cannot be parsed is fatal
    with open(yaml_file, 'r') as stream:
        try:
            parameters = yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            raise SystemExit('could not parse {}: {}'.format(yaml_file, exc))

    # every parameter is stored as a list in the yaml file; use its first entry
    N_x = parameters['N_x'][0]
    N_y = parameters['N_y'][0]
    D = parameters['D'][0]
    g = parameters['g'][0]
    K_E = parameters['K_E'][0]
    K_I = parameters['K_I'][0]
    r = parameters['r'][0]
    lam_E = parameters['lam_E'][0]
    lam_I = parameters['lam_I'][0]
    distr = parameters['distr'][0]
    seed = parameters['seed'][0]

    print('start calculating covariances')

    distance_matrices, covariances = simulation_2d2p(N_x, N_y, g, D, K_E, K_I,
                                                     r, lam_E, lam_I, distr,
                                                     seed,
                                                     return_pairwise=True)

    path = 'data'
    os.makedirs(path, exist_ok=True)

    # only the excitatory-excitatory blocks are stored
    np.save(os.path.join(path, 'covariances.npy'), covariances[0])
    np.save(os.path.join(path, 'distances.npy'), distance_matrices[0])

    # free the large matrices before drawing the next network
    del distance_matrices
    del covariances

    print('finished calculating covariances')

    print('start drawing connectivity matrix')
    w_E, w_I = weights_2p(r, K_E, K_I, g)

    # use the configured seed so that the eigenvalues belong to the same
    # network realisation whose covariances were saved above
    W, X, distances = draw_binomial_2d2p(N_x, N_y, g,
                                         K_E, K_I, K_E, K_I,
                                         w_E, w_I, w_E, w_I,
                                         lam_E, lam_I, lam_E, lam_I,
                                         distr, seed=seed)
    # only W is needed from here on
    del X, distances
    print('finished drawing connectivity matrix')

    print('start calculating eigenvalues')

    # single precision halves the memory of the dense eigenproblem
    W = np.float32(W)
    # only the eigenvalues are needed; eigvals skips computing eigenvectors
    eigvals = np.linalg.eigvals(W)

    # if your computer cannot calc the eigenvals due to finite resources, you
    # could try to use the following two lines to calculate the eigenvalues
    # W = scipy.sparse.csr_matrix(W)
    # eigvals = scipy.sparse.linalg.eigs(W, k=W.shape[0]-2,
    #                                    return_eigenvectors=False)
    print('finished calculation eigenvalues')

    np.save(os.path.join(path, 'eigenvalues.npy'), eigvals)
