
#include "correlation_data.h"
#include <iostream>
#include <algorithm> // for copy
#include <iterator> // for ostream_iterator
#include <cstdlib>
#include <cstdio>
#include <string>
#include <cmath>

/*
 * (Re)initialize the histogram geometry and clear all accumulated data.
 *
 * n_dim    : number of binned dimensions.
 * bin_size : requested bin width per dimension (n_dim entries). The width is
 *            adjusted so that an integral number of bins spans [min, max).
 * min, max : lower/upper bounds per dimension (n_dim entries each).
 *
 * Uses assign() rather than insert() so that calling Init a second time
 * replaces the previous configuration instead of prepending to it.
 * Exits with an error if an axis has a non-positive extent or bin width.
 */
void CorrelationData::Init(int n_dim, double bin_size[], double min[], double max[]) {
    n_dim_ = n_dim;
    n_meas_ = 0;

    bin_size_.assign(bin_size, bin_size + n_dim_);
    min_.assign(min, min + n_dim_);
    n_bin_.assign(n_dim_, 0);

    // h_ / h_inv_ are flattened n_dim x n_dim matrices; only the diagonal
    // (the extent of each axis and its reciprocal) is populated here.
    h_.assign(n_dim_ * n_dim_, 0.0);
    h_inv_.assign(n_dim_ * n_dim_, 0.0);

    n_linear_ = 1;
    for (int i = 0; i < n_dim_; ++i) {
        const int diag = i * n_dim_ + i;
        const double extent = max[i] - min[i];
        if (!(extent > 0.0) || !(bin_size_[i] > 0.0)) {
            std::cerr << "CorrelationData::Init: dimension " << i
                      << " needs max > min and a positive bin size\n";
            exit(1);
        }
        h_[diag] = extent;
        h_inv_[diag] = 1.0 / extent;

        // Always keep at least one bin: a bin width wider than the extent
        // would otherwise yield zero bins and a division by zero below.
        n_bin_[i] = std::max(1, static_cast<int>(extent / bin_size_[i]));
        bin_size_[i] = extent / n_bin_[i];
        n_linear_ *= n_bin_[i];
    }

    dist_.assign(n_linear_, 0.0);
}

/*
 * Bin the projection of an n_dim-component vector r onto the axis u_frame
 * into this (necessarily 1-D) correlation histogram, adding `value`.
 *
 * n_dim   : number of components in r and u_frame (the spatial dimension,
 *           unrelated to the histogram's own dimensionality).
 * Exits with an error if the histogram is not one-dimensional.
 */
void CorrelationData::BinAxis(int n_dim, double r[], double u_frame[], double value) {
    if (n_dim_ != 1) {
        std::cerr << "Invalid number of dimensions to use BinAxis member of CorrelationData\n";
        exit(1);
    }

    /* Add contributions to pair correlation function. */
    double r_axis = 0.0;
    for (int i = 0; i < n_dim; ++i)
        r_axis += r[i] * u_frame[i];

    // The projection is a single scalar. Bin it as one coordinate: passing
    // n_dim here would make Bin read n_dim doubles starting at &r_axis.
    Bin(1, &r_axis, value);
}

/*
 * Convert the per-dimension bin indices in i_bin_ into a flat index into
 * dist_, using row-major order (the last dimension varies fastest). For
 * 1-D this is i_bin_[0]; for 2-D it is i_bin_[0] * n_bin_[1] + i_bin_[1].
 *
 * Returns -1 if any per-dimension index lies outside [0, n_bin_[d]).
 * Exits with an error if the histogram has fewer than one dimension.
 */
int CorrelationData::LinearIndex() {
    if (n_dim_ < 1) {
        std::cerr << "CorrelationData of dimension '" << n_dim_ << "' not supported\n";
        exit(1);
    }

    int index = 0;
    for (int d = 0; d < n_dim_; ++d) {
        if (i_bin_[d] < 0 || i_bin_[d] >= n_bin_[d])
            return -1;
        index = index * n_bin_[d] + i_bin_[d];
    }
    return index;
}

/*
 * Locate the bin containing point r and add `value` to it. Points outside
 * [min, max) in any dimension are silently ignored.
 *
 * n_dim : number of entries available in r. Must be at least the histogram's
 *         dimensionality (n_dim_); only the first n_dim_ entries are used.
 * Exits with an error if r has too few coordinates.
 */
void CorrelationData::Bin(int n_dim, double r[], double value) {
    if (n_dim < n_dim_) {
        std::cerr << "CorrelationData::Bin given " << n_dim
                  << " coordinates for " << n_dim_ << "-dimensional data\n";
        exit(1);
    }

    // resize (not reserve): elements are written through operator[] below.
    s_axis_.resize(n_dim_);
    i_bin_.resize(n_dim_);
    for (int i = 0; i < n_dim_; ++i) {
        // Fractional position along the axis; in range points map to [0, 1).
        s_axis_[i] = h_inv_[i * n_dim_ + i] * (r[i] - min_[i]);

        // floor, not truncation: values slightly below min must fall outside
        // bin 0. Out of range / NaN values are mapped to sentinel indices
        // before the int cast so the cast is always well-defined;
        // LinearIndex rejects them.
        const double scaled = std::floor(s_axis_[i] * n_bin_[i]);
        if (!(scaled >= 0.0))
            i_bin_[i] = -1;
        else if (scaled >= n_bin_[i])
            i_bin_[i] = n_bin_[i];
        else
            i_bin_[i] = static_cast<int>(scaled);
    }

    IncrementBin(value);
}

