chiark / gitweb /
replace free_f_data flag with more general munging feature, use for proper refcountin...
[nlopt.git] / api / nlopt-in.hpp
1 /* Copyright (c) 2007-2010 Massachusetts Institute of Technology
2  *
3  * Permission is hereby granted, free of charge, to any person obtaining
4  * a copy of this software and associated documentation files (the
5  * "Software"), to deal in the Software without restriction, including
6  * without limitation the rights to use, copy, modify, merge, publish,
7  * distribute, sublicense, and/or sell copies of the Software, and to
8  * permit persons to whom the Software is furnished to do so, subject to
9  * the following conditions:
10  * 
11  * The above copyright notice and this permission notice shall be
12  * included in all copies or substantial portions of the Software.
13  * 
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21  */
22
23 // C++ style wrapper around NLopt API
24 // nlopt.hpp is AUTOMATICALLY GENERATED from nlopt-in.hpp - edit the latter!
25
26 #include <nlopt.h>
27
28 #include <vector>
29 #include <stdexcept>
30 #include <new>
31 #include <cstdlib>
32 #include <cstring>
33 #include <cmath>
34
35 // convenience overloading for below (not in nlopt:: since has nlopt_ prefix)
36 inline nlopt_result nlopt_get_initial_step(const nlopt_opt opt, double *dx) {
37       return nlopt_get_initial_step(opt, (const double *) NULL, dx);
38 }
39
40 namespace nlopt {
41
42   //////////////////////////////////////////////////////////////////////
43   // nlopt::* namespace versions of the C enumerated types
44   //          AUTOMATICALLY GENERATED, DO NOT EDIT
45   // GEN_ENUMS_HERE
46   //////////////////////////////////////////////////////////////////////
47
48   typedef nlopt_func func; // nlopt::func synoynm
49
50   // alternative to nlopt_func that takes std::vector<double>
51   // ... unfortunately requires a data copy
52   typedef double (*vfunc)(const std::vector<double> &x,
53                           std::vector<double> &grad, void *data);
54
55   //////////////////////////////////////////////////////////////////////
56   
57   // NLopt-specific exceptions (corresponding to error codes):
58   class roundoff_limited : public std::runtime_error {
59   public:
60     roundoff_limited() : std::runtime_error("nlopt roundoff-limited") {}
61   };
62
63   class forced_stop : public std::runtime_error {
64   public:
65     forced_stop() : std::runtime_error("nlopt forced stop") {}
66   };
67
68   //////////////////////////////////////////////////////////////////////
69
70   class opt {
71   private:
72     nlopt_opt o;
73     result last_result;
74     double last_optf;
75     
76     void mythrow(nlopt_result ret) const {
77       switch (ret) {
78       case NLOPT_FAILURE: throw std::runtime_error("nlopt failure");
79       case NLOPT_OUT_OF_MEMORY: throw std::bad_alloc();
80       case NLOPT_INVALID_ARGS: throw std::invalid_argument("nlopt invalid argument");
81       case NLOPT_ROUNDOFF_LIMITED: throw roundoff_limited();
82       case NLOPT_FORCED_STOP: throw forced_stop();
83       default: break;
84       }
85     }
86
87     typedef struct {
88       opt *o;
89       func f; void *f_data;
90       vfunc vf;
91       nlopt_munge munge_destroy, munge_copy; // non-NULL for SWIG wrappers
92     } myfunc_data;
93
94     // free/destroy f_data in nlopt_destroy and nlopt_copy, respectively
95     static void *free_myfunc_data(void *p) { 
96       myfunc_data *d = (myfunc_data *) p;
97       if (d) {
98         if (d->f_data && d->munge_destroy) d->munge_destroy(d->f_data);
99         delete d;
100       }
101       return NULL;
102     }
103     static void *dup_myfunc_data(void *p) {
104       myfunc_data *d = (myfunc_data *) p;
105       if (d) {
106         void *f_data;
107         if (d->f_data && d->munge_copy) {
108           f_data = d->munge_copy(d->f_data);
109           if (!f_data) return NULL;
110         }
111         else
112           f_data = d->f_data;
113         myfunc_data *dnew = new myfunc_data;
114         if (dnew) {
115           *dnew = *d;
116           dnew->f_data = f_data;
117         }
118         return (void*) dnew;
119       }
120       else return NULL;
121     }
122
123     // nlopt_func wrapper that catches exceptions
124     static double myfunc(unsigned n, const double *x, double *grad, void *d_) {
125       myfunc_data *d = reinterpret_cast<myfunc_data*>(d_);
126       try {
127         return d->f(n, x, grad, d->f_data);
128       }
129       catch (...) {
130         d->o->force_stop(); // stop gracefully, opt::optimize will re-throw
131         return HUGE_VAL;
132       }
133     }
134
135     std::vector<double> xtmp, gradtmp, gradtmp0; // scratch for myvfunc
136
137     // nlopt_func wrapper, using std::vector<double>
138     static double myvfunc(unsigned n, const double *x, double *grad, void *d_){
139       myfunc_data *d = reinterpret_cast<myfunc_data*>(d_);
140       try {
141         std::vector<double> &xv = d->o->xtmp;
142         if (n) std::memcpy(&xv[0], x, n * sizeof(double));
143         double val=d->vf(xv, grad ? d->o->gradtmp : d->o->gradtmp0, d->f_data);
144         if (grad && n) {
145           std::vector<double> &gradv = d->o->gradtmp;
146           std::memcpy(grad, &gradv[0], n * sizeof(double));
147         }
148         return val;
149       }
150       catch (...) {
151         d->o->force_stop(); // stop gracefully, opt::optimize will re-throw
152         return HUGE_VAL;
153       }
154     }
155
156     void alloc_tmp() {
157       if (xtmp.size() != nlopt_get_dimension(o)) {
158         xtmp = std::vector<double>(nlopt_get_dimension(o));
159         gradtmp = std::vector<double>(nlopt_get_dimension(o));
160       }
161     }
162
163   public:
164     // Constructors etc.
165     opt() : o(NULL), xtmp(0), gradtmp(0), gradtmp0(0), 
166             last_result(nlopt::FAILURE), last_optf(HUGE_VAL) {}
167     ~opt() { nlopt_destroy(o); }
168     opt(algorithm a, unsigned n) : 
169       o(nlopt_create(nlopt_algorithm(a), n)), 
170       xtmp(0), gradtmp(0), gradtmp0(0),
171       last_result(nlopt::FAILURE), last_optf(HUGE_VAL) {
172       if (!o) throw std::bad_alloc();
173       nlopt_set_munge(o, free_myfunc_data, dup_myfunc_data);
174     }
175     opt(const opt& f) : o(nlopt_copy(f.o)), 
176                         xtmp(f.xtmp), gradtmp(f.gradtmp), gradtmp0(0),
177                         last_result(f.last_result), last_optf(f.last_optf) {
178       if (f.o && !o) throw std::bad_alloc();
179     }
180     opt& operator=(opt const& f) {
181       if (this == &f) return *this; // self-assignment
182       nlopt_destroy(o);
183       o = nlopt_copy(f.o);
184       if (f.o && !o) throw std::bad_alloc();
185       xtmp = f.xtmp; gradtmp = f.gradtmp;
186       last_result = f.last_result; last_optf = f.last_optf;
187       return *this;
188     }
189
190     // Do the optimization:
191     result optimize(std::vector<double> &x, double &opt_f) {
192       if (o && nlopt_get_dimension(o) != x.size())
193         throw std::invalid_argument("dimension mismatch");
194       nlopt_result ret = nlopt_optimize(o, x.empty() ? NULL : &x[0], &opt_f);
195       last_result = result(ret);
196       last_optf = opt_f;
197       mythrow(ret);
198       return last_result;
199     }
200
201     // variant mainly useful for SWIG wrappers:
202     std::vector<double> optimize(const std::vector<double> &x0) {
203       std::vector<double> x(x0);
204       last_result = optimize(x, last_optf);
205       return x;
206     }
207     result last_optimize_result() const { return last_result; }
208     double last_optimum_value() const { return last_optf; }
209
210     // accessors:
211     algorithm get_algorithm() const {
212       if (!o) throw std::runtime_error("uninitialized nlopt::opt");
213       return algorithm(nlopt_get_algorithm(o));
214     }
215     const char *get_algorithm_name() const {
216       if (!o) throw std::runtime_error("uninitialized nlopt::opt");
217       return nlopt_algorithm_name(nlopt_get_algorithm(o));
218     }
219     unsigned get_dimension() const {
220       if (!o) throw std::runtime_error("uninitialized nlopt::opt");
221       return nlopt_get_dimension(o);
222     }
223
224     // Set the objective function
225     void set_min_objective(func f, void *f_data) {
226       myfunc_data *d = new myfunc_data;
227       if (!d) throw std::bad_alloc();
228       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
229       d->munge_destroy = d->munge_copy = NULL;
230       mythrow(nlopt_set_min_objective(o, myfunc, d)); // d freed via o
231     }
232     void set_min_objective(vfunc vf, void *f_data) {
233       myfunc_data *d = new myfunc_data;
234       if (!d) throw std::bad_alloc();
235       d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf;
236       d->munge_destroy = d->munge_copy = NULL;
237       mythrow(nlopt_set_min_objective(o, myvfunc, d)); // d freed via o
238       alloc_tmp();
239     }
240     void set_max_objective(func f, void *f_data) {
241       myfunc_data *d = new myfunc_data;
242       if (!d) throw std::bad_alloc();
243       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
244       d->munge_destroy = d->munge_copy = NULL;
245       mythrow(nlopt_set_max_objective(o, myfunc, d)); // d freed via o
246     }
247     void set_max_objective(vfunc vf, void *f_data) {
248       myfunc_data *d = new myfunc_data;
249       if (!d) throw std::bad_alloc();
250       d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf;
251       d->munge_destroy = d->munge_copy = NULL;
252       mythrow(nlopt_set_max_objective(o, myvfunc, d)); // d freed via o
253       alloc_tmp();
254     }
255
256     // for internal use in SWIG wrappers -- variant that
257     // takes ownership of f_data, with munging for destroy/copy
258     void set_min_objective(func f, void *f_data,
259                            nlopt_munge md, nlopt_munge mc) {
260       myfunc_data *d = new myfunc_data;
261       if (!d) throw std::bad_alloc();
262       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
263       d->munge_destroy = md; d->munge_copy = mc;
264       mythrow(nlopt_set_min_objective(o, myfunc, d)); // d freed via o
265     }
266     void set_max_objective(func f, void *f_data,
267                            nlopt_munge md, nlopt_munge mc) {
268       myfunc_data *d = new myfunc_data;
269       if (!d) throw std::bad_alloc();
270       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
271       d->munge_destroy = md; d->munge_copy = mc;
272       mythrow(nlopt_set_max_objective(o, myfunc, d)); // d freed via o
273     }
274
275     // Nonlinear constraints:
276
277     void remove_inequality_constraints() {
278       nlopt_result ret = nlopt_remove_inequality_constraints(o);
279       mythrow(ret);
280     }
281     void add_inequality_constraint(func f, void *f_data, double tol=0) {
282       myfunc_data *d = new myfunc_data;
283       if (!d) throw std::bad_alloc();
284       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
285       d->munge_destroy = d->munge_copy = NULL;
286       mythrow(nlopt_add_inequality_constraint(o, myfunc, d, tol));
287     }
288     void add_inequality_constraint(vfunc vf, void *f_data, double tol=0) {
289       myfunc_data *d = new myfunc_data;
290       if (!d) throw std::bad_alloc();
291       d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf;
292       d->munge_destroy = d->munge_copy = NULL;
293       mythrow(nlopt_add_inequality_constraint(o, myvfunc, d, tol));
294       alloc_tmp();
295     }
296
297     void remove_equality_constraints() {
298       nlopt_result ret = nlopt_remove_equality_constraints(o);
299       mythrow(ret);
300     }
301     void add_equality_constraint(func f, void *f_data, double tol=0) {
302       myfunc_data *d = new myfunc_data;
303       if (!d) throw std::bad_alloc();
304       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
305       d->munge_destroy = d->munge_copy = NULL;
306       mythrow(nlopt_add_equality_constraint(o, myfunc, d, tol));
307     }
308     void add_equality_constraint(vfunc vf, void *f_data, double tol=0) {
309       myfunc_data *d = new myfunc_data;
310       if (!d) throw std::bad_alloc();
311       d->o = this; d->f = NULL; d->f_data = f_data; d->vf = vf;
312       d->munge_destroy = d->munge_copy = NULL;
313       mythrow(nlopt_add_equality_constraint(o, myvfunc, d, tol));
314       alloc_tmp();
315     }
316
317     // For internal use in SWIG wrappers (see also above)
318     void add_inequality_constraint(func f, void *f_data, 
319                                    nlopt_munge md, nlopt_munge mc,
320                                    double tol=0) {
321       myfunc_data *d = new myfunc_data;
322       if (!d) throw std::bad_alloc();
323       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
324       d->munge_destroy = md; d->munge_copy = mc;
325       mythrow(nlopt_add_inequality_constraint(o, myfunc, d, tol));
326     }
327     void add_equality_constraint(func f, void *f_data, 
328                                  nlopt_munge md, nlopt_munge mc,
329                                  double tol=0) {
330       myfunc_data *d = new myfunc_data;
331       if (!d) throw std::bad_alloc();
332       d->o = this; d->f = f; d->f_data = f_data; d->vf = NULL;
333       d->munge_destroy = md; d->munge_copy = mc;
334       mythrow(nlopt_add_equality_constraint(o, myfunc, d, tol));
335     }
336
337 #define NLOPT_GETSET_VEC(name)                                          \
338     void set_##name(double val) {                                       \
339       mythrow(nlopt_set_##name##1(o, val));                             \
340     }                                                                   \
341     void get_##name(std::vector<double> &v) const {                     \
342       if (o && nlopt_get_dimension(o) != v.size())                      \
343         throw std::invalid_argument("dimension mismatch");              \
344       mythrow(nlopt_get_##name(o, v.empty() ? NULL : &v[0]));           \
345     }                                                                   \
346     std::vector<double> get_##name() const {                    \
347       if (!o) throw std::runtime_error("uninitialized nlopt::opt");     \
348       std::vector<double> v(nlopt_get_dimension(o));                    \
349       get_##name(v);                                                    \
350       return v;                                                         \
351     }                                                                   \
352     void set_##name(const std::vector<double> &v) {                     \
353       if (o && nlopt_get_dimension(o) != v.size())                      \
354         throw std::invalid_argument("dimension mismatch");              \
355       mythrow(nlopt_set_##name(o, v.empty() ? NULL : &v[0]));           \
356     }
357
358     NLOPT_GETSET_VEC(lower_bounds)
359     NLOPT_GETSET_VEC(upper_bounds)
360
361     // stopping criteria:
362
363 #define NLOPT_GETSET(T, name)                                           \
364     T get_##name() const {                                              \
365       if (!o) throw std::runtime_error("uninitialized nlopt::opt");     \
366       return nlopt_get_##name(o);                                       \
367     }                                                                   \
368     void set_##name(T name) {                                           \
369       mythrow(nlopt_set_##name(o, name));                               \
370     }
371     NLOPT_GETSET(double, stopval)
372     NLOPT_GETSET(double, ftol_rel)
373     NLOPT_GETSET(double, ftol_abs)
374     NLOPT_GETSET(double, xtol_rel)
375     NLOPT_GETSET_VEC(xtol_abs)
376     NLOPT_GETSET(int, maxeval)
377     NLOPT_GETSET(double, maxtime)
378
379     NLOPT_GETSET(int, force_stop)
380     void force_stop() { set_force_stop(1); }
381
382     // algorithm-specific parameters:
383
384     void set_local_optimizer(const opt &lo) {
385       nlopt_result ret = nlopt_set_local_optimizer(o, lo.o);
386       mythrow(ret);
387     }
388
389     NLOPT_GETSET(unsigned, population)
390     NLOPT_GETSET_VEC(initial_step)
391
392     void set_default_initial_step(const std::vector<double> &x) {
393       nlopt_result ret 
394         = nlopt_set_default_initial_step(o, x.empty() ? NULL : &x[0]);
395       mythrow(ret);
396     }
397     void get_initial_step(const std::vector<double> &x, std::vector<double> &dx) const {
398       if (o && (nlopt_get_dimension(o) != x.size()
399                 || nlopt_get_dimension(o) != dx.size()))
400         throw std::invalid_argument("dimension mismatch");
401       nlopt_result ret = nlopt_get_initial_step(o, x.empty() ? NULL : &x[0],
402                                                 dx.empty() ? NULL : &dx[0]);
403       mythrow(ret);
404     }
405     std::vector<double> get_initial_step_(const std::vector<double> &x) const {
406       if (!o) throw std::runtime_error("uninitialized nlopt::opt");
407       std::vector<double> v(nlopt_get_dimension(o));
408       get_initial_step(x, v);
409       return v;
410     }
411   };
412
413 #undef NLOPT_GETSET
414 #undef NLOPT_GETSET_VEC
415
416   //////////////////////////////////////////////////////////////////////
417
418   void srand(unsigned long seed) { nlopt_srand(seed); }
419   void srand_time() { nlopt_srand_time(); }
420   void version(int &major, int &minor, int &bugfix) {
421     nlopt_version(&major, &minor, &bugfix);
422   }
423   int version_major() {
424     int major, minor, bugfix;
425     nlopt_version(&major, &minor, &bugfix);
426     return major;
427   }
428   int version_minor() {
429     int major, minor, bugfix;
430     nlopt_version(&major, &minor, &bugfix);
431     return minor;
432   }
433   int version_bugfix() {
434     int major, minor, bugfix;
435     nlopt_version(&major, &minor, &bugfix);
436     return bugfix;
437   }
438   const char *algorithm_name(algorithm a) {
439     return nlopt_algorithm_name(nlopt_algorithm(a));
440   }
441
442   //////////////////////////////////////////////////////////////////////
443
444 } // namespace nlopt