// Implementation of the chromosome diffusion generator, which also checks the
// force/torque balance of the chromatin potential during the run

#include "bob.h"
#include "correlation_data.h"
#include "helpers.h"
#include "kinetochore.h"
#include "gen_chromosomes_diffusion.h"

#include <iostream>

// Execute every test registered in tests_ (in the container's iteration
// order), printing a banner around each test name. The first test whose
// callable returns false terminates the whole process with exit status 1.
void GenChromosomesDiffusion::RunTests() {
    const char *stars = "****************\n";
    std::cout << stars << "Generate Chromosome Diffusion Values\n" << stars;

    for (auto &entry : tests_) {
        const auto &name = entry.first;
        std::cout << "----------------\n"
                  << "Test : " << name << std::endl
                  << "----------------\n";
        if (!entry.second()) {
            std::cout << "Test : " << name << " failed, check output!\n";
            exit(1);
        }
    }
}

void GenChromosomesDiffusion::SetTests() {
    YAML::Node node = YAML::LoadFile(filename_);

    if (node["Diffusion"]) {
        tests_["Diffusion"] = std::bind(&GenChromosomesDiffusion::GenDiffusion, this);
    }
}

// Build the simulation state for a diffusion run described by testfile_:
//  - increment the RNG seed, reset parameters_ to defaults, then force a
//    3-D, non-periodic system with temp = 1.0,
//  - fill the unit cell matrix h by hand (diagonal entries 100, zeros elsewhere),
//  - allocate the GSL RNG and seed it with seed_,
//  - initialize the chromosomes from testfile_ and cache the initial r/u/v/w
//    vectors of kinetochores 0 and 1 into rA/uA/vA/wA and rB/uB/vB/wB,
//  - read "nsteps", "nposit" and "delta" from testfile_,
//  - allocate a one-component potential (chromosome_chromatin_potential),
//    the thermo virial/stress/pressure tensors, and the bond structure.
void GenChromosomesDiffusion::InitSystem() {
    seed_++;
    init_default_params(&parameters_);

    ndim_ = 3;
    parameters_.n_dim = ndim_;
    parameters_.n_periodic = 0;
    parameters_.delta = 0.001; // default only; overwritten by "delta" from testfile_ below
    parameters_.temp = 1.0;

    // Init h by hand. Write every entry (not just the diagonal) so the
    // off-diagonal terms are well defined regardless of whether
    // allocate_2d_array zero-fills its storage.
    properties_.unit_cell.h = (double**) allocate_2d_array(parameters_.n_dim, parameters_.n_dim, sizeof(double));
    for (int i = 0; i < parameters_.n_dim; ++i) {
        for (int j = 0; j < parameters_.n_dim; ++j) {
            properties_.unit_cell.h[i][j] = (i == j) ? 100.0 : 0.0;
        }
    }

    gsl_rng_env_setup();
    properties_.rng.T = gsl_rng_default;
    properties_.rng.r = gsl_rng_alloc(properties_.rng.T);
    gsl_rng_set(properties_.rng.r, seed_);

    YAML::Node node = YAML::LoadFile(testfile_);
    // Initialize the chromosomes
    properties_.chromosomes.Init(&parameters_,
                                 &properties_,
                                 testfile_.c_str(),
                                 NULL);

    // Save off position information.
    // NOTE(review): assumes the config yields at least two kinetochores.
    Kinetochore *kA = &(properties_.chromosomes.kinetochores_[0]);
    Kinetochore *kB = &(properties_.chromosomes.kinetochores_[1]);
    for (int i = 0; i < parameters_.n_dim; ++i) {
        rA[i] = kA->r_[i];
        uA[i] = kA->u_[i];
        vA[i] = kA->v_[i];
        wA[i] = kA->w_[i];
        rB[i] = kB->r_[i];
        uB[i] = kB->u_[i];
        vB[i] = kB->v_[i];
        wB[i] = kB->w_[i];
    }

    nsteps_ = node["nsteps"].as<int>();
    nposit_ = node["nposit"].as<int>();
    parameters_.n_steps = nsteps_;
    parameters_.n_posit = nposit_;
    parameters_.delta = node["delta"].as<double>();

    // Set up the potential information
    // For us, just the chromatin spring complex
    // Alias so it looks just like the inits
    system_parameters *parameters = &(parameters_);
    system_properties *properties = &(properties_);
    system_potential *potential = &(potential_);
    int n_comp = 1;
    potential->n_comp = n_comp;
    potential->pot_func = (double (**) (system_parameters *,
                                        system_properties *,
                                        double**,
                                        double**,
                                        double**,
                                        int*)) allocate_1d_array(n_comp, sizeof(void *));
    // No sites/bonds are registered for this test, so the per-site/per-bond
    // arrays below are allocated with zero extent.
    int n_bonds = 0;
    int n_sites = 0;
    int n_dim = ndim_;
    potential->f_comp = (double ***) allocate_3d_array(n_comp, n_sites-n_bonds, n_dim, sizeof(double));
    potential->u_comp = (double *) allocate_1d_array(n_comp, sizeof(double));
    potential->virial_comp = (double ***) allocate_3d_array(n_comp, n_dim, n_dim, sizeof(double));
    potential->t_comp = (double ***) allocate_3d_array(n_comp, n_bonds, n_dim, sizeof(double));
    potential->calc_matrix = (int*) allocate_1d_array(n_bonds, sizeof(int));

    // Need thermodynamics too
    properties->thermo.virial = (double **)
        allocate_2d_array(parameters->n_dim, parameters->n_dim, sizeof(double));
    properties->thermo.stress = (double **)
        allocate_2d_array(parameters->n_dim, parameters->n_dim, sizeof(double));
    properties->thermo.press_tensor = (double **)
        allocate_2d_array(parameters->n_dim, parameters->n_dim, sizeof(double));

    // bond structure
    init_bond_structure_sphero(parameters, properties);

    // Only the chromosome spring complex
    std::cout << "Potential: chromatin spring complex\n";
    potential->pot_func[0] = chromosome_chromatin_potential;
}

