#include <algorithm>
#include <vector>

#include "mex.h"
#include "LamSimAnn.h"
#include "Energy_fit.h"

//**************************************************************************************************************************************************

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// @args: logic, pt1, pt2, pt3, trg_syn, LB, UB, lambda, tau, theta0
{
	mwSize r,c;
	int i,j;
	int Np=0,logic;
	std::vector< std::vector<double> > pt1,pt2,pt3;
	std::vector<double> lb,ub,trgs;
	double *p;
	
	if (nrhs < 7) mexErrMsgTxt("At least 7 arguments required.");

	logic = int(*mxGetPr(prhs[0]));
	
	r = mxGetM(prhs[1]);
	c = mxGetN(prhs[1]);
	pt1.resize(r);
	for (i=0; i<r; i++) pt1[i].resize(c);
	p = mxGetPr(prhs[1]);
	for (i=0; i<c; i++) {
		for (j=0; j<r; j++) {
			pt1[j][i] = *p;
			p++;
		}
	}
	
	r = mxGetM(prhs[2]);
	c = mxGetN(prhs[2]);
	if (r != 0) {
		pt2.resize(r);
		for (i=0; i<r; i++) pt2[i].resize(c);
		p = mxGetPr(prhs[2]);
		for (i=0; i<c; i++) {
			for (j=0; j<r; j++) {
				pt2[j][i] = *p;
				p++;
			}
		}
	} else {
		Np = 1;
	}
	
	r = mxGetM(prhs[3]);
	c = mxGetN(prhs[3]);
	if (Np != 1) {
		if (r != 0) {
			pt3.resize(r);
			for (i=0; i<r; i++) pt3[i].resize(c);
			p = mxGetPr(prhs[3]);
			for (i=0; i<c; i++) {
				for (j=0; j<r; j++) {
					pt3[j][i] = *p;
					p++;
				}
			}
			Np = 3;
		} else {
			Np = 2;
		}
	}
	
	r = mxGetM(prhs[4]);
	c = mxGetN(prhs[4]);
	trgs.resize(c);
	p = mxGetPr(prhs[4]);
	for (i=0; i<c; i++) {
		trgs[i] = *p;
		p++;
	}
	
	r = mxGetM(prhs[5]);
	c = mxGetN(prhs[5]);
	lb.resize(c);
	p = mxGetPr(prhs[5]);
	for (i=0; i<c; i++) {
		lb[i] = *p;
		p++;
	}
	
	r = mxGetM(prhs[6]);
	c = mxGetN(prhs[6]);
	ub.resize(c);
	p = mxGetPr(prhs[6]);
	for (i=0; i<c; i++) {
		ub[i] = *p;
		p++;
	}

	double lam=0.0001,theta=0.1;
	int tau=100;
	
	if (nrhs > 7) {
		lam = *mxGetPr(prhs[7]);
		if (nrhs > 8) {
			tau = int(*mxGetPr(prhs[8]));
			if (nrhs > 9) {
				theta = *mxGetPr(prhs[9]);
			}
		}
	}
	
	LamSimAnn sim(lam,tau,theta);

	if (Np == 1) {
		Energy_fit fct(logic,trgs,lb,ub,pt1);
		sim.Run(&fct);
	} else if (Np == 2) {
		Energy_fit fct(logic,trgs,lb,ub,pt1,pt2);
		sim.Run(&fct);
	} else if (Np == 3) {
		Energy_fit fct(logic,trgs,lb,ub,pt1,pt2,pt3);
		sim.Run(&fct);
	}
	
	std::vector<double> tmp = sim.GetBestPar();
	plhs[0] = mxCreateDoubleMatrix(1,tmp.size(), mxREAL);
	p = mxGetPr(plhs[0]);
	for (i=0; i<tmp.size(); i++) {
		*p = tmp[i];
		p++;
	}
	
	double tmp2 = sim.GetBestEnergy();
	plhs[1] = mxCreateDoubleMatrix(1,1, mxREAL);
	p = mxGetPr(plhs[1]);
	*p = tmp2;
	
	if (nlhs > 2) {
		tmp = sim.GetMeanEnergyHist();
		plhs[2] = mxCreateDoubleMatrix(tmp.size(),1, mxREAL);
		p = mxGetPr(plhs[2]);
		for (i=0; i<tmp.size(); i++) {
			*p = tmp[i];
			p++;
		}
		
		tmp = sim.GetTHist();
		plhs[3] = mxCreateDoubleMatrix(tmp.size(),1, mxREAL);
		p = mxGetPr(plhs[3]);
		for (i=0; i<tmp.size(); i++) {
			*p = tmp[i];
			p++;
		}
		
		tmp = sim.GetAccRatioHist();
		plhs[4] = mxCreateDoubleMatrix(tmp.size(),1, mxREAL);
		p = mxGetPr(plhs[4]);
		for (i=0; i<tmp.size(); i++) {
			*p = tmp[i];
			p++;
		}
	}
	
	
}
