// Implementation of the lookup-table unit tests, which exercise the tables in several ways

#include "bob.h"
#include "lookup_table.h"
#include "lookup_table_advanced.h"
#include "probabilities.h"
#include "unit_test_lookup_table.h"

#include <iostream>

// Run the tests and all their variants
void UnitTestLookupTable::RunTests() {
    std::cout << "****************\n";
    std::cout << "Unit Test Lookup Table run tests\n";
    std::cout << "****************\n";

    for (auto &kv : tests_) {
        std::cout << "----------------\n";
        std::cout << "Test : " << kv.first << std::endl;
        std::cout << "----------------\n";

        for (int iv = 0; iv < node_["Tests"][kv.first].size(); ++iv) {
            std::cout << "  Variant : " << iv << std::endl;
            var_subnode_ = node_["Tests"][kv.first][iv];
            iv_ = iv;
            auto result = kv.second();

            if (!result) {
                std::cout << "Test : " << kv.first << " failed, check output!\n";
                exit(1);
            }
        }
    }
}

// Set the possible tests to run
void UnitTestLookupTable::SetTests() {
    node_ = YAML::LoadFile(filename_);

    // Check against method calls
    if (node_["Tests"]["BasicTest12"]) {
        tests_["BasicTest12"] = std::bind(&UnitTestLookupTable::BasicTest12, this);
    }
    if (node_["Tests"]["BasicTestFdep"]) {
        tests_["BasicTestFdep"] = std::bind(&UnitTestLookupTable::BasicTestFdep, this);
    }
    if (node_["Tests"]["LookupTest12"]) {
        tests_["LookupTest12"] = std::bind(&UnitTestLookupTable::LookupTest12, this);
    }
    if (node_["Tests"]["LookupTest12Values"]) {
        tests_["LookupTest12Values"] = std::bind(&UnitTestLookupTable::LookupTest12Values, this);
    }
    if (node_["Tests"]["AdvancedTest"]) {
        tests_["AdvancedTest"] = std::bind(&UnitTestLookupTable::AdvancedTest, this);
    }
    if (node_["Tests"]["Invert"]) {
        tests_["Invert"] = std::bind(&UnitTestLookupTable::Invert, this);
    }
}

// Check the selected answers against the raw table storage.
//
// sa maps a grid index a (along x_[0]) to a grid index b (along x_[1]) to an
// expected value. The table is read directly as table_[a * x_[1].size() + b],
// i.e. it is treated as row-major over the two grid axes. Indices outside the
// grid are reported as failures instead of being read out of bounds.
// Returns true only if every entry matches within 1e-5.
bool UnitTestLookupTable::CheckSelectedAnswers(LookupTable *lut,
                                               std::map<int, std::map<int, double>> *sa) {
    bool success = true;

    // Bind by reference: the grid is only needed for its sizes.
    const auto &x = lut->x_;
    const std::size_t n0 = x[0].size();
    const std::size_t n1 = x[1].size();

    for (const auto &kv0 : *sa) {
        for (const auto &kv1 : kv0.second) {
            const int ia = kv0.first;
            const int ib = kv1.first;
            const double ans = kv1.second;

            // A b >= n1 would silently alias into the next row; an a >= n0 (or
            // any negative index) would read past the table.
            if (ia < 0 || ib < 0 ||
                static_cast<std::size_t>(ia) >= n0 ||
                static_cast<std::size_t>(ib) >= n1) {
                std::cout << "    index  (x: " << ia << ", y: " << ib
                    << ") is outside the " << n0 << " x " << n1 << " table\n";
                success = false;
                continue;
            }

            const std::size_t flat = static_cast<std::size_t>(ia) * n1 +
                                     static_cast<std::size_t>(ib);
            double retval = lut->table_[flat];
            std::cout << "    value  (x: " << ia << ", y: " << ib
                << "), = " << retval << std::endl;
            std::cout << "    answer (x: " << ia << ", y: " << ib
                << "), = " << ans << std::endl;
            if (!almost_equal(retval, ans, 0.00001)) success = false;
        }
    }

    return success;
}

