#include "data_collection.h"

// Default constructor: performs no setup; members keep their default initialization.
DataCollection::DataCollection() = default;

// Loads the analysis configuration named by parameters->data_analysis_file,
// reads the averaging/threshold settings, opens the requested output files
// and initializes the data containers.
DataCollection::DataCollection(system_parameters *parameters, system_properties *properties) : properties_(properties), parameters_(parameters){
    node_ = YAML::LoadFile(parameters->data_analysis_file);/*{{{*/
    n_dim_ = parameters->n_dim;
    avg_flag_ = node_["avg_flag"].as<int>();
    avg_steps_ = node_["avg_steps"].as<int>();
    threshold_ = node_["threshold"].as<int>();

    // avg_steps_ is used as a modulus in CalculateValues() and as a divisor
    // when averaging; a non-positive value would divide by zero, so fall back
    // to sampling/writing every step.
    if (avg_steps_ <= 0){
        std::cerr << "avg_steps must be positive (got " << avg_steps_
                  << "); using 1 instead\n";
        avg_steps_ = 1;
    }

    // NOTE(review): the header builders run before InitData(), so they read
    // xl_data_.bin_num and bead_data_.n_traps before those containers are
    // initialized. Confirm that this order is intended.
    CreateOutFileMap();
    InitData();
    //ClearData();
}/*}}}*/

void DataCollection::InitData(){
    xl_data_.Init(parameters_, properties_);/*{{{*/
    //bead_data_; FIXME Will only work with optical trap stuff
    mt_data_.Init(n_dim_);
    spb_data_.Init(n_dim_);
    w_data_.Init(n_dim_);
}/*}}}*/

void DataCollection::ClearData(){
    xl_data_.Clear();/*{{{*/
    //bead_data_.Clear(); FIXME Will only work with optical trap stuff
    mt_data_.Clear();
    spb_data_.Clear();
    w_data_.Clear();
}/*}}}*/

/* Creates the output-file map and writes one header line to each file.{{{
 * Multiple files can be requested in the analysis file as entries of the
 * data_collection_type list. Each entry's "name" selects the file type.
 * The Get*Header() helpers build the column names and record which
 * quantities are collected (data_types_). The matching data structure is
 * flagged as used.*/
void DataCollection::CreateOutFileMap(){
    std::string avg_title = "";/*{{{*/
    const std::string path = node_["data_collection_path"].as<std::string>();
    //switch (avg_flag_) {
        //case 0: avg_title = "";break;
        //case 1: avg_title = ":(averaged over " + std::to_string(avg_steps_) + " steps)";break;
        //case 2: avg_title = ":(summed over " + std::to_string(avg_steps_) + " steps)";break;
    //}

    YAML::Node file_types = node_["data_collection_type"];
    const std::string suffix_title = node_["suffix_title"].as<std::string>();
    mt_data_.is_used = true;

    for (YAML::const_iterator type_it = file_types.begin();
            type_it != file_types.end(); ++type_it){
        const std::string type = (*type_it)["name"].as<std::string>();
        const std::string file_name = path + type + "_" + suffix_title;
        std::string header;

        // The file types are mutually exclusive.
        if (type == "optical_trap"){
            header = "n_steps";
            header += GetOpticalTrapHeader(&type_it);
            bead_data_.is_used = true;
        }
        else if (type == "crosslinks"){
            header = "n_steps";
            header += GetCrosslinkHeader(&type_it);
            xl_data_.is_used = true;
        }
        else if (type == "wall_test"){
            header = "time";
            header += GetWallTestHeader(&type_it);
            w_data_.is_used = true;
        }
        else if (type == "spb_data"){
            header = "n_steps";
            header += GetSPBHeader(&type_it);
            spb_data_.is_used = true;
        }
        else {
            // The file is still created, but with an empty header; make the
            // likely typo visible.
            std::cerr << "Unknown data_collection_type \"" << type << "\"\n";
        }

        // Look the stream up once; the map owns it for the rest of the run.
        auto &out = outfiles_[type];
        out.open(file_name, std::fstream::out);
        if (!out.is_open())
            std::cerr<<"Output "<< file_name <<" file did not open\n";
        else
            out<<header<<avg_title<<"\n";
    }
}/*}}}*/

