// Implementation of minimum distance calculation unit tests

#include "bob.h"
#include "unit_test_min_distance.h"
#include "minimum_distance.h"
#include "triangle_mesh.h"

#include <chrono>
#include <iostream>

// Run the tests and all the variants
void UnitTestMinDistance::RunTests() {

    std::cout << "****************\n";
    std::cout << "Unit Test Min Distance run tests\n";
    std::cout << "****************\n";

    for (auto &kv : tests_) {
        std::cout << "----------------\n";
        std::cout << "Test : " << kv.first << std::endl;
        std::cout << "----------------\n";

        for (int iv = 0; iv < node_["Tests"][kv.first].size(); ++iv) {
            std::cout << "^^^^^^^^^^^^^^^^\n";
            std::cout << "  Variant : " << iv << std::endl;
            std::cout << "^^^^^^^^^^^^^^^^\n";
            var_subnode_ = node_["Tests"][kv.first][iv];
            iv_ = iv;
            auto result = kv.second();

            if (!result) {
                std::cout << "Test : " << kv.first << " failed, check output!\n";
                exit(1);
            }
        }
    }
}

// Set the possible tests to run
void UnitTestMinDistance::SetTests() {
    node_ = YAML::LoadFile(filename_);

    // Check against names to bind method calls
    if (node_["Tests"]["TestRandomSetup"]) {
        tests_["TestRandomSetup"] = std::bind(&UnitTestMinDistance::TestRandomSetup, this);
    }
}

