#include "bob.h"
#include "kinetochore.h"

#include <iostream>

// Calculates the full chromatin/chromatid/kinetochore internal set of forces and torques
// See chromosome writeup for full information on derivation.
//
// For every sister pair (A = kinetochore 2*ic, B = kinetochore 2*ic+1) this evaluates
//   U = 1/2 k (|r| - r0)^2 + 1/2 kth thetaA^2 + 1/2 kth thetaB^2 + 1/2 kv thetaV^2
// with r = r_A - r_B, thetaA/thetaB the angles between u_A/u_B and rhat, and thetaV the
// angle between v_A and v_B. The pair force (+ on A, - on B) is accumulated into both
// chromosomes.f_ and chromosomes.f_interkc_, the torques into chromosomes.t_, and the
// pair virial r (x) f into `virial` when control.virial_flag is set.
//
// Parameters:
//   parameters  - supplies n_dim.
//   properties  - chromosome/kinetochore state. f_ and t_ are accumulated into (+=);
//                 presumably the caller zeroes them beforehand (TODO confirm).
//                 f_interkc_ is zeroed here before accumulation.
//   f_bond      - zeroed here (legacy requirement), not otherwise written.
//   virial      - zeroed here when virial_flag is set, then accumulated into.
//   t_bond      - zeroed here (legacy requirement), not otherwise written.
//   calc_matrix - unused.
// Returns the total potential energy over all pairs. Returns 0 (after the zeroing above)
// when do_anaphase_ is set and sac_status_ == 0.
double chromosome_chromatin_potential(system_parameters *parameters, system_properties *properties,
                           double **f_bond, double **virial, double **t_bond, int *calc_matrix) {
    // Setup pointers to stuff
    int nchromosomes = properties->chromosomes.nchromosomes_;
    int n_dim = parameters->n_dim;
    int n_bonds = properties->bonds.n_bonds;

    double **f_chromosome = properties->chromosomes.f_;
    double **f_interkc    = properties->chromosomes.f_interkc_;
    double **t_chromosome = properties->chromosomes.t_;

    double k   = properties->chromosomes.chromatin_k_;
    double kth = properties->chromosomes.chromatin_ktor_;
    double kv  = properties->chromosomes.chromatid_kvtor_;
    double r0  = properties->chromosomes.chromatin_r0_;

    // Legacy, always have to zero out the bond forces
    double u = 0.0;
    if (properties->control.virial_flag)
        memset(virial[0], 0, n_dim * n_dim * sizeof(double));
    memset(f_bond[0], 0, n_bonds * n_dim * sizeof(double));
    memset(t_bond[0], 0, n_bonds * 3 * sizeof(double));

    // Reset the inter-kinetochore force too, we don't want it hanging around
    for (int ic = 0; ic < nchromosomes; ++ic) {
        for (int i = 0; i < n_dim; ++i) {
            f_interkc[2*ic  ][i] = 0.0;
            f_interkc[2*ic+1][i] = 0.0;
        }
    }

    // Check if the SAC is active or not, and if we are even doing anaphase separation
    if (properties->chromosomes.do_anaphase_ && properties->chromosomes.sac_status_ == 0) {
        return u;
    }

    for (int ic = 0; ic < nchromosomes; ++ic) {
        double *rA = properties->chromosomes.kinetochores_[2*ic  ].r_;
        double *rB = properties->chromosomes.kinetochores_[2*ic+1].r_;
        double *uA = properties->chromosomes.kinetochores_[2*ic  ].u_;
        double *uB = properties->chromosomes.kinetochores_[2*ic+1].u_;
        double *vA = properties->chromosomes.kinetochores_[2*ic  ].v_;
        double *vB = properties->chromosomes.kinetochores_[2*ic+1].v_;

        // Separation vector r = r_A - r_B. Zero-initialized so components beyond n_dim
        // are defined in case cross_product reads all three entries.
        double r[3] = {0.0, 0.0, 0.0};
        double rhat[3] = {0.0, 0.0, 0.0};
        double rmag2 = 0.0;
        for (int i = 0; i < n_dim; ++i) {
            r[i] = rA[i] - rB[i];
            rmag2 += SQR(r[i]);
        }
        double rmag = sqrt(rmag2);
        // Coincident kinetochores have no defined direction. Keep rhat = 0 rather than
        // computing 0/0, which would turn every force, torque and energy term into NaN.
        // The rmag > 0 guards on the angular factors below rely on the same convention.
        if (rmag > 0.0) {
            for (int i = 0; i < n_dim; ++i) {
                rhat[i] = r[i] / rmag;
            }
        }

        // Angles between each kinetochore orientation u and the separation direction
        double thetaA = safe_acos(dot_product(n_dim, uA, rhat));
        double thetaB = safe_acos(dot_product(n_dim, uB, rhat));
        double sinthetaA = sin(thetaA);
        double sinthetaB = sin(thetaB);

        // theta/sin(theta). At sin == 0 the associated cross products vanish as well,
        // so 0 is a safe placeholder.
        double thetaAoversinA = (sinthetaA != 0.0) ? thetaA / sinthetaA : 0.0;
        double thetaBoversinB = (sinthetaB != 0.0) ? thetaB / sinthetaB : 0.0;

        // rhat x u and rhat x (rhat x u) for both kinetochores
        // (cross_product takes output pointers, so separate buffers are needed)
        double rhatcrossuA[3] = {0.0};
        double rhatcrossrhatcrossuA[3] = {0.0};
        cross_product(rhat, uA, rhatcrossuA, n_dim);
        cross_product(rhat, rhatcrossuA, rhatcrossrhatcrossuA, n_dim);

        double rhatcrossuB[3] = {0.0};
        double rhatcrossrhatcrossuB[3] = {0.0};
        cross_product(rhat, uB, rhatcrossuB, n_dim);
        cross_product(rhat, rhatcrossuB, rhatcrossrhatcrossuB, n_dim);

        // Prefactors: -dU/d|r| for the spring, -kth*theta/(|r| sin theta) for the bending
        double linearfactor = -k * (rmag - r0);
        double thetaAfactor = 0.0;
        double thetaBfactor = 0.0;
        if (rmag > 0.0) {
            thetaAfactor = -kth / rmag * thetaAoversinA;
            thetaBfactor = -kth / rmag * thetaBoversinB;
        }

        // Pair force acting on A (equal and opposite on B); computed once and reused
        double fpair[3] = {0.0, 0.0, 0.0};
        for (int i = 0; i < n_dim; ++i) {
            fpair[i] = linearfactor * rhat[i]
                     + thetaAfactor * rhatcrossrhatcrossuA[i]
                     + thetaBfactor * rhatcrossrhatcrossuB[i];
            f_chromosome[2*ic  ][i] += fpair[i];
            f_chromosome[2*ic+1][i] -= fpair[i];
            f_interkc[2*ic  ][i] += fpair[i];
            f_interkc[2*ic+1][i] -= fpair[i];
        }

        // Measure the virial contribution
        if (properties->control.virial_flag) {
            for (int i = 0; i < n_dim; ++i) {
                for (int j = 0; j < n_dim; ++j) {
                    virial[i][j] += r[i] * fpair[j];
                }
            }
        }

        // Twist between the two sisters' v vectors
        double thetaV = safe_acos(dot_product(n_dim, vA, vB));
        double sinthetaV = sin(thetaV);
        double thetaVoversinV = (sinthetaV != 0.0) ? thetaV / sinthetaV : 0.0;
        double vhatAcrossvhatB[3] = {0.0};
        cross_product(vA, vB, vhatAcrossvhatB, n_dim);

        // Torques
        for (int i = 0; i < n_dim; ++i) {
            t_chromosome[2*ic  ][i] += (-kth * thetaAoversinA * rhatcrossuA[i] + kv * thetaVoversinV * vhatAcrossvhatB[i]);
            t_chromosome[2*ic+1][i] += (-kth * thetaBoversinB * rhatcrossuB[i] - kv * thetaVoversinV * vhatAcrossvhatB[i]);
        }

        // Energy of this pair. The kv term is the potential whose gradient produces the
        // kv torque above, so it is included to keep energy and torques consistent.
        double u_pair = 0.5 * k * (rmag - r0) * (rmag - r0)  // linear component
                      + 0.5 * kth * thetaA * thetaA          // two angular components
                      + 0.5 * kth * thetaB * thetaB
                      + 0.5 * kv * thetaV * thetaV;          // sister twist component
        u += u_pair;

        // Test this pair only (x != x is true only for NaN). Testing the running total
        // would blame every later pair once a single NaN had appeared.
        if (u_pair != u_pair) {
            std::cerr << "NaN encountered in chromosome_chromatin_potential\n";
            std::cerr << "Step: " << properties->i_current_step << std::endl;
            std::cerr << "Kinetochore pair: " << ic << "(" << properties->chromosomes.kinetochores_[2*ic].cidx_ << ") ["
                << properties->chromosomes.kinetochores_[2*ic].idx_ << ", " << properties->chromosomes.kinetochores_[2*ic+1].idx_ << "]\n";
        }
    }

    return u;
}