// Builds the ":"-separated optical-trap column names for the "include" list of
// the given data_collection_type entry. Each recognized quantity is recorded
// in data_types_["optical_trap"]. Unrecognized entries are ignored.
std::string DataCollection::GetOpticalTrapHeader(YAML::const_iterator *ot_node){
    std::string header ="";/*{{{*/
    // Creating the entry (even with nothing included) makes WriteOutputs()
    // dispatch to WriteOpticalTrapOutputs().
    auto &collected = data_types_["optical_trap"];

    // bead_data_ is not set up by InitData(), so n_traps may be null. Guard
    // the dereference instead of crashing.
    auto trap_count = [this]() -> int {
        if (!bead_data_.n_traps){
            std::cerr << "optical_trap: number of traps unknown, per-trap columns omitted\n";
            return 0;
        }
        return *(bead_data_.n_traps);
    };

    //Loop over the different kinds of information you want included in your file
    const YAML::Node includes = (**ot_node)["include"];
    for (YAML::const_iterator data_it = includes.begin();
            data_it != includes.end(); ++data_it){
        const std::string item = data_it->as<std::string>();
        if (item == "force"){
            const int n_traps = trap_count();
            for (int i_trap = 0; i_trap < n_traps; i_trap++){
                const std::string idx = std::to_string(i_trap);
                header += ":f_x" + idx + ":f_y" + idx + ":f_z" + idx;
            }
            collected.push_back("force");
        }
        else if (item == "position"){
            const int n_traps = trap_count();
            for (int i_trap = 0; i_trap < n_traps; i_trap++){
                const std::string idx = std::to_string(i_trap);
                header += ":r_x" + idx + ":r_y" + idx + ":r_z" + idx;
            }
            collected.push_back("position");
        }
        else if (item == "overlap"){
            header += ":overlap";
            collected.push_back("overlap");
        }
        else if (item == "orientation"){
            header += ":cos(theta)";
            collected.push_back("orientation");
        }
        else if (item == "extension"){
            header += ":extension";
            collected.push_back("extension");
        }
    }
    return header;
}/*}}}*/

// Builds the ":"-separated crosslink column names for the "include" list of
// the given data_collection_type entry. Each recognized quantity is recorded
// in data_types_["crosslinks"]. The optional "mt_index" key sets
// xl_data_.xlink_mt.
std::string DataCollection::GetCrosslinkHeader(YAML::const_iterator *ot_node){
    std::string columns;/*{{{*/
    auto &requested = data_types_["crosslinks"];

    const YAML::Node mt_entry = (**ot_node)["mt_index"];
    if (mt_entry)
        xl_data_.xlink_mt = mt_entry.as<int>();

    const YAML::Node includes = (**ot_node)["include"];
    for (YAML::const_iterator it = includes.begin(); it != includes.end(); ++it){
        const std::string item = it->as<std::string>();
        if (item == "xlink_distribution"){
            //xl_data_.bin_num = int(properties_->bonds.length[xl_data_.xlink_mt]/4);
            for (int bin = 0; bin < xl_data_.bin_num; bin++)
                columns += ":" + std::to_string(bin*100);
        }
        else if (item == "stage_1_xlinks"){
            columns += ":s1_xlinks";
        }
        else if (item == "stage_2_xlinks"){
            columns += ":s2_xlinks";
        }
        else {
            continue;   // unrecognized entries are skipped
        }
        requested.push_back(item);
    }
    return columns;
}/*}}}*/