/// @brief Initialize the target system for min distance tests
///
/// Builds a randomly populated rod ("microtubule") system for the current test
/// variant (var_subnode_): reads the "params" and "setup" YAML nodes, seeds the
/// serial and per-thread GSL generators, sets up a cell matrix with 100 on the
/// diagonal and fills the bond arrays with random centres, orientations and
/// lengths. Scratch force/torque/virial/calc buffers (f_comp_, t_comp_,
/// virial_comp_, calc_matrix_) and the per-thread accumulation arrays are
/// allocated as well. Nothing allocated here is released by this function.
///
/// @param parameters  Populated via parse_parameters_node from var_subnode_["params"];
///                    n_spheros is overwritten with the number of microtubules
/// @param properties  Receives the RNG state, unit cell, site/bond data and
///                    per-thread buffers
void UnitTestMinDistance::InitSystem(system_parameters *parameters,
                                     system_properties *properties) {
    // Generic parameter initialization
    parse_parameters_node(var_subnode_["params"], parameters);
    // Serial RNG: GSL default generator type, seeded from the parsed parameters
    gsl_rng_env_setup();
    properties->rng.T = gsl_rng_default;
    properties->rng.r = gsl_rng_alloc(properties->rng.T);
    gsl_rng_set(properties->rng.r, parameters->seed);

    // Generate the unit cell
    int ndim = parameters->n_dim;
    double box = 100;
    properties->unit_cell.h = (double **) allocate_2d_array(ndim, ndim, sizeof(double));
    // Only the diagonal is written; the off-diagonal entries keep whatever
    // allocate_2d_array returned (presumably zeros - TODO confirm)
    for (int i = 0; i < ndim; ++i) {
        properties->unit_cell.h[i][i] = box;
    }

    // Get the microtubule information
    {
        int nmts = var_subnode_["setup"]["n_mts"].as<int>();
        // Runs after h has been filled in above
        init_unit_cell_structure(parameters, properties);
        // Two sites per microtubule, one bond per microtubule
        int nsites = 2*nmts;
        parameters->n_spheros = properties->bonds.n_bonds = nmts;
        properties->sites.n_sites = nsites;
        properties->sites.v = (double **) allocate_2d_array(nsites, ndim, sizeof(double));
        properties->sites.r = (double **) allocate_2d_array(nsites, ndim, sizeof(double));
        init_site_structure_sphero(parameters, properties);
        init_bond_structure_sphero(parameters, properties);
        double mt_length = var_subnode_["setup"]["mt_length"].as<double>();
        // RNG draw order per microtubule (matters for reproducibility with a
        // fixed seed): orientation, then the centre components, then the length
        for (int imt = 0; imt < nmts; ++imt) {
            // Sized for up to 3 dimensions; only the first ndim entries are copied
            double u[3] = {0.0};
            generate_random_unit_vector(ndim, u, properties->rng.r);
            for (int i = 0; i < ndim; ++i) {
                // Centre component: (uniform draw - 0.5) scaled by the box edge
                properties->bonds.r_bond[imt][i] = (gsl_rng_uniform(properties->rng.r) - 0.5) * box;
                properties->bonds.u_bond[imt][i] = u[i];
            }
            // Length: uniform draw scaled by the configured mt_length
            properties->bonds.length[imt] = (gsl_rng_uniform(properties->rng.r) * mt_length);
        }

        //PrintMTs(parameters, properties);

        // Create some force stuff: per-bond force and torque scratch arrays,
        // an ndim x ndim virial and a per-bond integer flag array
        f_comp_ = (double **) allocate_2d_array(nmts, ndim, sizeof(double));
        t_comp_ = (double **) allocate_2d_array(nmts, ndim, sizeof(double));
        virial_comp_ = (double **) allocate_2d_array(ndim, ndim, sizeof(double));
        calc_matrix_ = (int *) allocate_1d_array(nmts, sizeof(int));
    }

    // Load OMP settings
    set_omp_settings(parameters);
    print_omp_settings();

    // With ENABLE_OPENMP the block below is executed by the master thread of a
    // parallel region, so omp_get_num_threads() reports the size of the team;
    // without it the block runs once with a single "thread".
    #ifdef ENABLE_OPENMP
    #pragma omp parallel
    #pragma omp master
    #endif
    {
        int n_threads = properties->mp_local.n_threads = properties->rng_mt.n_threads = 1;

        #ifdef ENABLE_OPENMP
        n_threads = properties->mp_local.n_threads = properties->rng_mt.n_threads = omp_get_num_threads();
        #endif

        // One RNG per thread
        properties->rng_mt.rng = (rng_properties*)
            allocate_1d_array(properties->rng_mt.n_threads, sizeof(rng_properties));

        // Second call to gsl_rng_env_setup(); it was already invoked at the top
        gsl_rng_env_setup();
        for (int i_thr = 0; i_thr < properties->rng_mt.n_threads; ++i_thr) {
            properties->rng_mt.rng[i_thr].T = gsl_rng_default;
            properties->rng_mt.rng[i_thr].r = gsl_rng_alloc(properties->rng_mt.rng[i_thr].T);
            // Seed is (seed + 1) * thread index, so thread 0 is always seeded with 0.
            // NOTE(review): GSL appears to map a seed of 0 to the generator's
            // default seed - confirm this is intended.
            gsl_rng_set(properties->rng_mt.rng[i_thr].r, (parameters->seed+1) * i_thr);
        }

        // Per-thread accumulation buffers, indexed [thread][bond][component]
        properties->mp_local.f_local = (double***)
            allocate_3d_array(n_threads, properties->bonds.n_bonds, parameters->n_dim, sizeof(double));

        // Torques always get 3 components, independent of n_dim
        properties->mp_local.t_local = (double***)
            allocate_3d_array(n_threads, properties->bonds.n_bonds, 3, sizeof(double));

        // Per-thread virial, indexed [thread][n_dim][n_dim]
        properties->mp_local.virial_local = (double***)
            allocate_3d_array(n_threads, parameters->n_dim, parameters->n_dim, sizeof(double));

        // Per-thread integer flags, indexed [thread][bond]
        properties->mp_local.calc_local = (int**)
            allocate_2d_array(n_threads, properties->bonds.n_bonds, sizeof(int));
    }
}

