}
#define FLEN 128 /* max length of user function name */
-#define MAXRHS 2 /* max nrhs for user function */
-typedef struct {
+#define MAXRHS 3 /* max nrhs for user function */
+typedef struct user_function_data_s {
char f[FLEN];
mxArray *plhs[2];
mxArray *prhs[MAXRHS];
int xrhs, nrhs;
int verbose, neval;
+ struct user_function_data_s *dpre;
} user_function_data;
static double user_function(unsigned n, const double *x,
return f;
}
+static void user_pre(unsigned n, const double *x, const double *v,
+ double *vpre, void *d_)
+{
+ user_function_data *d = ((user_function_data *) d_)->dpre;
+ d->plhs[0] = d->plhs[1] = NULL;
+ memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
+ memcpy(mxGetPr(d->prhs[d->xrhs + 1]), v, n * sizeof(double));
+
+ CHECK0(0 == mexCallMATLAB(1, d->plhs,
+ d->nrhs, d->prhs, d->f),
+ "error calling user function");
+
+ CHECK0(mxIsDouble(d->plhs[0]) && !mxIsComplex(d->plhs[0])
+ && (mxGetM(d->plhs[0]) == 1 || mxGetN(d->plhs[0]) == 1)
+ && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == n,
+ "vpre vector from user function is the wrong size");
+ memcpy(vpre, mxGetPr(d->plhs[0]), n * sizeof(double));
+ mxDestroyArray(d->plhs[0]);
+ d->neval++;
+ if (d->verbose) mexPrintf("nlopt_optimize precond eval #%d\n", d->neval);
+}
+
#define CHECK1(cond, msg) if (!(cond)) { mxFree(tmp); nlopt_destroy(opt); nlopt_destroy(local_opt); mexWarnMsgTxt(msg); return NULL; };
nlopt_opt make_opt(const mxArray *opts, unsigned n)
double *x, *x0, opt_f;
nlopt_result ret;
mxArray *x_mx, *mx;
- user_function_data d, *dfc = NULL, *dh = NULL;
+ user_function_data d, dpre, *dfc = NULL, *dh = NULL;
nlopt_opt opt = NULL;
CHECK(nrhs == 2 && nlhs <= 3, "wrong number of arguments");
d.xrhs = 1;
}
d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL);
- if (struct_funcval(prhs[0], "min_objective"))
- nlopt_set_min_objective(opt, user_function, &d);
- else
- nlopt_set_max_objective(opt, user_function, &d);
-
+
+ if ((mx = struct_funcval(prhs[0], "pre"))) {
+ CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx),
+ "pre must contain function handles or function names");
+ if (mxIsChar(mx)) {
+ CHECK(mxGetString(mx, dpre.f, FLEN) == 0,
+ "error reading function name string (too long?)");
+ dpre.nrhs = 2;
+ dpre.xrhs = 0;
+ }
+ else {
+ dpre.prhs[0] = mx;
+ strcpy(dpre.f, "feval");
+ dpre.nrhs = 3;
+ dpre.xrhs = 1;
+ }
+ dpre.verbose = d.verbose > 2;
+ dpre.neval = 0;
+ dpre.prhs[dpre.xrhs] = d.prhs[d.xrhs];
+ dpre.prhs[d.xrhs+1] = mxCreateDoubleMatrix(1, n, mxREAL);
+ d.dpre = &dpre;
+
+ if (struct_funcval(prhs[0], "min_objective"))
+ nlopt_set_precond_min_objective(opt, user_function,user_pre,&d);
+ else
+ nlopt_set_precond_max_objective(opt, user_function,user_pre,&d);
+ }
+ else {
+ dpre.nrhs = 0;
+ if (struct_funcval(prhs[0], "min_objective"))
+ nlopt_set_min_objective(opt, user_function, &d);
+ else
+ nlopt_set_max_objective(opt, user_function, &d);
+ }
if ((mx = mxGetField(prhs[0], 0, "fc"))) {
int j, m;
mxFree(dh);
mxFree(dfc);
mxDestroyArray(d.prhs[d.xrhs]);
+ if (dpre.nrhs > 0) mxDestroyArray(dpre.prhs[d.xrhs+1]);
nlopt_destroy(opt);
plhs[0] = x_mx;