#include "cell_list.h"
#include "allocate.h"


/* Default constructor: produces an empty cell list that must be set up with
   init() before use. Counts are set to the sentinel -1, the allocation
   capacities to 0, and every array pointer to nullptr so that no member is
   left with an indeterminate value. */
cell_list::cell_list() {
    this->n_dim = -1;
    this->n_periodic = -1;
    this->n_objects = -1;
    this->r_cutoff_max = 0;
    this->n_shell = 1;
    this->n_linear = 0;
    this->n_linear_hold = 0;
    this->n_objects_hold = 0;

    /* Cell list arrays (per cell / per object). */
    this->head = nullptr;
    this->tail = nullptr;
    this->cell = nullptr;
    this->prev = nullptr;
    this->next = nullptr;

    /* Cell search list arrays. */
    this->n_search_half = nullptr;
    this->n_search_full = nullptr;
    this->search_cell_half = nullptr;
    this->search_cell_full = nullptr;
    this->shift_cell_half = nullptr;
    this->shift_cell_full = nullptr;
}

/* Constructs a cell list and builds it immediately by forwarding every
   argument to init(); see init() for their meaning. */
cell_list::cell_list(int n_dim, int n_periodic, int n_objects,
                     double r_cutoff_max, double **s, double *a_perp) {
    this->init(n_dim, n_periodic, n_objects, r_cutoff_max, s, a_perp);
}

/* Build the cell list from scratch.

   Input: number of spatial dimensions (n_dim)
          number of periodic dimensions (n_periodic); dimensions
              0 .. n_periodic-1 are treated as periodic
          number of objects (n_objects)
          maximum interaction cutoff distance (r_cutoff_max)
          scaled object coordinates (s), one n_dim-array per object
          perpendicular widths of the unit cell (a_perp)

   Output: the cell grid is sized, all arrays are freshly allocated, and
   every object is placed in its cell. The final update pass finds every
   object already in the cell that cell_index() reports, so it leaves the
   freshly built lists unchanged. */
void cell_list::init(int n_dim, int n_periodic, int n_objects,
                     double r_cutoff_max, double **s, double *a_perp) {
    /* Problem geometry and size. */
    this->n_dim = n_dim;
    this->n_periodic = n_periodic;
    this->n_objects = n_objects;
    this->r_cutoff_max = r_cutoff_max;

    /* One cell per cutoff distance. Capacities are reset to 0 so that
       init_cell_structure() sizes its arrays from the current counts. */
    this->n_shell = 1;
    this->n_objects_hold = 0;
    this->n_linear_hold = 0;

    this->number_of_cells(a_perp);
    this->init_cell_structure(0, s);
    this->update_cell_lists(s);
}

/* Rebuild the cell lists (head/tail per cell, prev/next/cell per object)
   from scratch.

   Object coordinates are scaled coordinates, expected to lie in the
   n_dim-dimensional unit hypercube centred on the origin.

   Input: scaled object coordinates (s), one n_dim-array per object

   Output: head/tail entries 0 .. n_linear_hold-1 and prev/next/cell entries
   0 .. n_objects_hold-1 are reset to -1. This includes spare capacity beyond
   n_linear / n_objects. Objects 0 .. n_objects-1 are then inserted into the
   cell returned by cell_index(). */
void cell_list::init_cell_lists(double **s) {
    /* Empty every allocated cell. */
    int i_linear = 0;
    while (i_linear < this->n_linear_hold) {
        this->head[i_linear] = -1;
        this->tail[i_linear] = -1;
        ++i_linear;
    }

    /* Detach every allocated object slot. */
    int i_object = 0;
    while (i_object < this->n_objects_hold) {
        this->cell[i_object] = -1;
        this->next[i_object] = -1;
        this->prev[i_object] = -1;
        ++i_object;
    }

    /* Bin each live object into the cell that contains it. */
    for (int k = 0; k < this->n_objects; ++k)
        this->add_cell_entry(k, this->cell_index(s[k]));
}

