// Implementation of testing of probabilities

#include "bob.h"
#include "probabilities.h"
#include "unit_test_probabilities.h"

#include <cstddef>
#include <cstdlib>
#include <iostream>

// Run every registered test against each of its variants.
// The variants of a test are the entries of node_["Tests"][<test name>]; before
// each call the current entry is exposed to the test through var_subnode_ and
// its index through iv_. The process exits with status 1 on the first variant
// whose test function returns false.
void UnitTestProbabilities::RunTests() {
    std::cout << "****************\n";
    std::cout << "Unit Test Probabilities run tests\n";
    std::cout << "****************\n";

    for (auto &kv : tests_) {
        const auto &test_name = kv.first;
        std::cout << "----------------\n";
        std::cout << "Test : " << test_name << std::endl;
        std::cout << "----------------\n";

        // Look the variant list up once instead of re-walking the document on
        // every iteration; convert the size once so the loop condition does
        // not mix signed and unsigned integers.
        auto variants = node_["Tests"][test_name];
        const int nvariants = static_cast<int>(variants.size());

        for (int iv = 0; iv < nvariants; ++iv) {
            std::cout << "  Variant : " << iv << std::endl;

            // Rebind with reset() rather than operator=: assigning to a
            // YAML::Node that already refers to a node overwrites that
            // referenced node, which would make the previous variant inside
            // node_ alias this one and corrupt the loaded document.
            var_subnode_.reset(variants[iv]);
            iv_ = iv;

            if (!kv.second()) {
                std::cout << "Test : " << test_name << " failed, check output!\n";
                std::exit(1);
            }
        }
    }
}

// Set the tests to run
void UnitTestProbabilities::SetTests() {
    node_ = YAML::LoadFile(filename_);

    // Check against method calls
    if (node_["Tests"]["Integrand12"]) {
        tests_["Integrand12"] = std::bind(&UnitTestProbabilities::Integrand12, this);
    }
    if (node_["Tests"]["Integrand12Fdep"]) {
        tests_["Integrand12Fdep"] = std::bind(&UnitTestProbabilities::Integrand12Fdep, this);
    }
    if (node_["Tests"]["Integrand12Advanced"]) {
        tests_["Integrand12Advanced"] = std::bind(&UnitTestProbabilities::Integrand12Advanced, this);
    }
    if (node_["Tests"]["Prob12"]) {
        tests_["Prob12"] = std::bind(&UnitTestProbabilities::Prob12, this);
    }
    if (node_["Tests"]["Prob12Fdep"]) {
        tests_["Prob12Fdep"] = std::bind(&UnitTestProbabilities::Prob12Fdep, this);
    }
    if (node_["Tests"]["Prob12Advanced"]) {
        tests_["Prob12Advanced"] = std::bind(&UnitTestProbabilities::Prob12Advanced, this);
    }
}

// Single-point check: print the parameters, the evaluation point x, the
// computed value and the expected value, then return whether retval and
// answer agree according to almost_equal(retval, answer, tol).
bool UnitTestProbabilities::PrintandCheck(integration_params *tparams,
                                          double x,
                                          double retval,
                                          double answer,
                                          double tol) {
    std::cout << "  testing integrand 12 with r0: " << tparams->r0
              << ", y0: " << tparams->y0
              << ", alpha: " << tparams->alpha
              << ", xc: " << tparams->xc
              << ", x: " << x << std::endl;
    std::cout << "    result: " << retval << std::endl;
    std::cout << "    answer: " << answer << std::endl;

    const bool within_tolerance = almost_equal(retval, answer, tol);
    return within_tolerance;
}

// Multi-coordinate check: print the parameters, every component of x, the
// computed value and the expected value, then return whether retval and
// answer agree according to almost_equal(retval, answer, tol).
bool UnitTestProbabilities::PrintandCheck(integration_params *tparams,
                                          std::vector<double> x,
                                          double retval,
                                          double answer,
                                          double tol) {
    std::cout << "  testing integrand 12 with r0: " << tparams->r0 << ", y0: " <<
        tparams->y0 << ", alpha: " << tparams->alpha << ", xc: " << tparams->xc;

    // Walk the actual size rather than hard-coding x[0] and x[1]: that read
    // out of bounds for vectors shorter than two and silently dropped any
    // further coordinates (e.g. the four-element vector of Prob12Advanced).
    // Output for two-element vectors is unchanged.
    for (std::size_t i = 0; i < x.size(); ++i) {
        std::cout << ", x[" << i << "]: " << x[i];
    }
    std::cout << std::endl;

    std::cout << "    result: " << retval << std::endl;
    std::cout << "    answer: " << answer << std::endl;

    return almost_equal(retval, answer, tol);
}


// Integrand12: evaluate integrand_1_2 at the point "x" with r0, y0 and alpha
// taken from the current variant's "params" node (xc is fixed at 0) and
// compare the result with "answer" using the default tolerance.
bool UnitTestProbabilities::Integrand12() {
    auto params = var_subnode_["params"];

    integration_params tparams;
    tparams.xc = 0.0;
    tparams.r0 = params["r0"].as<double>();
    tparams.y0 = params["y0"].as<double>();
    tparams.alpha = params["alpha"].as<double>();

    const double x = params["x"].as<double>();
    const double answer = params["answer"].as<double>();

    const double result = integrand_1_2(x, &tparams);
    return PrintandCheck(&tparams, x, result, answer);
}

