// Chromatid (kinetochore) - microtubule soft repulsive Gaussian potential, all-pairs evaluation

#include "bob.h"

#include "kinetochore.h"
#include "minimum_distance.h"

#include <iostream>

/**
 * Soft Gaussian interaction between chromatids and microtubule bonds,
 * evaluated by brute force over every (kinetochore, bond) pair.
 *
 * For each pair the minimum separation r_min is obtained from
 * min_distance_sphero_dr(); if |r_min|^2 is below chromatid_mt_cutoff2_ the
 * pair contributes
 *     u     += A / (sqrt(2 pi) sigma)   * exp(-|r_min|^2 / (2 sigma^2))
 *     fsoft  = -A / (sqrt(2 pi) sigma^3) * exp(-|r_min|^2 / (2 sigma^2)) * r_min
 * with A = chromatid_mt_repulsion_ and sigma = chromatid_mt_sigma_.
 * The magnitude of fsoft is limited to 0.1 / delta * gamma_par[ibond].
 *
 * @param parameters  Read for n_dim, n_periodic and delta.
 * @param properties  Read for bond and chromosome state. The chromosome
 *                    accumulators f_, t_ and fkcmttip_ are added to and are
 *                    NOT zeroed by this routine.
 * @param f_bond      Per-bond force output; zeroed on entry, then accumulated.
 * @param virial      Virial output; zeroed and accumulated only when
 *                    properties->control.virial_flag is set.
 * @param t_bond      Per-bond torque output (3 components); zeroed on entry,
 *                    then accumulated.
 * @param calc_matrix Per-bond flag; set to 1 for every bond that interacts.
 * @return            Sum of the pair potential over all interacting pairs.
 */