/* Allocate memory for and initialize variables in cell list structure.
   The structure is initialized if init_flag == 0 and re-initialized
   if init_flag == 1 (typically done when the number of cells changes).

   Input: initialization flag (init_flag)
   number of spatial dimensions (n_dim)
   number of periodic dimensions (n_periodic)
   number of objects (n_objects)
   array of scaled object coordinates (s)
   number of cells per interaction range (n_shell)
   pointer to cell list structure (cells)

   Output: memory for arrays in the cell list structure is allocated,
   and a number of variables in this structure are initialized. */
void cell_list::init_cell_structure(int init_flag, double **s) {
    if (this->n_linear_hold == 0)
        this->n_linear_hold = this->n_linear;
    if (this->n_objects_hold == 0)
        this->n_objects_hold = this->n_objects;

    int n_linear_alloc = this->n_linear + (this->n_linear > this->n_linear_hold) * 1000;

    /* Allocate memory for cell list arrays. */
    if (init_flag) {
        if (this->n_linear > this->n_linear_hold) {
            this->head = (int *) grealloc(this->head, (n_linear_alloc) * sizeof(int));
            this->tail = (int *) grealloc(this->tail, (n_linear_alloc) * sizeof(int));
        }
        if (this->n_objects > this->n_objects_hold) {
            while (this->n_objects_hold < this->n_objects)
                this->n_objects_hold = this->n_objects_hold + 100;

            this->cell = (int *) grealloc(this->cell, this->n_objects_hold * sizeof(int));
            this->prev = (int *) grealloc(this->prev, this->n_objects_hold * sizeof(int));
            this->next = (int *) grealloc(this->next, this->n_objects_hold * sizeof(int));
        }
    } else {
        this->head = (int *) allocate_1d_array(this->n_linear, sizeof(int));
        this->tail = (int *) allocate_1d_array(this->n_linear, sizeof(int));
        this->cell = (int *) allocate_1d_array(this->n_objects, sizeof(int));
        this->prev = (int *) allocate_1d_array(this->n_objects, sizeof(int));
        this->next = (int *) allocate_1d_array(this->n_objects, sizeof(int));
    }

    /* Compute maximum size of search domain. */
    int n_search_max = 1;
    int search_range = 2 * this->n_shell + 1;
    for (int i = 0; i < n_dim; ++i)
        n_search_max *= search_range;

    /* Allocate memory for cell search lists. */
    if (init_flag) {
        if (this->n_linear > this->n_linear_hold) {
            this->n_search_half =
                (int *) grealloc((void *) this->n_search_half, (n_linear_alloc) * sizeof(int));
            this->n_search_full =
                (int *) grealloc((void *) this->n_search_full, (n_linear_alloc) * sizeof(int));
            this->search_cell_half =
                (int **) realloc_2d_array((void **) this->search_cell_half, this->n_linear_hold,
                                          n_linear_alloc, n_search_max, sizeof(int));
            this->search_cell_full =
                (int **) realloc_2d_array((void **) this->search_cell_full, this->n_linear_hold,
                                          n_linear_alloc, n_search_max, sizeof(int));
            this->shift_cell_half =
                (int ***) realloc_3d_array((void ***) this->shift_cell_half, this->n_linear_hold,
                                           n_search_max, n_linear_alloc, n_search_max, this->n_dim,
                                           sizeof(int));
            this->shift_cell_full =
                (int ***) realloc_3d_array((void ***) this->shift_cell_full, this->n_linear_hold,
                                           n_search_max, n_linear_alloc, n_search_max, this->n_dim,
                                           sizeof(int));
            this->n_linear_hold = n_linear_alloc;
        }
    } else {
        this->n_search_half = (int *) allocate_1d_array(this->n_linear, sizeof(int));
        this->search_cell_half = (int **) allocate_2d_array(this->n_linear, n_search_max, sizeof(int));
        this->shift_cell_half
            = (int ***) allocate_3d_array(this->n_linear, n_search_max, n_dim, sizeof(int));
        this->n_search_full = (int *) allocate_1d_array(this->n_linear, sizeof(int));
        this->search_cell_full = (int **) allocate_2d_array(this->n_linear, n_search_max, sizeof(int));
        this->shift_cell_full
            = (int ***) allocate_3d_array(this->n_linear, n_search_max, this->n_dim, sizeof(int));
    }

    /* Initialize cell lists. */
    init_cell_lists(s);

    for (int i = 0; i < this->n_linear_hold; ++i)
        this->n_search_half[i] = this->n_search_full[i] = 0;

    int i_cell[3], j_cell[3], shift[3];
    /* Initialize half space cell search lists. */
    cell_search_list(1, this->n_dim - 1, i_cell, j_cell, shift);

    /* Initialize full space cell search lists. */
    cell_search_list(0, this->n_dim - 1, i_cell, j_cell, shift);
    
    return;
}