// Check the selected answers through LookupTable::Lookup using integer keys.
//
// sa maps a -> b -> expected value. Each (a, b) pair is converted to a pair
// of double coordinates and passed to lut->Lookup. Returns true only if every
// lookup matches its expected value within 1e-5.
bool UnitTestLookupTable::CheckSelectedAnswersLookup(LookupTable *lut,
                                                     std::map<int, std::map<int, double>> *sa) {
    bool success = true;

    for (const auto &kv0 : *sa) {
        for (const auto &kv1 : kv0.second) {
            double xq[2] = {static_cast<double>(kv0.first),
                            static_cast<double>(kv1.first)};
            double retval = lut->Lookup(xq);
            double ans = kv1.second;
            std::cout << "    value  (x: " << kv0.first << ", y: " << kv1.first
                << "), = " << retval << std::endl;
            std::cout << "    answer (x: " << kv0.first << ", y: " << kv1.first
                << "), = " << ans << std::endl;
            if (!almost_equal(retval, ans, 0.00001)) success = false;
        }
    }

    return success;
}

// Check the selected answers through LookupTable::Lookup using real-valued
// coordinates.
//
// sa maps a -> b -> expected value, with a and b used directly as the lookup
// coordinates. Returns true only if every lookup matches within 1e-5.
bool UnitTestLookupTable::CheckSelectedAnswersLookup(LookupTable *lut,
                                                     std::map<double, std::map<double, double>> *sa) {
    bool success = true;

    for (const auto &kv0 : *sa) {
        for (const auto &kv1 : kv0.second) {
            double xq[2] = {kv0.first, kv1.first};
            double retval = lut->Lookup(xq);
            double ans = kv1.second;
            std::cout << "    value  (x: " << kv0.first << ", y: " << kv1.first
                << "), = " << retval << std::endl;
            std::cout << "    answer (x: " << kv0.first << ", y: " << kv1.first
                << "), = " << ans << std::endl;
            if (!almost_equal(retval, ans, 0.00001)) success = false;
        }
    }

    return success;
}

// Check the selected answers of an advanced (4-coordinate) lookup table.
//
// Each tuple in sa is (x0, x1, x2, x3, expected). The first four entries are
// passed to LookupTableAdvanced::Lookup and the result is compared with the
// fifth within `tolerance`. Returns true only if every entry matches.
bool UnitTestLookupTable::CheckSelectedAnswersAdvanced(LookupTableAdvanced *lut,
                                                       std::vector<std::tuple<double, double, double, double, double>> *sa,
                                                       double tolerance) {
    bool success = true;

    for (const auto &kv0 : *sa) {
        double xtest[4] = {std::get<0>(kv0),
                           std::get<1>(kv0),
                           std::get<2>(kv0),
                           std::get<3>(kv0)};
        double retval = lut->Lookup(xtest);
        double ans = std::get<4>(kv0);
        std::cout << "    lut x[" <<
            xtest[0] << ", " <<
            xtest[1] << ", " <<
            xtest[2] << ", " <<
            xtest[3] << "] = " << retval << std::endl;
        std::cout << "    ans x[" <<
            xtest[0] << ", " <<
            xtest[1] << ", " <<
            xtest[2] << ", " <<
            xtest[3] << "] = " << ans << std::endl;
        if (!almost_equal(retval, ans, tolerance)) success = false;
    }

    return success;
}

// Check the inversion of an advanced lookup table.
//
// Each tuple in sa is (u, x1, x2, x3, expected). All four leading values are
// placed in the coordinate array handed to Invert, and element 0 is also
// passed separately as u. The first argument 0 is presumably the index of the
// coordinate being solved for (TODO confirm against LookupTableAdvanced).
// Returns true only if every inversion matches within `tolerance`.
bool UnitTestLookupTable::CheckInvert(LookupTableAdvanced *lut,
                                      std::vector<std::tuple<double, double, double, double, double>> *sa,
                                      double tolerance) {
    bool all_match = true;

    for (const auto &sample : *sa) {
        double coords[4];
        coords[0] = std::get<0>(sample);
        coords[1] = std::get<1>(sample);
        coords[2] = std::get<2>(sample);
        coords[3] = std::get<3>(sample);

        const double target_u = coords[0];
        const double inverted = lut->Invert(0, target_u, coords);
        const double expected = std::get<4>(sample);

        std::cout << "    Invert: x[unknown, "
            << coords[1] << ", "
            << coords[2] << ", "
            << coords[3] << "], u = "
            << target_u << std::endl;
        std::cout << "    lut s: " << inverted << std::endl;
        std::cout << "    ans s: " << expected << std::endl;

        if (!almost_equal(inverted, expected, tolerance)) {
            all_match = false;
        }
    }

    return all_match;
}

