// Implementation of testing the chromatin potential

#include "bob.h"
#include "helpers.h"
#include "kinetochore.h"
#include "minimum_distance.h"
#include "test_chromatin_mt_potential.h"
#include "triangle_mesh.h"

#include <iostream>

void TestChromatinMTPotential::RunTests() {
    // Runs every test registered by SetTests(), in tests_ iteration order.
    // Each test is announced with a banner. The first test that returns false
    // terminates the whole process with exit status 1.
    const char *stars = "****************\n";
    std::cout << stars << "Test Chromatin MT Gaussian Potential run tests\n" << stars;

    for (auto &entry : tests_) {
        const auto &test_name = entry.first;
        std::cout << "----------------\n" << "Test : " << test_name << std::endl
                  << "----------------\n";

        if (entry.second()) {
            continue;
        }
        std::cout << "Test : " << test_name << " failed, check output!\n";
        exit(1);
    }
}

void TestChromatinMTPotential::SetTests() {
    YAML::Node node = YAML::LoadFile(filename_);

    if (node["TestChromosomesMTRecoil"]) {
        tests_["TestChromosomesMTRecoil"] = std::bind(&TestChromatinMTPotential::TestChromosomesMTRecoil, this);
    }
    if (node["TestChromosomesMTRecoilAngle"]) {
        tests_["TestChromosomesMTRecoilAngle"] = std::bind(&TestChromatinMTPotential::TestChromosomesMTRecoilAngle, this);
    }
    if (node["TestKinetochoreMTRecoil"]) {
        tests_["TestKinetochoreMTRecoil"] = std::bind(&TestChromatinMTPotential::TestKinetochoreMTRecoil, this);
    }
    if (node["TestKinetochoreTriangulation"]) {
        tests_["TestKinetochoreTriangulation"] = std::bind(&TestChromatinMTPotential::TestKinetochoreTriangulation, this);
    }
    if (node["TestKinetochoremeshMTRecoil"]) {
        tests_["TestKinetochoremeshMTRecoil"] = std::bind(&TestChromatinMTPotential::TestKinetochoremeshMTRecoil, this);
    }
}

bool TestChromatinMTPotential::TestChromosomesMTRecoil() {
    // Loads test_chromosomes_mt_recoil.yaml, evaluates
    // chromosome_mt_soft_gaussian_potential_allpairs once and prints the
    // returned energy. It then reports forces/torques, calls
    // chromosomes.UpdatePositions(), and reports how the kinetochore frames
    // changed relative to the snapshot taken in InitSystem().
    // Returns false only if one of the report helpers returns false.
    bool success = true;
    testfile_ = "test_chromosomes_mt_recoil.yaml";

    InitSystem();

    double uret = chromosome_mt_soft_gaussian_potential_allpairs(&parameters_,
                                                        &properties_,
                                                        f_comp_,
                                                        NULL,
                                                        t_comp_,
                                                        calc_matrix_);
    std::cout << "uret: " << uret << std::endl;
    // Accumulate instead of overwriting so an earlier failure in this function
    // can never be silently discarded (matches the other Test* methods).
    success = success && TestForceTorqueBalance();

    properties_.chromosomes.UpdatePositions();
    properties_.chromosomes.PrintPositions(0);

    success = success && TestDeltas();

    return success;
}

bool TestChromatinMTPotential::TestChromosomesMTRecoilAngle() {
    // Same procedure as TestChromosomesMTRecoil, but on the configuration in
    // test_chromosomes_mt_recoil_angle.yaml.
    // Steps: evaluate the soft Gaussian potential, print the energy, report
    // forces/torques, call UpdatePositions(), then report kinetochore deltas.
    // Returns false only if one of the report helpers returns false.
    bool success = true;
    testfile_ = "test_chromosomes_mt_recoil_angle.yaml";

    InitSystem();

    double uret = chromosome_mt_soft_gaussian_potential_allpairs(&parameters_,
                                                        &properties_,
                                                        f_comp_,
                                                        NULL,
                                                        t_comp_,
                                                        calc_matrix_);
    std::cout << "uret: " << uret << std::endl;
    // Accumulate instead of overwriting so an earlier failure in this function
    // can never be silently discarded (matches the other Test* methods).
    success = success && TestForceTorqueBalance();

    properties_.chromosomes.UpdatePositions();
    properties_.chromosomes.PrintPositions(0);

    success = success && TestDeltas();

    return success;
}