// Builds the ":"-separated wall-test column names for the "include" list of
// the given data_collection_type entry. Each recognized quantity is recorded
// in data_types_["wall_test"]. The optional "mt_index" key sets
// w_data_.mt_index.
std::string DataCollection::GetWallTestHeader( YAML::const_iterator *ot_node){
    std::string columns;/*{{{*/
    auto &requested = data_types_["wall_test"];

    const YAML::Node mt_entry = (**ot_node)["mt_index"];
    if (mt_entry)
        w_data_.mt_index = mt_entry.as<int>();

    const YAML::Node includes = (**ot_node)["include"];
    for (YAML::const_iterator it = includes.begin(); it != includes.end(); ++it){
        const std::string item = it->as<std::string>();
        if (item == "protrusion")
            columns += ":protrusion";
        else if (item == "velocity")
            columns += ":velocity";
        else if (item == "force")
            columns += ":f_wall";
        else
            continue;   // unrecognized entries are skipped
        requested.push_back(item);
    }
    return columns;
}/*}}}*/

std::string DataCollection::GetSPBHeader(YAML::const_iterator *ot_node){return "";}
/*}}}*/

/*Output functions*/ 
// Writes one row to every active output file, then clears the accumulated data.
void DataCollection::WriteOutputs(){/*{{{*/
    n_steps_ = properties_->i_current_step;/*{{{*/

    // Iterate by const reference: each entry holds a vector of names, and a
    // by-value loop would copy (and allocate) it on every output step.
    for (const auto &type : data_types_){
        if (type.first == "optical_trap")
            WriteOpticalTrapOutputs();
        else if (type.first == "crosslinks")
            WriteCrosslinkOutputs();
        else if (type.first == "wall_test")
            WriteWallOutputs();
        //left open for other data collection types
    }
    ClearData();
}/*}}}*/

// Writes one row of the optical_trap file. The per-quantity columns are
// disabled below, so each row currently holds only the step number.
void DataCollection::WriteOpticalTrapOutputs(){
    outfiles_["optical_trap"] << n_steps_;/*{{{*/
    //double avg_factor = ((avg_flag_ == 1) ? avg_steps_:1);

    //for(auto type : data_types_["optical_trap"]){ 
        //if(type.compare("force") == 0){
            //for(int i_trap=0; i_trap<properties_->bonds.n_traps; i_trap++){
                //for(int i=0; i<n_dim_; i++)
                    //outfiles_["optical_trap"] << ":"<<f_trap_[i_trap][i]/avg_factor;
            //}
        //}
        //else if(type.compare("position") == 0){
            //for(int i_trap=0; i_trap<properties_->bonds.n_traps; i_trap++){
                //for(int i=0; i<n_dim_; i++)
                    //outfiles_["optical_trap"] << ":"<<r_bead_[i_trap][i]/avg_factor;
            //}
        //}
        //else if(type.compare("orientation") == 0){
        //}
        //else if(type.compare("overlap") == 0){
            //outfiles_["optical_trap"] << ":"<<mt_data_.overlap/avg_factor; continue;
        //}
        //else if(type.compare("extension") == 0){
            //outfiles_["optical_trap"] << ":"<<bead_data_.extension/avg_factor;
        //}
    //}

    // Always terminate the row; otherwise the step numbers of successive
    // writes run together on a single line.
    outfiles_["optical_trap"] << "\n";
}/*}}}*/

// Writes one row of the crosslinks file: the step number followed by each
// requested quantity, in the order listed in data_types_["crosslinks"].
void DataCollection::WriteCrosslinkOutputs(){
    auto &out = outfiles_["crosslinks"];
    out << n_steps_;/*{{{*/
    // With avg_flag_ == 1 the accumulated values are averaged over avg_steps_;
    // any other flag writes them as accumulated.
    const double avg_factor = ((avg_flag_ == 1) ? avg_steps_ : 1);

    for (const auto &type : data_types_["crosslinks"]){
        if (type == "xlink_distribution"){
            for (int i = 0; i < xl_data_.bin_num; i++)
                out << ":" << xl_data_.xlink_dist[i]/avg_factor;
        }
        else if (type == "stage_1_xlinks"){
            out << ":" << double(xl_data_.stage_1_xlinks)/avg_factor;
        }
        else if (type == "stage_2_xlinks"){
            out << ":" << double(xl_data_.stage_2_xlinks)/avg_factor;
        }
    }
    out << "\n";
}/*}}}*/