/// @brief Complete the system setup: neighbor lists plus one interaction pass
///
/// Allocates one nl_list per bond, fills the lists with the all-pairs update
/// routine and then runs brownian_sphero_neighbor_lists with the scratch
/// buffers owned by this object. The neighbor list array is not released here.
///
/// @param parameters  Supplies dimensionality, periodicity, skin, cutoff and
///                    the two-way neighbor list flag
/// @param properties  Supplies the unit cell and bond data; receives the
///                    neighbor lists
void UnitTestMinDistance::FinishInitSystem(system_parameters *parameters,
                                           system_properties *properties) {
    auto &bonds = properties->bonds;
    auto &neighbors = properties->neighbors;

    // One neighbor list per bond
    neighbors.neighbs = new nl_list[bonds.n_bonds];

    // Populate the neighbor lists from the current bond configuration
    update_neighbor_lists_sphero_all_pairs_mp(
        parameters->n_dim, parameters->n_periodic,
        properties->unit_cell.h,
        parameters->skin, parameters->r_cutoff,
        bonds.n_bonds,
        bonds.r_bond, bonds.s_bond, bonds.u_bond, bonds.length,
        neighbors.neighbs,
        parameters->nl_twoway_flag);

    // Evaluate the rod-rod distances so the values end up in the neighbor lists
    brownian_sphero_neighbor_lists(parameters, properties,
                                   f_comp_, virial_comp_,
                                   t_comp_, calc_matrix_);
}

/// @brief Print the microtubule information
///
/// Writes the centre position (r), orientation (u) and length (l) of every
/// bond to stdout. Exactly n_dim components are printed per vector, so a 2D
/// system no longer reads a third component that was never populated; the
/// output for a 3D system is unchanged.
///
/// @param parameters  Supplies n_dim
/// @param properties  Supplies the bond arrays (r_bond, u_bond, length) and n_bonds
void UnitTestMinDistance::PrintMTs(system_parameters *parameters, system_properties *properties) {
    const int ndim = parameters->n_dim;
    const int nmts = properties->bonds.n_bonds;

    // Prints "  <label>(c0, c1, ...)" followed by a newline
    const auto print_vector = [ndim](const char *label, const double *vec) {
        std::cout << "  " << label << "(";
        for (int i = 0; i < ndim; ++i) {
            if (i > 0) {
                std::cout << ", ";
            }
            std::cout << vec[i];
        }
        std::cout << ")\n";
    };

    for (int imt = 0; imt < nmts; ++imt) {
        std::cout << "New Microtubule:\n";
        print_vector("r", properties->bonds.r_bond[imt]);
        print_vector("u", properties->bonds.u_bond[imt]);
        std::cout << "  l " << properties->bonds.length[imt] << std::endl;
    }
}

/// @brief Test random setups of microtubules via min_distance_sphero and segment_dist
///
/// Initializes a random system for the current variant, then runs the
/// min_distance_sphero_dr pass and the segment_dist pass over it. For each pass
/// the accumulated totals and the elapsed wall time are printed so the two
/// implementations can be compared by inspecting the output.
///
/// @return true only if both passes report success (previously the result of
///         the first pass was overwritten by the second)
bool UnitTestMinDistance::TestRandomSetup() {
    // steady_clock is monotonic, which is what interval timing needs
    using Clock = std::chrono::steady_clock;

    system_parameters parameters;
    system_properties properties;

    // Initialize and Read the system
    InitSystem(&parameters, &properties);

    double total_rminmag2 = 0.0;
    double total_mu = 0.0;
    double total_lambda = 0.0;
    double total_rmin[3] = {0.0, 0.0, 0.0};

    // Each pass accumulates into the totals, so they start from zero every time
    const auto reset_totals = [&]() {
        total_rminmag2 = 0.0;
        total_mu = 0.0;
        total_lambda = 0.0;
        total_rmin[0] = 0.0;
        total_rmin[1] = 0.0;
        total_rmin[2] = 0.0;
    };

    // Prints the accumulated totals and the timing line for one pass
    const auto report = [&](const char *label, double seconds) {
        std::cout << label << ", length: " << total_rminmag2 << ", mu: " << total_mu << ", lambda: " << total_lambda;
        std::cout << ", rmin(" << total_rmin[0] << ", " << total_rmin[1] << ", " << total_rmin[2] << ")\n";
        std::cout << label << ": " << seconds << "s\n";
    };

    // Run the min distance sphero dr calc
    reset_totals();
    auto start_time = Clock::now();
    const bool sphero_ok = TestMinDistanceSpheroDr(&parameters,
                                                   &properties,
                                                   &total_rminmag2,
                                                   &total_lambda,
                                                   &total_mu,
                                                   total_rmin);
    auto end_time = Clock::now();
    report("min_distance_sphero_dr",
           std::chrono::duration<double>(end_time - start_time).count());

    // Check the segment distance code as well now; it runs even if the first
    // pass failed so that both sets of numbers are always printed
    reset_totals();
    start_time = Clock::now();
    const bool segments_ok = TestMinDistanceSegments(&parameters,
                                                     &properties,
                                                     &total_rminmag2,
                                                     &total_lambda,
                                                     &total_mu,
                                                     total_rmin);
    end_time = Clock::now();
    report("segment_dist",
           std::chrono::duration<double>(end_time - start_time).count());

    return sphero_ok && segments_ok;
}

