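/* Post-processing program for 3d spindle dynamics simulations: replays a
   stored position (trajectory) file, optionally redraws and grabs movie
   frames, and accumulates microtubule length, pairing and orientation
   statistics.

   Command-line input: parsed by parse_opts() -- a default parameter file,
   a run-specific parameter file, and one or more position files (only the
   first position file is read).

   Output: histogram and binary stream files for each target spindle length,
   plus spindle_separation.dat, frac_interpolar.dat and avg_mt_length.dat,
   all written to the working directory; only a completion message is
   written to standard output. */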

#include "bob.h"
#include "graphics.h"
#include "minimum_distance.h"

#include "correlation_data.h"

#include "parse_flags.h"

#include <iostream>
#include <fstream>
#include <unordered_map>
// Histograms keyed by distribution name ("total", "polar", "angle", ...).
using dist_map = std::unordered_map<std::string, CorrelationData>;
// Binary output streams keyed by the same names; main() allocates and frees them.
using stream_output = std::unordered_map<std::string, std::ofstream*>;

/* Bin microtubule length, pairing and orientation statistics for the
   current configuration.

   For every pair (i_mt, j_mt) with i_mt in [0, |anchor_list[0]|) and
   j_mt in [|anchor_list[0]|, |anchor_list[0]| + |anchor_list[1]|),
   antiparallel_overlap() yields a pairing distance.  Per bond, the maximum
   and the sum of its pairing distances are accumulated.  Then, per bond:
     - "pairing_length_max": bins the maximum pairing distance when > 0
       (the matching stream receives the bond length);
     - "pairing_length_tot": bins the summed pairing distance when > 0;
     - "interpolar" if the maximum pairing distance exceeds 12, otherwise
       "polar": bins and streams the bond length;
     - "total": bins and streams the bond length;
     - "angle": bins and streams acos(|u_spindle . u_bond|), the angle
       between the bond direction and the anchor 0 -> anchor 1 axis.
   Finally the measurement counter of every distribution is incremented.

   parameters:    not used by this function.
   properties:    system state (anchors, bonds, unit cell); read only.
   distributions: histograms keyed by name; updated in place.
   streams:       must already hold "pairing_length_max", "interpolar",
                  "polar", "total" and "angle" -- operator[] would otherwise
                  insert and dereference a null pointer. */
void bin_lengths(system_parameters *parameters, system_properties *properties,
                 dist_map &distributions, stream_output &streams) {

    std::vector<double> pairing_dist_max(properties->bonds.n_bonds, 0.0);
    std::vector<double> pairing_dist_tot(properties->bonds.n_bonds, 0.0);

    // Unit vector along the anchor 0 -> anchor 1 (pole-to-pole) axis.
    double u_spindle[3];
    for (int i = 0; i < 3; ++i) {
        u_spindle[i] = properties->anchors.r_anchor[1][i] -
            properties->anchors.r_anchor[0][i];
    }
    {
        double norm_factor = 1/sqrt(dot_product(3, u_spindle, u_spindle));
        for (int i = 0; i < 3; ++i)
            u_spindle[i] *= norm_factor;
    }

    // Index ranges used for the two anchors' bonds in the pair loop below.
    unsigned int n_first = properties->anchors.anchor_list[0].size();
    unsigned int n_total = n_first + properties->anchors.anchor_list[1].size();

    for (unsigned int i_mt = 0; i_mt < n_first; ++i_mt) {
        for (unsigned int j_mt = n_first; j_mt < n_total; ++j_mt) {
            double pairing_dist =
                antiparallel_overlap(3, 
                                     0, 
                                     properties->unit_cell.h, 
                                     1.6, 
                                     properties->bonds.r_bond[i_mt], 
                                     properties->bonds.s_bond[i_mt], 
                                     properties->bonds.u_bond[i_mt], 
                                     properties->bonds.length[i_mt], 
                                     properties->bonds.r_bond[j_mt], 
                                     properties->bonds.s_bond[j_mt], 
                                     properties->bonds.u_bond[j_mt], 
                                     properties->bonds.length[j_mt]);

            pairing_dist_max[i_mt] = MAX(pairing_dist_max[i_mt], pairing_dist);
            pairing_dist_max[j_mt] = MAX(pairing_dist_max[j_mt], pairing_dist);
            pairing_dist_tot[i_mt] += pairing_dist;
            pairing_dist_tot[j_mt] += pairing_dist;
        }
    }

    for (int i_mt = 0; i_mt < properties->bonds.n_bonds; ++i_mt) {
        if (pairing_dist_max[i_mt] > 0.0) {
            distributions["pairing_length_max"].Bin(1, &pairing_dist_max[i_mt], 1.0);
            // The stream records the bond length of each paired bond.
            streams["pairing_length_max"]->
                write(reinterpret_cast<char*>(&properties->bonds.length[i_mt]),
                      sizeof(properties->bonds.length[i_mt]));
        }
        // Bin the summed pairing distance (previously the maximum was binned
        // here, duplicating "pairing_length_max").
        if (pairing_dist_tot[i_mt] > 0.0)
            distributions["pairing_length_tot"].Bin(1, &pairing_dist_tot[i_mt], 1.0);

        if (pairing_dist_max[i_mt] > 12) {
            distributions["interpolar"].Bin(1, &properties->bonds.length[i_mt], 1.0);
            streams["interpolar"]->
                write(reinterpret_cast<char*>(&properties->bonds.length[i_mt]),
                      sizeof(properties->bonds.length[i_mt]));
        }
        else {
            distributions["polar"].Bin(1, &properties->bonds.length[i_mt], 1.0);
            streams["polar"]->
                write(reinterpret_cast<char*>(&properties->bonds.length[i_mt]),
                      sizeof(properties->bonds.length[i_mt]));
        }
        distributions["total"].Bin(1, &properties->bonds.length[i_mt], 1.0);
        streams["total"]->
                write(reinterpret_cast<char*>(&properties->bonds.length[i_mt]),
                      sizeof(properties->bonds.length[i_mt]));

        // fabs folds parallel and antiparallel orientations together.
        double theta = acos(fabs(dot_product(3,
                                             u_spindle,
                                             properties->bonds.u_bond[i_mt])));
        distributions["angle"].Bin(1, &theta, 1.0);
        streams["angle"]->write(reinterpret_cast<char*>(&theta), sizeof(theta));
    }

    for (dist_map::iterator dist = distributions.begin();
         dist != distributions.end();
         ++dist) {
        dist->second.IncrementMeasCounter();
    }
}

/* Decide whether the current configuration counts as a spindle whose pole
   separation is close to `center`.

   Returns 1 when the anchor 0 -> anchor 1 distance lies strictly inside
   (center - 2, center + 2) AND at least one pair (bond a from the first
   anchor's index range, bond b from the second's) has an
   antiparallel_overlap() greater than 12.  Returns 0 otherwise.
   `parameters` is not used. */
int is_spindle(system_parameters *parameters, system_properties *properties,
               double center) {
    double sep[3];
    for (int k = 0; k < 3; ++k)
        sep[k] = properties->anchors.r_anchor[1][k] - properties->anchors.r_anchor[0][k];
    const double separation = sqrt(dot_product(3, sep, sep));

    // Pole separation outside the window: reject without examining pairs.
    const bool in_window = separation > (center - 2) && separation < (center + 2);
    if (!in_window)
        return 0;

    const unsigned int first_end = properties->anchors.anchor_list[0].size();
    const unsigned int second_end = first_end + properties->anchors.anchor_list[1].size();

    for (unsigned int a = 0; a < first_end; ++a) {
        for (unsigned int b = first_end; b < second_end; ++b) {
            const double overlap =
                antiparallel_overlap(3, 
                                     0, 
                                     properties->unit_cell.h, 
                                     1.6, 
                                     properties->bonds.r_bond[a], 
                                     properties->bonds.s_bond[a], 
                                     properties->bonds.u_bond[a], 
                                     properties->bonds.length[a], 
                                     properties->bonds.r_bond[b], 
                                     properties->bonds.s_bond[b], 
                                     properties->bonds.u_bond[b], 
                                     properties->bonds.length[b]);
            if (overlap > 12)
                return 1;   // one sufficiently overlapping pair is enough
        }
    }
    return 0;
}

/* Fraction of microtubules that take part in at least one antiparallel
   pair with overlap greater than 12.

   Pairs are formed between the first anchor's index range and the second
   anchor's index range (see loops).  A microtubule is counted at most once.
   The result is (#paired) / (#microtubules of both anchors); it is NaN when
   both anchor lists are empty.  `parameters` is not used. */
double frac_interpolar(system_parameters *parameters,
                       system_properties *properties) {
    const unsigned int first_end = properties->anchors.anchor_list[0].size();
    const unsigned int second_end = first_end + properties->anchors.anchor_list[1].size();
    const int n_mts = second_end;

    std::vector<char> is_paired(n_mts, 0);

    for (unsigned int a = 0; a < first_end; ++a) {
        for (unsigned int b = first_end; b < second_end; ++b) {
            const double overlap =
                antiparallel_overlap(3, 
                                     0, 
                                     properties->unit_cell.h, 
                                     1.6, 
                                     properties->bonds.r_bond[a], 
                                     properties->bonds.s_bond[a], 
                                     properties->bonds.u_bond[a], 
                                     properties->bonds.length[a], 
                                     properties->bonds.r_bond[b], 
                                     properties->bonds.s_bond[b], 
                                     properties->bonds.u_bond[b], 
                                     properties->bonds.length[b]);
            if (overlap > 12) {
                is_paired[a] = 1;
                is_paired[b] = 1;
            }
        }
    }

    int n_paired = 0;
    for (char flag : is_paired)
        n_paired += flag;
    return static_cast<double>(n_paired) / n_mts;
}

/* Ratio of realized antiparallel overlap to the largest overlap possible.

   For every pair (bond a from the first anchor's index range, bond b from
   the second's), the numerator accumulates 2 * antiparallel_overlap(),
   counting both microtubules of the pair.  The denominator accumulates
   2 * min(length[a], length[b]).  Returns numerator / denominator, which is
   NaN when there are no pairs.  `parameters` is not used. */
double frac_length_interpolar(system_parameters *parameters, system_properties *properties){
    const double *len = properties->bonds.length;
    const unsigned int first_end = properties->anchors.anchor_list[0].size();
    const unsigned int second_end = first_end + properties->anchors.anchor_list[1].size();

    double realized = 0;
    double possible = 0;

    for (unsigned int a = 0; a < first_end; ++a) {
        for (unsigned int b = first_end; b < second_end; ++b) {
            const double overlap =
                antiparallel_overlap(3, 
                                     0, 
                                     properties->unit_cell.h, 
                                     1.6, 
                                     properties->bonds.r_bond[a], 
                                     properties->bonds.s_bond[a], 
                                     properties->bonds.u_bond[a], 
                                     properties->bonds.length[a], 
                                     properties->bonds.r_bond[b], 
                                     properties->bonds.s_bond[b], 
                                     properties->bonds.u_bond[b], 
                                     properties->bonds.length[b]);

            // The overlap of a pair can be at most its shorter member's length.
            const double shorter = (len[a] < len[b]) ? len[a] : len[b];
            possible += 2.0 * shorter;
            realized += overlap * 2.0;
        }
    }
    return realized / possible;
}

/* Mean of properties->bonds.length over all n_bonds bonds.
   Yields NaN when n_bonds is 0 (0 / 0). */
double avg_mt_length(system_properties *properties){
    const int n_bonds = properties->bonds.n_bonds;
    const double *len = properties->bonds.length;

    double sum = 0;
    for (int b = 0; b < n_bonds; ++b)
        sum += len[b];
    return sum / static_cast<double>(n_bonds);
}

/* Summed antiparallel overlap, normalized by total microtubule length.

   The numerator is the sum over all cross-anchor pairs of
   2 * antiparallel_overlap().  The denominator is the sum of
   bonds.length over the microtubules of both anchors.  A microtubule that
   overlaps several partners contributes once per partner.  `parameters` is
   not used. */
double norm_interpolar_length(system_parameters *parameters, 
                              system_properties *properties){
    const unsigned int first_end = properties->anchors.anchor_list[0].size();
    const unsigned int second_end = first_end + properties->anchors.anchor_list[1].size();
    const double *len = properties->bonds.length;

    double total_length = 0;
    for (unsigned int k = 0; k < second_end; ++k)
        total_length += len[k];

    double overlap_sum = 0;
    for (unsigned int a = 0; a < first_end; ++a) {
        for (unsigned int b = first_end; b < second_end; ++b) {
            const double overlap =
                antiparallel_overlap(3, 
                                     0, 
                                     properties->unit_cell.h, 
                                     1.6, 
                                     properties->bonds.r_bond[a], 
                                     properties->bonds.s_bond[a], 
                                     properties->bonds.u_bond[a], 
                                     properties->bonds.length[a], 
                                     properties->bonds.r_bond[b], 
                                     properties->bonds.s_bond[b], 
                                     properties->bonds.u_bond[b], 
                                     properties->bonds.length[b]);
            overlap_sum += 2 * overlap;   // both members of the pair overlap
        }
    }
    return overlap_sum / total_length;
}

/* Per-microtubule largest antiparallel overlap, summed and normalized by
   total microtubule length.

   Each microtubule keeps only the largest antiparallel_overlap() it has
   with any partner on the other anchor (starting from 0).  This avoids
   counting the same microtubule once per partner.  The per-microtubule
   maxima are summed and divided by the total bonds.length of both anchors'
   microtubules.  `parameters` is not used. */
double real_interpolar_length(system_parameters *parameters, 
                              system_properties *properties){
    const unsigned int first_end = properties->anchors.anchor_list[0].size();
    const unsigned int second_end = first_end + properties->anchors.anchor_list[1].size();
    const double *len = properties->bonds.length;

    double total_length = 0;
    for (unsigned int k = 0; k < second_end; ++k)
        total_length += len[k];

    // Largest overlap seen so far for each microtubule.
    std::vector<double> best_overlap(second_end, 0);

    for (unsigned int a = 0; a < first_end; ++a) {
        for (unsigned int b = first_end; b < second_end; ++b) {
            const double overlap =
                antiparallel_overlap(3, 
                                     0, 
                                     properties->unit_cell.h, 
                                     1.6, 
                                     properties->bonds.r_bond[a], 
                                     properties->bonds.s_bond[a], 
                                     properties->bonds.u_bond[a], 
                                     properties->bonds.length[a], 
                                     properties->bonds.r_bond[b], 
                                     properties->bonds.s_bond[b], 
                                     properties->bonds.u_bond[b], 
                                     properties->bonds.length[b]);
            if (overlap > best_overlap[a])
                best_overlap[a] = overlap;
            if (overlap > best_overlap[b])
                best_overlap[b] = overlap;
        }
    }

    double overlap_sum = 0;
    for (unsigned int k = 0; k < second_end; ++k)
        overlap_sum += best_overlap[k];
    return overlap_sum / total_length;
}


/* Entry point: replay a spindle trajectory and accumulate statistics.

   Reads the default and run-specific parameter files and the first position
   file given on the command line (via parse_opts).  Then, for each target
   pole separation in `centers`, it rewinds the position file and walks every
   stored configuration:
     - rebuilds bond vectors/positions and reads chromosome/crosslink state;
     - first pass only (graphics builds, graph_flag set): draws and grabs
       frames every n_graph steps;
     - first pass only: appends to spindle_separation.dat,
       frac_interpolar.dat and avg_mt_length.dat.  These quantities do not
       depend on the center, so one pass is enough;
     - calls bin_lengths() when is_spindle() accepts the configuration.
   After each pass the distributions are normalized and written with
   OutputBinary, and that pass's binary streams are closed and freed. */
int main(int argc, char *argv[]) {
    system_parameters parameters;
    system_properties properties;
    system_potential potential;
    FILE *f_posit;
    char default_file[F_MAX], param_file[F_MAX];
    char posit_file[F_MAX];

    // New awesome way of loading parameters
    run_options run_opts = parse_opts(argc, argv, 4);

    // The posit file must be the first p file
    std::string pfile = run_opts.posit_file_names[0];

    // The fixed-size buffers below are filled with strcpy; refuse names that
    // would not fit instead of overflowing them.
    if (run_opts.default_file.size() >= F_MAX ||
        run_opts.equil_file.size() >= F_MAX ||
        pfile.size() >= F_MAX) {
        std::cerr << "File name too long (maximum " << F_MAX - 1
                  << " characters)\n";
        exit(1);
    }
    strcpy(default_file, run_opts.default_file.c_str());
    strcpy(param_file, run_opts.equil_file.c_str());
    strcpy(posit_file, pfile.c_str());

    /* Read in default parameters. */
    parse_parameters(default_file, &parameters);

    /* Read in run-specific input parameters. */
    parse_parameters(param_file, &parameters);

    /* Initialize variables and allocate memory. */
    init_spindle_bd_mp(&parameters, &properties, &potential);

    
    /* Open input file, read header. */
    f_posit = gfopen(posit_file, "rb");
    properties.read_header_func(&parameters, &properties, f_posit);
    long int i_config = 0;
    long int n_configs = 0;

    long int start = ftell(f_posit);
    read_positions_spindle(parameters.n_dim,
                           properties.sites.n_sites,
                           properties.anchors.n_anchors,
                           &properties.time,
                           properties.unit_cell.h,
                           properties.sites.r,
                           properties.anchors.r_anchor,
                           properties.anchors.u_anchor,
                           f_posit);
    //properties.chromosomes.ReadState(&parameters, &properties);
    //properties.crosslinks.ReadState(&parameters, &properties);

    // Should really update everything correctly so we draw the right image
    // XXX FIXME Draw the crosslinks, chromosomes, etc, correctly via loading
    // and then resetting the file seek

    long int frame_size = ftell(f_posit) - start;
    fseek(f_posit, 0, SEEK_END);
    // A trajectory without a complete frame leaves frame_size == 0; do not
    // divide by it.
    n_configs = (frame_size > 0) ? (ftell(f_posit) - start) / frame_size : 0;
    fseek(f_posit, start, SEEK_SET);
    //parameters.n_graph = n_configs / 4;

    #ifndef NOGRAPH
    Graphics graphics;
    if (parameters.graph_flag) {
        graphics.Init(&parameters, parameters.n_dim,
                      properties.unit_cell.h,
                      properties.anchors.n_anchors);
        graphics.ResizeWindow(800, 800);
        graphics.SetBoundaryType("sphere");
        graphics.tomogram_view_ = run_opts.tomogram;
        graphics.Draw(properties.bonds.n_bonds,
                      properties.unit_cell.h,
                      properties.bonds.r_bond,
                      properties.bonds.u_bond,
                      properties.bonds.length,
                      properties.sites.n_sites,
                      properties.sites.r,
                      parameters.sphere_diameter);
    }
    #endif

    std::unordered_map<std::string, double> centers;
    centers["1.05um"] = 40 * 1.05;
    centers["1.825um"] = 40 * 1.825;
    centers["2.15um"] = 40 * 2.15;

    double nuclear_diameter = properties.unit_cell.h[0][0];

    for (std::unordered_map<std::string, double>::iterator center = centers.begin();
         center != centers.end();
         ++center) {
        // Center-independent outputs (graphics, .dat time series) are only
        // produced on the first pass.
        bool meas_spindle_separation = false;
        if (center == centers.begin())
            meas_spindle_separation = true;
        
        dist_map distributions;
        double bin_size[] = {8.0};
        double lims[] = {0.0, nuclear_diameter};
        distributions["total"].Init(1, bin_size, &lims[0], &lims[1]);
        distributions["polar"].Init(1, bin_size, &lims[0], &lims[1]);
        distributions["interpolar"].Init(1, bin_size, &lims[0], &lims[1]);
        distributions["pairing_length_tot"].Init(1, bin_size, &lims[0], &lims[1]);
        distributions["pairing_length_max"].Init(1, bin_size, &lims[0], &lims[1]);

        bin_size[0] = M_PI/10;
        // NOTE(review): bin_lengths() bins acos(|cos|), which lies in
        // [0, pi/2], but this histogram reuses [0, nuclear_diameter] as its
        // limits (assuming Init takes lower/upper limits), so most bins stay
        // empty. Left unchanged to preserve the output file layout.
        distributions["angle"].Init(1, bin_size, &lims[0], &lims[1]);

        // Heap-allocated because stream_output stores raw pointers; they are
        // closed and deleted at the end of this pass.
        stream_output streams;
        streams["total"] =
            new std::ofstream("length_stream_total_" + center->first, 
                              std::ios::out | std::ios::binary);
        streams["polar"] =
            new std::ofstream("length_stream_polar_" + center->first, 
                              std::ios::out | std::ios::binary);
        streams["interpolar"] =
            new std::ofstream("length_stream_interpolar_" + center->first, 
                              std::ios::out | std::ios::binary);
        streams["pairing_length_max"] =
            new std::ofstream("length_stream_pairing_length_max_" + center->first, 
                              std::ios::out | std::ios::binary);
        streams["angle"] =
            new std::ofstream("angle_stream_" + center->first,
                              std::ios::out | std::ios::binary);

        // All three time-series files belong to the first pass only.
        std::ofstream f_frac_interpolar;
        std::ofstream f_spindle_separation;
        std::ofstream f_avg_mt_length;
        if (meas_spindle_separation) {
            f_spindle_separation.open("spindle_separation.dat");
            f_frac_interpolar.open("frac_interpolar.dat");
            f_avg_mt_length.open("avg_mt_length.dat");
        }
        
        /* Loop over configurations. */
        bool last_config;
        fseek(f_posit, start, SEEK_SET);
        i_config = 0;
        do {
            i_config++;
            //fprintf(stdout,"i_config: %d, n_graph: %d\n", i_config, parameters.n_graph);
            /* Read site positions from trajectory file. */
            last_config = read_positions_spindle(parameters.n_dim,
                                                 properties.sites.n_sites,
                                                 properties.anchors.n_anchors,
                                                 &properties.time,
                                                 properties.unit_cell.h,
                                                 properties.sites.r,
                                                 properties.anchors.r_anchor,
                                                 properties.anchors.u_anchor,
                                                 f_posit);
            /* Compute unit cell vectors, unit cell volume, and related quantities. */
            unit_cell_dimensions(parameters.n_dim,
                                 properties.unit_cell.h,
                                 properties.unit_cell.h_inv,
                                 properties.unit_cell.a,
                                 properties.unit_cell.b,
                                 properties.unit_cell.a_perp,
                                 &(properties.unit_cell.volume));
        
            /* Update bond vectors. */
            update_bond_vectors(parameters.n_dim,
                                parameters.n_periodic,
                                properties.bonds.n_bonds,
                                properties.unit_cell.h,
                                properties.sites.s,
                                properties.sites.r,
                                properties.bonds.bond_site_1,
                                properties.bonds.bond_site_2,
                                properties.bonds.v_bond,
                                properties.bonds.u_bond,
                                properties.bonds.length,
                                properties.bonds.length2);
            update_bond_positions(parameters.n_dim,
                                  parameters.n_periodic,
                                  properties.bonds.n_bonds,
                                  properties.unit_cell.h,
                                  properties.unit_cell.h_inv,
                                  properties.sites.r,
                                  properties.bonds.bond_site_1,
                                  properties.bonds.v_bond,
                                  properties.bonds.r_bond,
                                  properties.bonds.s_bond);
            update_bond_site_positions(parameters.n_dim,
                                       parameters.n_periodic,
                                       properties.bonds.n_bonds,
                                       properties.sites.n_sites,
                                       properties.unit_cell.h,
                                       properties.unit_cell.h_inv,
                                       properties.bonds.bond_site_1,
                                       properties.bonds.bond_site_2,
                                       properties.bonds.r_bond,
                                       properties.bonds.u_bond,
                                       properties.bonds.length,
                                       properties.sites.r,
                                       properties.sites.s);

            properties.chromosomes.ReadState(&parameters, &properties);
            properties.crosslinks.ReadState(&parameters, &properties);

            properties.time = i_config * parameters.n_posit * parameters.delta;
            properties.i_current_step = (i_config-1) * parameters.n_posit;

            if (last_config)
                break;

            //print_simulation_step(&parameters, &properties);

            /* Display configuration. */
            #ifndef NOGRAPH
            if (parameters.graph_flag && (i_config*parameters.n_posit) % parameters.n_graph == 0 && meas_spindle_separation) {
                /* Rotate graphics. */

                //graphics.Draw(properties.bonds.n_bonds,
                //              properties.unit_cell.h,
                //              properties.bonds.r_bond,
                //              properties.bonds.u_bond,
                //              properties.bonds.length);
                graphics.Draw(properties.bonds.n_bonds,
                              properties.unit_cell.h,
                              properties.bonds.r_bond,
                              properties.bonds.u_bond,
                              properties.bonds.length,
                              properties.crosslinks.n_types_,
                              properties.crosslinks.stage_0_xlinks_,
                              properties.crosslinks.stage_1_xlinks_,
                              properties.crosslinks.stage_2_xlinks_,
                              properties.anchors.n_anchors,
                              1,
                              properties.anchors.color_,
                              properties.anchors.r_anchor,
                              properties.anchors.u_anchor,
                              properties.anchors.v_anchor,
                              properties.anchors.w_anchor,
                              properties.anchors.diameter,
                              properties.chromosomes.nchromosomes_,
                              &properties.chromosomes);

                //XXX STOP CHANGING TO THIS!!!!!!!
                grabber(graphics.width_fb_, graphics.height_fb_, parameters.grab_file, (i_config*parameters.n_posit)/parameters.n_graph);
                //XXX Changing this breaks the movie code//grabber(graphics.width_fb_, graphics.height_fb_, "snapshot", (i_config*parameters.n_posit)/parameters.n_graph);
            }
            #endif
            if (meas_spindle_separation) {
                double dr[3];
                for (int i = 0; i < 3; ++i)
                    dr[i] = properties.anchors.r_anchor[1][i] - properties.anchors.r_anchor[0][i];
                f_spindle_separation << properties.time << " " << sqrt(dot_product(3,dr,dr)) << std::endl;

                // These measurements are center-independent and each costs
                // O(N^2) overlap evaluations, so compute them only once.
                f_frac_interpolar << properties.time << " " << 
                    frac_interpolar(&parameters, &properties) << " " << 
                    frac_length_interpolar(&parameters, &properties) << " " <<
                    norm_interpolar_length(&parameters, &properties) << " " <<
                    real_interpolar_length(&parameters, &properties) << " " <<
                    std::endl;
                f_avg_mt_length << properties.time  << " " << 
                    avg_mt_length(&properties) << std::endl;
            }

            if (is_spindle(&parameters, &properties, center->second)) {
                bin_lengths(&parameters, &properties, distributions, streams);
            }

        } while (!last_config);

        if (meas_spindle_separation) {
            f_spindle_separation.close();
            f_frac_interpolar.close();
            f_avg_mt_length.close();
        }
    
        for (dist_map::iterator dist = distributions.begin();
             dist != distributions.end();
             ++dist) {
            dist->second.NormalizeNmeas();
            dist->second.OutputBinary(dist->first + "_dist" + center->first);
        }

        for (stream_output::iterator stream = streams.begin();
             stream != streams.end();
             ++stream) {
            (stream->second)->close();
            delete stream->second;
        }
    }

    
    /* Close output files. */
    fclose(f_posit);

    /* Normal termination. */
    std::cout << "Successful run!\n";
    exit(0);
}
