#include "mex.h"
#include "matrix.h"
#include "NetworkFinder.h"
#include <iostream>
#include <vector>
#include <math.h>

//**************************************************************************************************************************************************

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// @args: imat, nodes, N
{
    mxArray    *tmp;
    int        r,c,N;
    std::vector< std::vector<int> > imat;
    std::vector<Node> nodes;
    double *p;
    int i,j;
    
    /* check proper input and output */
    if(nrhs!=3) mexErrMsgTxt("Three arguments required.");
    else if(nlhs > 3) mexErrMsgTxt("Too many output arguments.");
    else if(!mxIsStruct(prhs[1])) mexErrMsgTxt("Third input must be a structure.");

    r = mxGetM(prhs[0]);
	c = mxGetN(prhs[0]);
    if(r!=c) mexErrMsgTxt("First argument must be a square matrix.");
	imat.resize(r);
	for (i=0; i<r; i++) imat[i].resize(c);
	p = mxGetPr(prhs[0]);
	for (i=0; i<c; i++) {
		for (j=0; j<r; j++) {
			imat[j][i] = *p;
			p++;
		}
	}
    
    N = mxGetNumberOfElements(prhs[1]);
    Node t;
    
    
    for (i=0; i<N; i++) {
        // 1st field: trg
        tmp = mxGetFieldByNumber(prhs[1], i, 0);
        t.trg = *mxGetPr(tmp);
        
        // 2nd field: tfs
        tmp = mxGetFieldByNumber(prhs[1], i, 1);
        c = mxGetNumberOfElements(tmp);
        p = mxGetPr(tmp);
        t.tfs.resize(c);
        for (j=0; j<c; j++) {
            t.tfs[j] = *p;
            p++;
        }
        
        // 3rd field: score
        tmp = mxGetFieldByNumber(prhs[1], i, 2);
        t.score = *mxGetPr(tmp);
        
        nodes.push_back(t);
    }
    
    N = int(*mxGetPr(prhs[2]));
    
    NetworkFinder net(imat);
    for (i=0; i<nodes.size(); i++) net.AddNode(nodes[i]);
    net.FindAll(N);
    
    net.ScoreNetworks();
    std::vector<double> netsc = net.GetNetScores();
    
    
    std::vector< std::vector<Node> > res = net.GetBestNetworks(N);
    
    // OUTPUT
    int L = 0;
    for (i=0; i<N; i++) {
        if (res[i].size() > L) L = res[i].size();
    }
    
    plhs[0] = mxCreateDoubleMatrix(L,N, mxREAL);
    p = mxGetPr(plhs[0]);
    
    for (i=0; i<N; i++) {
        for (j=0; j<L; j++) {
            if (j < res[i].size()) *p = res[i][j].node_id;
            else *p = -1;
            
            p++;	
        }
    }
    
    plhs[1] = mxCreateDoubleMatrix(netsc.size(),1, mxREAL);
    p = mxGetPr(plhs[1]);
    for (i = 0; i < netsc.size(); i++) {
        *p = netsc[i];
        p++;
    }
    
    plhs[2] = mxCreateDoubleMatrix(2,1, mxREAL);
    p = mxGetPr(plhs[2]);
    *p = net.GetN();
    p++;
    *p = net.GetNN();
    
    return;
}