/// @brief Accumulate min_distance_sphero_dr results over every unordered bond pair
///
/// Repeats the all-pairs sweep parameters->n_steps times and adds each pair's
/// squared minimum distance, lambda, mu and separation vector onto the
/// caller-provided accumulators (which are not reset here).
///
/// @param parameters  Supplies n_steps, n_dim and n_periodic
/// @param properties  Supplies the unit cell and the bond arrays
/// @param prminmag2   Accumulator for r_min_mag2
/// @param plambda     Accumulator for lambda
/// @param pmu         Accumulator for mu
/// @param prmin       Accumulator for the three r_min components
/// @return Always true; nothing is validated here
bool UnitTestMinDistance::TestMinDistanceSpheroDr(system_parameters *parameters,
                                                  system_properties *properties,
                                                  double *prminmag2,
                                                  double *plambda,
                                                  double *pmu,
                                                  double *prmin) {
    const int step_count = parameters->n_steps;
    const int bond_count = properties->bonds.n_bonds;
    auto &bonds = properties->bonds;

    for (int step = 0; step < step_count; ++step) {
        // Visit every unordered pair (first, second) with first < second
        for (int first = 0; first + 1 < bond_count; ++first) {
            for (int second = first + 1; second < bond_count; ++second) {
                // Outputs of the pair calculation, zeroed for every pair
                double separation[3] = {0.0};
                double closest[3] = {0.0};
                double closest_mag2 = 0.0;
                double pair_lambda = 0.0;
                double pair_mu = 0.0;

                min_distance_sphero_dr(parameters->n_dim,
                                       parameters->n_periodic,
                                       properties->unit_cell.h,
                                       bonds.r_bond[first],
                                       bonds.s_bond[first],
                                       bonds.u_bond[first],
                                       bonds.length[first],
                                       bonds.r_bond[second],
                                       bonds.s_bond[second],
                                       bonds.u_bond[second],
                                       bonds.length[second],
                                       separation,
                                       closest,
                                       &closest_mag2,
                                       &pair_lambda,
                                       &pair_mu);

                *prminmag2 += closest_mag2;
                *pmu += pair_mu;
                *plambda += pair_lambda;
                for (int k = 0; k < 3; ++k) {
                    prmin[k] += closest[k];
                }
            }
        }
    }

    return true;
}