/* Compute the linear index of the cell that contains an object, for any
   boundary conditions and up to three spatial dimensions (the scratch index
   array holds 3 entries).

   Object coordinates are scaled coordinates, expected to lie in the
   n_dim-dimensional unit hypercube centred on the origin, i.e. each
   component in [-0.5, 0.5].

   Input: scaled object coordinate (s), an array of n_dim values

   Output: linear cell index (return value). When every n_cells[i] >= 1 the
   result is always in [0, n_linear). A coordinate exactly at +0.5, or one
   slightly outside the box (e.g. from round-off), is clamped into the
   nearest boundary cell. It would otherwise produce an out-of-range index
   and corrupt the cell list arrays. */
int cell_list::cell_index(double *s) {
    int i_cell[3];   /* multidimensional cell index */

    for (int i = 0; i < this->n_dim; ++i) {
        int n = this->n_cells[i];
        int c = (int) ((s[i] + 0.5) * n);
        if (c >= n)
            c = n - 1;
        if (c < 0)
            c = 0;
        i_cell[i] = c;
    }

    /* Convert multidimensional cell index to linear cell index. */
    return linear_index(0, this->n_dim, i_cell, this->n_cells);
}

/* Recursively fill the cell search lists. For every home cell, the lists
   hold the linear indices of the cells within n_shell cells of it along each
   dimension, together with the periodic image shift needed to reach each one.

   Each recursion level handles dimension i_dim. It loops over the home-cell
   coordinate i_cell[i_dim] and the neighbour offset in [-n_shell, n_shell],
   then descends to dimension i_dim - 1. When the neighbour coordinate falls
   off the grid:
     - periodic dimension (i_dim < n_periodic): it wraps around, and
       shift[i_dim] becomes -1 (wrapped past the upper edge) or +1 (wrapped
       past the lower edge);
     - free dimension: the neighbour is skipped.

   Input: half_flag - nonzero: fill the *_half arrays, keeping only neighbours
                      whose linear index is >= the home cell's;
                      zero: fill the *_full arrays with every neighbour
          i_dim     - current recursion level (callers start at n_dim - 1)
          i_cell, j_cell, shift - scratch arrays (at least n_dim entries) holding
                      the home cell, neighbour cell and image shift built up
                      across recursion levels

   Output: for each neighbour found, search_cell_*[home][k] receives its linear
   index and shift_cell_*[home][k][0..n_dim-1] its image shift. Here
   k = n_search_*[home], which is then incremented. Counts accumulate, so
   the caller must zero n_search_* first (init_cell_structure does). */