// Check the inversion of an original (2D) LookupTable.
//
// Only elements 0, 1 and 4 of each tuple are used. Element 0 is passed as u
// and also occupies the first slot of the coordinate array, element 1 is the
// second coordinate, and element 4 is the expected result. Returns true only
// if every inversion matches within `tolerance`.
bool UnitTestLookupTable::CheckInvertOriginal(LookupTable *lut,
                                              std::vector<std::tuple<double, double, double, double, double>> *sa,
                                              double tolerance) {
    bool all_match = true;

    for (const auto &sample : *sa) {
        const double target_u = std::get<0>(sample);
        double coords[2] = {target_u, std::get<1>(sample)};
        const double inverted = lut->Invert(0, target_u, coords);
        const double expected = std::get<4>(sample);

        std::cout << "    Invert: x[unknown, "
            << coords[1] << "], u = "
            << target_u << std::endl;
        std::cout << "    lut s: " << inverted << std::endl;
        std::cout << "    ans s: " << expected << std::endl;

        if (!almost_equal(inverted, expected, tolerance)) {
            all_match = false;
        }
    }

    return all_match;
}

// Basic test of probability 12
bool UnitTestLookupTable::BasicTest12() {
    bool success = true;

    xlink_params params;
    std::vector<double> x[2];

    params.r0 = var_subnode_["params"]["r0"].as<double>();
    params.alpha = var_subnode_["params"]["alpha"].as<double>();
    double a_cutoff = var_subnode_["params"]["a_cutoff"].as<double>();
    double y_cutoff = var_subnode_["params"]["y_cutoff"].as<double>();
    double bin_size = var_subnode_["params"]["bin_size"].as<double>();

    std::map<int, std::map<int, double>> selected_answers;
    for (auto nit = var_subnode_["results"].begin();
              nit != var_subnode_["results"].end();
              ++nit) {
        YAML::Node mysubnode = *nit;
        int a = mysubnode["a"].as<int>();
        int b = mysubnode["b"].as<int>();
        double ans = mysubnode["ans"].as<double>();
        selected_answers[a][b] = ans;
    }

    for (double a = 0.0; a <= a_cutoff; a += bin_size) {
        x[0].push_back(a);
    }
    for (double y0 = 0.0; y0 <= y_cutoff; y0 += bin_size) {
        x[1].push_back(y0);
    }

    LookupTable lut;
    lut.Init(2, x, &prob_1_2, &params);

    success = CheckSelectedAnswers(&lut, &selected_answers);

    return success;
}

// Basic test of looup table fdep
bool UnitTestLookupTable::BasicTestFdep() {
    bool success = true;

    xlink_params params;
    std::vector<double> x[2];

    params.r0 = var_subnode_["params"]["r0"].as<double>();
    params.alpha = var_subnode_["params"]["alpha"].as<double>();
    params.xc = var_subnode_["params"]["xc"].as<double>();
    double a_cutoff = var_subnode_["params"]["a_cutoff"].as<double>();
    double y_cutoff = var_subnode_["params"]["y_cutoff"].as<double>();
    double bin_size = var_subnode_["params"]["bin_size"].as<double>();

    std::map<int, std::map<int, double>> selected_answers;
    for (auto nit = var_subnode_["results"].begin();
              nit != var_subnode_["results"].end();
              ++nit) {
        YAML::Node mysubnode = *nit;
        int a = mysubnode["a"].as<int>();
        int b = mysubnode["b"].as<int>();
        double ans = mysubnode["ans"].as<double>();
        selected_answers[a][b] = ans;
    }

    for (double a = 0.0; a <= a_cutoff; a += bin_size) {
        x[0].push_back(a);
    }
    for (double y0 = 0.0; y0 <= y_cutoff; y0 += bin_size) {
        x[1].push_back(y0);
    }

    LookupTable lut;
    lut.Init(2, x, &prob_1_2_fdep, &params);

    success = CheckSelectedAnswers(&lut, &selected_answers);

    return success;
}

// Basic test of looking up 12
bool UnitTestLookupTable::LookupTest12() {
    bool success = true;

    xlink_params params;
    std::vector<double> x[2];

    params.r0 = var_subnode_["params"]["r0"].as<double>();
    params.alpha = var_subnode_["params"]["alpha"].as<double>();
    double a_cutoff = var_subnode_["params"]["a_cutoff"].as<double>();
    double y_cutoff = var_subnode_["params"]["y_cutoff"].as<double>();
    double bin_size = var_subnode_["params"]["bin_size"].as<double>();

    std::map<int, std::map<int, double>> selected_answers;
    for (auto nit = var_subnode_["results"].begin();
              nit != var_subnode_["results"].end();
              ++nit) {
        YAML::Node mysubnode = *nit;
        int a = mysubnode["a"].as<int>();
        int b = mysubnode["b"].as<int>();
        double ans = mysubnode["ans"].as<double>();
        selected_answers[a][b] = ans;
    }

    for (double a = 0.0; a <= a_cutoff; a += bin_size) {
        x[0].push_back(a);
    }
    for (double y0 = 0.0; y0 <= y_cutoff; y0 += bin_size) {
        x[1].push_back(y0);
    }

    LookupTable lut;
    lut.Init(2, x, &prob_1_2, &params);

    success = CheckSelectedAnswersLookup(&lut, &selected_answers);

    return success;
}

