// implementation of probability functions

#include "probabilities.h"

#include <math.h>
#include <iostream>
#include <gsl/gsl_errno.h>

double advanced_integrand_1_2(double x, void *params) {
    integration_params *iparams = (integration_params *)params;
    double alpha1 = iparams->alpha;
    double alpha2 = iparams->alpha2;
    double costhetaminA = iparams->costhetaminA;
    double costheta2A   = iparams->costheta2A;
    double r0 = iparams->r0;
    double xc = iparams->xc;
    double chic = iparams->chic;
    double y0 = iparams->y0;

    //std::cout << "x: " << x << std::endl;
    //std::cout << "alpha1: " << alpha1 << std::endl;
    //std::cout << "alpha2: " << alpha2 << std::endl;
    //std::cout << "costehtaminA: " << costhetaminA << std::endl;
    //std::cout << "costheta2A: " << costheta2A << std::endl;
    //std::cout << "r0: " << r0 << std::endl;
    //std::cout << "xc: " << xc << std::endl;
    //std::cout << "chic: " << chic << std::endl;
    //std::cout << "y0: " << y0 << std::endl;

    double rms = sqrt(x*x + y0*y0);
    // First term
    double rprime = rms - r0 - xc;
    // Second term
    double angleprime = (y0*costhetaminA + x*costheta2A)/rms - 1.0 + chic;
    // Prefactor
    double prefactor = exp(alpha1*xc*xc + alpha2*chic*chic);

    // Put it all together
    return prefactor * exp(-alpha1*rprime*rprime -alpha2*angleprime*angleprime);
}

double advanced_integrand_1_2_quartic(double x, void *params) {
    integration_params *iparams = (integration_params *)params;
    double alpha1 = iparams->alpha;
    double alpha2 = iparams->alpha2;
    double costhetaminA = iparams->costhetaminA;
    double costheta2A   = iparams->costheta2A;
    double r0 = iparams->r0;
    double xc = iparams->xc;
    double chic = iparams->chic;
    double y0 = iparams->y0;

    double rms = sqrt(x*x + y0*y0);
    // First term
    double rprime = rms - r0 - xc;
    // Second term
    double angleprime = (y0*costhetaminA + x*costheta2A)/rms - 1.0;
    // Prefactor
    double prefactor = exp(alpha1*xc*xc);

    // Put it all together
    return prefactor * exp(-alpha1*rprime*rprime -alpha2*angleprime*angleprime*angleprime*angleprime -4.0*alpha2*chic*angleprime*angleprime*angleprime);
}

double prob_1_2_advanced(std::vector<double> &x, void *params) {
    gsl_integration_workspace *w
        = gsl_integration_workspace_alloc(10000);

    integration_params *inparams = (integration_params*) params;
    integration_params gslparams;
    gslparams.alpha = inparams->alpha;
    gslparams.alpha2 = inparams->alpha2;
    gslparams.costhetaminA = x[2];
    gslparams.costheta2A = x[3];
    gslparams.r0 = inparams->r0;
    gslparams.xc = inparams->xc;
    gslparams.chic = inparams->chic;
    gslparams.y0 = x[1];

    gsl_function F;
    F.function = &advanced_integrand_1_2;
    F.params = &gslparams;
    
    //std::cout << "Integration: x[" << x[0] << ", " << x[1] << ", " << x[2] << ", " << x[3] << "]\n";
    //std::cout << "  alpha: " << inparams->alpha << ", alpha2: " << inparams->alpha2 << std::endl;

    double result, error;
    if (x[0] > 0)
        gsl_integration_qags (&F, 0, x[0], 0, 1e-7, 10000,
                              w, &result, &error);
    else
        result = 0;

    gsl_integration_workspace_free (w);

    return result;
}

// Full integrand inclunding the lower bounds, eww
double prob_1_2_advanced_full(double *x, void *params, int *pstatus, gsl_integration_workspace *w) {
    //gsl_integration_workspace *w
    //    = gsl_integration_workspace_alloc(100000);

    integration_params *inparams = (integration_params*) params;
    integration_params gslparams;
    gslparams.alpha = inparams->alpha;
    gslparams.alpha2 = inparams->alpha2;
    gslparams.costhetaminA = x[3];
    gslparams.costheta2A = x[4];
    gslparams.r0 = inparams->r0;
    gslparams.xc = inparams->xc;
    gslparams.chic = inparams->chic;
    gslparams.y0 = x[2];

    gsl_function F;
    F.function = &advanced_integrand_1_2;
    F.params = &gslparams;
    
    //std::cout << "Integration: x[" << x[0] << ", " << x[1] << ", " << x[2] << ", " << x[3] << ", " << x[4] << "]\n";
    //std::cout << "  alpha: " << inparams->alpha << ", alpha2: " << inparams->alpha2 << std::endl;
    //std::cout << "  r0: " << inparams->r0 << ", xc: " << inparams->xc << std::endl;

    double result, error;
    *pstatus = gsl_integration_qags (&F, x[0], x[1], 0, 1e-7, 100000,
                                     w, &result, &error);
    if (*pstatus) {
        fprintf(stderr, "failed prob_1_2_advanced_full, gsl_errno=%d, error=%s\n", *pstatus, gsl_strerror(*pstatus));
        exit(1);
    }

    //gsl_integration_workspace_free (w);

    return result;
}