void cell_list::cell_search_list(int half_flag, int i_dim, int *i_cell, int *j_cell, int *shift) {
    /* Output arrays for the requested half- or full-space list. */
    int *n_search = half_flag ? this->n_search_half : this->n_search_full;
    int **search_cell = half_flag ? this->search_cell_half : this->search_cell_full;
    int ***shift_cell = half_flag ? this->shift_cell_half : this->shift_cell_full;

    const int n_along = this->n_cells[i_dim];
    const bool periodic = i_dim < this->n_periodic;

    for (i_cell[i_dim] = 0; i_cell[i_dim] < n_along; ++i_cell[i_dim]) {
        for (int offset = -this->n_shell; offset <= this->n_shell; ++offset) {
            /* Candidate neighbour coordinate along this dimension, before wrapping. */
            j_cell[i_dim] = i_cell[i_dim] + offset;
            shift[i_dim] = 0;

            if (j_cell[i_dim] >= n_along) {
                if (!periodic)
                    continue;   /* free boundary: nothing beyond the upper edge */
                j_cell[i_dim] -= n_along;
                shift[i_dim] = -1;
            } else if (j_cell[i_dim] < 0) {
                if (!periodic)
                    continue;   /* free boundary: nothing below the lower edge */
                j_cell[i_dim] += n_along;
                shift[i_dim] = 1;
            }

            /* Not yet at dimension 0: this coordinate is fixed, descend. */
            if (i_dim != 0) {
                cell_search_list(half_flag, i_dim - 1, i_cell, j_cell, shift);
                continue;
            }

            /* Dimension 0: every coordinate is fixed, so record the pair. */
            const int home = linear_index(0, this->n_dim, i_cell, this->n_cells);
            const int target = linear_index(0, this->n_dim, j_cell, this->n_cells);
            if (half_flag && target < home)
                continue;       /* half space keeps only target >= home */

            const int slot = n_search[home];
            search_cell[home][slot] = target;
            for (int d = 0; d < this->n_dim; ++d)
                shift_cell[home][slot][d] = shift[d];
            n_search[home] = slot + 1;
        }
    }
}

/* Insert an object at the front of a cell's doubly linked list.

   Input: label of the object to insert (i_object)
          linear index of the destination cell (i_linear)

   Output: cell[i_object] = i_linear. The object becomes head[i_linear],
   its prev is -1 and its next is the previous head. If the cell was empty,
   the object also becomes tail[i_linear]. Runs in constant time. */
void cell_list::add_cell_entry(int i_object, int i_linear) {
    const int first = this->head[i_linear];

    this->cell[i_object] = i_linear;
    this->next[i_object] = first;
    this->prev[i_object] = -1;

    if (first != -1)
        this->prev[first] = i_object;
    else
        this->tail[i_linear] = i_object;   /* cell was empty: new entry is also the last */

    this->head[i_linear] = i_object;
}

/* Recompute the cell grid for a (possibly) new box and cutoff, and rebuild
   the cell structure if the grid or the number of objects has changed.

   Input: number of objects (n_objects)
          perpendicular widths of the unit cell (a_perp)
          scaled object positions (s_object)
          maximum interaction cutoff distance (r_cutoff); stored as r_cutoff_max

   Output: 1 if the structure was re-initialized (via init_cell_structure(1, ...)),
   0 otherwise. When 0 is returned the objects are not re-binned here;
   presumably the caller runs update_cell_lists() for that. */
int cell_list::check_cell_lists(int n_objects, double *a_perp,
                                double **s_object, double r_cutoff) {
    /* Snapshot the current grid, then recompute it for the new box/cutoff. */
    int n_cells_old[3];
    this->r_cutoff_max = r_cutoff;
    for (int i = 0; i < this->n_dim; ++i)
        n_cells_old[i] = this->n_cells[i];
    this->number_of_cells(a_perp);

    bool changed = (n_objects != this->n_objects);
    for (int i = 0; !changed && i < this->n_dim; ++i)
        changed = (this->n_cells[i] != n_cells_old[i]);

    if (!changed)
        return 0;

    this->n_objects = n_objects;
    this->init_cell_structure(1, s_object);
    return 1;
}

/* Unlink an object from a cell's doubly linked list.

   Input: label of the object to remove (i_object)
          linear index of the cell it currently belongs to (i_linear)

   Output: the neighbouring entries are linked to each other, with
   head[i_linear] / tail[i_linear] updated when the object was first / last.
   The object's prev, next and cell entries are set to -1. Runs in constant
   time. */
void cell_list::remove_cell_entry(int i_object, int i_linear) {
    const int before = this->prev[i_object];
    const int after = this->next[i_object];

    /* Forward link: predecessor (or the cell head) now points to the successor. */
    if (before != -1)
        this->next[before] = after;
    else
        this->head[i_linear] = after;

    /* Backward link: successor (or the cell tail) now points to the predecessor. */
    if (after != -1)
        this->prev[after] = before;
    else
        this->tail[i_linear] = before;

    /* Mark the object as detached from every list. */
    this->cell[i_object] = -1;
    this->prev[i_object] = -1;
    this->next[i_object] = -1;
}

