chiark / gitweb /
replace free_f_data flag with more general munging feature, use for proper refcountin...
authorstevenj <stevenj@alum.mit.edu>
Sat, 5 Jun 2010 00:19:16 +0000 (20:19 -0400)
committerstevenj <stevenj@alum.mit.edu>
Sat, 5 Jun 2010 00:19:16 +0000 (20:19 -0400)
darcs-hash:20100605001916-c8de0-37678738283a4bc9aa454687ec1bf49ecc7adca7.gz

api/f77api.c
api/f77funcs_.h
api/nlopt-in.hpp
api/nlopt-internal.h
api/nlopt.h
api/options.c
swig/nlopt-guile.i
swig/nlopt-python.i

index 2961e4be0bcaa95d693c7da27457e28aeadff35b..da709aeb7689f95c0085b48103d4df9a281fe8e1 100644 (file)
@@ -36,6 +36,13 @@ typedef struct {
      void *f_data;
 } f77_func_data;
 
+static void *free_f77_func_data(void *p) { free(p); return NULL; }
+static void *dup_f77_func_data(void *p) { 
+     void *pnew = (void*) malloc(sizeof(f77_func_data));
+     if (pnew) memcpy(pnew, p, sizeof(f77_func_data));
+     return pnew;
+}
+
 static double f77_func_wrap_old(int n, const double *x, double *grad, void *data)
 {
      f77_func_data *d = (f77_func_data *) data;
index bac7e8204792cb2895e02f049dc576c63e1ff221..48b2ab163557287daf15e695824c049e09171791 100644 (file)
@@ -36,17 +36,13 @@ void F77_(nlo_create,NLO_CREATE)(nlopt_opt *opt, int *alg, int *n)
      if (*n < 0) *opt = NULL;
      else {
          *opt = nlopt_create((nlopt_algorithm) *alg, (unsigned) *n);
-         nlopt_set_free_f_data(*opt, 1);
+         nlopt_set_munge(*opt, free_f77_func_data, dup_f77_func_data);
      }
 }
 
 void F77_(nlo_copy,NLO_COPY)(nlopt_opt *nopt, nlopt_opt *opt)
 {
      *nopt = nlopt_copy(*opt);
-     if (*nopt && nlopt_dup_f_data(*nopt, sizeof(f77_func_data)) < 0) {
-         nlopt_destroy(*nopt);
-         *nopt = NULL;
-     }
 }
 
 void F77_(nlo_destroy,NLO_DESTROY)(nlopt_opt *opt)
index 5936d2c4ee49873e99b42bf873755b9f85d6f4fa..a6f2641bc1c549954a1f85eb83d5c92a3442b081 100644 (file)
@@ -88,8 +88,38 @@ namespace nlopt {
       opt *o;
       func f; void *f_data;
       vfunc vf;
+      nlopt_munge munge_destroy, munge_copy; // non-NULL for SWIG wrappers
     } myfunc_data;
 