bool TestChromatinMTPotential::TestKinetochoreMTRecoil() {
    // Loads test_kinetochore_mt_recoil.yaml and evaluates
    // kinetochore_mt_potential_allpairs once. It reports kinetochore deltas,
    // prints the energy, then (only if the delta report succeeded) reports
    // forces/torques. Finally it calls UpdatePositions() and prints positions.
    //
    // NOTE(review): unlike TestChromosomesMTRecoil, the deltas are reported
    // *before* UpdatePositions(), so they only show motion made by the
    // potential call itself — confirm this ordering is intended.
    testfile_ = "test_kinetochore_mt_recoil.yaml";
    InitSystem();

    const double energy = kinetochore_mt_potential_allpairs(&parameters_,
                                                            &properties_,
                                                            f_comp_,
                                                            NULL,
                                                            t_comp_,
                                                            calc_matrix_);

    bool ok = TestDeltas();

    std::cout << "uret: " << energy << std::endl;
    if (ok) {
        ok = TestForceTorqueBalance();
    }

    properties_.chromosomes.UpdatePositions();
    properties_.chromosomes.PrintPositions(0);

    return ok;
}

bool TestChromatinMTPotential::TestKinetochoreTriangulation() {
    // Smoke test for the kinetochore triangulation. It loads
    // test_kinetochore_triangulation.yaml, prints the first kinetochore's
    // polygon (tris_[0]), then calls min_distance_sphero_polygon for bond 0
    // against that polygon. The returned rminmag2, mu, rmin and rcontact are
    // printed. Nothing is asserted; always returns true.
    testfile_ = "test_kinetochore_triangulation.yaml";

    InitSystem();

    print_polygon(&(properties_.chromosomes.tris_[0]));

    // Zero-initialise the outputs so the printout below never reads
    // indeterminate values, even if the distance routine leaves any of them
    // unwritten on some code path.
    double rmin[3] = {0.0, 0.0, 0.0};
    double rcontact[3] = {0.0, 0.0, 0.0};
    double rminmag2 = 0.0;
    double mu = 0.0;
    min_distance_sphero_polygon(parameters_.n_dim,
                                parameters_.n_periodic,
                                properties_.unit_cell.h,
                                properties_.bonds.r_bond[0],
                                properties_.bonds.s_bond[0],
                                properties_.bonds.u_bond[0],
                                properties_.bonds.length[0],
                                &(properties_.chromosomes.tris_[0]),
                                rmin, &rminmag2, rcontact, &mu);

    std::cout << "--------\n";
    std::cout << "rminmag2: " << rminmag2 << ", mu: " << mu << std::endl;
    std::cout << "rmin    (" << rmin[0] << ", " << rmin[1] << ", " << rmin[2] << ")\n";
    std::cout << "contact (" << rcontact[0] << ", " << rcontact[1] << ", " << rcontact[2] << ")\n";
    return true;
}

bool TestChromatinMTPotential::TestKinetochoremeshMTRecoil() {
    bool success = true;
    testfile_ = "test_kinetochoremesh_mtrecoil.yaml";

    InitSystem();

    double uret = kinetochoremesh_mt_wca_potential_allpairs(&parameters_,
                                                            &properties_,
                                                            f_comp_,
                                                            NULL,
                                                            t_comp_,
                                                            calc_matrix_);

    success = success && TestDeltas();

    std::cout << "uret: " << uret << std::endl;
    success = success && TestForceTorqueBalance();

    properties_.chromosomes.UpdatePositions();
    properties_.chromosomes.PrintPositions(0);

    // Test that the mesh is on the kinetochore plane
    for (int ikc = 0; ikc < properties_.chromosomes.nkcs_; ++ikc) {
        for (int iv = 0; iv < properties_.chromosomes.nkctriverts_; ++iv) {
            double myvert[3];
            for (int i = 0; i < 3; ++i) {
                myvert[i] = properties_.chromosomes.tris_[ikc].verts[i][iv] - properties_.chromosomes.r_[ikc][i];
            }
            double dotprod = dot_product(parameters_.n_dim,
                                         properties_.chromosomes.u_[ikc],
                                         myvert);
            std::cout << "dotprod: " << dotprod << std::endl;
        }
    }

    return success;
}