// Writes one row of the wall_test file: the simulation time followed by each
// requested quantity, in the order listed in data_types_["wall_test"].
void DataCollection::WriteWallOutputs(){
    auto &out = outfiles_["wall_test"];
    out << properties_->time;/*{{{*/
    // With avg_flag_ == 1 the accumulated values are averaged over avg_steps_;
    // any other flag writes them as accumulated.
    const double avg_factor = ((avg_flag_ == 1) ? avg_steps_ : 1);

    for (const auto &type : data_types_["wall_test"]){
        if (type == "protrusion"){
            out << ":" << w_data_.dr/avg_factor;
        }
        else if (type == "velocity"){
            out << ":" << w_data_.vel/avg_factor;
        }
        else if (type == "force"){
            // CalcForce() accumulates into f_par_total, just as dr and vel are
            // accumulated above. f_par is only the latest sample, so dividing
            // it by avg_factor was wrong. Assumes w_data_.Clear() resets
            // f_par_total between writes. TODO confirm.
            out << ":" << w_data_.f_par_total/avg_factor;
        }
    }
    out << "\n";
}/*}}}*/
/*}}}*/

/*Data Collection and calculations*/
// Per-step entry point: accumulates the requested quantities and, every
// avg_steps_ steps, writes them out (WriteOutputs() also clears them).
void DataCollection::CalculateValues(){/*{{{*/
    n_steps_ = properties_->i_current_step;/*{{{*/
    const bool output_step = (n_steps_ % avg_steps_ == 0);

    // Sample on every step when averaging/summing (avg_flag_ != 0); otherwise
    // sample only on the steps that are written out.
    if (avg_flag_ != 0 || output_step){
        // Const reference avoids copying each (name, vector) entry every step.
        for (const auto &type : data_types_){
            if (type.first == "optical_trap")
                AddOpticalTrapData();
            else if (type.first == "crosslinks")
                AddCrosslinkData();
            else if (type.first == "wall_test")
                AddWallData();
            //Left open for other data that we might want to collect
        }
    }
    if (output_step)
        WriteOutputs();
}/*}}}*/

// Passes each requested optical-trap quantity to its accumulator.
void DataCollection::AddOpticalTrapData(){
    const auto &requested = data_types_["optical_trap"];/*{{{*/
    for (const auto &quantity : requested){
        if (quantity == "force")
            CalcTrapForce();
        else if (quantity == "position")
            CalcBeadPosition();
        else if (quantity == "orientation")
            CalcOrientation();
        else if (quantity == "overlap")
            CalcOverlap();
        else if (quantity == "extension")
            CalcExtension();
    }
}/*}}}*/

// Passes each requested crosslink quantity to its accumulator.
void DataCollection::AddCrosslinkData(){
    const auto &requested = data_types_["crosslinks"];/*{{{*/
    for (const auto &quantity : requested){
        if (quantity == "xlink_distribution")
            CalcXlinkDistribution();
        else if (quantity == "stage_1_xlinks")
            CalcStage1Xlink();
        else if (quantity == "stage_2_xlinks")
            CalcStage2Xlink();
    }
}/*}}}*/

// Passes each requested wall-test quantity to its accumulator.
void DataCollection::AddWallData(){
    const auto &requested = data_types_["wall_test"];/*{{{*/
    for (const auto &quantity : requested){
        if (quantity == "protrusion")
            CalcProtrusion();
        else if (quantity == "velocity")
            CalcVelocity();
        else if (quantity == "force")
            CalcForce();
    }
}/*}}}*/