double chromosome_mt_soft_gaussian_potential_allpairs(system_parameters *parameters, system_properties *properties,
                           double **f_bond, double **virial, double **t_bond, int *calc_matrix) {
    double u = 0.0;

    // Set up shortcuts
    int ndim = parameters->n_dim;
    int nperiodic = parameters->n_periodic;
    int n_bonds = properties->bonds.n_bonds;
    int nkcs = properties->chromosomes.nkcs_;
    double kcoffset = properties->chromosomes.chromatid_kc_offset_;
    double chromatidlength = properties->chromosomes.chromatid_length_;
    double chromatidamp = properties->chromosomes.chromatid_mt_repulsion_;
    double rcutoff2 = properties->chromosomes.chromatid_mt_cutoff2_;
    double sigma = properties->chromosomes.chromatid_mt_sigma_;
    double **h = properties->unit_cell.h;
    double **r_bond = properties->bonds.r_bond;
    double **s_bond = properties->bonds.s_bond;
    double **u_bond = properties->bonds.u_bond;
    double *l_bond = properties->bonds.length;
    double **f_kc = properties->chromosomes.f_;
    double **t_kc = properties->chromosomes.t_;
    double **ftip = properties->chromosomes.fkcmttip_;

    ChromosomeManagement* chromosomes = &properties->chromosomes;

    double sigma2 = SQR(sigma);
    double sigma3 = CUBE(sigma);
    double root2pi = sqrt(2*M_PI);

    // Reset the outputs owned by this routine. The memsets go through row 0
    // only, so they rely on each 2D array being one contiguous allocation.
    if (properties->control.virial_flag)
        memset(virial[0], 0, ndim * ndim * sizeof(double));
    memset(f_bond[0], 0, n_bonds * ndim * sizeof(double));
    memset(t_bond[0], 0, n_bonds * 3 * sizeof(double));

    // Really this loops over kinetochores (and chromatids) which are the interacting partners
    for (int ikc = 0; ikc < nkcs; ++ikc) {
        Kinetochore *kc_iter = &(chromosomes->kinetochores_[ikc]);

        // The chromatid reference point is offset from the kinetochore along
        // u_; the direction of the offset flips for the second sister.
        double flipsister = kc_iter->second_sister_ ? -1.0 : 1.0;
        double rchromatid[3] = {0.0};
        for (int i = 0; i < ndim; ++i) {
            rchromatid[i] = kc_iter->r_[i] - flipsister * kcoffset * kc_iter->u_[i];
        }

        // Brute force this for now
        for (int ibond = 0; ibond < n_bonds; ++ibond) {
            double lambda, mu, r_min_mag2, r_min[3], dr[3];

            // Chromatid (line along v_ of length chromatidlength) plays the
            // role of the first object, the bond the second.
            min_distance_sphero_dr(ndim,
                                   nperiodic,
                                   h,
                                   rchromatid,
                                   NULL,
                                   kc_iter->v_,
                                   chromatidlength,
                                   r_bond[ibond],
                                   s_bond[ibond],
                                   u_bond[ibond],
                                   l_bond[ibond],
                                   dr,
                                   r_min,
                                   &r_min_mag2,
                                   &lambda,
                                   &mu);

            if (r_min_mag2 < rcutoff2) {
                // Gaussian weight shared by the energy and the force
                double gauss = exp(-r_min_mag2 / (2. * sigma2));
                double factor = -chromatidamp / root2pi / sigma3 * gauss;
                u += chromatidamp / root2pi / sigma * gauss;

                // Limit the force magnitude. factor is negative for a
                // positive amplitude, so the comparison has to be made on
                // the magnitude and the sign restored afterwards; comparing
                // the signed value would never trigger the cap.
                double rminmag = sqrt(r_min_mag2);
                double f_cutoff = 0.1 / parameters->delta *properties->bonds.gamma_par[ibond];
                if (fabs(factor) * rminmag > f_cutoff) {
                    factor = (factor < 0.0 ? -f_cutoff : f_cutoff) / rminmag;
                    printf(" *** Force exceeded f_cutoff chromosome_mt_soft_gaussian_potential***\n");
                }

                double fsoft[3] = {0.0};
                for (int i = 0; i < ndim; ++i) {
                    fsoft[i] = factor * r_min[i];
                }

                // Add to accumulators: equal and opposite on kinetochore and bond
                for (int i = 0; i < ndim; ++i) {
                    f_kc[ikc][i] += fsoft[i];
                    f_bond[ibond][i] -= fsoft[i];
                }

                // Add to virial
                // Do the center to center separation of this
                // Notice that the chromosome takes the place of bond1, so need a minus sign
                if (properties->control.virial_flag) {
                    for (int i = 0; i < ndim; ++i) {
                        for (int j = 0; j < ndim; ++j) {
                            virial[i][j] -= dr[i] * fsoft[j];
                        }
                    }
                }

                // Calculate torques about each object's reference point
                double rcontact_kc[3] = {0.0};
                double rcontact_mt[3] = {0.0};
                for (int i = 0; i < ndim; ++i) {
                    rcontact_kc[i] = kc_iter->v_[i] * lambda;
                    rcontact_mt[i] = u_bond[ibond][i] * mu;
                }

                // NOTE(review): the forces above are +fsoft on the kinetochore
                // and -fsoft on the bond, while the torques below are
                // -(r x fsoft) on the kinetochore and +(r x fsoft) on the bond.
                // The original author's (disabled) sanity check was
                // t_kc + t_bond + dr x fsoft = 0. Confirm the sign convention
                // of r_min/lambda/mu in min_distance_sphero_dr.
                double tau[3] = {0.0};
                cross_product(rcontact_kc, fsoft, tau, 3);
                for (int i = 0; i < 3; ++i) {
                    t_kc[ikc][i] -= tau[i];
                }
                cross_product(rcontact_mt, fsoft, tau, 3);
                for (int i = 0; i < 3; ++i) {
                    t_bond[ibond][i] += tau[i];
                }

                // Set the inverse drag matrix to be calculated
                calc_matrix[ibond] = 1;

                // We need to do a special check to see if we're within the tip distance for the KC-MT interaction
                // for force dependent catastrophe
                // NOTE(review): mu scales u_bond directly in the torque arm
                // above but is combined with 0.5 and the bond length here;
                // confirm which convention min_distance_sphero_dr uses for mu.
                double tipdist = l_bond[ibond] - (mu - 0.5)*l_bond[ibond];
                if (tipdist < chromosomes->chromatid_mt_fc_distance_) {
                    for (int i = 0; i < ndim; ++i) {
                        ftip[ibond][i] -= fsoft[i];
                    }
                }
            }

            // NaN != NaN: report the first (and every following) pair at
            // which the accumulated energy has become NaN.
            if (u != u) {
                std::cerr << "NaN encountered in chromosome_mt_soft_gaussian_potential\n";
                std::cerr << "Step: " << properties->i_current_step << std::endl;
                std::cerr << "Kinetochore[" << ikc << "], bond[" << ibond << "]\n";
            }
        }
    }

    return u;
}