+    // free/destroy f_data in nlopt_destroy and nlopt_copy, respectively
+    static void *free_myfunc_data(void *p) { 
+      myfunc_data *d = (myfunc_data *) p;
+      if (d) {
+       if (d->f_data && d->munge_destroy) d->munge_destroy(d->f_data);
+       delete d;
+      }
+      return NULL;
+    }
+    static void *dup_myfunc_data(void *p) {
+      myfunc_data *d = (myfunc_data *) p;
+      if (d) {
+       void *f_data;
+       if (d->f_data && d->munge_copy) {
+         f_data = d->munge_copy(d->f_data);
+         if (!f_data) return NULL;
+       }
+       else
+         f_data = d->f_data;
+       myfunc_data *dnew = new myfunc_data;
+       if (dnew) {
+         *dnew = *d;
+         dnew->f_data = f_data;
+       }
+       return (void*) dnew;
+      }
+      else return NULL;
+    }
+
     // nlopt_func wrapper that catches exceptions
     static double myfunc(unsigned n, const double *x, double *grad, void *d_) {
       myfunc_data *d = reinterpret_cast<myfunc_data*>(d_);
@@ -140,20 +170,18 @@ namespace nlopt {
       xtmp(0), gradtmp(0), gradtmp0(0),
       last_result(nlopt::FAILURE), last_optf(HUGE_VAL) {
       if (!o) throw std::bad_alloc();
-      nlopt_set_free_f_data(o, 1);
+      nlopt_set_munge(o, free_myfunc_data, dup_myfunc_data);
     }
     opt(const opt& f) : o(nlopt_copy(f.o)), 
                        xtmp(f.xtmp), gradtmp(f.gradtmp), gradtmp0(0),
                        last_result(f.last_result), last_optf(f.last_optf) {
       if (f.o && !o) throw std::bad_alloc();
-      mythrow(nlopt_dup_f_data(o, sizeof(myfunc_data)));
     }
     opt& operator=(opt const& f) {
       if (this == &f) return *this; // self-assignment
       nlopt_destroy(o);
       o = nlopt_copy(f.o);
       if (f.o && !o) throw std::bad_alloc();
-      mythrow(nlopt_dup_f_data(o, sizeof(myfunc_data)));
       xtmp = f.xtmp; gradtmp = f.gradtmp;
       last_result = f.last_result; last_optf = f.last_optf;
       return *this;
@@ -195,32 +223,55 @@ namespace nlopt {
 
     // Set the objective function
     void set_min_objective(func f, void *f_data) {
-      myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data));
+      myfunc_data *d = new myfunc_data;
       if (!d) throw std::bad_alloc();
       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
+      d->munge_destroy = d->munge_copy = NULL;
       mythrow(nlopt_set_min_objective(o, myfunc, d)); // d freed via o
     }
     void set_min_objective(vfunc vf, void *f_data) {
-      myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data));
+      myfunc_data *d = new myfunc_data;
       if (!d) throw std::bad_alloc();
       d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf;
+      d->munge_destroy = d->munge_copy = NULL;
       mythrow(nlopt_set_min_objective(o, myvfunc, d)); // d freed via o
       alloc_tmp();
     }
     void set_max_objective(func f, void *f_data) {
-      myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data));
+      myfunc_data *d = new myfunc_data;
       if (!d) throw std::bad_alloc();
       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
+      d->munge_destroy = d->munge_copy = NULL;
       mythrow(nlopt_set_max_objective(o, myfunc, d)); // d freed via o
     }
     void set_max_objective(vfunc vf, void *f_data) {
-      myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data));
+      myfunc_data *d = new myfunc_data;
       if (!d) throw std::bad_alloc();
       d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf;
+      d->munge_destroy = d->munge_copy = NULL;
       mythrow(nlopt_set_max_objective(o, myvfunc, d)); // d freed via o
       alloc_tmp();
     }
 
+    // for internal use in SWIG wrappers -- variant that
+    // takes ownership of f_data, with munging for destroy/copy
+    void set_min_objective(func f, void *f_data,
+                          nlopt_munge md, nlopt_munge mc) {
+      myfunc_data *d = new myfunc_data;
+      if (!d) throw std::bad_alloc();
+      d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
+      d->munge_destroy = md; d->munge_copy = mc;
+      mythrow(nlopt_set_min_objective(o, myfunc, d)); // d freed via o
+    }
+    void set_max_objective(func f, void *f_data,
+                          nlopt_munge md, nlopt_munge mc) {
+      myfunc_data *d = new myfunc_data;
+      if (!d) throw std::bad_alloc();
+      d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
+      d->munge_destroy = md; d->munge_copy = mc;
+      mythrow(nlopt_set_max_objective(o, myfunc, d)); // d freed via o
+    }
+
     // Nonlinear constraints:
 
     void remove_inequality_constraints() {
@@ -228,15 +279,17 @@ namespace nlopt {
       mythrow(ret);
     }
     void add_inequality_constraint(func f, void *f_data, double tol=0) {
-      myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data));
+      myfunc_data *d = new myfunc_data;
       if (!d) throw std::bad_alloc();
       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
