chiark / gitweb /
added Matlab plug-in
authorstevenj <stevenj@alum.mit.edu>
Tue, 4 Sep 2007 01:20:15 +0000 (21:20 -0400)
committerstevenj <stevenj@alum.mit.edu>
Tue, 4 Sep 2007 01:20:15 +0000 (21:20 -0400)
darcs-hash:20070904012015-c8de0-b18f22f76ac879415165b98ed8d34c52b062a481.gz

configure.ac
octave/Makefile.am
octave/nlopt_minimize-mex.c [new file with mode: 0644]
octave/nlopt_minimize-oct.cc [moved from octave/nlopt_minimize.cc with 100% similarity]

index 0b1e50485fdb215453cc8190ff191dce7b85d42b..5d188145cabb6c5b0a7b548d8fd70a4d5484d131 100644 (file)
@@ -92,6 +92,47 @@ AC_SUBST(OCT_INSTALL_DIR)
 AC_SUBST(M_INSTALL_DIR)
 AC_SUBST(MKOCTFILE)
 
+dnl -----------------------------------------------------------------------
+dnl Compiling Matlab plug-in
+
+AC_ARG_VAR(MEX_INSTALL_DIR, [where to install Matlab .mex plug-ins])
+AC_ARG_VAR(MEX, [name of mex program to compile Matlab plug-ins])
+AC_CHECK_PROGS(MEX, mex, echo)
+if test "$MEX" = "echo"; then
+       AC_MSG_WARN([can't find mex: won't be able to compile Matlab plugin])
+elif test x"$MEX_INSTALL_DIR" = "x"; then
+     AC_MSG_CHECKING([for extension of compiled mex files])
+     rm -f conftest*
+     cat > conftest.c <<EOF
+#include <mex.h>
+void mexFunction(int nlhs, mxArray *plhs[[]],
+                 int nrhs, const mxArray *prhs[[]]) { }
+EOF
+     if $MEX conftest.c; then
+        MEXSUFF=`ls conftest.m* | head -1 | cut -d'.' -f2`
+       AC_MSG_RESULT($MEXSUFF)
+       AC_CHECK_PROGS(MATLAB, matlab, echo)
+     else
+           AC_MSG_WARN([$MEX failed to compile a simple file; won't compile Matlab plugin])
+       MATLAB=echo
+     fi
+
+     # try to find installation directory
+     if test x"$MATLAB" = xecho; then
+        AC_MSG_WARN([can't fine Matlab; won't compile Matlab plugin])
+     else
+        AC_MSG_CHECKING(for MATLAB mex installation dir)
+       matlabpath_line=`matlab -n | grep -n MATLABPATH |head -1 |cut -d: -f1`
+       matlabpath_line=`expr $matlabpath_line + 1`
+       MEX_INSTALL_DIR=`matlab -n | tail -n +$matlabpath_line | head -1 | tr -d ' '`
+       AC_MSG_RESULT($MEX_INSTALL_DIR)
+      fi
+fi
+AM_CONDITIONAL(WITH_MATLAB, test x"$MEX_INSTALL_DIR" != "x")
+AC_SUBST(MEX_INSTALL_DIR)
+AC_SUBST(MEX)
+AC_SUBST(MEXSUFF)
+
 dnl -----------------------------------------------------------------------
 dnl Debugging
 
index 8bdd59b2842f88651a646b54ba949c69e6ad54c6..5fe95647ecf2009976fadf5daf9341554337e3e8 100644 (file)
@@ -2,22 +2,34 @@ AM_CPPFLAGS = -I$(top_srcdir)/api
 
 MFILES = NLOPT_GN_DIRECT.m NLOPT_GN_DIRECT_L.m NLOPT_GN_DIRECT_L_RAND.m NLOPT_GN_DIRECT_NOSCAL.m NLOPT_GN_DIRECT_L_NOSCAL.m NLOPT_GN_DIRECT_L_RAND_NOSCAL.m NLOPT_GN_ORIG_DIRECT.m NLOPT_GN_ORIG_DIRECT_L.m NLOPT_LN_SUBPLEX.m NLOPT_GD_STOGO.m NLOPT_GD_STOGO_RAND.m NLOPT_LD_LBFGS.m NLOPT_LN_PRAXIS.m NLOPT_LD_VAR1.m NLOPT_LD_VAR2.m 
 
+#######################################################################
 octdir = $(OCT_INSTALL_DIR)
 mdir = $(M_INSTALL_DIR)
-
 if WITH_OCTAVE
 oct_DATA = nlopt_minimize.oct
 m_DATA = $(MFILES)
 endif
 
-nlopt_minimize.oct: nlopt_minimize.cc nlopt_minimize_usage.h
-       $(MKOCTFILE) $(DEFS) $(CPPFLAGS) $(srcdir)/nlopt_minimize.cc $(LDFLAGS) -L$(top_builddir)/.libs -lnlopt
+nlopt_minimize.oct: nlopt_minimize-oct.cc nlopt_minimize_usage.h
+       $(MKOCTFILE) -o $@ $(DEFS) $(DEFAULT_INCLUDES) $(INCLUDES) $(AM_CPPFLAGS) $(srcdir)/nlopt_minimize-oct.cc $(LDFLAGS) -L$(top_builddir)/.libs -lnlopt
+
+#######################################################################
+mexdir = $(MEX_INSTALL_DIR)
+if WITH_MATLAB
+mex_DATA = nlopt_minimize.$(MEXSUFF) $(MFILES)
+endif
+
+nlopt_minimize.$(MEXSUFF): nlopt_minimize-mex.c nlopt_minimize_usage.h
+       $(MEX) -output nlopt_minimize -O $(DEFS) $(DEFAULT_INCLUDES) $(INCLUDES) $(AM_CPPFLAGS) $(srcdir)/nlopt_minimize-mex.c $(LDFLAGS) -L$(top_builddir)/.libs -lnlopt
+
+#######################################################################
 
 nlopt_minimize_usage.h: $(srcdir)/nlopt_minimize.m
        echo "#define NLOPT_MINIMIZE_USAGE \\" > $@
        sed 's/\"/\\"/g' $(srcdir)/nlopt_minimize.m | sed 's,^% ,\",;s,^%,\",;s,$$,\\n\" \\,' >> $@
        echo "" >> $@
 
-EXTRA_DIST = nlopt_minimize.cc $(MFILES)
 
-CLEANFILES = nlopt_minimize.oct nlopt_minimize_usage.h
+EXTRA_DIST = nlopt_minimize-oct.cc nlopt_minimize-mex.c $(MFILES)
+
+CLEANFILES = nlopt_minimize.oct nlopt_minimize_usage.h nlopt_minimize.$(MEXSUFF) nlopt_minimize-oct.o
diff --git a/octave/nlopt_minimize-mex.c b/octave/nlopt_minimize-mex.c
new file mode 100644 (file)
index 0000000..6fef512
--- /dev/null
@@ -0,0 +1,172 @@
+/* Matlab MEX interface to NLopt, and in particular to nlopt_minimize */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <math.h>
+#include <mex.h>
+
+#include "nlopt.h"
+
+#define xSTRIZE(x) #x
+#define STRIZE(x) xSTRIZE(x)
+#define CHECK(cond, msg) if (!(cond)) mexErrMsgTxt(msg);
+
+static double struct_val_default(const mxArray *s, const char *name, double dflt)
+{
+     mxArray *val = mxGetField(s, 0, name);
+     if (val) {
+         CHECK(mxIsNumeric(val) && !mxIsComplex(val) 
+               && mxGetM(val) * mxGetN(val) == 1,
+               "stop fields, other than xtol_abs, must be real scalars");
+         return mxGetScalar(val);
+     }
+     return dflt;
+}
+
+#define FLEN 1024 /* max length of user function name */
+#define MAXRHS 1024 /* max nrhs for user function */
+typedef struct {
+     char f[FLEN];
+     mxArray *plhs[2];
+     mxArray *prhs[MAXRHS];
+     int nrhs;
+} user_function_data;
+
+static double user_function(int n, const double *x,
+                           double *gradient, /* NULL if not needed */
+                           void *d_)
+{
+  user_function_data *d = (user_function_data *) d_;
+  double f;
+
+  d->plhs[0] = d->plhs[1] = NULL;
+  memcpy(mxGetPr(d->prhs[0]), x, n * sizeof(double));
+
+  CHECK(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs, 
+                          d->nrhs, d->prhs, d->f),
+       "error calling user function");
+
+  CHECK(mxIsNumeric(d->plhs[0]) && !mxIsComplex(d->plhs[0]) 
+       && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == 1,
+       "user function must return real scalar");
+  f = mxGetScalar(d->plhs[0]);
+  mxDestroyArray(d->plhs[0]);
+  if (gradient) {
+     CHECK(mxIsDouble(d->plhs[1]) && !mxIsComplex(d->plhs[1])
+          && (mxGetM(d->plhs[1]) == 1 || mxGetN(d->plhs[1]) == 1)
+          && mxGetM(d->plhs[1]) * mxGetN(d->plhs[1]) == n,
+          "gradient vector from user function is the wrong size");
+     memcpy(gradient, mxGetPr(d->plhs[1]), n * sizeof(double));
+     mxDestroyArray(d->plhs[1]);
+  }
+  return f;
+}                               
+
+void mexFunction(int nlhs, mxArray *plhs[],
+                 int nrhs, const mxArray *prhs[])
+{
+     nlopt_algorithm algorithm;
+     int n, i;
+     double *lb, *ub, *x, *x0;
+     double minf_max, ftol_rel, ftol_abs, xtol_rel, *xtol_abs, maxtime;
+     int maxeval;
+     nlopt_result ret;
+     mxArray *x_mx;
+     double minf = HUGE_VAL;
+     user_function_data d;
+
+     CHECK(nrhs == 7 && nlhs <= 3, "wrong number of arguments");
+
+     /* algorithm = prhs[0] */
+     CHECK(mxIsNumeric(prhs[0]) && !mxIsComplex(prhs[0]) 
+          && mxGetM(prhs[0]) * mxGetN(prhs[0]) == 1,
+          "algorithm must be real (integer) scalar");
+     algorithm = (nlopt_algorithm) (mxGetScalar(prhs[0]) + 0.5);
+     CHECK(algorithm >= 0 && algorithm < NLOPT_NUM_ALGORITHMS,
+          "unknown algorithm");
+
+     /* function f = prhs[1] */
+     CHECK(mxIsChar(prhs[1]), "f must be a string");
+     CHECK(mxGetString(prhs[1], d.f, FLEN) == 0,
+         "error reading function name string (too long?)");
+     /* ... for mexCallMATLAB */
+     
+     /* Cell f_data = prhs[2] */
+     CHECK(mxIsCell(prhs[2]), "f_data must be a Cell array");
+     CHECK(mxGetM(prhs[2]) * mxGetN(prhs[2]) + 1 <= MAXRHS,
+          "user function cannot have more than " STRIZE(MAXRHS) " arguments");
+     d.nrhs = mxGetM(prhs[2]) * mxGetN(prhs[2]) + 1;
+     for (i = 0; i < d.nrhs - 1; ++i)
+         d.prhs[1+i] = mxGetCell(prhs[2], i);
+
+     /* lb = prhs[3] */
+     CHECK(mxIsDouble(prhs[3]) && !mxIsComplex(prhs[3])
+          && (mxGetM(prhs[3]) == 1 || mxGetN(prhs[3]) == 1),
+          "lb must be real row or column vector");
+     lb = mxGetPr(prhs[3]);
+     n = mxGetM(prhs[3]) * mxGetN(prhs[3]);
+
+     /* ub = prhs[4] */
+     CHECK(mxIsDouble(prhs[4]) && !mxIsComplex(prhs[4])
+          && (mxGetM(prhs[4]) == 1 || mxGetN(prhs[4]) == 1)
+          && mxGetM(prhs[4]) * mxGetN(prhs[4]) == n,
+          "ub must be real row or column vector of same length as lb");
+     ub = mxGetPr(prhs[4]);
+
+     /* x0 = prhs[5] */
+     CHECK(mxIsDouble(prhs[5]) && !mxIsComplex(prhs[5])
+          && (mxGetM(prhs[5]) == 1 || mxGetN(prhs[5]) == 1)
+          && mxGetM(prhs[5]) * mxGetN(prhs[5]) == n,
+          "x must be real row or column vector of same length as lb");
+     x0 = mxGetPr(prhs[5]);
+
+     /* stopping criteria = prhs[6] */
+     CHECK(mxIsStruct(prhs[6]), "stopping criteria must be a struct");
+     minf_max = struct_val_default(prhs[6], "minf_max", -HUGE_VAL);
+     ftol_rel = struct_val_default(prhs[6], "ftol_rel", 0);
+     ftol_abs = struct_val_default(prhs[6], "ftol_abs", 0);
+     xtol_rel = struct_val_default(prhs[6], "xtol_rel", 0);
+     maxeval = (int) (struct_val_default(prhs[6], "maxeval", -1) + 0.5);
+     maxtime = struct_val_default(prhs[6], "maxtime", -1);
+     {
+         mxArray *val = mxGetField(prhs[6], 0, "xtol_abs");
+         if (val) {
+              CHECK(mxIsNumeric(val) && !mxIsComplex(val) 
+                    && (mxGetM(val) == 1 || mxGetN(val) == 1)
+                    && mxGetM(val) * mxGetN(val) == n,
+                    "stop.xtol_abs must be real row/col vector of length n");
+              xtol_abs = mxGetPr(val);
+         }
+         else
+              xtol_abs = NULL;
+     }
+
+
+     x_mx = mxCreateDoubleMatrix(1, n, mxREAL);
+     x = mxGetPr(x_mx);
+     memcpy(x, x0, sizeof(double) * n);
+
+     d.prhs[0] = mxCreateDoubleMatrix(1, n, mxREAL);
+     
+     ret = nlopt_minimize(algorithm,
+                         n,
+                         user_function, &d,
+                         lb, ub, x, &minf,
+                         minf_max, ftol_rel, ftol_abs, xtol_rel, xtol_abs,
+                         maxeval, maxtime);
+
+     mxDestroyArray(d.prhs[0]);
+
+     plhs[0] = x_mx;
+     if (nlhs > 1) {
+         plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
+         *(mxGetPr(plhs[1])) = minf;
+     }
+     else
+         mxDestroyArray(d.plhs[0]);
+     if (nlhs > 2) {
+         plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
+         *(mxGetPr(plhs[2])) = (int) ret;
+     }
+}