void TestChromatinMTPotential::InitSystem() {
    // (Re)builds the test system from the YAML file named by testfile_:
    //   * default parameters, then forced to n_dim = 3, non-periodic,
    //     delta = 0.001, temp = 1.0;
    //   * a diagonal unit cell h (side 100, all off-diagonals zero);
    //   * a GSL RNG seeded with the incremented seed_;
    //   * one spherocylinder bond per entry of the YAML "mt" list, taking
    //     its r, u (per component) and l;
    //   * the chromosomes, initialised from the same file;
    //   * scratch arrays f_comp_, t_comp_ and calc_matrix_ for the potentials.
    // Finally it snapshots both kinetochores' r/u/v/w into rA..wB so that
    // TestDeltas() can report (current - initial).
    //
    // NOTE(review): each call allocates h, the RNG, sites.r/sites.v and the
    // scratch arrays anew without releasing the previous ones, so running
    // several tests in one process leaks them.
    seed_++;
    init_default_params(&parameters_);

    ndim_ = 3;
    parameters_.n_dim = ndim_;
    parameters_.n_periodic = 0;
    parameters_.delta = 0.001;
    parameters_.temp = 1.0;

    // Init h by hand. Every entry is written explicitly (diagonal = box side,
    // off-diagonal = 0) so the unit cell does not depend on whether
    // allocate_2d_array zero-fills its storage.
    const double box_side = 100.0;
    properties_.unit_cell.h = (double**) allocate_2d_array(parameters_.n_dim, parameters_.n_dim, sizeof(double));
    for (int i = 0; i < parameters_.n_dim; ++i) {
        for (int j = 0; j < parameters_.n_dim; ++j) {
            properties_.unit_cell.h[i][j] = (i == j) ? box_side : 0.0;
        }
    }

    gsl_rng_env_setup();
    properties_.rng.T = gsl_rng_default;
    properties_.rng.r = gsl_rng_alloc(properties_.rng.T);
    gsl_rng_set(properties_.rng.r, seed_);

    // Create the microtubules
    YAML::Node node = YAML::LoadFile(testfile_);
    nmts_ = node["mt"].size();
    parameters_.n_spheros = properties_.bonds.n_bonds = nmts_;

    // Init the structures for the bonds
    {
        init_unit_cell_structure(&parameters_, &properties_);
        int nsites = 2 * nmts_;
        properties_.sites.v = (double**) allocate_2d_array(nsites, ndim_, sizeof(double));
        properties_.sites.r = (double**) allocate_2d_array(nsites, ndim_, sizeof(double));
        init_site_structure_sphero(&parameters_, &properties_);
        init_bond_structure_sphero(&parameters_, &properties_);
    }

    for (int imt = 0; imt < nmts_; ++imt) {
        for (int i = 0; i < ndim_; ++i) {
            properties_.bonds.r_bond[imt][i] = node["mt"][imt]["r"][i].as<double>();
            properties_.bonds.u_bond[imt][i] = node["mt"][imt]["u"][i].as<double>();
        }
        properties_.bonds.length[imt] = node["mt"][imt]["l"].as<double>();
    }
    init_diffusion_sphero(&parameters_, &properties_);

    // Print the bond information
    for (int imt = 0; imt < nmts_; ++imt) {
        std::cout << "New Microtubule:\n";
        std::cout << "  r(" << properties_.bonds.r_bond[imt][0] << ", "
                            << properties_.bonds.r_bond[imt][1] << ", "
                            << properties_.bonds.r_bond[imt][2] << ")\n";
        std::cout << "  u(" << properties_.bonds.u_bond[imt][0] << ", "
                            << properties_.bonds.u_bond[imt][1] << ", "
                            << properties_.bonds.u_bond[imt][2] << ")\n";
        std::cout << "  l " << properties_.bonds.length[imt] << std::endl;
    }

    // Initialize the chromosomes
    properties_.chromosomes.Init(&parameters_,
                                 &properties_,
                                 testfile_.c_str(),
                                 NULL);
    // For simplicity (but dangerously), pretend only one kinetochore exists.
    // Its second slot is still read below and by TestDeltas /
    // TestForceTorqueBalance.
    properties_.chromosomes.nkcs_ = 1;

    // Potential needs some additional information
    f_comp_ = (double **) allocate_2d_array(nmts_, ndim_, sizeof(double));
    t_comp_ = (double **) allocate_2d_array(nmts_, ndim_, sizeof(double));
    calc_matrix_ = (int *) allocate_1d_array(nmts_, sizeof(int));

    // Save off position information
    Kinetochore *kA = &(properties_.chromosomes.kinetochores_[0]);
    Kinetochore *kB = &(properties_.chromosomes.kinetochores_[1]);
    for (int i = 0; i < parameters_.n_dim; ++i) {
        rA[i] = kA->r_[i];
        uA[i] = kA->u_[i];
        vA[i] = kA->v_[i];
        wA[i] = kA->w_[i];
        rB[i] = kB->r_[i];
        uB[i] = kB->u_[i];
        vB[i] = kB->v_[i];
        wB[i] = kB->w_[i];
    }
}