//Optical trap calculations
void DataCollection::CalcExtension(){/*{{{*/
    // Currently a no-op: the whole calculation below is commented out.
    // When re-enabled it would add to bead_data_.extension either the
    // bead-to-trap distance (one trap) or the bead separation minus the trap
    // separation (two or more traps), using the first two traps only.
    //const bond_properties *bonds = &(properties_->bonds);
    //double extension0 = 0;
    //double extension = 0;

    //if (bonds->n_traps == 1){
    //    for (int i=0; i<n_dim_; i++){
    //        extension += SQR(bonds->r_bead[0][i]-bonds->r_trap[0][i]);
    //    }
    //}
    //else if(bonds->n_traps >1){
    //    for (int i=0; i<n_dim_; i++){
    //        extension0 += SQR(bonds->r_trap[1][i]- bonds->r_trap[0][i] );
    //        extension += SQR(bonds->r_bead[1][i]-bonds->r_bead[0][i]);
    //    }
    //}
    //else
    //    std::cerr<<" Need at least one trap to calculate extension\n";
    //bead_data_.extension += sqrt(extension) - sqrt(extension0);

}

//TODO Make this so that you can do the orientation between every mt. Only does the first one at the moment
// Accumulates the dot product of the unit vectors of bonds 0 and 1 into
// mt_data_.orientation.
void DataCollection::CalcOrientation(){
    // Two bonds are required. With fewer, u_bond[1] would be read past the
    // last bond (assuming the bond arrays are sized by n_bonds).
    if (properties_->bonds.n_bonds < 2)
        return;
    mt_data_.orientation += dot_product(n_dim_, properties_->bonds.u_bond[0], properties_->bonds.u_bond[1]);
}

void DataCollection::CalcOverlap(){
    // Not implemented: requesting "overlap" currently accumulates nothing.
}

void DataCollection::CalcTrapForce(){
    // Currently a no-op: the calculation below is commented out. When
    // re-enabled it would add k*(r_trap - r_bead) to f_trap_ for every bond
    // with bond_type == 1, numbering those trap bonds in order.
    //const bond_properties *bonds = &(properties_->bonds);
    //double *k = bonds->traps; 
    //int i_trap = 0;
    
    //for( int i_bond = 0; i_bond < bonds->n_bonds; i_bond++){
    //    if (bonds->bond_type[i_bond] == 1){
    //        for (int i=0; i<n_dim_; i++){
    //            //This is the force exerted by the trap so F = -k(x1-x0) or F=k(x0-x1)
    //            f_trap_[i_trap][i] += k[i_bond]*(bonds->r_trap[i_trap][i] - bonds->r_bead[i_trap][i]);
    //        }
    //        i_trap++;
    //    }
    //}
}

void DataCollection::CalcBeadPosition(){
    // Currently a no-op: the calculation below is commented out. When
    // re-enabled it would add the bead position of every bond with
    // bond_type == 1 to r_bead_, numbering those trap bonds in order.
    //const bond_properties *bonds = &(properties_->bonds);
    //int i_trap = 0;
    //for( int i_bond = 0; i_bond < bonds->n_bonds; i_bond++){
    //    if (bonds->bond_type[i_bond] == 1){
    //        for (int i=0; i<n_dim_; i++){
    //            r_bead_[i_trap][i] += bonds->r_bead[i_trap][i];
    //        }
    //        i_trap++;
    //    }
    //}
}/*}}}*/

//Crosslink calculations
void DataCollection::CalcXlinkDistribution(){/*{{{*/
    int n_bonds = properties_->bonds.n_bonds;
    std::vector<XlinkEntry> *xlinks = properties_->crosslinks.stage_2_xlinks_[0];
    double length = properties_->bonds.length[xl_data_.xlink_mt];

    for (int i_bond=0; i_bond < n_bonds; i_bond++){
        for (xlink_list::iterator xlink = xlinks[i_bond].begin();
                xlink < xlinks[i_bond].end(); xlink++){
            int i_head;
            //Figure out which head of xlink is attached to the xlink_mt
            if (xlink->head_parent_[0] == xl_data_.xlink_mt) i_head = 0;
            else if (xlink->head_parent_[1] == xl_data_.xlink_mt) i_head = 1;
            else continue;

            int bin = int((xlink->cross_position_[i_head]/length)*xl_data_.bin_num);
            if (bin == xl_data_.bin_num) 
                bin--; //if position exceeds bin location put it on the last viable spot

            xl_data_.xlink_dist[bin]++;
        }
    }
}

