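/* This routine calculates the forces, torques, potential energy, and virial
   contributions of the active crosslinks (xlinks) held in the per-type, per-bond
   stage_2_xlinks_ lists, and then advances each of those crosslinks by calling its
   Step() method. The number of dimensions, the number of periodic dimensions, and
   the unit cell matrix are forwarded to the crosslink force calculation. When
   ENABLE_OPENMP is defined the bond loop is shared among threads, each of which
   accumulates into its own scratch arrays that are summed at the end.

   Input: pointer to parameters structure (parameters)
   pointer to properties structure (properties)

   Output: array of bond forces (f_bond)
   virial (virial), written only when properties->control.virial_flag is set
   array of bond torques (t_bond)
   flags set to 1 for bonds that received a crosslink contribution (calc_matrix)
   potential energy (return value) */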

#include "bob.h"
#include "xlink_management.h"
#include "xlink_entry.h"

#include <iostream>

/* Accumulates the force, torque and (optionally) virial contribution of every
   active crosslink stored in properties->crosslinks.stage_2_xlinks_[type][bond],
   then advances each of those crosslinks with Step().

   parameters   supplies n_dim, n_periodic and delta.
   properties   supplies the bond arrays, the crosslink lists, the per-thread
                scratch accumulators (mp_local) and the per-bond RNGs.
   f_bond       overwritten with the summed crosslink force on each bond.
   virial       overwritten with the summed crosslink virial; touched only when
                properties->control.virial_flag is set.
   t_bond       overwritten with the summed crosslink torque on each bond
                (always 3 components per bond).
   calc_matrix  entry set to 1 for every bond that received a contribution;
                entries are never cleared here.

   When the virial flag is set, properties->thermo.virial_xlink is also
   overwritten with the virial broken out by crosslink type.

   Returns the sum of the values returned by CalcForce() over all active
   crosslinks (the potential energy). */
double crosslink_interaction_bd_mp(system_parameters *parameters,
                                   system_properties *properties, double **f_bond,
                                   double **virial, double **t_bond, int *calc_matrix) {

    /* Per-thread scratch accumulators; the first index is the thread number. */
    double ***f_local = properties->mp_local.f_local;
    double ***t_local = properties->mp_local.t_local;
    int **calc_local = properties->mp_local.calc_local;
    double ***virial_local = properties->mp_local.virial_local;
    double ****virial_local_xlink = properties->mp_local.virial_local_xlink;

    /* Loop-invariant sizes and flags, read once. */
    const int n_dim = parameters->n_dim;
    const int n_bonds = properties->bonds.n_bonds;
    const int n_types = properties->crosslinks.n_types_;
    const int n_threads = properties->mp_local.n_threads;
    const bool do_virial = (properties->control.virial_flag != 0);

    /* Byte counts are formed in size_t so the products cannot overflow int. */
    const size_t force_bytes = static_cast<size_t>(n_bonds) * n_dim * sizeof(double);
    const size_t torque_bytes = static_cast<size_t>(n_bonds) * 3 * sizeof(double);
    const size_t calc_bytes = static_cast<size_t>(n_bonds) * sizeof(int);
    const size_t virial_bytes = static_cast<size_t>(n_dim) * n_dim * sizeof(double);

    /* Zero potential energy. */
    double u = 0.0;

    #ifdef ENABLE_OPENMP
    #pragma omp parallel
    #endif
    {
        /* Get current thread number to know which accumulator to work on */
        int i_thr = 0;
        #ifdef ENABLE_OPENMP
        i_thr = omp_get_thread_num();
        #endif

        /* Each thread clears only its own scratch space. The memsets assume each
           [thread] slab is one contiguous allocation starting at row 0, and are
           skipped when there are no bonds so that row 0 is never dereferenced. */
        if (n_bonds > 0) {
            memset(&(f_local[i_thr][0][0]), 0, force_bytes);
            memset(&(t_local[i_thr][0][0]), 0, torque_bytes);
            memset(calc_local[i_thr], 0, calc_bytes);
        }
        if (do_virial) {
            memset(&(virial_local[i_thr][0][0]), 0, virial_bytes);
            /* The per-type virial is cleared element by element; the original
               author noted that memset did not work on this array (presumably
               its storage is not contiguous -- TODO confirm). */
            for (int itype = 0; itype < n_types; ++itype) {
                for (int i = 0; i < n_dim; ++i) {
                    for (int j = 0; j < n_dim; ++j) {
                        virial_local_xlink[i_thr][itype][i][j] = 0.0;
                    }
                }
            }
        }

        /* Loop over all attached pairs. */
        #ifdef ENABLE_OPENMP
        #pragma omp for reduction(+:u) schedule(runtime)
        #endif
        for (int i_bond = 0; i_bond < n_bonds; ++i_bond) {
            for (int i_type = 0; i_type < n_types; ++i_type) {
                /* End test uses != so it does not rely on random-access iterators. */
                for (xlink_list::iterator xlink = properties->crosslinks.stage_2_xlinks_[i_type][i_bond].begin();
                     xlink != properties->crosslinks.stage_2_xlinks_[i_type][i_bond].end();
                     ++xlink) {

                    if (!xlink->IsActive())
                        continue;

                    /* f_link is handed to CalcForce as the force output buffer;
                       the return value is accumulated as potential energy. */
                    double f_link[3] = {0.0, 0.0, 0.0};
                    u += xlink->CalcForce(n_dim,
                                          parameters->n_periodic,
                                          properties->unit_cell.h,
                                          properties->bonds.r_bond,
                                          properties->bonds.s_bond,
                                          properties->bonds.u_bond,
                                          properties->bonds.length,
                                          f_link);

                    /* Equal and opposite forces on the two bonds the heads sit on. */
                    const int bond_1 = xlink->head_parent_[0];
                    const int bond_2 = xlink->head_parent_[1];
                    for (int i = 0; i < n_dim; ++i) {
                        f_local[i_thr][bond_1][i] += f_link[i];
                        f_local[i_thr][bond_2][i] -= f_link[i];
                    }

                    /* Calculate torques. The lever arm on each bond is its
                       orientation vector scaled by the head position minus half
                       the bond length (presumably the offset from the bond
                       center -- TODO confirm the cross_position_ convention). */
                    const double lambda = xlink->cross_position_[0] - 0.5 * properties->bonds.length[bond_1];
                    const double mu = xlink->cross_position_[1] - 0.5 * properties->bonds.length[bond_2];

                    double r_contact_i[3] = {0.0, 0.0, 0.0};
                    double r_contact_j[3] = {0.0, 0.0, 0.0};
                    for (int i = 0; i < n_dim; ++i) {
                        r_contact_i[i] = properties->bonds.u_bond[bond_1][i] * lambda;
                        r_contact_j[i] = properties->bonds.u_bond[bond_2][i] * mu;
                    }

                    /* Torque buffers start at zero so that all three components
                       are well defined even if cross_product leaves some of them
                       unwritten when n_dim < 3. */
                    double tau_i[3] = {0.0, 0.0, 0.0};
                    double tau_j[3] = {0.0, 0.0, 0.0};
                    cross_product(r_contact_i, f_link, tau_i, n_dim);
                    cross_product(r_contact_j, f_link, tau_j, n_dim);
                    for (int i = 0; i < 3; ++i) {
                        t_local[i_thr][bond_1][i] += tau_i[i];
                        t_local[i_thr][bond_2][i] -= tau_j[i];
                    }

                    /* Add contribution to virial (total and per crosslink type). */
                    if (do_virial) {
                        for (int i = 0; i < n_dim; ++i) {
                            for (int j = 0; j < n_dim; ++j) {
                                const double w_ij = xlink->dr_[i] * f_link[j];
                                virial_local_xlink[i_thr][i_type][i][j] -= w_ij;
                                virial_local[i_thr][i][j] -= w_ij;
                            }
                        }
                    }

                    /* Walk xlink along spherocylinders. The RNG belonging to the
                       bond being iterated (rng_local[i_bond]) is used instead of
                       the global properties->rng.r; each i_bond iteration is run
                       by a single thread, so no generator is used by two threads
                       at once. */
                    xlink->Step(n_dim,
                                parameters->delta,
                                properties->bonds.u_bond,
                                properties->bonds.length,
                                f_link,
                                properties->bonds.rng_local[i_bond].r);

                    /* Mark both bonds as touched by this thread. */
                    calc_local[i_thr][bond_1] = calc_local[i_thr][bond_2] = 1;
                }
            }
        }
    }

    /* Clear the output arrays (assumed contiguous starting at row 0). */
    if (n_bonds > 0) {
        memset(f_bond[0], 0, force_bytes);
        memset(t_bond[0], 0, torque_bytes);
    }

    /* Sum the per-thread accumulators into the outputs. Each bond is handled by
       exactly one iteration, so the writes do not collide. */
    #ifdef ENABLE_OPENMP
    #pragma omp parallel for
    #endif
    for (int i_bond = 0; i_bond < n_bonds; ++i_bond) {
        for (int i_thr = 0; i_thr < n_threads; ++i_thr) {
            if (!calc_local[i_thr][i_bond])
                continue;

            for (int i = 0; i < n_dim; ++i)
                f_bond[i_bond][i] += f_local[i_thr][i_bond][i];
            for (int i = 0; i < 3; ++i)
                t_bond[i_bond][i] += t_local[i_thr][i_bond][i];

            calc_matrix[i_bond] = 1;
        }
    }

    if (do_virial) {
        /* Total crosslink virial. */
        memset(virial[0], 0, virial_bytes);
        for (int i_thr = 0; i_thr < n_threads; ++i_thr)
            for (int i = 0; i < n_dim; ++i)
                for (int j = 0; j < n_dim; ++j)
                    virial[i][j] += virial_local[i_thr][i][j];

        /* Contribution broken out by xlink type: clear, then sum over threads. */
        for (int itype = 0; itype < n_types; ++itype) {
            for (int i = 0; i < n_dim; ++i) {
                for (int j = 0; j < n_dim; ++j) {
                    properties->thermo.virial_xlink[itype][i][j] = 0.0;
                }
            }
        }
        for (int i_thr = 0; i_thr < n_threads; ++i_thr) {
            for (int itype = 0; itype < n_types; ++itype) {
                for (int i = 0; i < n_dim; ++i) {
                    for (int j = 0; j < n_dim; ++j) {
                        properties->thermo.virial_xlink[itype][i][j] += virial_local_xlink[i_thr][itype][i][j];
                    }
                }
            }
        }
    }

    return u;
}
