// Implementation of the Fokker-Planck integration tests for xlinks (crosslinks)

#include "bob.h"
#include "correlation_data.h"
#include "integration_test_fokkerplanck.h"

#include <cstddef>
#include <fstream>
#include <iostream>
#include <vector>

// Run the tests and all their variants
void IntegrationTestFokkerPlanck::RunTests() {
    std::cout << "****************\n";
    std::cout << "Integration Test Fokker Planck run tests\n";
    std::cout << "****************\n";

    for (auto &kv : tests_) {
        std::cout << "----------------\n";
        std::cout << "Test : " << kv.first << std::endl;
        std::cout << "----------------\n";

        for (int iv = 0; iv < node_["Tests"][kv.first].size(); ++iv) {
            std::cout << "  Variant : " << iv << std::endl;
            var_subnode_ = node_["Tests"][kv.first][iv];
            iv_ = iv;
            auto result = kv.second();

            if (!result) {
                std::cout << "Test : " << kv.first << " failed, check output!\n";
                exit(1);
            }
        }
    }
}

// Set the possible tests to run
void IntegrationTestFokkerPlanck::SetTests() {
    node_ = YAML::LoadFile(filename_);

    // Check against names to bind method calls
    if (node_["Tests"]["Run"]) {
        tests_["Run"] = std::bind(&IntegrationTestFokkerPlanck::Run, this);
    }
}

// Initialize the system for the current test variant (var_subnode_). This
// covers parameters parsed from "params", the global GSL RNG, the unit cell,
// and microtubule bonds/sites filled from the variant's "microtubules" list
// (keys "r", "u", "l"). It also allocates the per-thread RNGs and
// force/torque/virial scratch arrays, one set per OpenMP thread when
// ENABLE_OPENMP is defined and a single set otherwise.
void IntegrationTestFokkerPlanck::InitSystem(system_parameters *parameters,
                                             system_properties *properties) {

    // Generic parameter initialization
    parse_parameters_node(var_subnode_["params"], parameters);
    gsl_rng_env_setup();
    properties->rng.T = gsl_rng_default;
    properties->rng.r = gsl_rng_alloc(properties->rng.T);
    gsl_rng_set(properties->rng.r, parameters->seed);

    // Generate the unit cell: a diagonal box of side 100. Every entry is
    // written so we do not depend on allocate_2d_array returning zeroed memory.
    int ndim = parameters->n_dim;
    properties->unit_cell.h = (double **) allocate_2d_array(ndim, ndim, sizeof(double));
    for (int i = 0; i < ndim; ++i) {
        for (int j = 0; j < ndim; ++j) {
            properties->unit_cell.h[i][j] = (i == j) ? 100.0 : 0.0;
        }
    }

    // Get the microtubule information
    {
        YAML::Node mt_list = var_subnode_["microtubules"];
        int nmts = mt_list.size();
        init_unit_cell_structure(parameters, properties);
        int nsites = 2*nmts;
        parameters->n_spheros = properties->bonds.n_bonds = nmts;
        properties->sites.n_sites = nsites;
        properties->sites.v = (double **) allocate_2d_array(nsites, ndim, sizeof(double));
        properties->sites.r = (double **) allocate_2d_array(nsites, ndim, sizeof(double));
        init_site_structure_sphero(parameters, properties);
        init_bond_structure_sphero(parameters, properties);
        for (int imt = 0; imt < nmts; ++imt) {
            YAML::Node mt = mt_list[imt];
            for (int i = 0; i < ndim; ++i) {
                properties->bonds.r_bond[imt][i] = mt["r"][i].as<double>();
                properties->bonds.u_bond[imt][i] = mt["u"][i].as<double>();
            }
            properties->bonds.length[imt] = mt["l"].as<double>();
        }

        // Print the bond information. Only the n_dim components that were
        // filled above are printed, so a 2D system does not read past the row.
        auto print_vector = [ndim](const char *label, const double *v) {
            std::cout << "  " << label << "(";
            for (int i = 0; i < ndim; ++i) {
                std::cout << (i > 0 ? ", " : "") << v[i];
            }
            std::cout << ")\n";
        };
        for (int imt = 0; imt < nmts; ++imt) {
            std::cout << "New Microtubule:\n";
            print_vector("r", properties->bonds.r_bond[imt]);
            print_vector("u", properties->bonds.u_bond[imt]);
            std::cout << "  l " << properties->bonds.length[imt] << std::endl;
        }
    }

    // Load OMP settings
    set_omp_settings(parameters);
    print_omp_settings();

    #ifdef ENABLE_OPENMP
    #pragma omp parallel 
    #pragma omp master
    #endif
    {
        int n_threads = properties->mp_local.n_threads = properties->rng_mt.n_threads = 1;
            
        #ifdef ENABLE_OPENMP
        n_threads = properties->mp_local.n_threads = properties->rng_mt.n_threads = omp_get_num_threads();
        #endif

        properties->rng_mt.rng = (rng_properties*) 
            allocate_1d_array(properties->rng_mt.n_threads, sizeof(rng_properties));

        // One generator per thread, seeded from the global seed and the thread index
        gsl_rng_env_setup();
        for (int i_thr = 0; i_thr < properties->rng_mt.n_threads; ++i_thr) {
            properties->rng_mt.rng[i_thr].T = gsl_rng_default;
            properties->rng_mt.rng[i_thr].r = gsl_rng_alloc(properties->rng_mt.rng[i_thr].T);
            gsl_rng_set(properties->rng_mt.rng[i_thr].r, (parameters->seed+1) * i_thr);
        }

        // Per-thread scratch: forces (n_dim), torques (always 3 components),
        // virial (n_dim x n_dim) and per-bond flags
        properties->mp_local.f_local = (double***)
            allocate_3d_array(n_threads, properties->bonds.n_bonds, parameters->n_dim, sizeof(double));

        properties->mp_local.t_local = (double***)
            allocate_3d_array(n_threads, properties->bonds.n_bonds, 3, sizeof(double));

        properties->mp_local.virial_local = (double***)
            allocate_3d_array(n_threads, parameters->n_dim, parameters->n_dim, sizeof(double));
        
        properties->mp_local.calc_local = (int**)
            allocate_2d_array(n_threads, properties->bonds.n_bonds, sizeof(int));
    }
}