// Adds this step's stage-1 crosslink count (both entries of n_bound_1_[0]).
// Only the first species (index 0) is recorded.
void DataCollection::CalcStage1Xlink(){
    const auto &species0 = properties_->crosslinks.n_bound_1_[0];
    xl_data_.stage_1_xlinks += species0[0] + species0[1];
}

// Adds this step's stage-2 crosslink count. Only the first species (index 0) is recorded.
void DataCollection::CalcStage2Xlink(){
    const auto species0_count = properties_->crosslinks.n_bound_2_[0];
    xl_data_.stage_2_xlinks += species0_count;
}/*}}}*/

//Wall calculations
// Despite its name this stores a value: it records f_par as the latest wall
// force, but only when it belongs to the tracked bond (w_data_.mt_index).
void DataCollection::GetWallForce(double f_par, int index){/*{{{*/
    if (index != w_data_.mt_index)
        return;
    w_data_.f_par = f_par;
}

// Stores in w_data_.pos the end point of the tracked bond: its position
// r_bond plus half its length along its unit vector u_bond.
void DataCollection::CalcTipPos(){
    const int tracked = w_data_.mt_index;/*{{{*/
    const double *center = properties_->bonds.r_bond[tracked];
    const double *direction = properties_->bonds.u_bond[tracked];
    const double length = properties_->bonds.length[tracked];

    for (int d = 0; d < n_dim_; ++d)
        w_data_.pos[d] = center[d] + .5*direction[d]*length;
}/*}}}*/

// Adds to w_data_.dr the distance of the tip from the origin minus half of
// unit_cell.h[0][0] (used as the sphere radius).
void DataCollection::CalcProtrusion(){
    CalcTipPos();/*{{{*/
    const double half_box = properties_->unit_cell.h[0][0]/2.0;
    const double tip_distance = sqrt(dot_product(n_dim_, w_data_.pos, w_data_.pos));
    w_data_.dr += tip_distance - half_box;
}/*}}}*/

// Adds to w_data_.vel the tip's displacement from pos0 divided by the time
// elapsed since t0. NOTE(review): pos0/t0 are not updated in this file, so
// this measures against a fixed reference set elsewhere. Confirm that this
// is intended.
void DataCollection::CalcVelocity(){
    w_data_.t = properties_->time;/*{{{*/
    CalcTipPos();

    // No elapsed time: the speed is undefined, so nothing is accumulated.
    if (w_data_.t == w_data_.t0)
        return;

    // RAII buffer instead of new[]/delete[]: cannot leak if anything throws.
    std::vector<double> dr(n_dim_);
    for (int i = 0; i < n_dim_; i++)
        dr[i] = w_data_.pos[i] - w_data_.pos0[i];

    w_data_.vel += sqrt(dot_product(n_dim_, dr.data(), dr.data()))/(w_data_.t - w_data_.t0);
}/*}}}*/

// Adds the latest wall force (stored by GetWallForce) to the running total.
void DataCollection::CalcForce(){
    const auto latest = w_data_.f_par;
    w_data_.f_par_total += latest;
}
/*}}}*/
/*}}}*/

//Check functions
// Reports whether wall-test collection is active. CreateOutFileMap() sets the
// flag when a "wall_test" entry is listed.
bool DataCollection::WallIsUsed(){/*{{{*/
    const bool used = w_data_.is_used;
    return used;
}
/*}}}*/


// Output period / averaging window, as read from "avg_steps" in the analysis file.
int DataCollection::GetAvgSteps(){
    return avg_steps_;
}

// Value of "threshold" from the analysis file (not used within this file).
int DataCollection::GetThreshold(){
    return threshold_;
}

void DataCollection::Close(){
    //for(auto type : outfiles_){
        //type.second.close();
    //}
}




