// Implementation of the unit tests for the kinetochore mesh potential

#include "bob.h"
#include "unit_test_kinetochoremesh_potential.h"

#include <fstream>
#include <iostream>
#include <numeric>

// Run the tests and all their variants
void UnitTestKinetochoremeshPotential::RunTests() {
    std::cout << "****************\n";
    std::cout << "Unit Test Kinetochoremesh Potential run tests\n";
    std::cout << "****************\n";

    for (auto &kv : tests_) {
        std::cout << "----------------\n";
        std::cout << "Test : " << kv.first << std::endl;
        std::cout << "----------------\n";

        for (int iv = 0; iv < node_["Tests"][kv.first].size(); ++iv) {
        std::cout << "----------------\n";
            //std::cout << "  Variant : " << iv << std::endl;
            var_subnode_ = node_["Tests"][kv.first][iv];
            iv_ = iv;
            auto result = kv.second();

            if (!result) {
                std::cout << "Test : " << kv.first << " failed, check output!\n";
                exit(1);
            }
        }
    }
}

// Set the possible tests to run
void UnitTestKinetochoremeshPotential::SetTests() {
    node_ = YAML::LoadFile(filename_);

    // Check against names to bind method calls
    if (node_["Tests"]["TestPotentialUpdate"]) {
        tests_["TestPotentialUpdate"] = std::bind(&UnitTestKinetochoremeshPotential::TestPotentialUpdate, this);
    }
}