/*
 * Add `value` to the histogram entry selected by the current i_bin_ indices.
 * Nothing happens when those indices fall outside the histogram.
 */
void CorrelationData::IncrementBin(double value) {
    const int bin = LinearIndex();
    // LinearIndex signals an out-of-range point with -1.
    if (bin < 0 || bin >= n_linear_)
        return;
    dist_[bin] += value;
}

/*
 * Multiply every histogram entry by norm_factor.
 *
 * Iterates over dist_ itself rather than up to n_linear_. After Fill() with
 * an array of a different size, the two can disagree, and indexing to
 * n_linear_ would run past the end of dist_.
 */
void CorrelationData::NormalizeByConstant(double norm_factor) {
    for (auto &entry : dist_)
        entry *= norm_factor;
}

/*
 * Divide the histogram by the measurement count n_meas_. When no
 * measurements were recorded, the histogram is zeroed instead.
 */
void CorrelationData::NormalizeNmeas() {
    double factor = 0.0;
    if (n_meas_ != 0)
        factor = 1.0 / n_meas_;
    NormalizeByConstant(factor);
}

/*
 * Scale the histogram so that its entries sum to one.
 *
 * If the entries sum to exactly zero, the histogram is left unchanged;
 * dividing by zero would otherwise fill it with NaN/inf. Iterates over
 * dist_ itself so a size mismatch with n_linear_ (possible after Fill)
 * cannot index out of bounds.
 */
void CorrelationData::NormalizeSumUnity() {
    double tot = 0.0;
    for (const auto &entry : dist_)
        tot += entry;

    if (tot == 0.0)
        return;

    for (auto &entry : dist_)
        entry /= tot;
}

void CorrelationData::OutputBinary(std::string outfile) {
    FILE *f_out = std::fopen(outfile.c_str(), "w");
   
    //std::cout << "ndim: " << n_dim_ << std::endl;
    std::fwrite(&n_dim_, sizeof(int), 1, f_out);
    //for (int i = 0; i < n_dim_; ++i) {
    //    std::cout << "nbin[" << i << "] = " << n_bin_[i] << std::endl;
    //}
    std::fwrite(n_bin_.data(), sizeof(int), n_dim_, f_out);
    //for (int i = 0; i < n_dim_ * n_dim_; ++i) {
    //    std::cout << "h[" << i << "] = " << h_[i] << std::endl;
    //}
    std::fwrite(h_.data(), sizeof(double), n_dim_ * n_dim_, f_out);
    std::fwrite(&n_linear_, sizeof(int), 1, f_out);
    std::fwrite(dist_.data(), sizeof(double), n_linear_, f_out);
    std::fclose(f_out);
}

/*
 * Replace the histogram contents with `size` doubles copied from `data`.
 *
 * A size that differs from n_linear_ triggers a warning but is still copied,
 * so dist_ takes on the new size. A negative size is rejected with a warning
 * and leaves dist_ untouched; it would otherwise form an invalid range.
 */
void CorrelationData::Fill(int size, double *data) {
    if (size < 0) {
        std::fprintf(stderr, "Warning: Filling correlation data (%p) with negative size (%d); ignored\n",
                     static_cast<void *>(this), size);
        return;
    }
    if (size != n_linear_) {
        // %zu for size_t and void* for %p: %lu / typed pointers are not
        // portable format arguments.
        std::fprintf(stderr, "Warning: Filling correlation data (%p, %zu) with array of the wrong size (%p, %d)\n",
                     static_cast<void *>(this), dist_.size(), static_cast<void *>(data), size);
    }
    dist_.assign(data, data + size);
}

/*
 * Dump the histogram geometry (address, n_dim, n_bin, bin_size and the
 * flattened h matrix) to standard output, one labelled field per line.
 * The histogram contents themselves are not printed.
 */
void CorrelationData::Print() {
    std::cout << "Printing CorrelationData object at: " << this << std::endl;
    std::cout << "  n_dim: " << n_dim_ << std::endl;

    std::cout << "  n_bin: ";
    std::copy(n_bin_.begin(), n_bin_.end(), std::ostream_iterator<int>(std::cout, " "));
    std::cout << std::endl;

    std::cout << "  bin_size: ";
    std::copy(bin_size_.begin(), bin_size_.end(), std::ostream_iterator<double>(std::cout, " "));
    std::cout << std::endl;

    std::cout << "  h: ";
    std::copy(h_.begin(), h_.end(), std::ostream_iterator<double>(std::cout, " "));
    // Terminate the last field like the others so later output starts on a
    // fresh line.
    std::cout << std::endl;
}