// Second initialization phase, run after the crosslinks exist: allocate one
// neighbor-list entry per bond and populate it with
// update_neighbor_lists_sphero_all_pairs_mp.
void IntegrationTestFokkerPlanck::FinishInitSystem(system_parameters *parameters,
                                                   system_properties *properties) {
    auto &bonds = properties->bonds;
    auto &neighbors = properties->neighbors;

    neighbors.neighbs = new nl_list[bonds.n_bonds];

    update_neighbor_lists_sphero_all_pairs_mp(
        parameters->n_dim,
        parameters->n_periodic,
        properties->unit_cell.h,
        parameters->skin,
        parameters->r_cutoff,
        bonds.n_bonds,
        bonds.r_bond,
        bonds.s_bond,
        bonds.u_bond,
        bonds.length,
        neighbors.neighbs,
        parameters->nl_twoway_flag);
}

// Generate the invert distribution of attachments along MT
bool IntegrationTestFokkerPlanck::Run() {
    bool success = true;

    system_parameters parameters;
    system_properties properties;

    // Initialize the system
    InitSystem(&parameters, &properties);

    // Create the xlinks, then create a node based on the basic information
    // and then override it
    std::string x_file = parameters.crosslink_file;
    std::cout << "xlink file: " << x_file << std::endl;
    YAML::Node xlink_node = YAML::LoadFile(x_file.c_str());

    // Override the values
    xlink_node["force_dependent"]                           = var_subnode_["xlink"]["force_dependent"].as<bool>();
    xlink_node["crosslink"][0]["reservoir"]                 = var_subnode_["xlink"]["reservoir"].as<int>();
    xlink_node["crosslink"][0]["spring_constant"]           = var_subnode_["xlink"]["spring_constant"].as<double>();
    xlink_node["crosslink"][0]["velocity"]                  = var_subnode_["xlink"]["velocity"].as<double>();
    xlink_node["crosslink"][0]["velocity_polar_scale"]      = var_subnode_["xlink"]["velocity_polar_scale"].as<double>();
    xlink_node["crosslink"][0]["velocity_antipolar_scale"]  = var_subnode_["xlink"]["velocity_antipolar_scale"].as<double>();
    xlink_node["crosslink"][0]["equilibrium_length"]        = var_subnode_["xlink"]["equilibrium_length"].as<double>();
    xlink_node["crosslink"][0]["barrier_weight"]            = var_subnode_["xlink"]["barrier_weight"].as<double>();
    xlink_node["crosslink"][0]["characteristic_length"]     = var_subnode_["xlink"]["characteristic_length"].as<double>();
    xlink_node["crosslink"][0]["polar_affinity"]            = var_subnode_["xlink"]["polar_affinity"].as<double>();
    xlink_node["crosslink"][0]["diffusion_bound"]           = var_subnode_["xlink"]["diffusion_bound"].as<double>();
    xlink_node["crosslink"][0]["diffusion_bound_2"]         = var_subnode_["xlink"]["diffusion_bound_2"].as<double>();
    xlink_node["crosslink"][0]["concentration_1"]           = var_subnode_["xlink"]["concentration_1"].as<double>();
    xlink_node["crosslink"][0]["concentration_2"]           = var_subnode_["xlink"]["concentration_2"].as<double>();
    xlink_node["crosslink"][0]["on_rate_1"]                 = var_subnode_["xlink"]["on_rate_1"].as<double>();
    xlink_node["crosslink"][0]["on_rate_2"]                 = var_subnode_["xlink"]["on_rate_2"].as<double>();
    xlink_node["crosslink"][0]["stall_force"]               = var_subnode_["xlink"]["stall_force"].as<double>();

    // Create the crosslinks
    properties.crosslinks.Init(&parameters,
                               &properties,
                               &xlink_node);

    FinishInitSystem(&parameters, &properties);

    // Now , mimic what is seen in the previous motor_test_fokker_planck from Robert
    const int N = 1600;
    double length = properties.bonds.length[0];
    double delta_t = 0.0005;
    double t_tot = 10.0;
    double nsteps = (int) (t_tot/delta_t);
    std::cout << "Nsteps: " << nsteps << std::endl;

    int n_print_percent;
    if (nsteps > 100) {
        n_print_percent = nsteps / 100;
    } else {
        n_print_percent = 1;
    }

    double x[N], y[N];
    // set up the grids
    for (int i = 0; i < N; ++i) {
        x[i] = y[i] = (i+0.5)*(length/N) - 0.5*length;
    }

    double** v1 = (double**) allocate_2d_array(N, N, sizeof(double));
    double** v2 = (double**) allocate_2d_array(N, N, sizeof(double));
    double** enmat_on  = (double**) allocate_2d_array(N, N, sizeof(double));
    double** enmat_off = (double**) allocate_2d_array(N, N, sizeof(double));
    double dr[3] = {properties.bonds.r_bond[1][0] - properties.bonds.r_bond[0][0],
                    properties.bonds.r_bond[1][1] - properties.bonds.r_bond[0][1],
                    properties.bonds.r_bond[1][2] - properties.bonds.r_bond[0][2]};
    double **u_bond = properties.bonds.u_bond;
    double r0 = properties.crosslinks.r_equil_[0];
    double barrier_weight = properties.crosslinks.barrier_weight_[0];
    double xc = properties.crosslinks.xc_[0];
    double max_on = 0.0;
    double max_off = 0.0;

    //std::cout << "Initialize on/off rates\n";
    //std::cout << "  r0: " << r0 << std::endl;
    //std::cout << "  barrier: " << barrier_weight << std::endl;
    //std::cout << "  xc: " << xc << std::endl;
    for (int i = 0; i < N; ++i) {
        for (int j = 0; j < N; ++j) {
            double k_stretch = properties.crosslinks.k_stretch_[0];

            double f_stall = properties.crosslinks.f_stall_[0][0];
            int bond_1 = 0; int bond_2 = 1;

            double dr_cross[3] = {(dr[0] + y[i] * u_bond[bond_2][0] - x[j] * u_bond[bond_1][0]),
                                  (dr[1] + y[i] * u_bond[bond_2][1] - x[j] * u_bond[bond_1][1]),
                                  (dr[2] + y[i] * u_bond[bond_2][2] - x[j] * u_bond[bond_1][2])};
            double f[3] = {k_stretch * dr_cross[0],
                           k_stretch * dr_cross[1],
                           k_stretch * dr_cross[2]};
                                  
            double ui_dot_f = dot_product(3, u_bond[bond_1], f);
            double uj_dot_f = dot_product(3, u_bond[bond_2], f);

            double f_mag_i = ui_dot_f;
            double f_mag_j = -uj_dot_f;

            // Do the proper velocity relation with the crosslinkg to stall it out.
            // If parallel force is along velocity vector, then it shouldn't slow
            // down crosslink
            v1[i][j] = properties.crosslinks.velocity_[0][0];
            if (f_mag_i*v1[i][j] > 0.0)
                f_mag_i = 0.0;
            else {
                f_mag_i = ABS(f_mag_i);
            }

            v2[i][j] = properties.crosslinks.velocity_[0][0];
            if (f_mag_j*v2[i][j] > 0.0)
                f_mag_j = 0.0;
            else {
                f_mag_j = ABS(f_mag_j);
            }

            if (f_mag_i < f_stall)
                v1[i][j] *= 1.0 - f_mag_i/f_stall;
            else
                v1[i][j] = 0.0;
            if (f_mag_j < f_stall)
                v2[i][j] *= 1.0 - f_mag_j/f_stall;
            else
                v2[i][j] = 0.0;

            // Do the original (or zero) energy dependent version
            double rmag = sqrt(dot_product(3, dr_cross, dr_cross));
            if (!properties.crosslinks.force_dependent_) {
                enmat_on[i][j]  = exp(-0.5 * k_stretch * (1 - barrier_weight) * (rmag - r0) * (rmag - r0));
                enmat_off[i][j] = exp( 0.5 * k_stretch * barrier_weight * (rmag - r0) * (rmag - r0));
                //std::cout << "enmat_on [" << i << "][" << j << "] = " << enmat_on[i][j] << std::endl;
                //std::cout << "enmat_off[" << i << "][" << j << "] = " << enmat_off[i][j] << std::endl;
                // Restrict the off matrix to something large, but not nan
                if (rmag > 1.0 + properties.crosslinks.r_cutoff_1_2_[0]) {
                    enmat_off[i][j] = max_off;
                } else {
                    max_off = MAX(max_off, enmat_off[i][j]);
                }
            } else {
                enmat_on[i][j]  = exp(-0.5 * k_stretch * (rmag - r0) * (rmag - r0) + k_stretch * xc * (rmag - r0));
                enmat_off[i][j] = exp(k_stretch * xc * (rmag - r0));
                if (rmag > 1.0 + properties.crosslinks.r_cutoff_1_2_[0]) {
                    enmat_off[i][j] = max_off;
                } else {
                    max_off = MAX(max_off, enmat_off[i][j]);
                }
            }
        }
    }

    std::cout << "Prep probability densities\n";
    std::cout << "  max_off: " << max_off << std::endl;
    // Prep the probability density
    double psi1[N], psi2[N];
    double **psi = (double**) allocate_2d_array(N, N, sizeof(double));
    for (int i = 0; i < N; ++i) {
        psi1[i] = psi2[i] = 0.0;
        for (int j = 0; j < N; ++j) {
            psi[i][j] = 0.0;
        }
    }

    // Initialize temporary flux matrices
    double** xfluxdiff = (double**) allocate_2d_array(N, N, sizeof(double));
    double** yfluxdiff = (double**) allocate_2d_array(N, N, sizeof(double));
    double** dpsifrom1 = (double**) allocate_2d_array(N, N, sizeof(double));
    double** dpsifrom2 = (double**) allocate_2d_array(N, N, sizeof(double));

    std::cout << "Prep boltzmann factors\n";
    double boltz1[N], boltz2[N];
    for (int i = 0; i < N; ++i) {
        boltz1[i] = boltz2[i] = 0.0;
        for (int j = 0; j < N; ++j) {
            boltz1[i] += (length/N) * enmat_on[i][j];
            boltz2[i] += (length/N) * enmat_on[j][i];
        }
    }

    // Proceed to solve the distribution function!
    for (int istep = 0; istep < nsteps; ++istep) {
        if (istep % n_print_percent == 0) {
            fprintf(stdout, "%d%% Complete\n", (int)(100 * (float)istep / (float)nsteps));
            fflush(stdout);

            //// Dump the concentration in the middle of each
            //std::cout << "  psi: " << psi[(int)N/2][(int)N/2] << std::endl;
            //std::cout << "  psi1: " << psi1[(int)N/2] << std::endl;
            //std::cout << "  psi2: " << psi2[(int)N/2] << std::endl;
            //std::cout << "  enmat_on : " << enmat_on[(int)N/2][(int)N/2] << std::endl;
            //std::cout << "  enmat_off: " << enmat_off[(int)N/2][(int)N/2] << std::endl;
            //std::cout << "  max_off: " << max_off << std::endl;
            //std::cout << "  xfluxdiff: " << xfluxdiff[(int)N/2][(int)N/2] << std::endl;
            //std::cout << "  yfluxdiff: " << yfluxdiff[(int)N/2][(int)N/2] << std::endl;
            //std::cout << "  dpsifrom1: " << dpsifrom1[(int)N/2][(int)N/2] << std::endl;
            //std::cout << "  dpsifrom2: " << dpsifrom2[(int)N/2][(int)N/2] << std::endl;
        }
        double v0 = properties.crosslinks.velocity_[0][0];
        double psi1flux[N];
        double psi2flux[N];

        // Stage 2 fluxes for joint distribution
        for (int i = 0; i < N; ++i) {
            xfluxdiff[i][0] = v1[i][0]*psi[i][0];
            for (int j = 1; j < N; ++j) {
                xfluxdiff[i][j] = v1[i][j]*psi[i][j]-v1[i][j-1]*psi[i][j-1];
            }
        }
        for (int i = 0; i < N; ++i) {
            yfluxdiff[0][i] = v2[0][i]*psi[0][i];
        }
        for (int i = 1; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                yfluxdiff[i][j] = v2[i][j]*psi[i][j]-v2[i-1][j]*psi[i-1][j];
            }
        }

        // Total flux for filaments 1 and 2
        double dpsi1[N];
        double delta_x = length/N;
        double eps01 = properties.crosslinks.eps_eff_1_[0][0];
        double eps12 = properties.crosslinks.eps_eff_2_[0][0];
        double k01 = properties.crosslinks.on_rate_1_[0][0];
        double k12 = properties.crosslinks.on_rate_2_[0][0];

        // Stage 1 fluxes for filaments 1 and 2  (motors move at speed v0)
        psi1flux[0] = v0 * psi1[0];
        for (int i = 1; i < N; ++i) {
            psi1flux[i] = v0 * psi1[i]-v0*psi1[i-1];
        }
        psi2flux[0] = v0 * psi2[0];
        for (int i = 1; i < N; ++i) {
            psi2flux[i] = v0 * psi2[i]-v0*psi2[i-1];
        }

        for (int i = 0; i < N; ++i) {
            dpsi1[i] = eps01 * k01;
            dpsi1[i] -= k01 * psi1[i];
            dpsi1[i] -= k12 * eps12 *boltz1[i] * psi1[i];
            for (int j = 0; j < N; ++j) {
                dpsi1[i] += k12 * psi[i][j] * delta_x * enmat_off[i][j];
            }
            dpsi1[i] += v1[i][N-1] * psi[i][N-1];
            dpsi1[i] -= psi1flux[i] / delta_x;
        }

        double dpsi2[N];
        for (int i = 0; i < N; ++i) {
            dpsi2[i] = eps01 * k01;
            dpsi2[i] -= k01 * psi2[i];
            dpsi2[i] -= k12 * eps12 * boltz2[i] * psi2[i];
            for (int j = 0; j < N; ++j)
                dpsi2[i] += k12 * psi[j][i] * delta_x * enmat_off[j][i];
            dpsi2[i] += v2[N-1][i] * psi[N-1][i];
            dpsi2[i] -= psi2flux[i] / delta_x;
        }

        // Flux into "stage 2" from "stage 1"
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                dpsifrom1[i][j] = psi1[i] * enmat_on[i][j] * k12 * eps12;
            }
        }
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                dpsifrom2[i][j] = psi2[j] * enmat_on[i][j] * k12 * eps12;
            }
        }

        for (int i = 0; i < N; ++i) {
            psi1[i] += dpsi1[i] * delta_t;
            psi2[i] += dpsi2[i] * delta_t;
        }

        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                psi[i][j] += delta_t *
                    (-xfluxdiff[i][j]/delta_x -
                     yfluxdiff[i][j]/delta_x +
                     dpsifrom1[i][j] +
                     dpsifrom2[i][j] -
                     2 * k12 * psi[i][j] * enmat_off[i][j]);
            }
        }
    }

    std::string fname_base = var_subnode_["result_file"].as<std::string>();
    std::string fname_psi1 = "fp_" + fname_base + "_psi1.dat";
    std::string fname_psi = "fp_" + fname_base + "_psi.dat";

    CorrelationData psi1_dist;
    {
        double bin_size[] = {length/N};
        double min[] = {0};
        double max[] = {properties.bonds.length[0]};
        psi1_dist.Init(1, bin_size, min, max);
    }
    psi1_dist.Fill(N, psi1);
    psi1_dist.OutputBinary(fname_psi1.c_str());

    CorrelationData psi_dist;
    {
        double bin_size[] = {length/N, length/N};
        double min[] = {0, 0};
        double max[] = {properties.bonds.length[0], properties.bonds.length[0]};
        psi_dist.Init(2, bin_size, min, max);
    }

    psi_dist.Fill(N*N, &psi[0][0]);
    psi_dist.OutputBinary(fname_psi.c_str());

    return success;
}