// Build the simulation state for the current test variant (var_subnode_).
//
// The setup proceeds in this order:
// - Parse "params" into parameters and seed the global RNG.
// - Set the unit-cell diagonal to 100.
// - Create the microtubules. They come either from the explicit
//   "microtubules" list (r, u, l per entry) or, when that list is absent, are
//   placed at random using "setup" (n_mts, mt_length).
// - Allocate this test's scratch force/torque/virial buffers.
// - Initialize dynamic instability and diffusion.
// - Set up the per-thread RNGs and force accumulators.
void UnitTestKinetochoremeshPotential::InitSystem(system_parameters *parameters,
                                                  system_properties *properties) {
    // Generic parameters and the global random number generator
    parse_parameters_node(var_subnode_["params"], parameters);
    gsl_rng_env_setup();
    properties->rng.T = gsl_rng_default;
    properties->rng.r = gsl_rng_alloc(properties->rng.T);
    gsl_rng_set(properties->rng.r, parameters->seed);

    // Unit cell: only the diagonal is set here; off-diagonal entries keep
    // whatever allocate_2d_array left in them.
    const int n_dim = parameters->n_dim;
    const double box = 100.0;
    properties->unit_cell.h = (double **) allocate_2d_array(n_dim, n_dim, sizeof(double));
    for (int d = 0; d < n_dim; ++d) {
        properties->unit_cell.h[d][d] = box;
    }

    // Microtubules come from an explicit list if one is given, otherwise the
    // requested number is generated at random.
    YAML::Node mt_list = var_subnode_["microtubules"];
    const bool mts_listed = mt_list ? true : false;
    const int n_mts = mts_listed ? static_cast<int>(mt_list.size())
                                 : var_subnode_["setup"]["n_mts"].as<int>();

    // Sites (two per microtubule) and bonds are set up the same way either way
    init_unit_cell_structure(parameters, properties);
    const int n_sites = 2 * n_mts;
    parameters->n_spheros = properties->bonds.n_bonds = n_mts;
    properties->sites.n_sites = n_sites;
    properties->sites.v = (double **) allocate_2d_array(n_sites, n_dim, sizeof(double));
    properties->sites.r = (double **) allocate_2d_array(n_sites, n_dim, sizeof(double));
    init_site_structure_sphero(parameters, properties);
    init_bond_structure_sphero(parameters, properties);

    if (mts_listed) {
        // Copy position, orientation and length straight from the YAML list
        for (int imt = 0; imt < n_mts; ++imt) {
            YAML::Node mt = mt_list[imt];
            for (int d = 0; d < n_dim; ++d) {
                properties->bonds.r_bond[imt][d] = mt["r"][d].as<double>();
                properties->bonds.u_bond[imt][d] = mt["u"][d].as<double>();
            }
            properties->bonds.length[imt] = mt["l"].as<double>();
        }
    } else {
        // Random orientation, position uniform in [-box/2, box/2) per axis,
        // length uniform in [0, mt_length)
        const double mt_length = var_subnode_["setup"]["mt_length"].as<double>();
        for (int imt = 0; imt < n_mts; ++imt) {
            double u[3] = {0.0};
            generate_random_unit_vector(n_dim, u, properties->rng.r);
            for (int d = 0; d < n_dim; ++d) {
                properties->bonds.r_bond[imt][d] = (gsl_rng_uniform(properties->rng.r) - 0.5) * box;
                properties->bonds.u_bond[imt][d] = u[d];
            }
            properties->bonds.length[imt] = gsl_rng_uniform(properties->rng.r) * mt_length;
        }
    }
    PrintMTs(parameters, properties);

    // Scratch buffers handed to the potential / neighbor-list routines
    f_comp_ = (double **) allocate_2d_array(n_mts, n_dim, sizeof(double));
    t_comp_ = (double **) allocate_2d_array(n_mts, n_dim, sizeof(double));
    virial_comp_ = (double **) allocate_2d_array(n_dim, n_dim, sizeof(double));
    calc_matrix_ = (int *) allocate_1d_array(n_mts, sizeof(int));

    properties->anchors.n_anchors = 0;
    // Original note: needed for some reason to not die on successive reads
    parameters->poly_config = NULL;
    init_dynamic_instability(parameters, properties);
    init_diffusion_sphero(parameters, properties);

    // OpenMP settings
    set_omp_settings(parameters);
    print_omp_settings();

    #ifdef ENABLE_OPENMP
    #pragma omp parallel
    #pragma omp master
    #endif
    {
        // Thread count is 1 unless OpenMP is enabled
        int n_threads = 1;
        #ifdef ENABLE_OPENMP
        n_threads = omp_get_num_threads();
        #endif
        properties->mp_local.n_threads = properties->rng_mt.n_threads = n_threads;

        // One GSL generator per thread, seeded from the global seed
        properties->rng_mt.rng = (rng_properties*)
            allocate_1d_array(properties->rng_mt.n_threads, sizeof(rng_properties));
        gsl_rng_env_setup();
        for (int thr = 0; thr < properties->rng_mt.n_threads; ++thr) {
            rng_properties &thread_rng = properties->rng_mt.rng[thr];
            thread_rng.T = gsl_rng_default;
            thread_rng.r = gsl_rng_alloc(thread_rng.T);
            gsl_rng_set(thread_rng.r, (parameters->seed + 1) * thr);
        }

        // Per-thread force, torque, virial and calc-flag accumulators
        properties->mp_local.f_local = (double***)
            allocate_3d_array(n_threads, properties->bonds.n_bonds, parameters->n_dim, sizeof(double));
        properties->mp_local.t_local = (double***)
            allocate_3d_array(n_threads, properties->bonds.n_bonds, 3, sizeof(double));
        properties->mp_local.virial_local = (double***)
            allocate_3d_array(n_threads, parameters->n_dim, parameters->n_dim, sizeof(double));
        properties->mp_local.calc_local = (int**)
            allocate_2d_array(n_threads, properties->bonds.n_bonds, sizeof(int));
    }
}

// Convenience overload: initialize the chromosome using the overrides found
// under the current variant's "chromosome" node.
void UnitTestKinetochoremeshPotential::InitXsomeParams(system_parameters *parameters,
                                                       system_properties *properties) {
    YAML::Node chromosome_overrides = var_subnode_["chromosome"];
    InitXsomeParams(parameters, properties, chromosome_overrides);
}

