chiark / gitweb /
3c9a4efc72cccee59e8c10d74846dad51b22dbc1
[nlopt.git] / octave / nlopt_minimize-oct.cc
1 #include <octave/oct.h>
2 #include <octave/oct-map.h>
3 #include <octave/ov.h>
4 #include <math.h>
5 #include <stdio.h>
6
7 #include "nlopt.h"
8 #include "nlopt_minimize_usage.h"
9
10 static double struct_val_default(Octave_map &m, const std::string& k,
11                                  double dflt)
12 {
13   if (m.contains(k)) {
14     if (m.contents(k).length() == 1 && (m.contents(k))(0).is_real_scalar())
15       return (m.contents(k))(0).double_value();
16   }
17   return dflt;
18 }
19
20 static Matrix struct_val_default(Octave_map &m, const std::string& k,
21                                  Matrix &dflt)
22 {
23   if (m.contains(k)) {
24     if ((m.contents(k)).length() == 1) {
25       if ((m.contents(k))(0).is_real_scalar())
26         return Matrix(1, dflt.length(), (m.contents(k))(0).double_value());
27       else if ((m.contents(k))(0).is_real_matrix())
28         return (m.contents(k))(0).matrix_value();
29     }
30   }
31   return dflt;
32 }
33
34 typedef struct {
35   octave_function *f;
36   Cell f_data;
37   int neval, verbose;
38 } user_function_data;
39
40 static double user_function(int n, const double *x,
41                             double *gradient, /* NULL if not needed */
42                             void *data_)
43 {
44   user_function_data *data = (user_function_data *) data_;
45   octave_value_list args(1 + data->f_data.length(), 0);
46   Matrix xm(1,n);
47   for (int i = 0; i < n; ++i)
48     xm(i) = x[i];
49   args(0) = xm;
50   for (int i = 0; i < data->f_data.length(); ++i)
51     args(1 + i) = data->f_data(i);
52   octave_value_list res = data->f->do_multi_index_op(gradient ? 2 : 1, args); 
53   if (res.length() < (gradient ? 2 : 1))
54     gripe_user_supplied_eval("nlopt_minimize");
55   else if (!res(0).is_real_scalar()
56            || (gradient && !res(1).is_real_matrix()
57                && !(n == 1 && res(1).is_real_scalar())))
58     gripe_user_returned_invalid("nlopt_minimize");
59   else {
60     if (gradient) {
61       if (n == 1 && res(1).is_real_scalar())
62         gradient[0] = res(1).double_value();
63       else {
64         Matrix grad = res(1).matrix_value();
65         for (int i = 0; i < n; ++i)
66           gradient[i] = grad(i);
67       }
68     }
69     data->neval++;
70     if (data->verbose) printf("nlopt_minimize eval #%d: %g\n", 
71                               data->neval, res(0).double_value());
72     return res(0).double_value();
73   }
74   return 0;
75 }                                
76
77 #define CHECK(cond, msg) if (!(cond)) { fprintf(stderr, msg "\n\n"); print_usage("nlopt_minimize"); return retval; }
78
79 DEFUN_DLD(nlopt_minimize, args, nargout, NLOPT_MINIMIZE_USAGE)
80 {
81   octave_value_list retval;
82   double A;
83
84   CHECK(args.length() == 7 && nargout <= 3, "wrong number of args");
85
86   CHECK(args(0).is_real_scalar(), "n must be real scalar");
87   nlopt_algorithm algorithm = nlopt_algorithm(args(0).int_value());
88
89   user_function_data d;
90   CHECK(args(1).is_function() || args(1).is_function_handle(), 
91         "f must be function");
92   d.f = args(1).function_value();
93   CHECK(args(2).is_cell(), "f_data must be cell array");
94   d.f_data = args(2).cell_value();
95
96   CHECK(args(3).is_real_matrix() || args(3).is_real_scalar(),
97         "lb must be real vector");
98   Matrix lb = args(3).is_real_scalar() ?
99     Matrix(1, 1, args(3).double_value()) : args(3).matrix_value();
100   int n = lb.length();
101   
102   CHECK(args(4).is_real_matrix() || args(4).is_real_scalar(),
103         "ub must be real vector");
104   Matrix ub = args(4).is_real_scalar() ?
105     Matrix(1, 1, args(4).double_value()) : args(4).matrix_value();
106   CHECK(n == ub.length(), "lb and ub must have same length");
107
108   CHECK(args(5).is_real_matrix() || args(5).is_real_scalar(),
109         "x must be real vector");
110   Matrix x = args(5).is_real_scalar() ?
111     Matrix(1, 1, args(5).double_value()) : args(5).matrix_value();
112   CHECK(n == x.length(), "x and lb/ub must have same length");
113
114   CHECK(args(6).is_map(), "stop must be structure");
115   Octave_map stop = args(6).map_value();
116   double minf_max = struct_val_default(stop, "minf_max", -HUGE_VAL);
117   double ftol_rel = struct_val_default(stop, "ftol_rel", 0);
118   double ftol_abs = struct_val_default(stop, "ftol_abs", 0);
119   double xtol_rel = struct_val_default(stop, "xtol_rel", 0);
120   Matrix zeros(1, n, 0.0);
121   Matrix xtol_abs = struct_val_default(stop, "xtol_abs", zeros);
122   CHECK(n == xtol_abs.length(), "stop.xtol_abs must have same length as x");
123   int maxeval = int(struct_val_default(stop, "maxeval", -1));
124   double maxtime = struct_val_default(stop, "maxtime", -1);
125   
126   double minf = HUGE_VAL;
127   nlopt_result ret = nlopt_minimize(algorithm,
128                                     n,
129                                     user_function, &d,
130                                     lb.data(), ub.data(),
131                                     x.fortran_vec(), &minf,
132                                     minf_max, ftol_rel, ftol_abs,
133                                     xtol_rel, xtol_abs.data(),
134                                     maxeval, maxtime);
135                                     
136   retval(0) = x;
137   if (nargout > 1)
138     retval(1) = minf;
139   if (nargout > 2)
140     retval(2) = int(ret);
141
142   return retval;
143 }