chiark / gitweb /
support function handles from Matlab
authorstevenj <stevenj@alum.mit.edu>
Tue, 4 Sep 2007 02:15:50 +0000 (22:15 -0400)
committerstevenj <stevenj@alum.mit.edu>
Tue, 4 Sep 2007 02:15:50 +0000 (22:15 -0400)
darcs-hash:20070904021550-c8de0-49623cb50c1970965af5b72315b656c38025c367.gz

octave/Makefile.am
octave/nlopt_minimize-mex.c

index bbc7a4b6a7f2a6f5f7b6e3510444377dce8b77f9..c05a0d73a579d2ae462c24e25009edb9fd3d9ede 100644 (file)
@@ -7,7 +7,7 @@ octdir = $(OCT_INSTALL_DIR)
 mdir = $(M_INSTALL_DIR)
 if WITH_OCTAVE
 oct_DATA = nlopt_minimize.oct
-m_DATA = $(MFILES)
+m_DATA = $(MFILES) nlopt_minimize.m
 endif
 
 nlopt_minimize.oct: nlopt_minimize-oct.cc nlopt_minimize_usage.h
@@ -21,7 +21,7 @@ nlopt_minimize_usage.h: $(srcdir)/nlopt_minimize.m
 #######################################################################
 mexdir = $(MEX_INSTALL_DIR)
 if WITH_MATLAB
-mex_DATA = nlopt_minimize.$(MEXSUFF) $(MFILES)
+mex_DATA = nlopt_minimize.$(MEXSUFF) $(MFILES) nlopt_minimize.m
 endif
 
 nlopt_minimize.$(MEXSUFF): nlopt_minimize-mex.c
@@ -29,6 +29,6 @@ nlopt_minimize.$(MEXSUFF): nlopt_minimize-mex.c
 
 #######################################################################
 
-EXTRA_DIST = nlopt_minimize-oct.cc nlopt_minimize-mex.c $(MFILES)
+EXTRA_DIST = nlopt_minimize-oct.cc nlopt_minimize-mex.c $(MFILES) nlopt_minimize.m
 
 CLEANFILES = nlopt_minimize.oct nlopt_minimize_usage.h nlopt_minimize.$(MEXSUFF) nlopt_minimize-oct.o
index 6fef5126cc9eab6ee22f5cfa0d78ddddfc041857..1d9a6000896f92de8ab2d50b26f2fdd53d786f24 100644 (file)
@@ -30,7 +30,7 @@ typedef struct {
      char f[FLEN];
      mxArray *plhs[2];
      mxArray *prhs[MAXRHS];
-     int nrhs;
+     int xrhs, nrhs;
 } user_function_data;
 
 static double user_function(int n, const double *x,
@@ -41,7 +41,7 @@ static double user_function(int n, const double *x,
   double f;
 
   d->plhs[0] = d->plhs[1] = NULL;
-  memcpy(mxGetPr(d->prhs[0]), x, n * sizeof(double));
+  memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
 
   CHECK(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs, 
                           d->nrhs, d->prhs, d->f),
@@ -87,18 +87,28 @@ void mexFunction(int nlhs, mxArray *plhs[],
           "unknown algorithm");
 
      /* function f = prhs[1] */
-     CHECK(mxIsChar(prhs[1]), "f must be a string");
-     CHECK(mxGetString(prhs[1], d.f, FLEN) == 0,
-         "error reading function name string (too long?)");
-     /* ... for mexCallMATLAB */
+     CHECK(mxIsChar(prhs[1]) || mxIsFunctionHandle(prhs[1]), 
+          "f must be a function handle or function name");
+     if (mxIsChar(prhs[1])) {
+         CHECK(mxGetString(prhs[1], d.f, FLEN) == 0,
+               "error reading function name string (too long?)");
+         d.nrhs = 1;
+         d.xrhs = 0;
+     }
+     else {
+         d.prhs[0] = prhs[1];
+         strcpy(d.f, "feval");
+         d.nrhs = 2;
+         d.xrhs = 1;
+     }
      
      /* Cell f_data = prhs[2] */
      CHECK(mxIsCell(prhs[2]), "f_data must be a Cell array");
      CHECK(mxGetM(prhs[2]) * mxGetN(prhs[2]) + 1 <= MAXRHS,
           "user function cannot have more than " STRIZE(MAXRHS) " arguments");
-     d.nrhs = mxGetM(prhs[2]) * mxGetN(prhs[2]) + 1;
-     for (i = 0; i < d.nrhs - 1; ++i)
-         d.prhs[1+i] = mxGetCell(prhs[2], i);
+     d.nrhs += mxGetM(prhs[2]) * mxGetN(prhs[2]);
+     for (i = 0; i < d.nrhs - (1+d.xrhs); ++i)
+         d.prhs[(1+d.xrhs)+i] = mxGetCell(prhs[2], i);
 
      /* lb = prhs[3] */
      CHECK(mxIsDouble(prhs[3]) && !mxIsComplex(prhs[3])
@@ -147,7 +157,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
      x = mxGetPr(x_mx);
      memcpy(x, x0, sizeof(double) * n);
 
-     d.prhs[0] = mxCreateDoubleMatrix(1, n, mxREAL);
+     d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL);
      
      ret = nlopt_minimize(algorithm,
                          n,
@@ -156,7 +166,7 @@ void mexFunction(int nlhs, mxArray *plhs[],
                          minf_max, ftol_rel, ftol_abs, xtol_rel, xtol_abs,
                          maxeval, maxtime);
 
-     mxDestroyArray(d.prhs[0]);
+     mxDestroyArray(d.prhs[d.xrhs]);
 
      plhs[0] = x_mx;
      if (nlhs > 1) {