bool TestChromatinMTPotential::TestForceTorqueBalance() {
    // Prints the forces and torques stored for kinetochore slots 0 and 1
    // (chromosomes.f_/t_) and for microtubule 0 (f_comp_/t_comp_ row 0).
    //
    // The intended balance is tA + tB + r x f = 0, but it is NOT evaluated:
    // this function only reports the values and always returns true.
    // TODO: assert the balance once the origin/sign convention of r is pinned.
    //
    // At most 3 components are copied (the buffers are fixed at 3); unused
    // components print as 0.0. Slot 1 is read even though InitSystem() sets
    // nkcs_ = 1.
    const int n_copy = parameters_.n_dim < 3 ? parameters_.n_dim : 3;

    double fA[3] = {0.0};
    double fB[3] = {0.0};
    double tA[3] = {0.0};
    double tB[3] = {0.0};
    double fmt[3] = {0.0};
    double tmt[3] = {0.0};
    for (int i = 0; i < n_copy; ++i) {
        fA[i] = properties_.chromosomes.f_[0][i];
        fB[i] = properties_.chromosomes.f_[1][i];
        fmt[i] = f_comp_[0][i];
        tA[i] = properties_.chromosomes.t_[0][i];
        tB[i] = properties_.chromosomes.t_[1][i];
        tmt[i] = t_comp_[0][i];
    }

    std::cout << "force  A(" << fA[0] << ", " << fA[1] << ", " << fA[2] << ")\n";
    std::cout << "force  B(" << fB[0] << ", " << fB[1] << ", " << fB[2] << ")\n";
    std::cout << "force mt(" << fmt[0] << ", " << fmt[1] << ", " << fmt[2] << ")\n";
    std::cout << "torque A(" << tA[0] << ", " << tA[1] << ", " << tA[2] << ")\n";
    std::cout << "torque B(" << tB[0] << ", " << tB[1] << ", " << tB[2] << ")\n";
    std::cout << "torquemt(" << tmt[0] << ", " << tmt[1] << ", " << tmt[2] << ")\n";

    return true;
}

bool TestChromatinMTPotential::TestDeltas() {
    bool success = true;
    // Just write out the deltas of the change in positions
    // final - initial
    double dr[3] = {0.0};
    double du[3] = {0.0};
    double dv[3] = {0.0};
    double dw[3] = {0.0};

    std::cout << "kA deltas: \n";
    Kinetochore *kA = &(properties_.chromosomes.kinetochores_[0]);
    Kinetochore *kB = &(properties_.chromosomes.kinetochores_[1]);
    for (int i = 0; i < parameters_.n_dim; ++i) {
        dr[i] = kA->r_[i] - rA[i];
        du[i] = kA->u_[i] - uA[i];
        dv[i] = kA->v_[i] - vA[i];
        dw[i] = kA->w_[i] - wA[i];
    }
    std::cout << "  drA(" << dr[0] << ", " << dr[1] << ", " << dr[2] << ")\n";
    std::cout << "  duA(" << du[0] << ", " << du[1] << ", " << du[2] << ")\n";
    std::cout << "  dvA(" << dv[0] << ", " << dv[1] << ", " << dv[2] << ")\n";
    std::cout << "  dwA(" << dw[0] << ", " << dw[1] << ", " << dw[2] << ")\n";

    std::cout << "kB deltas: \n";
    for (int i = 0; i < parameters_.n_dim; ++i) {
        dr[i] = kB->r_[i] - rB[i];
        du[i] = kB->u_[i] - uB[i];
        dv[i] = kB->v_[i] - vB[i];
        dw[i] = kB->w_[i] - wB[i];
    }
    std::cout << "  drB(" << dr[0] << ", " << dr[1] << ", " << dr[2] << ")\n";
    std::cout << "  duB(" << du[0] << ", " << du[1] << ", " << du[2] << ")\n";
    std::cout << "  dvB(" << dv[0] << ", " << dv[1] << ", " << dv[2] << ")\n";
    std::cout << "  dwB(" << dw[0] << ", " << dw[1] << ", " << dw[2] << ")\n";

    return success;
}