bool UnitTestLookupTable::LookupTest12Values() {
    bool success = true;

    xlink_params params;
    std::vector<double> x[2];

    params.r0 = var_subnode_["params"]["r0"].as<double>();
    params.alpha = var_subnode_["params"]["alpha"].as<double>();
    double a_cutoff = var_subnode_["params"]["a_cutoff"].as<double>();
    double y_cutoff = var_subnode_["params"]["y_cutoff"].as<double>();
    double bin_size = var_subnode_["params"]["bin_size"].as<double>();

    std::map<double, std::map<double, double>> selected_answers;
    for (auto nit = var_subnode_["results"].begin();
              nit != var_subnode_["results"].end();
              ++nit) {
        YAML::Node mysubnode = *nit;
        double a = mysubnode["a"].as<double>();
        double b = mysubnode["b"].as<double>();
        double ans = mysubnode["ans"].as<double>();
        selected_answers[a][b] = ans;
    }

    for (double a = 0.0; a <= a_cutoff; a += bin_size) {
        x[0].push_back(a);
    }
    for (double y0 = 0.0; y0 <= y_cutoff; y0 += bin_size) {
        x[1].push_back(y0);
    }

    LookupTable lut;
    lut.Init(2, x, &prob_1_2, &params);

    success = CheckSelectedAnswersLookup(&lut, &selected_answers);
    return success;
}


// Advanced test of the new lookup tables (up to 4 dimensions).
//
// Builds a LookupTableAdvanced over prob_1_2_advanced from the grid settings
// in var_subnode_["params"] and compares Lookup() at each (a, b, c, d) point
// under "results" with its "ans" value. A "dim" outside [1, 4] is reported as
// a failure, because only four axis vectors are allocated below.
bool UnitTestLookupTable::AdvancedTest() {
    // Number of axis vectors in x. Init presumably reads x[0..ndim-1], so a
    // larger dim would index past the array (TODO confirm in LookupTableAdvanced).
    constexpr int kMaxDim = 4;

    xlink_params params;
    int ndim = var_subnode_["params"]["dim"].as<int>();
    if (ndim < 1 || ndim > kMaxDim) {
        std::cout << "    AdvancedTest: dim = " << ndim
            << " is not supported (expected 1.." << kMaxDim << ")\n";
        return false;
    }
    std::vector<double> x[kMaxDim];

    params.r0 = var_subnode_["params"]["r0"].as<double>();
    params.alpha = var_subnode_["params"]["alpha"].as<double>();
    params.alpha2 = 0.0;
    params.xc = var_subnode_["params"]["xc"].as<double>();
    double a_cutoff = var_subnode_["params"]["a_cutoff"].as<double>();
    double y_cutoff = var_subnode_["params"]["y_cutoff"].as<double>();
    double bin_size = var_subnode_["params"]["bin_size"].as<double>();

    // NOTE(review): one tolerance is shared by all results. If several results
    // set "tolerance", the last one read applies to every result.
    std::vector<std::tuple<double, double, double, double, double>> selected_answers;
    double tolerance = 0.00001;
    for (auto nit =  var_subnode_["results"].begin();
              nit != var_subnode_["results"].end();
              ++nit) {
        YAML::Node mysubnode = *nit;
        double a = mysubnode["a"].as<double>();
        double b = mysubnode["b"].as<double>();
        double c = mysubnode["c"].as<double>();
        double d = mysubnode["d"].as<double>();
        double ans = mysubnode["ans"].as<double>();
        selected_answers.emplace_back(a, b, c, d, ans);

        if (mysubnode["tolerance"]) {
            tolerance = mysubnode["tolerance"].as<double>();
        }
    }

    // Axis 0 spans [0, a_cutoff] and axis 1 spans [0, y_cutoff], in steps of
    // bin_size.
    for (double a = 0.0; a <= a_cutoff; a += bin_size) {
        x[0].push_back(a);
    }
    for (double y0 = 0.0; y0 <= y_cutoff; y0 += bin_size) {
        x[1].push_back(y0);
    }
    if (ndim > 2) {
        // Axis 2 spans [-1, 1]. For dim == 4 axis 3 matches it; for dim == 3
        // axis 3 is filled with zeros of the same length.
        for (double theta0 = -1.0; theta0 <= 1.0; theta0 += bin_size) {
            if (ndim == 3) {
                x[2].push_back(theta0);
                x[3].push_back(0.0);
            } else if (ndim == 4) {
                x[2].push_back(theta0);
                x[3].push_back(theta0);
            }
        }
        params.alpha2 = var_subnode_["params"]["alpha2"].as<double>();
    }

    LookupTableAdvanced lut;
    lut.Init(ndim, x, &prob_1_2_advanced, &params);

    return CheckSelectedAnswersAdvanced(&lut, &selected_answers, tolerance);
}

