// Implementation of unit tests for xlink management

#include "bob.h"
#include "unit_test_xlink_management.h"
#include "xlink_entry.h"

//typedef std::vector<XlinkEntry> xlink_list;
#include <cmath>
#include <iostream>

// Run the tests and all their variants
void UnitTestXlinkManagement::RunTests() {
    std::cout << "****************\n";
    std::cout << "Unit Test Xlink Management run tests\n";
    std::cout << "****************\n";

    for (auto &kv : tests_) {
        std::cout << "----------------\n";
        std::cout << "Test : " << kv.first << std::endl;
        std::cout << "----------------\n";

        for (int iv = 0; iv < node_["Tests"][kv.first].size(); ++iv) {
            std::cout << "  Variant : " << iv << std::endl;
            var_subnode_ = node_["Tests"][kv.first][iv];
            iv_ = iv;
            auto result = kv.second();

            if (!result) {
                std::cout << "Test : " << kv.first << " failed, check output!\n";
                exit(1);
            }
        }
    }
}

// Common set-up shared by the test variants.
//
// parameters : filled from the current variant's "params" subnode, then
//              forced to a non-periodic, 3-dimensional system.
// properties : receives a freshly allocated GSL random number generator
//              (rng.T / rng.r) and an n_dim x n_dim unit cell matrix
//              (unit_cell.h).
//
// NOTE(review): nothing in this file releases the generator or the unit cell
// matrix allocated here; ownership stays with the caller.
void UnitTestXlinkManagement::BasicInit(system_parameters * parameters,
                                        system_properties * properties){

    // Generic parameter initialization. This must happen before the random
    // number generator is seeded: seeding reads parameters->seed, and
    // presumably the parser is what fills that field from the variant's
    // "params" subnode (confirm against parse_parameters_node).
    parse_parameters_node(var_subnode_["params"], parameters);
    //init_default_params(parameters);

    // Set up random number generators
    gsl_rng_env_setup();
    properties->rng.T = gsl_rng_default;
    properties->rng.r = gsl_rng_alloc(properties->rng.T);
    gsl_rng_set(properties->rng.r, parameters->seed);

    // Custom parameter initialization: these values override whatever the
    // "params" subnode specified.
    parameters->n_periodic = 0;
    const int n_dim = parameters->n_dim = 3;
    properties->unit_cell.h = (double **) allocate_2d_array(n_dim, n_dim, sizeof(double));
}

// Set up the possible tests to run
void UnitTestXlinkManagement::SetTests() {
    node_ = YAML::LoadFile(filename_);

    // Check against names to bind method calls
    if (node_["Tests"]["Stage0Diffusion"]) {
        tests_["Stage0Diffusion"] = std::bind(&UnitTestXlinkManagement::Stage0Diffusion, this);
    }
    if (node_["Tests"]["Stage1Diffusion"]) {
        tests_["Stage1Diffusion"] = std::bind(&UnitTestXlinkManagement::Stage1Diffusion, this);
    }
    if (node_["Tests"]["Stage2Diffusion"]) {
        tests_["Stage2Diffusion"] = std::bind(&UnitTestXlinkManagement::Stage2Diffusion, this);
    }
}