// Check the torque balance of the chromatin interaction between the two
// chromosomes:  tA + tB + r x f == 0, where r = rlastA - rlastB and f is the
// force stored on chromosome 0. Each residual component must be within 1e-10.
//
// uret       : energy value; only printed when print_info is true
// print_info : dump energy, force, both torques and the residual to stdout
// Returns true when the balance holds, false (with a message) otherwise.
bool GenChromosomesDiffusion::TestForceTorqueBalance(double uret, bool print_info) {
    const int n_dim = parameters_.n_dim;

    double sep[3] = {0.0};
    double force[3] = {0.0};
    double torqueA[3] = {0.0};
    double torqueB[3] = {0.0};
    for (int d = 0; d < n_dim; ++d) {
        sep[d] = rlastA[d] - rlastB[d];
        force[d] = properties_.chromosomes.f_[0][d];
        torqueA[d] = properties_.chromosomes.t_[0][d];
        torqueB[d] = properties_.chromosomes.t_[1][d];
    }

    if (print_info) {
        std::cout << "energy  [" << uret << "]\n"
                  << "force   (" << force[0] << ", " << force[1] << ", " << force[2] << ")\n"
                  << "torque A(" << torqueA[0] << ", " << torqueA[1] << ", " << torqueA[2] << ")\n"
                  << "torque B(" << torqueB[0] << ", " << torqueB[1] << ", " << torqueB[2] << ")\n";
    }

    double moment[3] = {0.0};
    cross_product(sep, force, moment, n_dim);

    double residual[3] = {0.0};
    for (int d = 0; d < n_dim; ++d) {
        residual[d] = torqueA[d] + torqueB[d] + moment[d];
    }
    if (print_info) {
        std::cout << "final   (" << residual[0] << ", " << residual[1] << ", " << residual[2] << ")\n";
    }

    bool balanced = true;
    for (int d = 0; d < 3; ++d) {
        if (fabs(residual[d]) > 1e-10) {
            balanced = false;
        }
    }
    if (!balanced) {
        std::cout << "Force balance didin't work correctly, failing test!\n";
    }
    return balanced;
}

// Generate chromosome diffusion data from "gen_chromosomes_diffusion.yaml":
// initialize the system, advance it nsteps_ times, and on every step i with
// i % nposit_ == 0 check the chromatin force/torque balance and write the
// chromosome state.
//
// Returns false if nposit_ is not positive or a balance check fails (the run
// stops at the first failure); true otherwise.
bool GenChromosomesDiffusion::GenDiffusion() {
    testfile_ = "gen_chromosomes_diffusion.yaml";

    InitSystem();

    // nposit_ is used as a modulus below; zero would be division by zero (UB).
    if (nposit_ <= 0) {
        std::cout << "nposit must be positive (got " << nposit_ << "), failing test!\n";
        return false;
    }

    properties_.time = 0.0;

    std::cout << "Running " << nsteps_ << " steps\n";
    for (int i = 0; i < nsteps_; ++i) {
        properties_.i_current_step = i;
        const bool readout = (i % nposit_ == 0);

        // The balance check needs the chromosome positions from before this
        // step. Snapshot them here rather than after step i-1. Otherwise
        // rlastA/rlastB are never set for the readout at i == 0, and with
        // nposit_ == 1 they would hold post-step positions.
        if (readout) {
            std::cout << "Step[" << i << "] recording position information for update\n";
            for (int idim = 0; idim < parameters_.n_dim; ++idim) {
                rlastA[idim] = properties_.chromosomes.r_[0][idim];
                rlastB[idim] = properties_.chromosomes.r_[1][idim];
            }
        }

        position_step_spindle_bd_mp(&parameters_, &properties_, &potential_);

        if (readout) {
            std::cout << "Step[" << i << "] readout\n";
            // Test the force/torque balance of the chromatin spring
            if (!TestForceTorqueBalance(0.0, true)) {
                return false;
            }
            properties_.chromosomes.WriteState(&parameters_, &properties_);
        }
    }

    return true;
}