// Inversion test.
//
// Reads the grid settings from var_subnode_["params"] and, for every entry
// under "results", stores (u, b, c, d, ans). For dim == 2 it checks
// LookupTable (prob_1_2_fdep) with CheckInvertOriginal; otherwise it checks
// LookupTableAdvanced (prob_1_2_advanced) with CheckInvert. A "dim" outside
// [1, 4] is reported as a failure, because only four axis vectors are
// allocated below.
bool UnitTestLookupTable::Invert() {
    // Number of axis vectors in x. Init presumably reads x[0..ndim-1], so a
    // larger dim would index past the array (TODO confirm in the table classes).
    constexpr int kMaxDim = 4;

    xlink_params params;
    int ndim = var_subnode_["params"]["dim"].as<int>();
    if (ndim < 1 || ndim > kMaxDim) {
        std::cout << "    Invert: dim = " << ndim
            << " is not supported (expected 1.." << kMaxDim << ")\n";
        return false;
    }
    std::vector<double> x[kMaxDim];

    params.r0 = var_subnode_["params"]["r0"].as<double>();
    params.alpha = var_subnode_["params"]["alpha"].as<double>();
    params.alpha2 = 0.0;
    params.xc = var_subnode_["params"]["xc"].as<double>();
    double a_cutoff = var_subnode_["params"]["a_cutoff"].as<double>();
    double y_cutoff = var_subnode_["params"]["y_cutoff"].as<double>();
    double bin_size = var_subnode_["params"]["bin_size"].as<double>();

    // NOTE(review): one tolerance is shared by all results. If several results
    // set "tolerance", the last one read applies to every result.
    std::vector<std::tuple<double, double, double, double, double>> selected_answers;
    double tolerance = 0.00001;
    for (auto nit =  var_subnode_["results"].begin();
              nit != var_subnode_["results"].end();
              ++nit) {
        YAML::Node mysubnode = *nit;
        // Element 0 of the tuple is the "u" value passed to Invert.
        double a = mysubnode["u"].as<double>();
        double b = mysubnode["b"].as<double>();
        double c = mysubnode["c"].as<double>();
        double d = mysubnode["d"].as<double>();
        double ans = mysubnode["ans"].as<double>();
        selected_answers.emplace_back(a, b, c, d, ans);

        if (mysubnode["tolerance"]) {
            tolerance = mysubnode["tolerance"].as<double>();
        }
    }

    // Axis 0 spans [0, a_cutoff] and axis 1 spans [0, y_cutoff], in steps of
    // bin_size.
    for (double a = 0.0; a <= a_cutoff; a += bin_size) {
        x[0].push_back(a);
    }
    for (double y0 = 0.0; y0 <= y_cutoff; y0 += bin_size) {
        x[1].push_back(y0);
    }
    if (ndim > 2) {
        // Axis 2 spans [-1, 1]. For dim == 4 axis 3 matches it; for dim == 3
        // axis 3 is filled with zeros of the same length.
        for (double theta0 = -1.0; theta0 <= 1.0; theta0 += bin_size) {
            if (ndim == 3) {
                x[2].push_back(theta0);
                x[3].push_back(0.0);
            } else if (ndim == 4) {
                x[2].push_back(theta0);
                x[3].push_back(theta0);
            }
        }
        params.alpha2 = var_subnode_["params"]["alpha2"].as<double>();
    }

    bool success = true;
    if (ndim == 2) {
        LookupTable lut;
        lut.Init(ndim, x, &prob_1_2_fdep, &params);

        success = CheckInvertOriginal(&lut, &selected_answers, tolerance);
    } else {
        LookupTableAdvanced lut;
        lut.Init(ndim, x, &prob_1_2_advanced, &params);

        success = CheckInvert(&lut, &selected_answers, tolerance);
    }

    return success;
}
