//
//  NetworkFinder.cpp
//  IrreducibleNetworks
//
//  Created by Patrick Hillenbrand on 17.05.12.
//  Copyright (c) 2012 __MyCompanyName__. All rights reserved.
//

#include "NetworkFinder.h"

#include <stdio.h>

#include <algorithm>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

NetworkFinder::NetworkFinder (std::vector< std::vector<int> > imat_)
{
    // Store the interaction matrix. ListNetworks treats imat[i][n] == 1 as
    // "i is a candidate TF of gene n".
    // imat_ is already a private copy (passed by value), so swapping its
    // contents into the member avoids a second deep copy of the matrix.
    imat.swap(imat_);
}

//********************************************************************************************************

long NetworkFinder::FindAll (int nb)
{
    m_N = 0;
    m_nn = 0;
    curr_netw.clear();
    in_netw.clear();
    
    networks.resize(nb);
    net_scores.resize(nb);
    score_map.resize(nb);
    
    for (int i = 0; i < nb; i++) {
        net_scores[i] = 1.;
        score_map[i] = i;
    }
    
    ListNetworks(0);
    
    return networks.size();
}

//********************************************************************************************************

int NetworkFinder::isConnected ()
{
    // Checks whether the current candidate (members listed in in_netw, their
    // regulators in curr_netw[...].tfs) is a closed, connected network:
    //  - it has at least two members,
    //  - every regulator of every member is itself a member,
    //  - walking upstream from any member (ListNodes) reaches exactly as many
    //    nodes as there are members.
    // If so, the candidate is handed to StoreNetw(0) and 1 is returned;
    // otherwise 0. Every call, accepted or not, is counted in m_nn.
    const int found = 1;
    m_nn += found;

    if (in_netw.size() < 2) return 0;

    for (std::size_t idx = 0; idx < in_netw.size(); ++idx) {
        const int member = in_netw[idx];

        // Reject as soon as a member is regulated from outside the network.
        for (std::size_t r = 0; r < curr_netw[member].tfs.size(); ++r) {
            const bool inside =
                std::find(in_netw.begin(), in_netw.end(), curr_netw[member].tfs[r]) != in_netw.end();
            if (!inside) return 0;
        }

        // Collect everything reachable upstream of this member.
        nodes.clear();
        ListNodes(member);
        if (nodes.size() != in_netw.size()) return 0;
    }

    StoreNetw(0);
    return found;
}

//********************************************************************************************************

void NetworkFinder::ListNodes (int n)
{
    nodes.push_back(n);
    for (int i = 0; i < curr_netw[n].tfs.size(); i++) {
        if (std::find(nodes.begin(),nodes.end(),curr_netw[n].tfs[i]) == nodes.end()) {
            ListNodes(curr_netw[n].tfs[i]);
        }
    }
}

//********************************************************************************************************

void NetworkFinder::ListNetworks(int n)
{
    // Recursively enumerates every regulation pattern for genes n..imat.size()-1.
    // For gene n it tries:
    //   (a) the gene is not in the network,
    //   (b) the gene is regulated by a single candidate TF,
    //   (c) the gene is regulated by a pair of distinct candidate TFs.
    // Once the last gene is assigned, the complete candidate is checked with
    // isConnected(), whose return value is added to m_N.
    // NOTE(review): assumes every row of imat has at least n+1 entries
    // (presumably imat is square) -- confirm with callers.

    // Out-of-range gene: nothing to enumerate. This also prevents the runaway
    // recursion an empty imat caused, since `imat.size()-1` wrapped around.
    if (n < 0 || static_cast<std::size_t>(n) >= imat.size()) return;
    const bool lastGene = (static_cast<std::size_t>(n) + 1 == imat.size());

    // Candidate TFs of gene n: every row i with imat[i][n] == 1.
    std::vector<int> tfs;
    for (std::size_t i = 0; i < imat.size(); i++) {
        if (imat[i][n] == 1) tfs.push_back(static_cast<int>(i));
    }

    // Open a slot for gene n. curr_netw.back() is re-read after every
    // recursive call because deeper levels grow (and may reallocate)
    // curr_netw before shrinking it back to this level.
    curr_netw.resize(curr_netw.size()+1);

    // a) node is not in network
    curr_netw.back().tfs.clear();
    if (lastGene) {
        m_N += isConnected();
    } else {
        ListNetworks(n+1);
    }

    // b) single interactions
    in_netw.push_back(n);
    curr_netw.back().tfs.resize(1);
    for (std::size_t i = 0; i < tfs.size(); i++) {
        curr_netw.back().tfs[0] = tfs[i];
        if (lastGene) {
            m_N += isConnected();
        } else {
            ListNetworks(n+1);
        }
    }

    // c) combinatorial interactions: every unordered pair of distinct TFs.
    // Written as `i + 1 < tfs.size()`: the former `i < tfs.size() - 1` wrapped
    // around for a gene without TFs and then read tfs[0], tfs[1], ... out of
    // bounds.
    curr_netw.back().tfs.resize(2);
    for (std::size_t i = 0; i + 1 < tfs.size(); i++) {
        curr_netw.back().tfs[0] = tfs[i];
        for (std::size_t j = i + 1; j < tfs.size(); j++) {
            curr_netw.back().tfs[1] = tfs[j];
            if (lastGene) {
                m_N += isConnected();
            } else {
                ListNetworks(n+1);
            }
        }
    }

    // shorten network again
    curr_netw.pop_back();
    in_netw.pop_back();
}