// Integrand12Fdep: like Integrand12 but also reads "xc" from the variant and
// evaluates integrand_1_2_fdep; compares with "answer" using the default
// tolerance.
bool UnitTestProbabilities::Integrand12Fdep() {
    auto params = var_subnode_["params"];

    integration_params tparams;
    tparams.r0 = params["r0"].as<double>();
    tparams.y0 = params["y0"].as<double>();
    tparams.alpha = params["alpha"].as<double>();
    tparams.xc = params["xc"].as<double>();

    const double x = params["x"].as<double>();
    const double answer = params["answer"].as<double>();

    const double result = integrand_1_2_fdep(x, &tparams);
    return PrintandCheck(&tparams, x, result, answer);
}

// Integrand12Advanced: evaluate advanced_integrand_1_2 at "x" with r0, y0,
// alpha1 (stored in tparams.alpha), alpha2, xc, costhetaminA and costheta2A
// from the current variant, and compare with "answer" within the variant's
// own "tolerance".
bool UnitTestProbabilities::Integrand12Advanced() {
    auto params = var_subnode_["params"];

    integration_params tparams;
    tparams.r0 = params["r0"].as<double>();
    tparams.y0 = params["y0"].as<double>();
    tparams.alpha = params["alpha1"].as<double>();
    tparams.alpha2 = params["alpha2"].as<double>();
    tparams.xc = params["xc"].as<double>();
    tparams.costhetaminA = params["costhetaminA"].as<double>();
    tparams.costheta2A = params["costheta2A"].as<double>();

    const double x = params["x"].as<double>();
    const double answer = params["answer"].as<double>();
    const double tol = params["tolerance"].as<double>();

    const double result = advanced_integrand_1_2(x, &tparams);
    return PrintandCheck(&tparams, x, result, answer, tol);
}

// Prob12: evaluate prob_1_2 on the two-element point {xmax, y0} read from the
// current variant, with r0 and alpha from the variant and tparams.y0 / xc
// fixed at 0; compare with "answer" using the default tolerance.
bool UnitTestProbabilities::Prob12() {
    auto params = var_subnode_["params"];

    integration_params tparams;
    tparams.r0 = params["r0"].as<double>();
    tparams.y0 = 0.0;
    tparams.alpha = params["alpha"].as<double>();
    tparams.xc = 0.0;

    const double xmax = params["xmax"].as<double>();
    const double y0_value = params["y0"].as<double>();
    std::vector<double> x{xmax, y0_value};

    const double answer = params["answer"].as<double>();

    const double result = prob_1_2(x, &tparams);
    return PrintandCheck(&tparams, x, result, answer);
}

// Prob12Fdep: like Prob12 but also reads "xc" from the variant and evaluates
// prob_1_2_fdep on {xmax, y0}; tparams.y0 is fixed at 0. Compares with
// "answer" using the default tolerance.
bool UnitTestProbabilities::Prob12Fdep() {
    auto params = var_subnode_["params"];

    integration_params tparams;
    tparams.r0 = params["r0"].as<double>();
    tparams.y0 = 0.0;
    tparams.alpha = params["alpha"].as<double>();
    tparams.xc = params["xc"].as<double>();

    const double xmax = params["xmax"].as<double>();
    const double y0_value = params["y0"].as<double>();
    std::vector<double> x{xmax, y0_value};

    const double answer = params["answer"].as<double>();

    const double result = prob_1_2_fdep(x, &tparams);
    return PrintandCheck(&tparams, x, result, answer);
}

// Prob12Advanced: evaluate prob_1_2_advanced on the four-element point
// {xmax, y0, costhetaminA, costheta2A} read from the current variant, with
// r0, alpha1 (stored in tparams.alpha), alpha2 and xc from the variant and
// tparams.y0 / chic fixed at 0. The angular values travel in x[2] and x[3]
// rather than in tparams. Compares with "answer" within the variant's own
// "tolerance".
bool UnitTestProbabilities::Prob12Advanced() {
    auto params = var_subnode_["params"];

    // Value-initialize so fields this test deliberately leaves unset
    // (costhetaminA / costheta2A, passed via x instead) are zero rather than
    // indeterminate.
    integration_params tparams{};
    tparams.r0 = params["r0"].as<double>();
    tparams.y0 = 0.0;
    tparams.alpha = params["alpha1"].as<double>();
    tparams.alpha2 = params["alpha2"].as<double>();
    tparams.xc = params["xc"].as<double>();
    tparams.chic = 0.0;

    std::vector<double> x(4);
    x[0] = params["xmax"].as<double>();
    x[1] = params["y0"].as<double>();
    x[2] = params["costhetaminA"].as<double>();
    x[3] = params["costheta2A"].as<double>();

    const double answer = params["answer"].as<double>();
    const double tol = params["tolerance"].as<double>();

    const double retval = prob_1_2_advanced(x, &tparams);

    // Forward the configured tolerance: it used to be read but not passed,
    // so the comparison silently fell back to the default tolerance.
    return PrintandCheck(&tparams, x, retval, answer, tol);
}
