#include "mex.h"
#include "LamSimAnn.h"
#include "Energy_fit.h"
#include "Energy_fit_lambda.h"
#include <iostream>

//**************************************************************************************************************************************************

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// @args: logic, pt1, pt2, pt3, trg_syn, LB, UB, lambda, tau, theta0, start_par
{
	mwSize r,c;
	int i,j;
	int Np=0,logic;
	std::vector<double> pt1,pt2,pt3;
	std::vector<double> lb,ub,trgs;
    std::vector<double> start_par;
	double *p;
	
	if (nrhs < 7) mexErrMsgTxt("At least 7 arguments required.");

	logic = int(*mxGetPr(prhs[0]));
	
	c = mxGetN(prhs[1]);
	pt1.resize(c);
	p = mxGetPr(prhs[1]);
	for (i=0; i<c; i++) {
        pt1[i] = *p;
        p++;
	}
	
	c = mxGetN(prhs[2]);
	if (c != 0) {
		pt2.resize(c);
		p = mxGetPr(prhs[2]);
		for (i=0; i<c; i++) {
            pt2[i] = *p;
            p++;
		}
	} else {
		Np = 1;
	}
	
	c = mxGetN(prhs[3]);
	if (Np != 1) {
		if (c != 0) {
			pt3.resize(c);
			p = mxGetPr(prhs[3]);
			for (i=0; i<c; i++) {
                pt3[i] = *p;
                p++;
			}
			Np = 3;
		} else {
			Np = 2;
		}
	}
	
	r = mxGetM(prhs[4]);
	c = mxGetN(prhs[4]);
	trgs.resize(c);
	p = mxGetPr(prhs[4]);
	for (i=0; i<c; i++) {
		trgs[i] = *p;
		p++;
	}
	
	r = mxGetM(prhs[5]);
	c = mxGetN(prhs[5]);
	lb.resize(c);
	p = mxGetPr(prhs[5]);
	for (i=0; i<c; i++) {
		lb[i] = *p;
		p++;
	}
	
	r = mxGetM(prhs[6]);
	c = mxGetN(prhs[6]);
	ub.resize(c);
	p = mxGetPr(prhs[6]);
	for (i=0; i<c; i++) {
		ub[i] = *p;
		p++;
	}

	double lam=0.0001,theta=0.1;
	int tau=100;
	
	if (nrhs > 7) {
		lam = *mxGetPr(prhs[7]);
		if (nrhs > 8) {
			tau = int(*mxGetPr(prhs[8]));
			if (nrhs > 9) {
				theta = *mxGetPr(prhs[9]);
			}
		}
	}
    
    if (nrhs == 11) {
        c = mxGetN(prhs[10]);
        start_par.resize(c);
        p = mxGetPr(prhs[10]);
        for (i=0; i<c; i++) {
            start_par[i] = *p;
            p++;
        }
    }
    	
	LamSimAnn sim(lam,tau,theta,10,1.e-10);

	if (Np == 1) {
		Energy_fit_lambda fct(logic,trgs,lb,ub,pt1);
		if (start_par.size()==0) {
            sim.Run(&fct);
        } else {
            sim.Run(&fct,&start_par[0]);
        }
	} else if (Np == 2) {
		Energy_fit_lambda fct(logic,trgs,lb,ub,pt1,pt2);
		if (start_par.size()==0) {
            sim.Run(&fct);
        } else {
            sim.Run(&fct,&start_par[0]);
        }
	} else if (Np == 3) {
		Energy_fit_lambda fct(logic,trgs,lb,ub,pt1,pt2,pt3);
		if (start_par.size()==0) {
            sim.Run(&fct);
        } else {
            sim.Run(&fct,&start_par[0]);
        }
	}
	
	std::vector<double> tmp = sim.GetBestPar();
	plhs[0] = mxCreateDoubleMatrix(1,tmp.size(), mxREAL);
	p = mxGetPr(plhs[0]);
	for (i=0; i<tmp.size(); i++) {
		*p = tmp[i];
		p++;
	}
	
	double tmp2 = sim.GetBestEnergy();
	plhs[1] = mxCreateDoubleMatrix(1,1, mxREAL);
	p = mxGetPr(plhs[1]);
	*p = tmp2;
	
	if (nlhs > 2) {
		tmp = sim.GetMeanEnergyHist();
		plhs[2] = mxCreateDoubleMatrix(tmp.size(),1, mxREAL);
		p = mxGetPr(plhs[2]);
		for (i=0; i<tmp.size(); i++) {
			*p = tmp[i];
			p++;
		}
		
		tmp = sim.GetTHist();
		plhs[3] = mxCreateDoubleMatrix(tmp.size(),1, mxREAL);
		p = mxGetPr(plhs[3]);
		for (i=0; i<tmp.size(); i++) {
			*p = tmp[i];
			p++;
		}
		
		tmp = sim.GetAccRatioHist();
		plhs[4] = mxCreateDoubleMatrix(tmp.size(),1, mxREAL);
		p = mxGetPr(plhs[4]);
		for (i=0; i<tmp.size(); i++) {
			*p = tmp[i];
			p++;
		}
	}
	
	
}
