chiark / gitweb /
added prototype matlab precond interface (for objective only)
authorstevenj <stevenj@alum.mit.edu>
Tue, 15 Nov 2011 22:47:59 +0000 (17:47 -0500)
committerstevenj <stevenj@alum.mit.edu>
Tue, 15 Nov 2011 22:47:59 +0000 (17:47 -0500)
Ignore-this: 483c95a6036a94ae9f5a61201730de80

darcs-hash:20111115224759-c8de0-6d0296c52baf1a8c579303faf3f4e1e55fb832f0.gz

octave/nlopt_optimize-mex.c

index ac3cdf9f370789d003dbce9d03225f141d47e308..b275f9c4568dda5d9421f1eacec1ee33bc014b12 100644 (file)
@@ -76,13 +76,14 @@ static double *fill(double *arr, unsigned n, double val)
 }
 
 #define FLEN 128 /* max length of user function name */
-#define MAXRHS 2 /* max nrhs for user function */
-typedef struct {
+#define MAXRHS 3 /* max nrhs for user function */
+typedef struct user_function_data_s {
      char f[FLEN];
      mxArray *plhs[2];
      mxArray *prhs[MAXRHS];
      int xrhs, nrhs;
      int verbose, neval;
+     struct user_function_data_s *dpre;
 } user_function_data;
 
 static double user_function(unsigned n, const double *x,
@@ -117,6 +118,28 @@ static double user_function(unsigned n, const double *x,
   return f;
 }
 
+static void user_pre(unsigned n, const double *x, const double *v,
+                      double *vpre, void *d_)
+{
+  user_function_data *d = ((user_function_data *) d_)->dpre;
+  d->plhs[0] = d->plhs[1] = NULL;
+  memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
+  memcpy(mxGetPr(d->prhs[d->xrhs + 1]), v, n * sizeof(double));
+
+  CHECK0(0 == mexCallMATLAB(1, d->plhs, 
+                           d->nrhs, d->prhs, d->f),
+        "error calling user function");
+
+  CHECK0(mxIsDouble(d->plhs[0]) && !mxIsComplex(d->plhs[0])
+        && (mxGetM(d->plhs[0]) == 1 || mxGetN(d->plhs[0]) == 1)
+        && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == n,
+        "vpre vector from user function is the wrong size");
+  memcpy(vpre, mxGetPr(d->plhs[0]), n * sizeof(double));
+  mxDestroyArray(d->plhs[0]);
+  d->neval++;
+  if (d->verbose) mexPrintf("nlopt_optimize precond eval #%d\n", d->neval);
+}
+
 #define CHECK1(cond, msg) if (!(cond)) { mxFree(tmp); nlopt_destroy(opt); nlopt_destroy(local_opt); mexWarnMsgTxt(msg); return NULL; };
 
 nlopt_opt make_opt(const mxArray *opts, unsigned n)
@@ -180,7 +203,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
      double *x, *x0, opt_f;
      nlopt_result ret;
      mxArray *x_mx, *mx;
-     user_function_data d, *dfc = NULL, *dh = NULL;
+     user_function_data d, dpre, *dfc = NULL, *dh = NULL;
      nlopt_opt opt = NULL;
 
      CHECK(nrhs == 2 && nlhs <= 3, "wrong number of arguments");
@@ -217,11 +240,40 @@ void mexFunction(int nlhs, mxArray *plhs[],
          d.xrhs = 1;
      }
      d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL);
-     if (struct_funcval(prhs[0], "min_objective"))
-         nlopt_set_min_objective(opt, user_function, &d);
-     else
-         nlopt_set_max_objective(opt, user_function, &d);
-     
+
+     if ((mx = struct_funcval(prhs[0], "pre"))) {
+         CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx),
+               "pre must contain function handles or function names");
+         if (mxIsChar(mx)) {
+              CHECK(mxGetString(mx, dpre.f, FLEN) == 0,
+                     "error reading function name string (too long?)");
+              dpre.nrhs = 2;
+              dpre.xrhs = 0;
+         }
+         else {
+              dpre.prhs[0] = mx;
+              strcpy(dpre.f, "feval");
+              dpre.nrhs = 3;
+              dpre.xrhs = 1;
+         }
+         dpre.verbose = d.verbose > 2;
+         dpre.neval = 0;
+         dpre.prhs[dpre.xrhs] = d.prhs[d.xrhs];
+         dpre.prhs[d.xrhs+1] = mxCreateDoubleMatrix(1, n, mxREAL);
+         d.dpre = &dpre;
+
+         if (struct_funcval(prhs[0], "min_objective"))
+              nlopt_set_precond_min_objective(opt, user_function,user_pre,&d);
+         else
+              nlopt_set_precond_max_objective(opt, user_function,user_pre,&d);
+     }
+     else {
+         dpre.nrhs = 0;
+         if (struct_funcval(prhs[0], "min_objective"))
+              nlopt_set_min_objective(opt, user_function, &d);
+         else
+              nlopt_set_max_objective(opt, user_function, &d);
+     }
 
      if ((mx = mxGetField(prhs[0], 0, "fc"))) {
          int j, m;
@@ -304,6 +356,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
      mxFree(dh);
      mxFree(dfc);
      mxDestroyArray(d.prhs[d.xrhs]);
+     if (dpre.nrhs > 0) mxDestroyArray(dpre.prhs[d.xrhs+1]);
      nlopt_destroy(opt);
 
      plhs[0] = x_mx;