+      d->munge_destroy = d->munge_copy = NULL;
       mythrow(nlopt_add_inequality_constraint(o, myfunc, d, tol));
     }
     void add_inequality_constraint(vfunc vf, void *f_data, double tol=0) {
-      myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data));
+      myfunc_data *d = new myfunc_data;
       if (!d) throw std::bad_alloc();
       d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf;
+      d->munge_destroy = d->munge_copy = NULL;
       mythrow(nlopt_add_inequality_constraint(o, myvfunc, d, tol));
       alloc_tmp();
     }
@@ -246,19 +299,41 @@ namespace nlopt {
       mythrow(ret);
     }
     void add_equality_constraint(func f, void *f_data, double tol=0) {
-      myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data));
+      myfunc_data *d = new myfunc_data;
       if (!d) throw std::bad_alloc();
       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
+      d->munge_destroy = d->munge_copy = NULL;
       mythrow(nlopt_add_equality_constraint(o, myfunc, d, tol));
     }
     void add_equality_constraint(vfunc vf, void *f_data, double tol=0) {
-      myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data));
+      myfunc_data *d = new myfunc_data;
       if (!d) throw std::bad_alloc();
       d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf;
+      d->munge_destroy = d->munge_copy = NULL;
       mythrow(nlopt_add_equality_constraint(o, myvfunc, d, tol));
       alloc_tmp();
     }
 
