From: stevenj Date: Sat, 5 Jun 2010 00:19:16 +0000 (-0400) Subject: replace free_f_data flag with more general munging feature, use for proper refcountin... X-Git-Url: http://www.chiark.greenend.org.uk/ucgi/~ianmdlvl/git?a=commitdiff_plain;h=6998400850365868de2c19ddc9b8146a2a57f8b6;p=nlopt.git replace free_f_data flag with more general munging feature, use for proper refcounting of function pointers in SWIG guile/python wrappers darcs-hash:20100605001916-c8de0-37678738283a4bc9aa454687ec1bf49ecc7adca7.gz --- diff --git a/api/f77api.c b/api/f77api.c index 2961e4b..da709ae 100644 --- a/api/f77api.c +++ b/api/f77api.c @@ -36,6 +36,13 @@ typedef struct { void *f_data; } f77_func_data; +static void *free_f77_func_data(void *p) { free(p); return NULL; } +static void *dup_f77_func_data(void *p) { + void *pnew = (void*) malloc(sizeof(f77_func_data)); + if (pnew) memcpy(pnew, p, sizeof(f77_func_data)); + return pnew; +} + static double f77_func_wrap_old(int n, const double *x, double *grad, void *data) { f77_func_data *d = (f77_func_data *) data; diff --git a/api/f77funcs_.h b/api/f77funcs_.h index bac7e82..48b2ab1 100644 --- a/api/f77funcs_.h +++ b/api/f77funcs_.h @@ -36,17 +36,13 @@ void F77_(nlo_create,NLO_CREATE)(nlopt_opt *opt, int *alg, int *n) if (*n < 0) *opt = NULL; else { *opt = nlopt_create((nlopt_algorithm) *alg, (unsigned) *n); - nlopt_set_free_f_data(*opt, 1); + nlopt_set_munge(*opt, free_f77_func_data, dup_f77_func_data); } } void F77_(nlo_copy,NLO_COPY)(nlopt_opt *nopt, nlopt_opt *opt) { *nopt = nlopt_copy(*opt); - if (*nopt && nlopt_dup_f_data(*nopt, sizeof(f77_func_data)) < 0) { - nlopt_destroy(*nopt); - *nopt = NULL; - } } void F77_(nlo_destroy,NLO_DESTROY)(nlopt_opt *opt) diff --git a/api/nlopt-in.hpp b/api/nlopt-in.hpp index 5936d2c..a6f2641 100644 --- a/api/nlopt-in.hpp +++ b/api/nlopt-in.hpp @@ -88,8 +88,38 @@ namespace nlopt { opt *o; func f; void *f_data; vfunc vf; + nlopt_munge munge_destroy, munge_copy; // non-NULL for SWIG wrappers } myfunc_data; + // free/destroy f_data in nlopt_destroy and nlopt_copy, respectively + static void *free_myfunc_data(void *p) { + myfunc_data *d = (myfunc_data *) p; + if (d) { + if (d->f_data && d->munge_destroy) d->munge_destroy(d->f_data); + delete d; + } + return NULL; + } + static void *dup_myfunc_data(void *p) { + myfunc_data *d = (myfunc_data *) p; + if (d) { + void *f_data; + if (d->f_data && d->munge_copy) { + f_data = d->munge_copy(d->f_data); + if (!f_data) return NULL; + } + else + f_data = d->f_data; + myfunc_data *dnew = new myfunc_data; + if (dnew) { + *dnew = *d; + dnew->f_data = f_data; + } + return (void*) dnew; + } + else return NULL; + } + // nlopt_func wrapper that catches exceptions static double myfunc(unsigned n, const double *x, double *grad, void *d_) { myfunc_data *d = reinterpret_cast(d_); @@ -140,20 +170,18 @@ namespace nlopt { xtmp(0), gradtmp(0), gradtmp0(0), last_result(nlopt::FAILURE), last_optf(HUGE_VAL) { if (!o) throw std::bad_alloc(); - nlopt_set_free_f_data(o, 1); + nlopt_set_munge(o, free_myfunc_data, dup_myfunc_data); } opt(const opt& f) : o(nlopt_copy(f.o)), xtmp(f.xtmp), gradtmp(f.gradtmp), gradtmp0(0), last_result(f.last_result), last_optf(f.last_optf) { if (f.o && !o) throw std::bad_alloc(); - mythrow(nlopt_dup_f_data(o, sizeof(myfunc_data))); } opt& operator=(opt const& f) { if (this == &f) return *this; // self-assignment nlopt_destroy(o); o = nlopt_copy(f.o); if (f.o && !o) throw std::bad_alloc(); - mythrow(nlopt_dup_f_data(o, sizeof(myfunc_data))); xtmp = f.xtmp; gradtmp = f.gradtmp; last_result = f.last_result; last_optf = f.last_optf; return *this; @@ -195,32 +223,55 @@ namespace nlopt { // Set the objective function void set_min_objective(func f, void *f_data) { - myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data)); + myfunc_data *d = new myfunc_data; if (!d) throw std::bad_alloc(); d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL; + d->munge_destroy = d->munge_copy = NULL; mythrow(nlopt_set_min_objective(o, myfunc, d)); // d freed via o } void set_min_objective(vfunc vf, void *f_data) { - myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data)); + myfunc_data *d = new myfunc_data; if (!d) throw std::bad_alloc(); d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf; + d->munge_destroy = d->munge_copy = NULL; mythrow(nlopt_set_min_objective(o, myvfunc, d)); // d freed via o alloc_tmp(); } void set_max_objective(func f, void *f_data) { - myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data)); + myfunc_data *d = new myfunc_data; if (!d) throw std::bad_alloc(); d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL; + d->munge_destroy = d->munge_copy = NULL; mythrow(nlopt_set_max_objective(o, myfunc, d)); // d freed via o } void set_max_objective(vfunc vf, void *f_data) { - myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data)); + myfunc_data *d = new myfunc_data; if (!d) throw std::bad_alloc(); d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf; + d->munge_destroy = d->munge_copy = NULL; mythrow(nlopt_set_max_objective(o, myvfunc, d)); // d freed via o alloc_tmp(); } + // for internal use in SWIG wrappers -- variant that + // takes ownership of f_data, with munging for destroy/copy + void set_min_objective(func f, void *f_data, + nlopt_munge md, nlopt_munge mc) { + myfunc_data *d = new myfunc_data; + if (!d) throw std::bad_alloc(); + d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL; + d->munge_destroy = md; d->munge_copy = mc; + mythrow(nlopt_set_min_objective(o, myfunc, d)); // d freed via o + } + void set_max_objective(func f, void *f_data, + nlopt_munge md, nlopt_munge mc) { + myfunc_data *d = new myfunc_data; + if (!d) throw std::bad_alloc(); + d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL; + d->munge_destroy = md; d->munge_copy = mc; + mythrow(nlopt_set_max_objective(o, myfunc, d)); // d freed via o + } + // Nonlinear constraints: void remove_inequality_constraints() { @@ -228,15 +279,17 @@ namespace nlopt { mythrow(ret); } void add_inequality_constraint(func f, void *f_data, double tol=0) { - myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data)); + myfunc_data *d = new myfunc_data; if (!d) throw std::bad_alloc(); d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL; + d->munge_destroy = d->munge_copy = NULL; mythrow(nlopt_add_inequality_constraint(o, myfunc, d, tol)); } void add_inequality_constraint(vfunc vf, void *f_data, double tol=0) { - myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data)); + myfunc_data *d = new myfunc_data; if (!d) throw std::bad_alloc(); d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf; + d->munge_destroy = d->munge_copy = NULL; mythrow(nlopt_add_inequality_constraint(o, myvfunc, d, tol)); alloc_tmp(); } @@ -246,19 +299,41 @@ namespace nlopt { mythrow(ret); } void add_equality_constraint(func f, void *f_data, double tol=0) { - myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data)); + myfunc_data *d = new myfunc_data; if (!d) throw std::bad_alloc(); d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL; + d->munge_destroy = d->munge_copy = NULL; mythrow(nlopt_add_equality_constraint(o, myfunc, d, tol)); } void add_equality_constraint(vfunc vf, void *f_data, double tol=0) { - myfunc_data *d = (myfunc_data *) std::malloc(sizeof(myfunc_data)); + myfunc_data *d = new myfunc_data; if (!d) throw std::bad_alloc(); d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf; + d->munge_destroy = d->munge_copy = NULL; mythrow(nlopt_add_equality_constraint(o, myvfunc, d, tol)); alloc_tmp(); } + // For internal use in SWIG wrappers (see also above) + void add_inequality_constraint(func f, void *f_data, + nlopt_munge md, nlopt_munge mc, + double tol=0) { + myfunc_data *d = new myfunc_data; + if (!d) throw std::bad_alloc(); + d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL; + d->munge_destroy = md; d->munge_copy = mc; + mythrow(nlopt_add_inequality_constraint(o, myfunc, d, tol)); + } + void add_equality_constraint(func f, void *f_data, + nlopt_munge md, nlopt_munge mc, + double tol=0) { + myfunc_data *d = new myfunc_data; + if (!d) throw std::bad_alloc(); + d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL; + d->munge_destroy = md; d->munge_copy = mc; + mythrow(nlopt_add_equality_constraint(o, myfunc, d, tol)); + } + #define NLOPT_GETSET_VEC(name) \ void set_##name(double val) { \ mythrow(nlopt_set_##name##1(o, val)); \ diff --git a/api/nlopt-internal.h b/api/nlopt-internal.h index e3389ef..fc43db3 100644 --- a/api/nlopt-internal.h +++ b/api/nlopt-internal.h @@ -50,7 +50,7 @@ struct nlopt_opt_s { unsigned p_alloc; /* number of inequality constraints allocated */ nlopt_constraint *h; /* equality constraints, length p_alloc */ - int free_f_data; /* flag (for f77 api) to free f_data in nlopt_destroy */ + nlopt_munge munge_on_destroy, munge_on_copy; /* hack for wrappers */ /* stopping criteria */ double stopval; /* stop when f reaches stopval or better */ diff --git a/api/nlopt.h b/api/nlopt.h index 8d45723..c785743 100644 --- a/api/nlopt.h +++ b/api/nlopt.h @@ -248,11 +248,14 @@ NLOPT_EXTERN nlopt_result nlopt_set_initial_step1(nlopt_opt opt, double dx); NLOPT_EXTERN nlopt_result nlopt_get_initial_step(const nlopt_opt opt, const double *x, double *dx); -/* set to 1: nlopt_destroy should call free() on all of the f_data pointers - (for the objective, constraints, etcetera) ... mainly for internal use */ -NLOPT_EXTERN nlopt_result nlopt_set_free_f_data(nlopt_opt opt, int val); -NLOPT_EXTERN int nlopt_get_free_f_data(const nlopt_opt opt); -NLOPT_EXTERN nlopt_result nlopt_dup_f_data(nlopt_opt opt, size_t sz); +/* the following are functions mainly designed to be used internally + by the Fortran and SWIG wrappers, allow us to tel nlopt_destroy and + nlopt_copy to do something to the f_data pointers (e.g. free or + duplicate them, respectively) */ +typedef void* (*nlopt_munge)(void *p); +NLOPT_EXTERN void nlopt_set_munge(nlopt_opt opt, + nlopt_munge munge_on_destroy, + nlopt_munge munge_on_copy); /*************************** DEPRECATED API **************************/ /* The new "object-oriented" API is preferred, since it allows us to diff --git a/api/options.c b/api/options.c index e766ca4..1f592d1 100644 --- a/api/options.c +++ b/api/options.c @@ -34,13 +34,14 @@ void nlopt_destroy(nlopt_opt opt) { if (opt) { - if (opt->free_f_data) { + if (opt->munge_on_destroy) { + nlopt_munge munge = opt->munge_on_destroy; unsigned i; - free(opt->f_data); + munge(opt->f_data); for (i = 0; i < opt->m; ++i) - free(opt->fc[i].f_data); + munge(opt->fc[i].f_data); for (i = 0; i < opt->p; ++i) - free(opt->h[i].f_data); + munge(opt->h[i].f_data); } free(opt->lb); free(opt->ub); free(opt->xtol_abs); @@ -65,7 +66,7 @@ nlopt_opt nlopt_create(nlopt_algorithm algorithm, unsigned n) opt->n = n; opt->f = NULL; opt->f_data = NULL; opt->maximize = 0; - opt->free_f_data = 0; + opt->munge_on_destroy = opt->munge_on_copy = NULL; opt->lb = opt->ub = NULL; opt->m = opt->m_alloc = 0; @@ -117,7 +118,10 @@ nlopt_opt nlopt_copy(const nlopt_opt opt) nopt->local_opt = NULL; nopt->dx = NULL; opt->force_stop_child = NULL; - opt->free_f_data = 0; + + nlopt_munge munge = nopt->munge_on_copy; + if (munge && nopt->f_data) + if (!(nopt->f_data = munge(nopt->f_data))) goto oom; if (opt->n > 0) { nopt->lb = (double *) malloc(sizeof(double) * (opt->n)); @@ -138,6 +142,12 @@ nlopt_opt nlopt_copy(const nlopt_opt opt) * (opt->m)); if (!nopt->fc) goto oom; memcpy(nopt->fc, opt->fc, sizeof(nlopt_constraint) * (opt->m)); + if (munge) + for (unsigned i = 0; i < opt->m; ++i) + if (nopt->fc[i].f_data && + !(nopt->fc[i].f_data + = munge(nopt->fc[i].f_data))) + goto oom; } if (opt->p) { @@ -146,6 +156,12 @@ nlopt_opt nlopt_copy(const nlopt_opt opt) * (opt->p)); if (!nopt->h) goto oom; memcpy(nopt->h, opt->h, sizeof(nlopt_constraint) * (opt->p)); + if (munge) + for (unsigned i = 0; i < opt->p; ++i) + if (nopt->h[i].f_data && + !(nopt->h[i].f_data + = munge(nopt->h[i].f_data))) + goto oom; } if (opt->local_opt) { @@ -162,6 +178,7 @@ nlopt_opt nlopt_copy(const nlopt_opt opt) return nopt; oom: + nopt->munge_on_destroy = NULL; // better to leak mem than crash nlopt_destroy(nopt); return NULL; } @@ -262,6 +279,11 @@ nlopt_result nlopt_get_upper_bounds(nlopt_opt opt, double *ub) nlopt_result nlopt_remove_inequality_constraints(nlopt_opt opt) { if (!opt) return NLOPT_INVALID_ARGS; + if (opt->munge_on_destroy) { + nlopt_munge munge = opt->munge_on_destroy; + for (unsigned i = 0; i < opt->m; ++i) + munge(opt->fc[i].f_data); + } free(opt->fc); opt->fc = NULL; opt->m = opt->m_alloc = 0; @@ -315,6 +337,11 @@ nlopt_result nlopt_add_inequality_constraint(nlopt_opt opt, nlopt_result nlopt_remove_equality_constraints(nlopt_opt opt) { if (!opt) return NLOPT_INVALID_ARGS; + if (opt->munge_on_destroy) { + nlopt_munge munge = opt->munge_on_destroy; + for (unsigned i = 0; i < opt->p; ++i) + munge(opt->h[i].f_data); + } free(opt->h); opt->h = NULL; opt->p = opt->p_alloc = 0; @@ -528,40 +555,13 @@ nlopt_result nlopt_set_default_initial_step(nlopt_opt opt, const double *x) /*************************************************************************/ -GETSET(free_f_data, int, free_f_data) - -/* the dup_f_data function replaces all f_data pointers with a new - pointer to a duplicate block of memory, assuming all non-NULL - f_data pointers point to a block of sz bytes... this is pretty - exclusively intended for internal use (e.g. it may lead to a - double-free if one subsequently calles add_inequality_constraint - etc.), e.g. in the C++ API */ - -static int dup(void **p, size_t sz) { - if (*p) { - void *pdup = malloc(sz); - if (pdup) { - memcpy(pdup, *p, sz); - *p = pdup; - return 1; - } - else return 0; - } - else return 1; -} - -nlopt_result nlopt_dup_f_data(nlopt_opt opt, size_t sz) { +void nlopt_set_munge(nlopt_opt opt, + nlopt_munge munge_on_destroy, + nlopt_munge munge_on_copy) { if (opt) { - unsigned i; - if (!dup(&opt->f_data, sz)) return NLOPT_OUT_OF_MEMORY; - for (i = 0; i < opt->m; ++i) - if (!dup(&opt->fc[i].f_data, sz)) return NLOPT_OUT_OF_MEMORY; - for (i = 0; i < opt->p; ++i) - if (!dup(&opt->h[i].f_data, sz)) return NLOPT_OUT_OF_MEMORY; - nlopt_set_free_f_data(opt, 1); // nlopt_destroy must now free f_data! - return NLOPT_SUCCESS; + opt->munge_on_destroy = munge_on_destroy; + opt->munge_on_copy = munge_on_copy; } - return NLOPT_INVALID_ARGS; } /*************************************************************************/ diff --git a/swig/nlopt-guile.i b/swig/nlopt-guile.i index 0c70518..1978a13 100644 --- a/swig/nlopt-guile.i +++ b/swig/nlopt-guile.i @@ -1,23 +1,31 @@ // -*- C++ -*- %{ +// because our f_data pointer to the Scheme function is stored on the +// heap, rather than the stack, it may be missed by the Guile garbage +// collection and be accidentally freed. Hence, use NLopts munge +// feature to prevent this, by incrementing Guile's reference count. +static void *free_guilefunc(void *p) { + scm_gc_unprotect_object((SCM) p); return p; } +static void *dup_guilefunc(void *p) { + scm_gc_protect_object((SCM) p); return p; } + // vfunc wrapper around Guile function (val . grad) = f(x) -static double vfunc_guile(const std::vector &x, - std::vector &grad, void *f) { - SCM xscm = scm_c_make_vector(x.size(), SCM_UNSPECIFIED); - for (unsigned i = 0; i < x.size(); ++i) +static double func_guile(unsigned n, const double *x, double *grad, void *f) { + SCM xscm = scm_c_make_vector(n, SCM_UNSPECIFIED); + for (unsigned i = 0; i < n; ++i) scm_c_vector_set_x(xscm, i, scm_make_real(x[i])); SCM ret = scm_call_1((SCM) f, xscm); if (scm_real_p(ret)) { - if (grad.size()) throw std::invalid_argument("missing gradient"); + if (grad) throw std::invalid_argument("missing gradient"); return scm_to_double(ret); } else if (scm_is_pair(ret)) { /* must be (cons value gradient) */ SCM valscm = SCM_CAR(ret), grad_scm = grad_scm; - if (grad.size() > 0 + if (grad && scm_is_vector(grad_scm) - && scm_c_vector_length(grad_scm) == grad.size()) { - for (unsigned i = 0; i < grad.size(); ++i) + && scm_c_vector_length(grad_scm) == n) { + for (unsigned i = 0; i < n; ++i) grad[i] = scm_to_double(scm_c_vector_ref(grad_scm, i)); } else throw std::invalid_argument("invalid gradient"); @@ -28,9 +36,11 @@ static double vfunc_guile(const std::vector &x, } %} -%typemap(in)(nlopt::vfunc vf, void *f_data) { +%typemap(in)(nlopt::vfunc vf, void *f_data, nlopt_munge md, nlopt_munge mc) { $1 = vfunc_guile; - $2 = (void*) $input; // input is SCM pointer to Scheme function + $2 = dup_guilefunc((void*) $input); // input = SCM pointer to Scheme function + $3 = free_guilefunc; + $4 = dup_guilefunc; } %typecheck(SWIG_TYPECHECK_POINTER)(nlopt::vfunc vf, void *f_data) { $1 = SCM_NFALSEP(scm_procedure_p($input)); diff --git a/swig/nlopt-python.i b/swig/nlopt-python.i index 32e3305..cd46acd 100644 --- a/swig/nlopt-python.i +++ b/swig/nlopt-python.i @@ -59,6 +59,9 @@ // Wrapper for objective function callbacks %{ +static void *free_pyfunc(void *p) { Py_DECREF((PyObject*) p); return p; } +static void *dup_pyfunc(void *p) { Py_INCREF((PyObject*) p); return p; } + static double func_python(unsigned n, const double *x, double *grad, void *f) { npy_intp sz = npy_intp(n), sz0 = 0; @@ -85,10 +88,11 @@ static double func_python(unsigned n, const double *x, double *grad, void *f) } %} -%typemap(in)(nlopt::func f, void *f_data) { - Py_INCREF($input); +%typemap(in)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) { $1 = func_python; - $2 = (void*) $input; + $2 = dup_pyfunc((void*) $input); + $3 = free_pyfunc; + $4 = dup_pyfunc; } %typecheck(SWIG_TYPECHECK_POINTER)(nlopt::func f, void *f_data) { $1 = PyCallable_Check($input);