// Test the diffusion of stage 0 (free) xlinks.
//
// Procedure:
//   1) Initialize xlink management from the xlink file named in the
//      variant's "params" subnode.
//   2) Move every xlink to the center of the system and run n_steps update
//      steps.
//   3) Estimate the diffusion coefficient from the mean squared displacement
//      of the free xlinks,
//          D = <r^2> / (2 * n_dim * delta * n_steps),
//      and compare it with "diffusion_free" of the first "crosslink" entry
//      of the xlink file.
//   4) Cleanup (not implemented: the generator and unit cell allocated by
//      BasicInit are not released here).
//
// Returns true when the magnitude of the difference between the given and
// the estimated coefficient does not exceed the standard error of the
// estimate and, if the variant defines results.max_error, when neither the
// standard error nor that magnitude exceeds max_error. Returns false
// otherwise, and also when the run cannot produce meaningful statistics
// (non-positive step count or time step, fewer than two free xlinks).
bool UnitTestXlinkManagement::Stage0Diffusion() {
    // Value-initialize so no field is read while indeterminate (BasicInit
    // reads parameters.seed, for example).
    system_parameters parameters{};
    system_properties properties{};

    // Defaults specific to this test. They are written before BasicInit so
    // that the variant's "params" subnode can override them when it is parsed.
    properties.bonds.n_bonds = 0;
    parameters.n_steps = 1000000;
    parameters.delta = .00025;

    const std::string xlink_file = var_subnode_["params"]["xlink_file"].as<std::string>();
    YAML::Node xlink_node = YAML::LoadFile(xlink_file);

    const double diffusion_coeff = xlink_node["crosslink"][0]["diffusion_free"].as<double>();

    BasicInit(&parameters, &properties);

    // Read the step count and time step back only now, so the loop below and
    // the estimate of the diffusion coefficient agree on the values actually
    // in effect after the "params" subnode was parsed.
    const int n_steps = parameters.n_steps;
    const double delta = parameters.delta;
    if (n_steps <= 0 || delta <= 0.0) {
        std::cout << "Stage0Diffusion requires n_steps > 0 and delta > 0\n";
        return false;
    }
    const int n_print_percent = (n_steps > 100) ? n_steps / 100 : 1;

    // Length scale used to size the box (twice this value per edge).
    // NOTE(review): described originally as the max distance a random walk
    // could make; confirm that delta * n_steps * D is the intended bound.
    const double max_dist = delta * n_steps * diffusion_coeff;

    for (int i = 0; i < parameters.n_dim; ++i) {
        properties.unit_cell.h[i][i] = max_dist*2;
    }

    init_unit_cell_structure(&parameters, &properties);

    XlinkManagement xlink_mgt(&parameters, &properties, xlink_file.c_str(), nullptr);

    // Initialize all xlinks to the center of the system
    xlink_mgt.CenterXlinks();

    // Main unit test loop
    for (int i_step = 0; i_step < n_steps; ++i_step) {
        // Print out the current percentage complete
        if (i_step % n_print_percent == 0) {
            fprintf(stdout, "%d%% Complete\n", static_cast<int>(100.0 * i_step / n_steps));
            fflush(stdout);
        }

        //TODO Not sure if this is the best way to do this
        xlink_mgt.Update_0_1_Probability(&parameters, &properties);
    }

    // Count the free xlinks over all types
    int tot_free = 0;
    for (int i_type = 0; i_type < xlink_mgt.n_types_; ++i_type) {
        tot_free += xlink_mgt.n_free_[i_type];
    }

    // With fewer than two samples the mean and its standard error are NaN,
    // and every comparison against NaN is false, which would let the test
    // pass without having checked anything.
    if (tot_free < 2) {
        std::cout << "Stage0Diffusion requires at least 2 free xlinks, found "
                  << tot_free << "\n";
        return false;
    }

    // Make a vector of the distances squared
    std::vector<double> diffuse_dist2(tot_free, 0);
    xlink_mgt.FillFreeXlinkSqrdDisplacementVec(diffuse_dist2);

    // Mean of the squared displacements
    double diffuse_dist2_mean = 0;
    for (int i = 0; i < tot_free; ++i) {
        diffuse_dist2_mean += diffuse_dist2[i];
    }
    diffuse_dist2_mean /= double(tot_free);

    // Standard error of that mean: sqrt( sum (x - mean)^2 / ((n - 1) * n) )
    double diffuse_dist2_std = 0;
    for (int i = 0; i < tot_free; ++i) {
        const double deviation = diffuse_dist2[i] - diffuse_dist2_mean;
        diffuse_dist2_std += deviation * deviation;
    }
    diffuse_dist2_std /= (double(tot_free)-1.0)*double(tot_free);
    diffuse_dist2_std = std::sqrt(diffuse_dist2_std);

    // Convert mean squared displacement (and its error) to a diffusion coefficient
    const double calc_diffusion_coeff = diffuse_dist2_mean/(2*double(parameters.n_dim)*delta*n_steps);
    const double calc_diffusion_SE = diffuse_dist2_std/(2*double(parameters.n_dim)*delta*n_steps);

    // TODO A more in-depth analysis (per-axis mean and variance of the xlink
    // displacements) is not implemented.

    // The signed error is reported; the pass/fail checks use its magnitude so
    // that over- and under-estimates are treated alike.
    const double error = diffusion_coeff-calc_diffusion_coeff;
    const double abs_error = std::fabs(error);

    std::cout<<"Given diffusion constant: "<< diffusion_coeff <<
        ", Calculated diffusion constant: "<< calc_diffusion_coeff <<std::endl <<
        "   Error: "<< error <<
        ", Standard Error: "<< calc_diffusion_SE << std::endl;

    bool success = true;

    // TODO is there a better way to check diffusion than this?
    if (abs_error > calc_diffusion_SE) success = false;

    // Optional absolute tolerance supplied by the variant
    if (var_subnode_["results"]["max_error"]){
        const double max_error = var_subnode_["results"]["max_error"].as<double>();
        if (max_error < calc_diffusion_SE || max_error < abs_error)
            success = false;
    }

    return success;
}

// Placeholder for the stage 1 xlink diffusion test: prints a notice that the
// test is not implemented and reports success.
bool UnitTestXlinkManagement::Stage1Diffusion() {
    std::cout << "Stage1Diffusion test not implemented\n" << std::endl;
    return true;
}


// Placeholder for the stage 2 xlink diffusion test: prints a notice that the
// test is not implemented and reports success.
bool UnitTestXlinkManagement::Stage2Diffusion() {
    std::cout << "Stage2Diffusion test not implemented\n" << std::endl;
    return true;
}