+    // For internal use in SWIG wrappers (see also above)
+    void add_inequality_constraint(func f, void *f_data, 
+                                  nlopt_munge md, nlopt_munge mc,
+                                  double tol=0) {
+      myfunc_data *d = new myfunc_data;
+      if (!d) throw std::bad_alloc();
+      d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
+      d->munge_destroy = md; d->munge_copy = mc;
+      mythrow(nlopt_add_inequality_constraint(o, myfunc, d, tol));
+    }
+    void add_equality_constraint(func f, void *f_data, 
+                                nlopt_munge md, nlopt_munge mc,
+                                double tol=0) {
+      myfunc_data *d = new myfunc_data;
+      if (!d) throw std::bad_alloc();
+      d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
+      d->munge_destroy = md; d->munge_copy = mc;
+      mythrow(nlopt_add_equality_constraint(o, myfunc, d, tol));
+    }
+
 #define NLOPT_GETSET_VEC(name)                                         \
     void set_##name(double val) {                                      \
       mythrow(nlopt_set_##name##1(o, val));                            \
index e3389ef41926d8c5a128e61bc617e3c2c47e9096..fc43db3b411ff7546a287556333f1d1e73202061 100644 (file)
@@ -50,7 +50,7 @@ struct nlopt_opt_s {
      unsigned p_alloc; /* number of inequality constraints allocated */
      nlopt_constraint *h; /* equality constraints, length p_alloc */
 
-     int free_f_data; /* flag (for f77 api) to free f_data in nlopt_destroy */
+     nlopt_munge munge_on_destroy, munge_on_copy; /* hack for wrappers */
 
      /* stopping criteria */
      double stopval; /* stop when f reaches stopval or better */
index 8d45723732a882f1b55e11e68d563ce5dfaedb02..c7857436df73d5143e5d325e9b7c45b4ee152023 100644 (file)
@@ -248,11 +248,14 @@ NLOPT_EXTERN nlopt_result nlopt_set_initial_step1(nlopt_opt opt, double dx);
 NLOPT_EXTERN nlopt_result nlopt_get_initial_step(const nlopt_opt opt, 
                                                 const double *x, double *dx);
 
-/* set to 1: nlopt_destroy should call free() on all of the f_data pointers
-   (for the objective, constraints, etcetera) ... mainly for internal use */
-NLOPT_EXTERN nlopt_result nlopt_set_free_f_data(nlopt_opt opt, int val);
-NLOPT_EXTERN int nlopt_get_free_f_data(const nlopt_opt opt);
-NLOPT_EXTERN nlopt_result nlopt_dup_f_data(nlopt_opt opt, size_t sz);
+/* the following are functions mainly designed to be used internally
+   by the Fortran and SWIG wrappers, allow us to tel nlopt_destroy and
+   nlopt_copy to do something to the f_data pointers (e.g. free or
+   duplicate them, respectively) */
+typedef void* (*nlopt_munge)(void *p);
+NLOPT_EXTERN void nlopt_set_munge(nlopt_opt opt,
+                                 nlopt_munge munge_on_destroy,
+                                 nlopt_munge munge_on_copy);
 
 /*************************** DEPRECATED API **************************/
 /* The new "object-oriented" API is preferred, since it allows us to
index e766ca4bd1c9bde0cc86b9d3ac33d4d6678ec999..1f592d1cb0a81efd9505d7a6a11f55dce1c8debd 100644 (file)
 void nlopt_destroy(nlopt_opt opt)
 {
      if (opt) {
-         if (opt->free_f_data) {
+         if (opt->munge_on_destroy) {
+              nlopt_munge munge = opt->munge_on_destroy;
               unsigned i;
-              free(opt->f_data);
+              munge(opt->f_data);
               for (i = 0; i < opt->m; ++i)
-                   free(opt->fc[i].f_data);
+                   munge(opt->fc[i].f_data);
               for (i = 0; i < opt->p; ++i)
-                   free(opt->h[i].f_data);
+                   munge(opt->h[i].f_data);
          }
          free(opt->lb); free(opt->ub);
          free(opt->xtol_abs);
@@ -65,7 +66,7 @@ nlopt_opt nlopt_create(nlopt_algorithm algorithm, unsigned n)
          opt->n = n;
          opt->f = NULL; opt->f_data = NULL;
          opt->maximize = 0;
-         opt->free_f_data = 0;
+         opt->munge_on_destroy = opt->munge_on_copy = NULL;
 
          opt->lb = opt->ub = NULL;
          opt->m = opt->m_alloc = 0;
@@ -117,7 +118,10 @@ nlopt_opt nlopt_copy(const nlopt_opt opt)
          nopt->local_opt = NULL;
          nopt->dx = NULL;
          opt->force_stop_child = NULL;
-         opt->free_f_data = 0;
+
+         nlopt_munge munge = nopt->munge_on_copy;
+         if (munge && nopt->f_data)
+              if (!(nopt->f_data = munge(nopt->f_data))) goto oom;
 
          if (opt->n > 0) {
               nopt->lb = (double *) malloc(sizeof(double) * (opt->n));
@@ -138,6 +142,12 @@ nlopt_opt nlopt_copy(const nlopt_opt opt)
                                                      * (opt->m));
               if (!nopt->fc) goto oom;
               memcpy(nopt->fc, opt->fc, sizeof(nlopt_constraint) * (opt->m));
+              if (munge)
+                   for (unsigned i = 0; i < opt->m; ++i)
+                        if (nopt->fc[i].f_data &&
+                            !(nopt->fc[i].f_data
+                              = munge(nopt->fc[i].f_data)))
+                             goto oom;
          }
 
          if (opt->p) {
@@ -146,6 +156,12 @@ nlopt_opt nlopt_copy(const nlopt_opt opt)
                                                     * (opt->p));
               if (!nopt->h) goto oom;
               memcpy(nopt->h, opt->h, sizeof(nlopt_constraint) * (opt->p));
+              if (munge)
+                   for (unsigned i = 0; i < opt->p; ++i)
+                        if (nopt->h[i].f_data &&
+                            !(nopt->h[i].f_data
+                              = munge(nopt->h[i].f_data)))
+                             goto oom;
          }
 
          if (opt->local_opt) {
@@ -162,6 +178,7 @@ nlopt_opt nlopt_copy(const nlopt_opt opt)
      return nopt;
 
 oom:
+     nopt->munge_on_destroy = NULL; // better to leak mem than crash
      nlopt_destroy(nopt);
      return NULL;
 }
@@ -262,6 +279,11 @@ nlopt_result nlopt_get_upper_bounds(nlopt_opt opt, double *ub)
 nlopt_result nlopt_remove_inequality_constraints(nlopt_opt opt)
 {
      if (!opt) return NLOPT_INVALID_ARGS;
+     if (opt->munge_on_destroy) {
+         nlopt_munge munge = opt->munge_on_destroy;
+         for (unsigned i = 0; i < opt->m; ++i)
+              munge(opt->fc[i].f_data);
+     }
      free(opt->fc);
      opt->fc = NULL;
      opt->m = opt->m_alloc = 0;
@@ -315,6 +337,11 @@ nlopt_result nlopt_add_inequality_constraint(nlopt_opt opt,
 nlopt_result nlopt_remove_equality_constraints(nlopt_opt opt)
 {
      if (!opt) return NLOPT_INVALID_ARGS;
+     if (opt->munge_on_destroy) {
+         nlopt_munge munge = opt->munge_on_destroy;
+         for (unsigned i = 0; i < opt->p; ++i)
+              munge(opt->h[i].f_data);
+     }
      free(opt->h);
      opt->h = NULL;
      opt->p = opt->p_alloc = 0;
@@ -528,40 +555,13 @@ nlopt_result nlopt_set_default_initial_step(nlopt_opt opt, const double *x)
 
 /*************************************************************************/
 
-GETSET(free_f_data, int, free_f_data)
-
-/* the dup_f_data function replaces all f_data pointers with a new
-   pointer to a duplicate block of memory, assuming all non-NULL
-   f_data pointers point to a block of sz bytes...  this is pretty
-   exclusively intended for internal use (e.g. it may lead to a
-   double-free if one subsequently calles add_inequality_constraint
-   etc.), e.g. in the C++ API */
-
-static int dup(void **p, size_t sz) {
-     if (*p) {
-         void *pdup = malloc(sz);
-         if (pdup) {
-              memcpy(pdup, *p, sz);
-              *p = pdup;
-              return 1;
-         }
-         else return 0;
-     }
-     else return 1;
-}
-
-nlopt_result nlopt_dup_f_data(nlopt_opt opt, size_t sz) {
+void nlopt_set_munge(nlopt_opt opt,
+                    nlopt_munge munge_on_destroy,
+                    nlopt_munge munge_on_copy) {
      if (opt) {
-         unsigned i;
-         if (!dup(&opt->f_data, sz)) return NLOPT_OUT_OF_MEMORY;
-         for (i = 0; i < opt->m; ++i)
-              if (!dup(&opt->fc[i].f_data, sz)) return NLOPT_OUT_OF_MEMORY;
-         for (i = 0; i < opt->p; ++i)
-              if (!dup(&opt->h[i].f_data, sz)) return NLOPT_OUT_OF_MEMORY;
-         nlopt_set_free_f_data(opt, 1); // nlopt_destroy must now free f_data!
-         return NLOPT_SUCCESS;
+         opt->munge_on_destroy = munge_on_destroy;
+         opt->munge_on_copy = munge_on_copy;
      }
-     return NLOPT_INVALID_ARGS;
 }
 
 /*************************************************************************/
index 0c705183133dce1ed228c493e2002384519341a1..1978a130b9c275b89bb64c62702a8c5d8369adb1 100644 (file)
@@ -1,23 +1,31 @@
 // -*- C++ -*-
 
 %{
+// because our f_data pointer to the Scheme function is stored on the
+// heap, rather than the stack, it may be missed by the Guile garbage
+// collection and be accidentally freed.  Hence, use NLopts munge
+// feature to prevent this, by incrementing Guile's reference count.
+static void *free_guilefunc(void *p) { 
+  scm_gc_unprotect_object((SCM) p); return p; }
+static void *dup_guilefunc(void *p) { 
+  scm_gc_protect_object((SCM) p); return p; }
+
 // vfunc wrapper around Guile function (val . grad) = f(x)
-static double vfunc_guile(const std::vector<double> &x,
-                          std::vector<double> &grad, void *f) {
-  SCM xscm = scm_c_make_vector(x.size(), SCM_UNSPECIFIED);
-  for (unsigned i = 0; i < x.size(); ++i)
+static double func_guile(unsigned n, const double *x, double *grad, void *f) {
+  SCM xscm = scm_c_make_vector(n, SCM_UNSPECIFIED);
+  for (unsigned i = 0; i < n; ++i)
     scm_c_vector_set_x(xscm, i, scm_make_real(x[i]));
   SCM ret = scm_call_1((SCM) f, xscm);
   if (scm_real_p(ret)) {
-    if (grad.size()) throw std::invalid_argument("missing gradient");
+    if (grad) throw std::invalid_argument("missing gradient");
     return scm_to_double(ret);
   }
   else if (scm_is_pair(ret)) { /* must be (cons value gradient) */
     SCM valscm = SCM_CAR(ret), grad_scm = grad_scm;
-    if (grad.size() > 0
+    if (grad
        && scm_is_vector(grad_scm)
-       && scm_c_vector_length(grad_scm) == grad.size()) {
-      for (unsigned i = 0; i < grad.size(); ++i)
+       && scm_c_vector_length(grad_scm) == n) {
+      for (unsigned i = 0; i < n; ++i)
        grad[i] = scm_to_double(scm_c_vector_ref(grad_scm, i));
     }
     else throw std::invalid_argument("invalid gradient");
@@ -28,9 +36,11 @@ static double vfunc_guile(const std::vector<double> &x,
 }
 %}
 
-%typemap(in)(nlopt::vfunc vf, void *f_data) {
+%typemap(in)(nlopt::vfunc vf, void *f_data, nlopt_munge md, nlopt_munge mc) {
   $1 = vfunc_guile;
-  $2 = (void*) $input; // input is SCM pointer to Scheme function
+  $2 = dup_guilefunc((void*) $input); // input = SCM pointer to Scheme function
+  $3 = free_guilefunc;
+  $4 = dup_guilefunc;
 }
 %typecheck(SWIG_TYPECHECK_POINTER)(nlopt::vfunc vf, void *f_data) {
   $1 = SCM_NFALSEP(scm_procedure_p($input));
index 32e3305808a018a972b2ae497a53702757a7f9b9..cd46acde0eac5d72e2bbd5d867498d5dd1a4dd58 100644 (file)
@@ -59,6 +59,9 @@
 // Wrapper for objective function callbacks
 
 %{
+static void *free_pyfunc(void *p) { Py_DECREF((PyObject*) p); return p; }
+static void *dup_pyfunc(void *p) { Py_INCREF((PyObject*) p); return p; }
+
 static double func_python(unsigned n, const double *x, double *grad, void *f)
 {
   npy_intp sz = npy_intp(n), sz0 = 0;
@@ -85,10 +88,11 @@ static double func_python(unsigned n, const double *x, double *grad, void *f)
 }
 %}
 
-%typemap(in)(nlopt::func f, void *f_data) {
-  Py_INCREF($input);
+%typemap(in)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) {
   $1 = func_python;
-  $2 = (void*) $input;
+  $2 = dup_pyfunc((void*) $input);
+  $3 = free_pyfunc;
+  $4 = dup_pyfunc;
 }
 %typecheck(SWIG_TYPECHECK_POINTER)(nlopt::func f, void *f_data) {
   $1 = PyCallable_Check($input);