From 583ebf462b355e9d82a7e04a2f11abe29ca79757 Mon Sep 17 00:00:00 2001 From: stevenj Date: Tue, 15 Nov 2011 17:47:59 -0500 Subject: [PATCH] added prototype matlab precond interface (for objective only) Ignore-this: 483c95a6036a94ae9f5a61201730de80 darcs-hash:20111115224759-c8de0-6d0296c52baf1a8c579303faf3f4e1e55fb832f0.gz --- octave/nlopt_optimize-mex.c | 69 ++++++++++++++++++++++++++++++++----- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/octave/nlopt_optimize-mex.c b/octave/nlopt_optimize-mex.c index ac3cdf9..b275f9c 100644 --- a/octave/nlopt_optimize-mex.c +++ b/octave/nlopt_optimize-mex.c @@ -76,13 +76,14 @@ static double *fill(double *arr, unsigned n, double val) } #define FLEN 128 /* max length of user function name */ -#define MAXRHS 2 /* max nrhs for user function */ -typedef struct { +#define MAXRHS 3 /* max nrhs for user function */ +typedef struct user_function_data_s { char f[FLEN]; mxArray *plhs[2]; mxArray *prhs[MAXRHS]; int xrhs, nrhs; int verbose, neval; + struct user_function_data_s *dpre; } user_function_data; static double user_function(unsigned n, const double *x, @@ -117,6 +118,28 @@ static double user_function(unsigned n, const double *x, return f; } +static void user_pre(unsigned n, const double *x, const double *v, + double *vpre, void *d_) +{ + user_function_data *d = ((user_function_data *) d_)->dpre; + d->plhs[0] = d->plhs[1] = NULL; + memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double)); + memcpy(mxGetPr(d->prhs[d->xrhs + 1]), v, n * sizeof(double)); + + CHECK0(0 == mexCallMATLAB(1, d->plhs, + d->nrhs, d->prhs, d->f), + "error calling user function"); + + CHECK0(mxIsDouble(d->plhs[0]) && !mxIsComplex(d->plhs[0]) + && (mxGetM(d->plhs[0]) == 1 || mxGetN(d->plhs[0]) == 1) + && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == n, + "vpre vector from user function is the wrong size"); + memcpy(vpre, mxGetPr(d->plhs[0]), n * sizeof(double)); + mxDestroyArray(d->plhs[0]); + d->neval++; + if (d->verbose) mexPrintf("nlopt_optimize precond eval #%d\n", d->neval); +} + #define CHECK1(cond, msg) if (!(cond)) { mxFree(tmp); nlopt_destroy(opt); nlopt_destroy(local_opt); mexWarnMsgTxt(msg); return NULL; }; nlopt_opt make_opt(const mxArray *opts, unsigned n) @@ -180,7 +203,7 @@ void mexFunction(int nlhs, mxArray *plhs[], double *x, *x0, opt_f; nlopt_result ret; mxArray *x_mx, *mx; - user_function_data d, *dfc = NULL, *dh = NULL; + user_function_data d, dpre, *dfc = NULL, *dh = NULL; nlopt_opt opt = NULL; CHECK(nrhs == 2 && nlhs <= 3, "wrong number of arguments"); @@ -217,11 +240,40 @@ void mexFunction(int nlhs, mxArray *plhs[], d.xrhs = 1; } d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL); - if (struct_funcval(prhs[0], "min_objective")) - nlopt_set_min_objective(opt, user_function, &d); - else - nlopt_set_max_objective(opt, user_function, &d); - + + if ((mx = struct_funcval(prhs[0], "pre"))) { + CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx), + "pre must contain function handles or function names"); + if (mxIsChar(mx)) { + CHECK(mxGetString(mx, dpre.f, FLEN) == 0, + "error reading function name string (too long?)"); + dpre.nrhs = 2; + dpre.xrhs = 0; + } + else { + dpre.prhs[0] = mx; + strcpy(dpre.f, "feval"); + dpre.nrhs = 3; + dpre.xrhs = 1; + } + dpre.verbose = d.verbose > 2; + dpre.neval = 0; + dpre.prhs[dpre.xrhs] = d.prhs[d.xrhs]; + dpre.prhs[d.xrhs+1] = mxCreateDoubleMatrix(1, n, mxREAL); + d.dpre = &dpre; + + if (struct_funcval(prhs[0], "min_objective")) + nlopt_set_precond_min_objective(opt, user_function,user_pre,&d); + else + nlopt_set_precond_max_objective(opt, user_function,user_pre,&d); + } + else { + dpre.nrhs = 0; + if (struct_funcval(prhs[0], "min_objective")) + nlopt_set_min_objective(opt, user_function, &d); + else + nlopt_set_max_objective(opt, user_function, &d); + } if ((mx = mxGetField(prhs[0], 0, "fc"))) { int j, m; @@ -304,6 +356,7 @@ void mexFunction(int nlhs, mxArray *plhs[], mxFree(dh); mxFree(dfc); mxDestroyArray(d.prhs[d.xrhs]); + if (dpre.nrhs > 0) mxDestroyArray(dpre.prhs[d.xrhs+1]); nlopt_destroy(opt); plhs[0] = x_mx; -- 2.30.2