//********************************************************************************************************

void NetworkFinder::StoreNetw(int in)
{
    // Scores the current candidate network (members in in_netw) and, if it
    // beats the worst retained network, stores a copy of its nodes in the
    // freed slot of `networks`.
    //
    // Bookkeeping: net_scores holds the retained scores in descending order,
    // so net_scores[0] is the worst kept score. score_map[k] is the
    // `networks` slot holding the network with the k-th entry of net_scores.
    //
    // `in` is unused. The old implementation only recursed from `in` up to
    // in_netw.size() before doing the work, and recursed forever when in was
    // larger. The parameter is kept so the declared signature is unchanged.
    (void)in;

    // Nothing to score, or no result slots to fill (e.g. FindAll(0)).
    if (in_netw.empty() || net_scores.empty()) return;

    // Score = mean score of the matching all_nodes entries.
    double score = 0.;
    std::size_t i, j;
    for (i = 0; i < in_netw.size(); i++) {
        // operator== compares trg, so the target must be set before the lookup.
        curr_netw[in_netw[i]].trg = in_netw[i];
        for (j = 0; j < all_nodes.size(); j++) {
            if (curr_netw[in_netw[i]] == all_nodes[j]) break;
        }
        if (j == all_nodes.size()) {
            // Previously read all_nodes[size()] (out of bounds).
            throw std::runtime_error("StoreNetw: a network node has no matching entry in all_nodes");
        }
        score += all_nodes[j].score;
    }
    score /= double(in_netw.size());

    // Keep only networks strictly better than the worst retained one.
    if (!(score < net_scores[0])) return;

    // Insertion point that keeps net_scores in descending order.
    std::size_t pos = 0;
    while (pos < net_scores.size() && !(net_scores[pos] < score)) pos++;

    // Drop the worst entry (front) and recycle its slot for the new network
    // at its rank.
    const int slot = score_map.front();
    net_scores.insert(net_scores.begin() + pos, score);
    net_scores.erase(net_scores.begin());
    score_map.insert(score_map.begin() + pos, slot);
    score_map.erase(score_map.begin());

    // Copy the member nodes (trg already set above) into the slot.
    networks[slot].clear();
    for (i = 0; i < in_netw.size(); i++) {
        networks[slot].push_back(curr_netw[in_netw[i]]);
    }
}

//********************************************************************************************************

void NetworkFinder::AddNode (Node node)
{
    // Appends a copy of `node` to all_nodes, the catalogue that StoreNetw and
    // ScoreNetworks search (via operator==) to look up node scores.
    all_nodes.insert(all_nodes.end(), node);
}

//********************************************************************************************************

void NetworkFinder::ScoreNetworks ()
{
    int i,j;
    net_scores.resize(networks.size());
    
    // score all networks
    for (int n = 0; n < net_scores.size(); n++) {
    
        net_scores[n] = 0.;
        
        for (i = 0; i < networks[n].size(); i++) {
            for (j = 0; j < all_nodes.size(); j++) {
                if (networks[n][i] == all_nodes[j]) break;
            }
            if (j == all_nodes.size()) {
                char buffer [200];
                sprintf(buffer,"n=%i\none node has not been found:\ntrg:\t%i\ntf:\t%i,%i\nscore:\t%f\n",n,networks[n][i].trg,networks[n][i].tfs[0],networks[n][i].tfs[1],networks[n][i].score);
                throw std::runtime_error(buffer);
            } else {
                net_scores[n] += all_nodes[j].score;
                networks[n][i].node_id = j;
            }
        }
        
        net_scores[n] /= double(networks[n].size());
    
    }
}

//********************************************************************************************************

std::vector< std::vector<Node> > NetworkFinder::GetBestNetworks (int howmany)
{
    // Removes and returns up to `howmany` networks with the lowest net_scores,
    // in ascending score order; ties go to the lowest index.
    // Destructive: the chosen entries are erased from both `networks` and
    // `net_scores`.
    // Expects net_scores to be index-aligned with networks, as ScoreNetworks
    // leaves it. After FindAll alone, net_scores is in rank order (see
    // StoreNetw), so call ScoreNetworks() first.
    std::vector< std::vector<Node> > res;

    for (int n = 0; n < howmany; n++) {
        // Only indices present in both parallel containers are candidates.
        const std::size_t count = std::min(networks.size(), net_scores.size());
        if (count == 0) break; // previously indexed empty vectors (UB)

        std::size_t bi = 0;
        for (std::size_t i = 1; i < count; i++) {
            if (net_scores[i] < net_scores[bi]) bi = i;
        }

        res.push_back(networks[bi]);
        networks.erase(networks.begin() + bi);
        net_scores.erase(net_scores.begin() + bi);
    }

    return res;
}

//********************************************************************************************************

bool operator== (Node x, Node y)
{
    // Two nodes are equal when they have the same target (trg) and the same
    // TFs, irrespective of TF order.
    // x and y are private copies (passed by value), so their TF lists can be
    // sorted in place and compared element-wise. This compares TFs as a
    // multiset and keeps the relation symmetric. The former "every TF of x
    // occurs somewhere in y" test reported {1,1} == {1,2} but not
    // {1,2} == {1,1}.
    if (x.trg != y.trg) return false;
    if (x.tfs.size() != y.tfs.size()) return false;
    std::sort(x.tfs.begin(), x.tfs.end());
    std::sort(y.tfs.begin(), y.tfs.end());
    return std::equal(x.tfs.begin(), x.tfs.end(), y.tfs.begin());
}