/// @brief Test the min distance segments from triangle mesh
///
/// Repeats an all-pairs sweep parameters->n_steps times. For every unordered
/// bond pair the two segment endpoints (centre -/+ half length along the
/// orientation) are handed to segment_dist, and the returned parameters t1/t2
/// are converted to lambda/mu measured from the bond centres. Results are
/// added onto the caller-provided accumulators (which are not reset here).
///
/// segment_dist takes 3D coordinates. Endpoints are staged in zero-padded
/// 3-vectors and only the first n_dim components are read from the bond
/// arrays, so a 2D system does not index a third component. For n_dim == 3
/// the arithmetic is the same as before.
///
/// @param parameters  Supplies n_steps and n_dim
/// @param properties  Supplies the bond arrays (r_bond, u_bond, length)
/// @param prminmag2   Accumulator for the `dist` value returned by segment_dist
///                    (NOTE(review): confirm whether that is a distance or a
///                    squared distance before comparing with the sphero totals)
/// @param plambda     Accumulator for lambda (offset along the first bond)
/// @param pmu         Accumulator for mu (offset along the second bond)
/// @param prmin       Accumulator for the three separation-vector components
/// @return Always true; nothing is validated here
bool UnitTestMinDistance::TestMinDistanceSegments(system_parameters *parameters,
                                                  system_properties *properties,
                                                  double *prminmag2,
                                                  double *plambda,
                                                  double *pmu,
                                                  double *prmin) {
    const int nsteps = parameters->n_steps;
    // The staging arrays below hold 3 components, so never index beyond that
    const int ndim = parameters->n_dim < 3 ? parameters->n_dim : 3;
    const int nbonds = properties->bonds.n_bonds;

    for (int istep = 0; istep < nsteps; ++istep) {
        // Calculate the min distance between all pairs of rods in the system
        for (int i_bond = 0; i_bond < nbonds - 1; ++i_bond) {
            const double *r_i = properties->bonds.r_bond[i_bond];
            const double *u_i = properties->bonds.u_bond[i_bond];
            const double length_i = properties->bonds.length[i_bond];
            const double half_length_i = 0.5 * length_i;

            // Endpoints of bond i do not depend on j, so build them once
            double start_i[3] = {0.0, 0.0, 0.0};
            double end_i[3] = {0.0, 0.0, 0.0};
            for (int k = 0; k < ndim; ++k) {
                start_i[k] = r_i[k] - half_length_i * u_i[k];
                end_i[k] = r_i[k] + half_length_i * u_i[k];
            }

            for (int j_bond = i_bond + 1; j_bond < nbonds; ++j_bond) {
                const double *r_j = properties->bonds.r_bond[j_bond];
                const double *u_j = properties->bonds.u_bond[j_bond];
                const double length_j = properties->bonds.length[j_bond];
                const double half_length_j = 0.5 * length_j;

                double start_j[3] = {0.0, 0.0, 0.0};
                double end_j[3] = {0.0, 0.0, 0.0};
                for (int k = 0; k < ndim; ++k) {
                    start_j[k] = r_j[k] - half_length_j * u_j[k];
                    end_j[k] = r_j[k] + half_length_j * u_j[k];
                }

                double t1 = 0.0;
                double t2 = 0.0;
                double dist = 0.0;
                segment_dist(start_i[0], start_i[1], start_i[2],
                             end_i[0], end_i[1], end_i[2],
                             start_j[0], start_j[1], start_j[2],
                             end_j[0], end_j[1], end_j[2],
                             &t1, &t2, &dist);

                // Shift the returned parameters by 0.5 and scale by the bond
                // length to get signed offsets from each bond centre
                const double mu = (t2 - 0.5) * length_j;
                const double lambda = (t1 - 0.5) * length_i;

                // Separation from the point on bond i to the point on bond j
                double r_min[3] = {0.0, 0.0, 0.0};
                for (int k = 0; k < ndim; ++k) {
                    r_min[k] = (r_j[k] + mu * u_j[k]) - (r_i[k] + lambda * u_i[k]);
                }

                *prminmag2 += dist;
                *pmu += mu;
                *plambda += lambda;
                prmin[0] += r_min[0];
                prmin[1] += r_min[1];
                prmin[2] += r_min[2];
            }
        }
    }

    return true;
}