// Load the chromosome configuration file named by parameters->chromosome_config.
// Apply the per-variant overrides in param_node, initialize the chromosomes
// from the merged node, then complete system setup via FinishInitSystem.
void UnitTestKinetochoremeshPotential::InitXsomeParams(system_parameters *parameters,
                                                       system_properties *properties,
                                                       YAML::Node param_node) {
    const std::string config_path(parameters->chromosome_config);
    YAML::Node merged_config = GetXsomeParamNode(config_path, param_node);
    properties->chromosomes.Init(parameters, properties, &merged_config);
    FinishInitSystem(parameters, properties);
}

// Load the chromosome configuration from x_file and return it with the
// vector-valued parameters "rA", "uA", "rB", "uB" of the first chromosome
// overridden by any of those keys present in mod_node.
//
// Each override is expected to be a sequence with at least three numeric
// entries. Only entries 0..2 are copied, and a shorter sequence makes
// as<double>() throw.
YAML::Node UnitTestKinetochoremeshPotential::GetXsomeParamNode(std::string x_file, YAML::Node mod_node){
    std::cout << "chromosome_file: " << x_file << std::endl;
    YAML::Node xsome_node = YAML::LoadFile(x_file);

    // Chromosome parameters that a test variant is allowed to override
    const std::vector<std::string> position_param_strs = {
        "rA",
        "uA",
        "rB",
        "uB"
    };

    // Iterate by reference: no per-iteration string copies
    for (const auto &param : position_param_strs) {
        // Override the value from the chromosome file only when the test
        // variant specifies it
        if (YAML::Node p_node = mod_node[param]) {
            for (int i = 0; i < 3; ++i) {
                xsome_node["chromosomes"]["chromosome"][0][param][i] = p_node[i].as<double>();
            }
        }
    }
    return xsome_node;
}

// Finish initializing the system. This allocates one neighbor list per bond
// (with new[]; nothing in this file releases it) and fills the lists with an
// all-pairs search. brownian_sphero_neighbor_lists then computes the rod-rod
// distances stored in them, using this test's scratch buffers.
void UnitTestKinetochoremeshPotential::FinishInitSystem(system_parameters *parameters,
                                                        system_properties *properties) {
    const int n_bonds = properties->bonds.n_bonds;
    properties->neighbors.neighbs = new nl_list[n_bonds];

    update_neighbor_lists_sphero_all_pairs_mp(parameters->n_dim,
                                              parameters->n_periodic,
                                              properties->unit_cell.h,
                                              parameters->skin,
                                              parameters->r_cutoff,
                                              n_bonds,
                                              properties->bonds.r_bond,
                                              properties->bonds.s_bond,
                                              properties->bonds.u_bond,
                                              properties->bonds.length,
                                              properties->neighbors.neighbs,
                                              parameters->nl_twoway_flag);

    // Store rod-rod distances in the freshly built neighbor lists
    brownian_sphero_neighbor_lists(parameters, properties,
                                   f_comp_, virial_comp_,
                                   t_comp_, calc_matrix_);
}

// Print position, orientation and length of every microtubule (bond).
// Only the first n_dim components of r_bond/u_bond are printed. The original
// always printed three components, which reads past the end of each row when
// the rows are n_dim wide and n_dim == 2. (NOTE(review): row width is set in
// init_bond_structure_sphero, not visible here; confirm it is n_dim.)
void UnitTestKinetochoremeshPotential::PrintMTs(system_parameters *parameters, system_properties *properties) {
    const int ndim = parameters->n_dim;
    const int nmts = properties->bonds.n_bonds;

    // Writes "(c0, c1[, c2 ...])\n" for the first ndim entries of vec
    auto print_vector = [ndim](const double *vec) {
        std::cout << "(";
        for (int i = 0; i < ndim; ++i) {
            if (i > 0) {
                std::cout << ", ";
            }
            std::cout << vec[i];
        }
        std::cout << ")\n";
    };

    for (int imt = 0; imt < nmts; ++imt) {
        std::cout << "New Microtubule:\n";
        std::cout << "  r";
        print_vector(properties->bonds.r_bond[imt]);
        std::cout << "  u";
        print_vector(properties->bonds.u_bond[imt]);
        std::cout << "  l " << properties->bonds.length[imt] << std::endl;
    }
}