/* This routine updates the cell lists for cell-based neighbor search routines.
   The object coordinates are assumed to lie within a n_dim-dimensional
   hypercube of unit sidelength centered on the origin (scaled coordinates).

   Input: number of spatial dimensions (n_dim)
   number of objects (n_objects)
   scaled object coordinates (s)
   pointer to cell list (cells)

   Output: the cell list arrays are initialized on output. */
void cell_list::update_cell_lists(double **s) {
    /* Loop over objects, assigning objects to cells and updating cell entries as necessary. */
    for (int i_object = 0; i_object < n_objects; ++i_object)
        update_cell_lists_single(i_object, s[i_object]);

    return;
}

/* Re-bin a single object after it may have moved.

   Object coordinates are scaled coordinates, expected to lie in the
   n_dim-dimensional unit hypercube centred on the origin.

   Input: object label (i_object)
          scaled object coordinate (s), an array of n_dim values

   Output: if the object's cell changed, it is unlinked from its old cell's
   list (when it is in one) and prepended to the new cell's list.
   cell[i_object] == -1 means the object is currently in no list; this is
   the state left by init_cell_lists' purge and by remove_cell_entry. In that
   case the object is only inserted. Calling remove_cell_entry with cell
   index -1 would write head[-1] / tail[-1]. */
void cell_list::update_cell_lists_single(int i_object, double *s) {
    const int i_linear_new = cell_index(s);
    const int i_linear_old = this->cell[i_object];

    if (i_linear_new == i_linear_old)
        return;

    if (i_linear_old != -1)
        remove_cell_entry(i_object, i_linear_old);
    add_cell_entry(i_object, i_linear_new);
}

/* Map a multidimensional grid point to a linear index. Dimension i_dim
   varies fastest:

       index = i_grid[i_dim] + n_grid[i_dim] * (i_grid[i_dim+1]
             + n_grid[i_dim+1] * ( ... + n_grid[n_dim-2] * i_grid[n_dim-1]))

   The nested form is evaluated iteratively (Horner's scheme), from the last
   dimension down to i_dim.

   Input: first dimension included in the index (i_dim); must be <= n_dim - 1
          number of spatial dimensions (n_dim)
          indices of the grid point along each dimension (i_grid)
          number of grid points along each dimension (n_grid)

   Output: the linear index of the grid point (return value). */
int cell_list::linear_index(int i_dim, int n_dim, int *i_grid, int *n_grid) {
    int index = i_grid[n_dim - 1];
    for (int d = n_dim - 2; d >= i_dim; --d)
        index = i_grid[d] + n_grid[d] * index;
    return index;
}

/* Compute the number of cells along each dimension of the unit cell and the
   total number of cells, for any boundary conditions (free, periodic, or
   mixed).

   The width of the unit cell along dimension i is a_perp[i]: the component
   of the i-th unit cell vector orthogonal to all the others. Cells must be at
   least r_cutoff_max / n_shell wide, so the count along each dimension is
   floor(n_shell * a_perp[i] / r_cutoff_max). It is clamped to at least 1.
   A box narrower than the cutoff would otherwise give 0 cells, zero-length
   arrays and negative cell indices. A NaN ratio also maps to 1, because the
   comparison below is false. r_cutoff_max is expected to be positive.

   Input: perpendicular widths of the unit cell (a_perp)
          members n_dim, n_shell and r_cutoff_max

   Output: members n_cells[0 .. n_dim-1] (cells per dimension) and n_linear
   (their product). Nothing is returned. */
void cell_list::number_of_cells(double *a_perp) {
    this->n_linear = 1;
    for (int i = 0; i < this->n_dim; ++i) {
        double n_fit = this->n_shell * a_perp[i] / this->r_cutoff_max;
        this->n_cells[i] = (n_fit >= 1.0) ? (int) n_fit : 1;
        this->n_linear *= this->n_cells[i];
    }
}