// Full integrand with quartic angular force
double prob_1_2_advanced_quartic(double *x, void *params, int *pstatus, gsl_integration_workspace *w) {
    //gsl_integration_workspace *w
    //    = gsl_integration_workspace_alloc(100000);

    integration_params *inparams = (integration_params*) params;
    integration_params gslparams;
    gslparams.alpha = inparams->alpha;
    gslparams.alpha2 = inparams->alpha2;
    gslparams.costhetaminA = x[3];
    gslparams.costheta2A = x[4];
    gslparams.r0 = inparams->r0;
    gslparams.xc = inparams->xc;
    gslparams.chic = inparams->chic;
    gslparams.y0 = x[2];

    gsl_function F;
    F.function = &advanced_integrand_1_2_quartic;
    F.params = &gslparams;
    
    //std::cout << "Integration: x[" << x[0] << ", " << x[1] << ", " << x[2] << ", " << x[3] << ", " << x[4] << "]\n";
    //std::cout << "  alpha: " << inparams->alpha << ", alpha2: " << inparams->alpha2 << std::endl;
    //std::cout << "  r0: " << inparams->r0 << ", xc: " << inparams->xc << std::endl;

    double result, error;
    *pstatus = gsl_integration_qags (&F, x[0], x[1], 0, 1e-7, 100000,
                                     w, &result, &error);
    if (*pstatus) {
        fprintf(stderr, "failed prob_1_2_advanced_quartic, gsl_errno=%d\n", *pstatus);
        exit(1);
    }

    //gsl_integration_workspace_free (w);

    return result;
}

// Probability of 1 to 2 on the upper bound side
double prob_1_2_advanced_upper(std::vector<double> &x, void *params) {
    gsl_integration_workspace *w
        = gsl_integration_workspace_alloc(10000);

    integration_params *inparams = (integration_params*) params;
    integration_params gslparams;
    gslparams.alpha = inparams->alpha;
    gslparams.alpha2 = inparams->alpha2;
    gslparams.costhetaminA = x[2];
    gslparams.costheta2A = x[3];
    gslparams.r0 = inparams->r0;
    gslparams.xc = inparams->xc;
    gslparams.y0 = x[1];

    gsl_function F;
    F.function = &advanced_integrand_1_2;
    F.params = &gslparams;
    
    double result, error;
    if (x[0] > 0)
        gsl_integration_qags (&F, 0, x[0], 0, 1e-7, 10000,
                              w, &result, &error);
    else
        result = 0;

    gsl_integration_workspace_free (w);

    return result;
}

// Probability of 1 to 2 on the lower bound side
double prob_1_2_advanced_lower(std::vector<double> &x, void *params) {
    gsl_integration_workspace *w
        = gsl_integration_workspace_alloc(10000);

    integration_params *inparams = (integration_params*) params;
    integration_params gslparams;
    gslparams.alpha = inparams->alpha;
    gslparams.alpha2 = inparams->alpha2;
    gslparams.costhetaminA = x[2];
    gslparams.costheta2A = x[3];
    gslparams.r0 = inparams->r0;
    gslparams.xc = inparams->xc;
    gslparams.y0 = x[1];

    gsl_function F;
    F.function = &advanced_integrand_1_2;
    F.params = &gslparams;
    
    double result, error;
    if (x[0] < 0)
        gsl_integration_qags (&F, x[0], 0, 0, 1e-7, 10000,
                              w, &result, &error);
    else
        result = 0;

    gsl_integration_workspace_free (w);

    return result;
}

double integrand_1_2(double x, void *params) {
    xlink_params* myparams = (xlink_params *) params;
    double alpha = myparams->alpha;
    double r0 = myparams->r0;
    double y0 = myparams->y0;

    double exponent = sqrt(x*x + y0*y0) - r0;
    exponent *= -alpha * exponent;
    return exp(exponent);
}

double integrand_1_2_fdep(double x, void *params) {
    xlink_params* myparams = (xlink_params *)params;
    double alpha = myparams->alpha;
    double r0 = myparams->r0;
    double xc = myparams->xc;
    double y0 = myparams->y0;

    double rprime = sqrt(x*x + y0*y0) - r0 - xc;
    double prefactor = exp(alpha*xc*xc);
    return prefactor*exp(-alpha*rprime*rprime);
}

