chiark / gitweb /
version, copyright-year bump
[nlopt.git] / octave / nlopt_minimize_constrained-mex.c
1 /* Copyright (c) 2007-2010 Massachusetts Institute of Technology
2  *
3  * Permission is hereby granted, free of charge, to any person obtaining
4  * a copy of this software and associated documentation files (the
5  * "Software"), to deal in the Software without restriction, including
6  * without limitation the rights to use, copy, modify, merge, publish,
7  * distribute, sublicense, and/or sell copies of the Software, and to
8  * permit persons to whom the Software is furnished to do so, subject to
9  * the following conditions:
10  * 
11  * The above copyright notice and this permission notice shall be
12  * included in all copies or substantial portions of the Software.
13  * 
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
17  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
18  * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
19  * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
20  * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21  */
22
23 /* Matlab MEX interface to NLopt, and in particular to nlopt_minimize_constrained */
24
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <math.h>
29 #include <mex.h>
30
31 #include "nlopt.h"
32
33 #define xSTRIZE(x) #x
34 #define STRIZE(x) xSTRIZE(x)
35 #define CHECK(cond, msg) if (!(cond)) mexErrMsgTxt(msg);
36
37 static double struct_val_default(const mxArray *s, const char *name, double dflt)
38 {
39      mxArray *val = mxGetField(s, 0, name);
40      if (val) {
41           CHECK(mxIsNumeric(val) && !mxIsComplex(val) 
42                 && mxGetM(val) * mxGetN(val) == 1,
43                 "stop fields, other than xtol_abs, must be real scalars");
44           return mxGetScalar(val);
45      }
46      return dflt;
47 }
48
49 #define FLEN 128 /* max length of user function name */
50 #define MAXRHS 128 /* max nrhs for user function */
51 typedef struct {
52      char f[FLEN];
53      mxArray *plhs[2];
54      mxArray *prhs[MAXRHS];
55      int xrhs, nrhs;
56      int verbose, neval;
57 } user_function_data;
58
59 static double user_function(int n, const double *x,
60                             double *gradient, /* NULL if not needed */
61                             void *d_)
62 {
63   user_function_data *d = (user_function_data *) d_;
64   double f;
65
66   d->plhs[0] = d->plhs[1] = NULL;
67   memcpy(mxGetPr(d->prhs[d->xrhs]), x, n * sizeof(double));
68
69   CHECK(0 == mexCallMATLAB(gradient ? 2 : 1, d->plhs, 
70                            d->nrhs, d->prhs, d->f),
71         "error calling user function");
72
73   CHECK(mxIsNumeric(d->plhs[0]) && !mxIsComplex(d->plhs[0]) 
74         && mxGetM(d->plhs[0]) * mxGetN(d->plhs[0]) == 1,
75         "user function must return real scalar");
76   f = mxGetScalar(d->plhs[0]);
77   mxDestroyArray(d->plhs[0]);
78   if (gradient) {
79      CHECK(mxIsDouble(d->plhs[1]) && !mxIsComplex(d->plhs[1])
80            && (mxGetM(d->plhs[1]) == 1 || mxGetN(d->plhs[1]) == 1)
81            && mxGetM(d->plhs[1]) * mxGetN(d->plhs[1]) == n,
82            "gradient vector from user function is the wrong size");
83      memcpy(gradient, mxGetPr(d->plhs[1]), n * sizeof(double));
84      mxDestroyArray(d->plhs[1]);
85   }
86   d->neval++;
87   if (d->verbose) mexPrintf("nlopt_minimize_constrained eval #%d: %g\n", d->neval, f);
88   return f;
89 }                                
90
91 void mexFunction(int nlhs, mxArray *plhs[],
92                  int nrhs, const mxArray *prhs[])
93 {
94      nlopt_algorithm algorithm;
95      int n, m, i, j;
96      double *lb, *ub, *x, *x0;
97      double minf_max, ftol_rel, ftol_abs, xtol_rel, *xtol_abs, maxtime;
98      int maxeval;
99      nlopt_result ret;
100      mxArray *x_mx;
101      double minf = HUGE_VAL;
102      user_function_data d, *dc;
103
104      CHECK(nrhs == 9 && nlhs <= 3, "wrong number of arguments");
105
106      /* algorithm = prhs[0] */
107      CHECK(mxIsNumeric(prhs[0]) && !mxIsComplex(prhs[0]) 
108            && mxGetM(prhs[0]) * mxGetN(prhs[0]) == 1,
109            "algorithm must be real (integer) scalar");
110      algorithm = (nlopt_algorithm) (mxGetScalar(prhs[0]) + 0.5);
111      CHECK(algorithm >= 0 && algorithm < NLOPT_NUM_ALGORITHMS,
112            "unknown algorithm");
113
114      /* function f = prhs[1] */
115      CHECK(mxIsChar(prhs[1]) || mxIsFunctionHandle(prhs[1]), 
116            "f must be a function handle or function name");
117      if (mxIsChar(prhs[1])) {
118           CHECK(mxGetString(prhs[1], d.f, FLEN) == 0,
119                 "error reading function name string (too long?)");
120           d.nrhs = 1;
121           d.xrhs = 0;
122      }
123      else {
124           d.prhs[0] = (mxArray *) prhs[1];
125           strcpy(d.f, "feval");
126           d.nrhs = 2;
127           d.xrhs = 1;
128      }
129      
130      /* Cell f_data = prhs[2] */
131      CHECK(mxIsCell(prhs[2]), "f_data must be a Cell array");
132      CHECK(mxGetM(prhs[2]) * mxGetN(prhs[2]) + 1 <= MAXRHS,
133            "user function cannot have more than " STRIZE(MAXRHS) " arguments");
134      d.nrhs += mxGetM(prhs[2]) * mxGetN(prhs[2]);
135      for (i = 0; i < d.nrhs - (1+d.xrhs); ++i)
136           d.prhs[(1+d.xrhs)+i] = mxGetCell(prhs[2], i);
137
138      /* m = length(fc = prhs[3]) = length(fc_data = prhs[4])  */
139      CHECK(mxIsCell(prhs[3]), "fc must be a Cell array");
140      CHECK(mxIsCell(prhs[4]), "fc_data must be a Cell array");
141      m = mxGetM(prhs[3]) * mxGetN(prhs[3]);
142      CHECK(m == mxGetM(prhs[4]) * mxGetN(prhs[4]), "fc and fc_data must have the same length");
143      dc = (user_function_data *) malloc(sizeof(user_function_data) * m);
144
145      for (j = 0; j < m; ++j) {
146           mxArray *fc, *fc_data;
147
148           /* function fc = phrs[3] */
149           fc = mxGetCell(prhs[3], j);
150           CHECK(mxIsChar(fc) || mxIsFunctionHandle(fc),
151                 "fc must be Cell array of function handles or function names");
152           if (mxIsChar(fc)) {
153                CHECK(mxGetString(fc, dc[j].f, FLEN) == 0,
154                      "error reading function name string (too long?)");
155                dc[j].nrhs = 1;
156                dc[j].xrhs = 0;
157           }
158           else {
159                dc[j].prhs[0] = fc;
160                strcpy(dc[j].f, "feval");
161                dc[j].nrhs = 2;
162                dc[j].xrhs = 1;
163           }
164           
165           /* Cell fc_data = prhs[4] */
166           fc_data = mxGetCell(prhs[4], j);
167           CHECK(mxIsCell(fc_data), "fc_data must be a Cell array of Cell arrays");
168           CHECK(mxGetM(fc_data) * mxGetN(fc_data) + 1 <= MAXRHS,
169                 "user function cannot have more than " STRIZE(MAXRHS) " arguments");
170           dc[j].nrhs += mxGetM(fc_data) * mxGetN(fc_data);
171           for (i = 0; i < dc[j].nrhs - (1+dc[j].xrhs); ++i)
172                dc[j].prhs[(1+dc[j].xrhs)+i] = mxGetCell(fc_data, i);
173      }
174
175      /* lb = prhs[5] */
176      CHECK(mxIsDouble(prhs[5]) && !mxIsComplex(prhs[5])
177            && (mxGetM(prhs[5]) == 1 || mxGetN(prhs[5]) == 1),
178            "lb must be real row or column vector");
179      lb = mxGetPr(prhs[5]);
180      n = mxGetM(prhs[5]) * mxGetN(prhs[5]);
181
182      /* ub = prhs[6] */
183      CHECK(mxIsDouble(prhs[6]) && !mxIsComplex(prhs[6])
184            && (mxGetM(prhs[6]) == 1 || mxGetN(prhs[6]) == 1)
185            && mxGetM(prhs[6]) * mxGetN(prhs[6]) == n,
186            "ub must be real row or column vector of same length as lb");
187      ub = mxGetPr(prhs[6]);
188
189      /* x0 = prhs[7] */
190      CHECK(mxIsDouble(prhs[7]) && !mxIsComplex(prhs[7])
191            && (mxGetM(prhs[7]) == 1 || mxGetN(prhs[7]) == 1)
192            && mxGetM(prhs[7]) * mxGetN(prhs[7]) == n,
193            "x must be real row or column vector of same length as lb");
194      x0 = mxGetPr(prhs[7]);
195
196      /* stopping criteria = prhs[8] */
197      CHECK(mxIsStruct(prhs[8]), "stopping criteria must be a struct");
198      minf_max = struct_val_default(prhs[8], "minf_max", -HUGE_VAL);
199      ftol_rel = struct_val_default(prhs[8], "ftol_rel", 0);
200      ftol_abs = struct_val_default(prhs[8], "ftol_abs", 0);
201      xtol_rel = struct_val_default(prhs[8], "xtol_rel", 0);
202      maxeval = (int) (struct_val_default(prhs[8], "maxeval", -1) + 0.5);
203      maxtime = struct_val_default(prhs[8], "maxtime", -1);
204      d.verbose = (int) struct_val_default(prhs[8], "verbose", 0);
205      d.neval = 0;
206      for (i = 0; i < m; ++i) {
207           dc[i].verbose = d.verbose > 1;
208           dc[i].neval = 0;
209      }
210      {
211           mxArray *val = mxGetField(prhs[8], 0, "xtol_abs");
212           if (val) {
213                CHECK(mxIsNumeric(val) && !mxIsComplex(val) 
214                      && (mxGetM(val) == 1 || mxGetN(val) == 1)
215                      && mxGetM(val) * mxGetN(val) == n,
216                      "stop.xtol_abs must be real row/col vector of length n");
217                xtol_abs = mxGetPr(val);
218           }
219           else
220                xtol_abs = NULL;
221      }
222
223
224      x_mx = mxCreateDoubleMatrix(1, n, mxREAL);
225      x = mxGetPr(x_mx);
226      memcpy(x, x0, sizeof(double) * n);
227
228      d.prhs[d.xrhs] = mxCreateDoubleMatrix(1, n, mxREAL);
229      for (i = 0; i < m;++i)
230           dc[i].prhs[dc[i].xrhs] = d.prhs[d.xrhs];
231      
232      ret = nlopt_minimize_constrained(algorithm,
233                                       n,
234                                       user_function, &d,
235                                       m, user_function, dc,
236                                       sizeof(user_function_data),
237                                       lb, ub, x, &minf, minf_max, 
238                                       ftol_rel, ftol_abs, xtol_rel, xtol_abs,
239                                       maxeval, maxtime);
240
241      mxDestroyArray(d.prhs[d.xrhs]);
242      free(dc);
243
244      plhs[0] = x_mx;
245      if (nlhs > 1) {
246           plhs[1] = mxCreateDoubleMatrix(1, 1, mxREAL);
247           *(mxGetPr(plhs[1])) = minf;
248      }
249      if (nlhs > 2) {
250           plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
251           *(mxGetPr(plhs[2])) = (int) ret;
252      }
253 }