chiark / gitweb /
b275f9c4568dda5d9421f1eacec1ee33bc014b12
[nlopt.git] / octave / nlopt_optimize-mex.c
1 /* Copyright (c) 2007-2011 Massachusetts Institute of Technology
2  *
3  * Permission is hereby granted, free of charge, to any person obtaining
4  * a copy of this software and associated documentation files (the
5  * "Software"), to deal in the Software without restriction, including
6  * without limitation the rights to use, copy, modify, merge, publish,
7  * distribute, sublicense, and/or sell copies of the Software, and to
8  * permit persons to whom the Software is furnished to do so, subject to
9  * the following conditions:
10  * 
11  * The above copyright notice and this permission notice shall be
12  * included in all copies or substantial portions of the Software.
13  * 
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21  */
22
23 /* Matlab MEX interface to NLopt, and in particular to nlopt_optimize */
24
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <math.h>
29 #include <mex.h>
30
31 #include "nlopt.h"
32
33 #define CHECK0(cond, msg) if (!(cond)) mexErrMsgTxt(msg);
34
35 static double struct_val_default(const mxArray *s, const char *name, double dflt)
36 {
37      mxArray *val = mxGetField(s, 0, name);
38      if (val) {
39           CHECK0(mxIsNumeric(val) && !mxIsComplex(val) 
40                 && mxGetM(val) * mxGetN(val) == 1,
41                 "opt fields, other than xtol_abs, must be real scalars");
42           return mxGetScalar(val);
43      }
44      return dflt;
45 }
46
47 static double *struct_arrval(const mxArray *s, const char *name, unsigned n,
48                              double *dflt)
49 {
50      mxArray *val = mxGetField(s, 0, name);
51      if (val) {
52           CHECK0(mxIsNumeric(val) && !mxIsComplex(val) 
53                 && mxGetM(val) * mxGetN(val) == n,
54                 "opt vector field is not of length n");
55           return mxGetPr(val);
56      }
57      return dflt;
58 }
59
60 static mxArray *struct_funcval(const mxArray *s, const char *name)
61 {
62      mxArray *val = mxGetField(s, 0, name);
63      if (val) {
64           CHECK0(mxIsChar(val) || mxIsFunctionHandle(val),
65                  "opt function field is not a function handle/name");
66           return val;
67      }
68      return NULL;
69 }
70
71 static double *fill(double *arr, unsigned n, double val)
72 {
73      unsigned i;
74      for (i = 0; i < n; ++i) arr[i] = val;
75      return arr;
76 }
77
78 #define FLEN 128 /* max length of user function name */
79 #define MAXRHS 3 /* max nrhs for user function */
80 typedef struct user_function_data_s {
81      char f[FLEN];
82      mxArray *plhs[2];
83      mxArray *prhs[MAXRHS];
84      int xrhs, nrhs;
85      int verbose, neval;
86      struct user_function_data_s *dpre;
87 } user_function_data;
88
89 static double user_function(unsigned n, const double *x,
90                             double *gradient, /* NULL if not needed */
91                             void *d_)
92 {
93   user_function_data *d = (user_function_data *) d_;
94   double f;
95
96   d->plhs[0] = d->plhs[1] = NULL;
97   memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
98
99   CHECK0(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs, 
100                            d->nrhs, d->prhs, d->f),
101         "error calling user function");
102
103   CHECK0(mxIsNumeric(d->plhs[0]) && !mxIsComplex(d->plhs[0]) 
104         && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == 1,
105         "user function must return real scalar");
106   f = mxGetScalar(d->plhs[0]);
107   mxDestroyArray(d->plhs[0]);
108   if (gradient) {
109      CHECK0(mxIsDouble(d->plhs[1]) && !mxIsComplex(d->plhs[1])
110            && (mxGetM(d->plhs[1]) == 1 || mxGetN(d->plhs[1]) == 1)
111            && mxGetM(d->plhs[1]) * mxGetN(d->plhs[1]) == n,
112            "gradient vector from user function is the wrong size");
113      memcpy(gradient, mxGetPr(d->plhs[1]), n * sizeof(double));
114      mxDestroyArray(d->plhs[1]);
115   }
116   d->neval++;
117   if (d->verbose) mexPrintf("nlopt_optimize eval #%d: %g\n", d->neval, f);
118   return f;
119 }
120
121 static void user_pre(unsigned n, const double *x, const double *v,
122                        double *vpre, void *d_)
123 {
124   user_function_data *d = ((user_function_data *) d_)->dpre;
125   d->plhs[0] = d->plhs[1] = NULL;
126   memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
127   memcpy(mxGetPr(d->prhs[d->xrhs + 1]), v, n * sizeof(double));
128
129   CHECK0(0 == mexCallMATLAB(1, d->plhs, 
130                             d->nrhs, d->prhs, d->f),
131          "error calling user function");
132
133   CHECK0(mxIsDouble(d->plhs[0]) && !mxIsComplex(d->plhs[0])
134          && (mxGetM(d->plhs[0]) == 1 || mxGetN(d->plhs[0]) == 1)
135          && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == n,
136          "vpre vector from user function is the wrong size");
137   memcpy(vpre, mxGetPr(d->plhs[0]), n * sizeof(double));
138   mxDestroyArray(d->plhs[0]);
139   d->neval++;
140   if (d->verbose) mexPrintf("nlopt_optimize precond eval #%d\n", d->neval);
141 }
142
143 #define CHECK1(cond, msg) if (!(cond)) { mxFree(tmp); nlopt_destroy(opt); nlopt_destroy(local_opt); mexWarnMsgTxt(msg); return NULL; };
144
145 nlopt_opt make_opt(const mxArray *opts, unsigned n)
146 {
147      nlopt_opt opt = NULL, local_opt = NULL;
148      nlopt_algorithm algorithm;
149      double *tmp = NULL;
150      unsigned i;
151
152      algorithm = (nlopt_algorithm)
153           struct_val_default(opts, "algorithm", NLOPT_NUM_ALGORITHMS);
154      CHECK1(((int)algorithm) >= 0 && algorithm < NLOPT_NUM_ALGORITHMS,
155             "invalid opt.algorithm");
156
157      tmp = (double *) mxCalloc(n, sizeof(double));
158      opt = nlopt_create(algorithm, n);
159      CHECK1(opt, "nlopt: out of memory");
160
161      nlopt_set_lower_bounds(opt, struct_arrval(opts, "lower_bounds", n,
162                                                fill(tmp, n, -HUGE_VAL)));
163      nlopt_set_upper_bounds(opt, struct_arrval(opts, "upper_bounds", n,
164                                                fill(tmp, n, +HUGE_VAL)));
165
166      nlopt_set_stopval(opt, struct_val_default(opts, "stopval", -HUGE_VAL));
167      nlopt_set_ftol_rel(opt, struct_val_default(opts, "ftol_rel", 0.0));
168      nlopt_set_ftol_abs(opt, struct_val_default(opts, "ftol_abs", 0.0));
169      nlopt_set_xtol_rel(opt, struct_val_default(opts, "xtol_rel", 0.0));
170      nlopt_set_xtol_abs(opt, struct_arrval(opts, "xtol_abs", n,
171                                            fill(tmp, n, 0.0)));
172      nlopt_set_maxeval(opt, struct_val_default(opts, "maxeval", 0.0) < 0 ?
173                        0 : struct_val_default(opts, "maxeval", 0.0));
174      nlopt_set_maxtime(opt, struct_val_default(opts, "maxtime", 0.0));
175
176      nlopt_set_population(opt, struct_val_default(opts, "population", 0));
177      nlopt_set_vector_storage(opt, struct_val_default(opts, "vector_storage", 0));
178
179      if (struct_arrval(opts, "initial_step", n, NULL))
180           nlopt_set_initial_step(opt,
181                                  struct_arrval(opts, "initial_step", n, NULL));
182      
183      if (mxGetField(opts, 0, "local_optimizer")) {
184           const mxArray *local_opts = mxGetField(opts, 0, "local_optimizer");
185           CHECK1(mxIsStruct(local_opts),
186                  "opt.local_optimizer must be a structure");
187           CHECK1(local_opt = make_opt(local_opts, n),
188                  "error initializing local optimizer");
189           nlopt_set_local_optimizer(opt, local_opt);
190           nlopt_destroy(local_opt); local_opt = NULL;
191      }
192
193      mxFree(tmp);
194      return opt;
195 }
196
197 #define CHECK(cond, msg) if (!(cond)) { mxFree(dh); mxFree(dfc); nlopt_destroy(opt); mexErrMsgTxt(msg); }
198
199 void mexFunction(int nlhs, mxArray *plhs[],
200                  int nrhs, const mxArray *prhs[])
201 {
202      unsigned n;
203      double *x, *x0, opt_f;
204      nlopt_result ret;
205      mxArray *x_mx, *mx;
206      user_function_data d, dpre, *dfc = NULL, *dh = NULL;
207      nlopt_opt opt = NULL;
208
209      CHECK(nrhs == 2 && nlhs <= 3, "wrong number of arguments");
210
211      /* options = prhs[0] */
212      CHECK(mxIsStruct(prhs[0]), "opt must be a struct");
213      
214      /* x0 = prhs[1] */
215      CHECK(mxIsDouble(prhs[1]) && !mxIsComplex(prhs[1])
216            && (mxGetM(prhs[1]) == 1 || mxGetN(prhs[1]) == 1),
217            "x must be real row or column vector");
218      n = mxGetM(prhs[1]) * mxGetN(prhs[1]),
219      x0 = mxGetPr(prhs[1]);
220
221      CHECK(opt = make_opt(prhs[0], n), "error initializing nlopt options");
222
223      d.neval = 0;
224      d.verbose = (int) struct_val_default(prhs[0], "verbose", 0);
225
226      /* function f = prhs[1] */
227      mx = struct_funcval(prhs[0], "min_objective");
228      if (!mx) mx = struct_funcval(prhs[0], "max_objective");
229      CHECK(mx, "either opt.min_objective or opt.max_objective must exist");
230      if (mxIsChar(mx)) {
231           CHECK(mxGetString(mx, d.f, FLEN) == 0,
232                 "error reading function name string (too long?)");
233           d.nrhs = 1;
234           d.xrhs = 0;
235      }
236      else {
237           d.prhs[0] = mx;
238           strcpy(d.f, "feval");
239           d.nrhs = 2;
240           d.xrhs = 1;
241      }
242      d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL);
243
244      if ((mx = struct_funcval(prhs[0], "pre"))) {
245           CHECK(mxIsChar(mx) || mxIsFunctionHandle(mx),
246                 "pre must contain function handles or function names");
247           if (mxIsChar(mx)) {
248                CHECK(mxGetString(mx, dpre.f, FLEN) == 0,
249                      "error reading function name string (too long?)");
250                dpre.nrhs = 2;
251                dpre.xrhs = 0;
252           }
253           else {
254                dpre.prhs[0] = mx;
255                strcpy(dpre.f, "feval");
256                dpre.nrhs = 3;
257                dpre.xrhs = 1;
258           }
259           dpre.verbose = d.verbose > 2;
260           dpre.neval = 0;
261           dpre.prhs[dpre.xrhs] = d.prhs[d.xrhs];
262           dpre.prhs[d.xrhs+1] = mxCreateDoubleMatrix(1, n, mxREAL);
263           d.dpre = &dpre;
264
265           if (struct_funcval(prhs[0], "min_objective"))
266                nlopt_set_precond_min_objective(opt, user_function,user_pre,&d);
267           else
268                nlopt_set_precond_max_objective(opt, user_function,user_pre,&d);
269      }
270      else {
271           dpre.nrhs = 0;
272           if (struct_funcval(prhs[0], "min_objective"))
273                nlopt_set_min_objective(opt, user_function, &d);
274           else
275                nlopt_set_max_objective(opt, user_function, &d);
276      }
277
278      if ((mx = mxGetField(prhs[0], 0, "fc"))) {
279           int j, m;
280           double *fc_tol;
281           
282           CHECK(mxIsCell(mx), "fc must be a Cell array");
283           m = mxGetM(mx) * mxGetN(mx);;
284           dfc = (user_function_data *) mxCalloc(m, sizeof(user_function_data));
285           fc_tol = struct_arrval(prhs[0], "fc_tol", m, NULL);
286
287           for (j = 0; j < m; ++j) {
288                mxArray *fc = mxGetCell(mx, j);
289                CHECK(mxIsChar(fc) || mxIsFunctionHandle(fc),
290                      "fc must contain function handles or function names");
291                if (mxIsChar(fc)) {
292                     CHECK(mxGetString(fc, dfc[j].f, FLEN) == 0,
293                      "error reading function name string (too long?)");
294                     dfc[j].nrhs = 1;
295                     dfc[j].xrhs = 0;
296                }
297                else {
298                     dfc[j].prhs[0] = fc;
299                     strcpy(dfc[j].f, "feval");
300                     dfc[j].nrhs = 2;
301                     dfc[j].xrhs = 1;
302                }
303                dfc[j].verbose = d.verbose > 1;
304                dfc[j].neval = 0;
305                dfc[j].prhs[dfc[j].xrhs] = d.prhs[d.xrhs];
306                CHECK(nlopt_add_inequality_constraint(opt, user_function,
307                                                      dfc + j,
308                                                      fc_tol ? fc_tol[j] : 0)
309                      > 0, "nlopt error adding inequality constraint");
310           }
311      }
312
313
314      if ((mx = mxGetField(prhs[0], 0, "h"))) {
315           int j, m;
316           double *h_tol;
317           
318           CHECK(mxIsCell(mx), "h must be a Cell array");
319           m = mxGetM(mx) * mxGetN(mx);;
320           dh = (user_function_data *) mxCalloc(m, sizeof(user_function_data));
321           h_tol = struct_arrval(prhs[0], "h_tol", m, NULL);
322
323           for (j = 0; j < m; ++j) {
324                mxArray *h = mxGetCell(mx, j);
325                CHECK(mxIsChar(h) || mxIsFunctionHandle(h),
326                      "h must contain function handles or function names");
327                if (mxIsChar(h)) {
328                     CHECK(mxGetString(h, dh[j].f, FLEN) == 0,
329                      "error reading function name string (too long?)");
330                     dh[j].nrhs = 1;
331                     dh[j].xrhs = 0;
332                }
333                else {
334                     dh[j].prhs[0] = h;
335                     strcpy(dh[j].f, "feval");
336                     dh[j].nrhs = 2;
337                     dh[j].xrhs = 1;
338                }
339                dh[j].verbose = d.verbose > 1;
340                dh[j].neval = 0;
341                dh[j].prhs[dh[j].xrhs] = d.prhs[d.xrhs];
342                CHECK(nlopt_add_equality_constraint(opt, user_function,
343                                                      dh + j,
344                                                    h_tol ? h_tol[j] : 0)
345                      > 0, "nlopt error adding equality constraint");
346           }
347      }
348
349
350      x_mx = mxCreateDoubleMatrix(mxGetM(prhs[1]), mxGetN(prhs[1]), mxREAL);
351      x = mxGetPr(x_mx);
352      memcpy(x, x0, sizeof(double) * n);
353
354      ret = nlopt_optimize(opt, x, &opt_f);
355
356      mxFree(dh);
357      mxFree(dfc);
358      mxDestroyArray(d.prhs[d.xrhs]);
359      if (dpre.nrhs > 0) mxDestroyArray(dpre.prhs[d.xrhs+1]);
360      nlopt_destroy(opt);
361
362      plhs[0] = x_mx;
363      if (nlhs > 1) {
364           plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
365           *(mxGetPr(plhs[1])) = opt_f;
366      }
367      if (nlhs > 2) {
368           plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
369           *(mxGetPr(plhs[2])) = (int) ret;
370      }
371 }