// Performs a binary search to get the value within some tolerance
double InvertBinary(double u,
                    double *xvec,
                    double (*func) (double *x, void *params, int *pstatus, gsl_integration_workspace *w),
                    gsl_integration_workspace *w,
                    double tolerance,
                    void *params) {
    double s = 0.0;
    int pstatus;

    integration_params *inparams = (integration_params*) params;
    integration_params rparams;
    rparams.alpha = inparams->alpha;
    rparams.alpha2 = inparams->alpha2;
    rparams.r0 = inparams->r0;
    rparams.xc = inparams->xc;
    rparams.chic = inparams->chic;

    //std::cout << "Target u: " << u << std::endl;
    double total_integral = func(xvec, &rparams, &pstatus, w);
    //std::cout << "total integral: " << total_integral << std::endl;
    //std::cout << "  lim0: " << xvec[0] << ", lim1: " << xvec[1] << std::endl;

    // The probability is the integral from lim0 to x, check to see if this results in a given u
    bool found = false;
    double xlow_lim  = xvec[0];
    double xhigh_lim = xvec[1];
    xvec[1] = xlow_lim;
    double xlow = func(xvec, &rparams, &pstatus, w)/total_integral;
    xvec[1] = xhigh_lim;
    double xhigh = func(xvec, &rparams, &pstatus, w)/total_integral;
    int niter = 10000;
    int iiter = 0;
    while (!found && iiter < niter) {
        // Compute the integral at the halfway point between xlow and xhigh
        //std::cout << "xlow_lim: " << xlow_lim << ", xhigh_lim: " << xhigh_lim << std::endl;
        //std::cout << "xlow: " << xlow << ", xhigh: " << xhigh << std::endl;
        double x = (xlow_lim + xhigh_lim) / 2.;
        //std::cout << "  x: " << x << std::endl;
        xvec[1] = x; // the upper limit of integration
        double t = func(xvec, &rparams, &pstatus, w)/total_integral;
        //std::cout << "  t: " << t << std::endl;
        // Determine if we are higher or lower than the bound
        // First, check exit condition, that is, we are within tolerance of the target
        if (fabs(t - u) < tolerance) {
            found = true;
            s = x;
            break;
        }

        // Reset xlow and xhigh appropriately, along with the limits
        if (t < u && u < xhigh) {
            xlow_lim = x;
            xlow = t;
        } else if (xlow < u && u < t) {
            xhigh_lim = x;
            xhigh = t;
        }
        iiter++;
        if (iiter == niter) {
            std::cout << "WARNING: Reached maximum number of iterations for InvertBinary " << niter << ", going with what we have!\n";
        }

        //exit(1);
    }

    //std::cout << "Invert binary result: " << s << std::endl;
    //xvec[1] = s;
    //std::cout << "   fractional: " << func(xvec, &rparams)/total_integral << std::endl;
    //exit(1);
    return s;
}

double InverseTransformSample(double u,
                              double *xvec,
                              double lline,
                              double (*func) (double *x, void *params, int *pstatus, gsl_integration_workspace *w),
                              gsl_integration_workspace *w,
                              double tolerance,
                              void *params) {
    //std::cout << "InverseTransformSample\n";
    double s = 0.0;
    int pstatus;

    // Integration paramters
    integration_params *inparams = (integration_params*) params;
    integration_params rparams;
    rparams.alpha = inparams->alpha;
    rparams.alpha2 = inparams->alpha2;
    rparams.r0 = inparams->r0;
    rparams.xc = inparams->xc;
    rparams.chic = inparams->chic;

    integration_params rparams2;
    rparams2.alpha = inparams->alpha;
    rparams2.alpha2 = inparams->alpha2;
    rparams2.costhetaminA = xvec[3];
    rparams2.costheta2A = xvec[4];
    rparams2.r0 = inparams->r0;
    rparams2.xc = inparams->xc;
    rparams2.chic = inparams->chic;
    rparams2.y0 = xvec[2];

    double xmin = xvec[0];
    double xmax = xvec[1];

    //std::cout << "xmin = " << xmin << ", xmax = " << xmax << std::endl;

    // First, just do the rho expected values
    double h = 0.05;
    int N = (xmax - xmin) / h;
    //std::cout << "N = " << N << std::endl;

    // Generate the density norm factor
    double density_norm_factor = func(xvec, &rparams, &pstatus, w) * lline;
    //std::cout << "density_norm_factor: " << density_norm_factor << std::endl;

    //std::vector<double> rho(N, 0.0);
    //for (int i = 0; i < N; ++i) {
    //    rho[i] = advanced_integrand_1_2(xmin + i*h, &rparams2);
    //}
    //for (auto xpos: rho) {
    //    std::cout << xpos << ' ';
    //}

    std::vector<double> v(N, 0.0);
    double diffmin = 100000000.0;
    int argmin = -1;
    for (int j = 0; j < N; ++j) {
        for (int i = 0; i < j; ++i) {
            v[j] += advanced_integrand_1_2(xmin + i*h, &rparams2) / density_norm_factor;
        }
        // Check if we are the closest to the random number
        double diffabs = fabs(v[j] - u);
        if (diffabs < diffmin) {
            diffmin = diffabs;
            argmin = j;
        }
    }

    //// Return the correct value
    //std::cout << "argmin = " << argmin << std::endl;
    s = xmin + argmin*h;

    return s;
}