// Build the system for the current variant and evaluate the kinetochore-mesh /
// microtubule WCA potential. The resulting kinetochore force and torque are
// then compared against the variant's "results" node.
// Returns false if "results/ptype" names an unsupported potential type.
// Previously such a type skipped the evaluation but still ran the force
// check, so the test could pass without exercising the potential.
bool UnitTestKinetochoremeshPotential::TestPotentialUpdate() {
    system_parameters parameters;
    system_properties properties;

    // Initialize the system
    InitSystem(&parameters, &properties);
    InitXsomeParams(&parameters, &properties);

    // Run the potential calculation
    const std::string ptype = var_subnode_["results"]["ptype"].as<std::string>();
    if (ptype != "allpairs") {
        std::cout << "Unknown potential type '" << ptype
                  << "' (supported: allpairs)\n";
        return false;
    }
    const double uret = kinetochoremesh_mt_wca_potential_allpairs(&parameters,
                                                                  &properties,
                                                                  f_comp_,
                                                                  virial_comp_,
                                                                  t_comp_,
                                                                  calc_matrix_);
    std::cout << "potential energy: " << uret << "\n";

    return TestForceTorqueBalance(&parameters, &properties, true);
}

// Compare the force and torque on the first chromosome (chromosomes.f_[0],
// chromosomes.t_[0], printed as "KC") with the expected "fA" and "tA" vectors
// in the variant's "results" node. Comparison uses almost_equal with a
// tolerance of 0.01.
// The force and torque on the first microtubule (f_comp_[0], t_comp_[0]) are
// read only for printing, and only if at least one bond exists.
// Every mismatching component is reported on stdout.
// Returns true when all compared components match.
bool UnitTestKinetochoremeshPotential::TestForceTorqueBalance(system_parameters *parameters,
                                                              system_properties *properties,
                                                              bool print_info) {
    const double tolerance = 0.01;
    double fA[3] = {0.0};
    double tA[3] = {0.0};
    double fmt[3] = {0.0};
    double tmt[3] = {0.0};

    const int ndim = parameters->n_dim;
    const bool have_mts = properties->bonds.n_bonds > 0;
    for (int i = 0; i < ndim; ++i) {
        fA[i] = properties->chromosomes.f_[0][i];
        tA[i] = properties->chromosomes.t_[0][i];
        // Row 0 of the MT buffers only exists when there is at least one bond
        if (have_mts) {
            fmt[i] = f_comp_[0][i];
            tmt[i] = t_comp_[0][i];
        }
    }

    if (print_info) {
        auto print_vec = [](const char *label, const double *v) {
            std::cout << label << "("
                << v[0] << ", "
                << v[1] << ", "
                << v[2] << ")\n";
        };
        print_vec("force  KC", fA);
        print_vec("torque KC", tA);
        print_vec("force  MT", fmt);
        print_vec("torque MT", tmt);
    }

    bool success = true;
    for (int i = 0; i < ndim; ++i) {
        const double f_expected = var_subnode_["results"]["fA"][i].as<double>();
        const double t_expected = var_subnode_["results"]["tA"][i].as<double>();

        if (!almost_equal(f_expected, fA[i], tolerance)) {
            std::cout << "  fA[" << i << "] mismatch: expected " << f_expected
                      << ", got " << fA[i] << "\n";
            success = false;
        }
        if (!almost_equal(t_expected, tA[i], tolerance)) {
            std::cout << "  tA[" << i << "] mismatch: expected " << t_expected
                      << ", got " << tA[i] << "\n";
            success = false;
        }
    }

    return success;
}
