chiark / gitweb /
recommend building in a subdir
[nlopt.git] / octave / nlopt_optimize-mex.c
1 /* Copyright (c) 2007-2014 Massachusetts Institute of Technology
2  *
3  * Permission is hereby granted, free of charge, to any person obtaining
4  * a copy of this software and associated documentation files (the
5  * "Software"), to deal in the Software without restriction, including
6  * without limitation the rights to use, copy, modify, merge, publish,
7  * distribute, sublicense, and/or sell copies of the Software, and to
8  * permit persons to whom the Software is furnished to do so, subject to
9  * the following conditions:
10  * 
11  * The above copyright notice and this permission notice shall be
12  * included in all copies or substantial portions of the Software.
13  * 
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21  */
22
23 /* Matlab MEX interface to NLopt, and in particular to nlopt_optimize */
24
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <math.h>
29 #include <mex.h>
30
31 #include "nlopt.h"
32
33 #define CHECK0(cond, msg) if (!(cond)) mexErrMsgTxt(msg);
34
35 static double struct_val_default(const mxArray *s, const char *name, double dflt)
36 {
37      mxArray *val = mxGetField(s, 0, name);
38      if (val) {
39           CHECK0(mxIsNumeric(val) && !mxIsComplex(val) 
40                 && mxGetM(val) * mxGetN(val) == 1,
41                 "opt fields, other than xtol_abs, must be real scalars");
42           return mxGetScalar(val);
43      }
44      return dflt;
45 }
46
47 static double *struct_arrval(const mxArray *s, const char *name, unsigned n,
48                              double *dflt)
49 {
50      mxArray *val = mxGetField(s, 0, name);
51      if (val) {
52           CHECK0(mxIsNumeric(val) && !mxIsComplex(val) 
53                 && mxGetM(val) * mxGetN(val) == n,
54                 "opt vector field is not of length n");
55           return mxGetPr(val);
56      }
57      return dflt;
58 }
59
60 static mxArray *struct_funcval(const mxArray *s, const char *name)
61 {
62      mxArray *val = mxGetField(s, 0, name);
63      if (val) {
64           CHECK0(mxIsChar(val) || mxIsFunctionHandle(val),
65                  "opt function field is not a function handle/name");
66           return val;
67      }
68      return NULL;
69 }
70
71 static double *fill(double *arr, unsigned n, double val)
72 {
73      unsigned i;
74      for (i = 0; i < n; ++i) arr[i] = val;
75      return arr;
76 }
77
78 #define FLEN 128 /* max length of user function name */
79 #define MAXRHS 3 /* max nrhs for user function */
80 typedef struct user_function_data_s {
81      char f[FLEN];
82      mxArray *plhs[2];
83      mxArray *prhs[MAXRHS];
84      int xrhs, nrhs;
85      int verbose, neval;
86      struct user_function_data_s *dpre;
87      nlopt_opt opt;
88 } user_function_data;
89
90 static double user_function(unsigned n, const double *x,
91                             double *gradient, /* NULL if not needed */
92                             void *d_)
93 {
94   user_function_data *d = (user_function_data *) d_;
95   double f;
96
97   d->plhs[0] = d->plhs[1] = NULL;
98   memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
99
100   CHECK0(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs, 
101                            d->nrhs, d->prhs, d->f),
102         "error calling user function");
103
104   CHECK0(mxIsNumeric(d->plhs[0]) && !mxIsComplex(d->plhs[0]) 
105         && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == 1,
106         "user function must return real scalar");
107   f = mxGetScalar(d->plhs[0]);
108   mxDestroyArray(d->plhs[0]);
109   if (gradient) {
110      CHECK0(mxIsDouble(d->plhs[1]) && !mxIsComplex(d->plhs[1])
111            && (mxGetM(d->plhs[1]) == 1 || mxGetN(d->plhs[1]) == 1)
112            && mxGetM(d->plhs[1]) * mxGetN(d->plhs[1]) == n,
113            "gradient vector from user function is the wrong size");
114      memcpy(gradient, mxGetPr(d->plhs[1]), n * sizeof(double));
115      mxDestroyArray(d->plhs[1]);
116   }
117   d->neval++;
118   if (d->verbose) mexPrintf("nlopt_optimize eval #%d: %g\n", d->neval, f);
119   if (mxIsNaN(f)) nlopt_force_stop(d->opt);
120   return f;
121 }
122
123 static void user_pre(unsigned n, const double *x, const double *v,
124                        double *vpre, void *d_)
125 {
126   user_function_data *d = ((user_function_data *) d_)->dpre;
127   d->plhs[0] = d->plhs[1] = NULL;
128   memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
129   memcpy(mxGetPr(d->prhs[d->xrhs + 1]), v, n * sizeof(double));
130
131   CHECK0(0 == mexCallMATLAB(1, d->plhs, 
132                             d->nrhs, d->prhs, d->f),
133          "error calling user function");
134
135   CHECK0(mxIsDouble(d->plhs[0]) && !mxIsComplex(d->plhs[0])
136          && (mxGetM(d->plhs[0]) == 1 || mxGetN(d->plhs[0]) == 1)
137          && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == n,
138          "vpre vector from user function is the wrong size");
139   memcpy(vpre, mxGetPr(d->plhs[0]), n * sizeof(double));
140   mxDestroyArray(d->plhs[0]);
141   d->neval++;
142   if (d->verbose) mexPrintf("nlopt_optimize precond eval #%d\n", d->neval);
143 }
144
145 #define CHECK1(cond, msg) if (!(cond)) { mxFree(tmp); nlopt_destroy(opt); nlopt_destroy(local_opt); mexWarnMsgTxt(msg); return NULL; };
146
147 nlopt_opt make_opt(const mxArray *opts, unsigned n)
148 {
149      nlopt_opt opt = NULL, local_opt = NULL;
150      nlopt_algorithm algorithm;
151      double *tmp = NULL;
152      unsigned i;
153
154      algorithm = (nlopt_algorithm)
155           struct_val_default(opts, "algorithm", NLOPT_NUM_ALGORITHMS);
156      CHECK1(((int)algorithm) >= 0 && algorithm < NLOPT_NUM_ALGORITHMS,
157             "invalid opt.algorithm");
158
159      tmp = (double *) mxCalloc(n, sizeof(double));
160      opt = nlopt_create(algorithm, n);
161      CHECK1(opt, "nlopt: out of memory");
162
163      nlopt_set_lower_bounds(opt, struct_arrval(opts, "lower_bounds", n,
164                                                fill(tmp, n, -HUGE_VAL)));
165      nlopt_set_upper_bounds(opt, struct_arrval(opts, "upper_bounds", n,
166                                                fill(tmp, n, +HUGE_VAL)));
167
168      nlopt_set_stopval(opt, struct_val_default(opts, "stopval", -HUGE_VAL));
169      nlopt_set_ftol_rel(opt, struct_val_default(opts, "ftol_rel", 0.0));
170      nlopt_set_ftol_abs(opt, struct_val_default(opts, "ftol_abs", 0.0));
171      nlopt_set_xtol_rel(opt, struct_val_default(opts, "xtol_rel", 0.0));
172      nlopt_set_xtol_abs(opt, struct_arrval(opts, "xtol_abs", n,
173                                            fill(tmp, n, 0.0)));
174      nlopt_set_maxeval(opt, struct_val_default(opts, "maxeval", 0.0) < 0 ?
175                        0 : struct_val_default(opts, "maxeval", 0.0));
176      nlopt_set_maxtime(opt, struct_val_default(opts, "maxtime", 0.0));
177
178      nlopt_set_population(opt, struct_val_default(opts, "population", 0));
179      nlopt_set_vector_storage(opt, struct_val_default(opts, "vector_storage", 0));
180
181      if (struct_arrval(opts, "initial_step", n, NULL))
182           nlopt_set_initial_step(opt,
183                                  struct_arrval(opts, "initial_step", n, NULL));
184      
185      if (mxGetField(opts, 0, "local_optimizer")) {
186           const mxArray *local_opts = mxGetField(opts, 0, "local_optimizer");
187           CHECK1(mxIsStruct(local_opts),
188                  "opt.local_optimizer must be a structure");
189           CHECK1(local_opt = make_opt(local_opts, n),
190                  "error initializing local optimizer");
191           nlopt_set_local_optimizer(opt, local_opt);
192           nlopt_destroy(local_opt); local_opt = NULL;
193      }
194
195      mxFree(tmp);
196      return opt;
197 }
198
199 #define CHECK(cond, msg) if (!(cond)) { mxFree(dh); mxFree(dfc); nlopt_destroy(opt); mexErrMsgTxt(msg); }
200
201 void mexFunction(int nlhs, mxArray *plhs[],
202                  int nrhs, const mxArray *prhs[])
203 {
204      unsigned n;
205      double *x, *x0, opt_f;
206      nlopt_result ret;
207      mxArray *x_mx, *mx;
208      user_function_data d, dpre, *dfc = NULL, *dh = NULL;
209      nlopt_opt opt = NULL;
210
211      CHECK(nrhs == 2 && nlhs <= 3, "wrong number of arguments");
212
213      /* options = prhs[0] */
214      CHECK(mxIsStruct(prhs[0]), "opt must be a struct");
215      
216      /* x0 = prhs[1] */
217      CHECK(mxIsDouble(prhs[1]) && !mxIsComplex(prhs[1])
218            && (mxGetM(prhs[1]) == 1 || mxGetN(prhs[1]) == 1),
219            "x must be real row or column vector");
220      n = mxGetM(prhs[1]) * mxGetN(prhs[1]),
221      x0 = mxGetPr(prhs[1]);
222
223      CHECK(opt = make_opt(prhs[0], n), "error initializing nlopt options");
224
225      d.neval = 0;
226      d.verbose = (int) struct_val_default(prhs[0], "verbose", 0);
227      d.opt = opt;
228
229      /* function f = prhs[1] */
230      mx = struct_funcval(prhs[0], "min_objective");
231      if (!mx) mx = struct_funcval(prhs[0], "max_objective");
232      CHECK(mx, "either opt.min_objective or opt.max_objective must exist");
233      if (mxIsChar(mx)) {
234           CHECK(mxGetString(mx, d.f, FLEN) == 0,
235                 "error reading function name string (too long?)");
236           d.nrhs = 1;
237           d.xrhs = 0;
238      }
239      else {
240           d.prhs[0] = mx;
241           strcpy(d.f, "feval");
242           d.nrhs = 2;
243           d.xrhs = 1;
244      }
245      d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL);
246
247      if ((mx = struct_funcval(prhs[0], "pre"))) {
248           CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx),
249                 "pre must contain function handles or function names");
250           if (mxIsChar(mx)) {
251                CHECK(mxGetString(mx, dpre.f, FLEN) == 0,
252                      "error reading function name string (too long?)");
253                dpre.nrhs = 2;
254                dpre.xrhs = 0;
255           }
256           else {
257                dpre.prhs[0] = mx;
258                strcpy(dpre.f, "feval");
259                dpre.nrhs = 3;
260                dpre.xrhs = 1;
261           }
262           dpre.verbose = d.verbose > 2;
263           dpre.opt = opt;
264           dpre.neval = 0;
265           dpre.prhs[dpre.xrhs] = d.prhs[d.xrhs];
266           dpre.prhs[d.xrhs+1] = mxCreateDoubleMatrix(1, n, mxREAL);
267           d.dpre = &dpre;
268
269           if (struct_funcval(prhs[0], "min_objective"))
270                nlopt_set_precond_min_objective(opt, user_function,user_pre,&d);
271           else
272                nlopt_set_precond_max_objective(opt, user_function,user_pre,&d);
273      }
274      else {
275           dpre.nrhs = 0;
276           if (struct_funcval(prhs[0], "min_objective"))
277                nlopt_set_min_objective(opt, user_function, &d);
278           else
279                nlopt_set_max_objective(opt, user_function, &d);
280      }
281
282      if ((mx = mxGetField(prhs[0], 0, "fc"))) {
283           int j, m;
284           double *fc_tol;
285           
286           CHECK(mxIsCell(mx), "fc must be a Cell array");
287           m = mxGetM(mx) * mxGetN(mx);;
288           dfc = (user_function_data *) mxCalloc(m, sizeof(user_function_data));
289           fc_tol = struct_arrval(prhs[0], "fc_tol", m, NULL);
290
291           for (j = 0; j < m; ++j) {
292                mxArray *fc = mxGetCell(mx, j);
293                CHECK(mxIsChar(fc) || mxIsFunctionHandle(fc),
294                      "fc must contain function handles or function names");
295                if (mxIsChar(fc)) {
296                     CHECK(mxGetString(fc, dfc[j].f, FLEN) == 0,
297                      "error reading function name string (too long?)");
298                     dfc[j].nrhs = 1;
299                     dfc[j].xrhs = 0;
300                }
301                else {
302                     dfc[j].prhs[0] = fc;
303                     strcpy(dfc[j].f, "feval");
304                     dfc[j].nrhs = 2;
305                     dfc[j].xrhs = 1;
306                }
307                dfc[j].verbose = d.verbose > 1;
308                dfc[j].opt = opt;
309                dfc[j].neval = 0;
310                dfc[j].prhs[dfc[j].xrhs] = d.prhs[d.xrhs];
311                CHECK(nlopt_add_inequality_constraint(opt, user_function,
312                                                      dfc + j,
313                                                      fc_tol ? fc_tol[j] : 0)
314                      > 0, "nlopt error adding inequality constraint");
315           }
316      }
317
318
319      if ((mx = mxGetField(prhs[0], 0, "h"))) {
320           int j, m;
321           double *h_tol;
322           
323           CHECK(mxIsCell(mx), "h must be a Cell array");
324           m = mxGetM(mx) * mxGetN(mx);;
325           dh = (user_function_data *) mxCalloc(m, sizeof(user_function_data));
326           h_tol = struct_arrval(prhs[0], "h_tol", m, NULL);
327
328           for (j = 0; j < m; ++j) {
329                mxArray *h = mxGetCell(mx, j);
330                CHECK(mxIsChar(h) || mxIsFunctionHandle(h),
331                      "h must contain function handles or function names");
332                if (mxIsChar(h)) {
333                     CHECK(mxGetString(h, dh[j].f, FLEN) == 0,
334                      "error reading function name string (too long?)");
335                     dh[j].nrhs = 1;
336                     dh[j].xrhs = 0;
337                }
338                else {
339                     dh[j].prhs[0] = h;
340                     strcpy(dh[j].f, "feval");
341                     dh[j].nrhs = 2;
342                     dh[j].xrhs = 1;
343                }
344                dh[j].verbose = d.verbose > 1;
345                dh[j].opt = opt;
346                dh[j].neval = 0;
347                dh[j].prhs[dh[j].xrhs] = d.prhs[d.xrhs];
348                CHECK(nlopt_add_equality_constraint(opt, user_function,
349                                                      dh + j,
350                                                    h_tol ? h_tol[j] : 0)
351                      > 0, "nlopt error adding equality constraint");
352           }
353      }
354
355
356      x_mx = mxCreateDoubleMatrix(mxGetM(prhs[1]), mxGetN(prhs[1]), mxREAL);
357      x = mxGetPr(x_mx);
358      memcpy(x, x0, sizeof(double) * n);
359
360      ret = nlopt_optimize(opt, x, &opt_f);
361
362      mxFree(dh);
363      mxFree(dfc);
364      mxDestroyArray(d.prhs[d.xrhs]);
365      if (dpre.nrhs > 0) mxDestroyArray(dpre.prhs[d.xrhs+1]);
366      nlopt_destroy(opt);
367
368      plhs[0] = x_mx;
369      if (nlhs > 1) {
370           plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
371           *(mxGetPr(plhs[1])) = opt_f;
372      }
373      if (nlhs > 2) {
374           plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
375           *(mxGetPr(plhs[2])) = (int) ret;
376      }
377 }