chiark / gitweb /
python interface for mconstraint feature
authorstevenj <stevenj@alum.mit.edu>
Tue, 6 Jul 2010 20:58:19 +0000 (16:58 -0400)
committerstevenj <stevenj@alum.mit.edu>
Tue, 6 Jul 2010 20:58:19 +0000 (16:58 -0400)
darcs-hash:20100706205819-c8de0-4544309b57ae3e6abec33e0b145ca178ffbeee2c.gz

swig/nlopt-python.i

index 37fdabc2a8d9bad574c4a90da3144c16ee4e721c..ed58b699b235b17339a2a11b9c6c012eb8741ea7 100644 (file)
@@ -150,6 +150,33 @@ static double func_python(unsigned n, const double *x, double *grad, void *f)
   }
   return val;
 }
+
+static void mfunc_python(unsigned m, double *result,
+                        unsigned n, const double *x, double *grad, void *f)
+{
+  npy_intp nsz = npy_intp(n), msz = npy_intp(m), mnsz = npy_intp(m * n);
+  npy_intp sz0 = 0, stride1 = sizeof(double);
+  PyObject *xpy = PyArray_New(&PyArray_Type, 1, &nsz, NPY_DOUBLE, &stride1,
+                             const_cast<double*>(x), // not NPY_WRITEABLE
+                             0, NPY_C_CONTIGUOUS | NPY_ALIGNED, NULL);
+  PyObject *rpy = PyArray_SimpleNewFromData(1, &msz, NPY_DOUBLE, result);
+  PyObject *gradpy = grad
+    ? PyArray_SimpleNewFromData(1, &mnsz, NPY_DOUBLE, grad)
+    : PyArray_SimpleNew(1, &sz0, NPY_DOUBLE);
+  
+  PyObject *arglist = Py_BuildValue("OOO", rpy, xpy, gradpy);
+  PyObject *res = PyEval_CallObject((PyObject *) f, arglist);
+  Py_XDECREF(res);
+  Py_DECREF(arglist);
+
+  Py_DECREF(gradpy);
+  Py_DECREF(rpy);
+  Py_DECREF(xpy);
+
+  if (PyErr_Occurred()) {
+    throw nlopt::forced_stop(); // just stop, don't call PyErr_Clear()
+  }
+}
 %}
 
 %typemap(in)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) {
@@ -161,3 +188,13 @@ static double func_python(unsigned n, const double *x, double *grad, void *f)
 %typecheck(SWIG_TYPECHECK_POINTER)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) {
   $1 = PyCallable_Check($input);
 }
+
+%typemap(in)(nlopt::mfunc f, void *f_data, nlopt_munge md, nlopt_munge mc) {
+  $1 = mfunc_python;
+  $2 = dup_pyfunc((void*) $input);
+  $3 = free_pyfunc;
+  $4 = dup_pyfunc;
+}
+%typecheck(SWIG_TYPECHECK_POINTER)(nlopt::mfunc f, void *f_data, nlopt_munge md, nlopt_munge mc) {
+  $1 = PyCallable_Check($input);
+}