chiark / gitweb /
e82f28697ce70cb54732fee15afe9b8a3219915e
[nlopt.git] / swig / nlopt-python.i
1 // -*- C++ -*-
2
3 //////////////////////////////////////////////////////////////////////////////
4 // Converting NLopt/C++ exceptions to Python exceptions
5
6 %{
7
// ExceptionSubclass(EXCNAME, EXCDOC)
//
// Defines a static Python type object MyExc_<EXCNAME> named "nlopt.<EXCNAME>"
// with docstring EXCDOC, plus a helper init_<EXCNAME>(PyObject *m) that
// finishes the type (deriving it from PyExc_Exception) and adds it to
// module m under the name <EXCNAME>.
//
// PyVarObject_HEAD_INIT(NULL, 0) supplies the object header *and* ob_size,
// which keeps the positional initializers aligned with the PyTypeObject
// fields on both Python 2.6+ and Python 3.  The run of sixteen zeros covers
// tp_itemsize through tp_as_buffer.
//
// If PyType_Ready() fails, the type is not added to the module and the
// Python error it raised is left pending.  If PyModule_AddObject() fails it
// has not stolen our reference, so the reference is released again.
#define ExceptionSubclass(EXCNAME, EXCDOC)                              \
  static PyTypeObject MyExc_ ## EXCNAME = {                             \
    PyVarObject_HEAD_INIT(NULL, 0)                                      \
      "nlopt." # EXCNAME,                                               \
      sizeof(PyBaseExceptionObject),                                    \
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,                   \
      Py_TPFLAGS_DEFAULT,                                               \
      PyDoc_STR(EXCDOC)                                                 \
  };                                                                    \
  static void init_ ## EXCNAME(PyObject *m) {                           \
    MyExc_ ## EXCNAME .tp_base = (PyTypeObject *) PyExc_Exception;      \
    if (PyType_Ready(&MyExc_ ## EXCNAME) < 0) {                         \
      return;                                                           \
    }                                                                   \
    Py_INCREF(&MyExc_ ## EXCNAME);                                      \
    if (PyModule_AddObject(m, # EXCNAME,                                \
                           (PyObject *) &MyExc_ ## EXCNAME) < 0) {      \
      Py_DECREF(&MyExc_ ## EXCNAME);                                    \
    }                                                                   \
  }
24
25
26 ExceptionSubclass(ForcedStop,
27                   "Python version of nlopt::forced_stop exception.")
28
29 ExceptionSubclass(RoundoffLimited,
30                   "Python version of nlopt::roundoff_limited exception.")
31
32 %}
33
// Module initialisation: register both exception types on the extension
// module object `m`.
%init %{
  init_ForcedStop(m);
  init_RoundoffLimited(m);
%}
%pythoncode %{
  # Re-export the exception classes defined by the C extension module.
  ForcedStop = _nlopt.ForcedStop
  RoundoffLimited = _nlopt.RoundoffLimited
%}
42
// std::bad_alloc becomes a Python MemoryError carrying the C++ message.
%typemap(throws) std::bad_alloc %{
  {
    const char *bad_alloc_msg = ($1).what();
    PyErr_SetString(PyExc_MemoryError, bad_alloc_msg);
  }
  SWIG_fail;
%}
47
// nlopt::forced_stop becomes nlopt.ForcedStop, unless a Python exception is
// already pending, in which case that exception is left in place.
%typemap(throws) nlopt::forced_stop %{
  if (PyErr_Occurred() == NULL) {
    PyErr_SetString((PyObject*)&MyExc_ForcedStop, "NLopt forced stop");
  }
  SWIG_fail;
%}
53
// nlopt::roundoff_limited becomes nlopt.RoundoffLimited.
%typemap(throws) nlopt::roundoff_limited %{
  {
    PyObject *roundoff_type = (PyObject*)&MyExc_RoundoffLimited;
    PyErr_SetString(roundoff_type, "NLopt roundoff-limited");
  }
  SWIG_fail;
%}
58
59 //////////////////////////////////////////////////////////////////////////////
60
61 %{
62 #define SWIG_FILE_WITH_INIT
63 #define array_stride(a,i)        (((PyArrayObject *)a)->strides[i])
64 %}
65 %include "numpy.i"
// The NumPy C API must be imported during module initialisation, before
// any PyArray_* function is called.
%init %{
  import_array();
%}
// Instantiate the numpy.i typemaps for double data (NPY_DOUBLE), passing
// `unsigned` as the third macro argument.
%numpy_typemaps(double, NPY_DOUBLE, unsigned)
70
71 //////////////////////////////////////////////////////////////////////////////
72 // numpy.i does not include maps for std::vector<double>, so I add them here,
73 // taking advantage of the conversion functions provided by numpy.i
74
75 // Typemap for input arguments of type const std::vector<double> &
// Overload dispatch check: accept either a NumPy array or any object that
// implements the Python sequence protocol.
%typecheck(SWIG_TYPECHECK_POINTER, fragment="NumPy_Macros")
  const std::vector<double> &
{
  if (is_array($input)) {
    $1 = 1;
  } else {
    $1 = PySequence_Check($input) ? 1 : 0;
  }
}
// Input conversion: obtain a 1-D array of doubles from the Python argument
// (converting it if necessary) and copy its elements into a temporary
// std::vector<double> whose address is passed to the wrapped function.
//
// Elements are addressed by byte offset using the signed npy_intp stride,
// so views with negative strides, or with strides that are not a multiple
// of sizeof(double), are copied correctly.  Each element is fetched with
// memcpy so that no aligned double load is assumed, and sizes stay in
// npy_intp instead of being truncated to int.
%typemap(in, fragment="NumPy_Fragments")
  const std::vector<double> &
(PyArrayObject* array=NULL, int is_new_object=0, std::vector<double> arrayv)
{
  npy_intp size[1] = { -1 };
  array = obj_to_array_allow_conversion($input, NPY_DOUBLE, &is_new_object);
  if (!array || !require_dimensions(array, 1) ||
      !require_size(array, size, 1)) SWIG_fail;
  {
    const npy_intp arr_sz = array_size(array,0);
    const npy_intp arr_stride = array_stride(array,0);
    const char *arr_bytes = (const char *) array_data(array);
    npy_intp arr_i;
    arrayv.resize(arr_sz);
    for (arr_i = 0; arr_i < arr_sz; ++arr_i)
      std::memcpy(&arrayv[arr_i], arr_bytes + arr_i * arr_stride,
                  sizeof(double));
  }
  $1 = &arrayv;
}
// Cleanup: drop the reference to the array only when the input typemap had
// to create a new array object for the conversion.
%typemap(freearg)
  const std::vector<double> &
{
  if (array$argnum && is_new_object$argnum) {
    Py_DECREF(array$argnum);
  }
}
105
106 // Typemap for return values of type std::vector<double>
// Output conversion: copy the returned vector into a freshly allocated 1-D
// NumPy array of doubles.  A failed allocation aborts the wrapper with the
// Python error left pending, and nothing is copied for an empty vector, so
// memcpy is never handed a null pointer.
%typemap(out, fragment="NumPy_Fragments") std::vector<double>
{
  npy_intp sz = $1.size();
  $result = PyArray_SimpleNew(1, &sz, NPY_DOUBLE);
  if (!$result) SWIG_fail;
  if (sz > 0)
    std::memcpy(array_data($result), &$1[0], sizeof(double) * sz);
}
114
115 //////////////////////////////////////////////////////////////////////////////
116 // Wrapper for objective function callbacks
117
118 %{
119 static void *free_pyfunc(void *p) { Py_DECREF((PyObject*) p); return p; }
120 static void *dup_pyfunc(void *p) { Py_INCREF((PyObject*) p); return p; }
121
122 static double func_python(unsigned n, const double *x, double *grad, void *f)
123 {
124   npy_intp sz = npy_intp(n), sz0 = 0, stride1 = sizeof(double);
125   PyObject *xpy = PyArray_New(&PyArray_Type, 1, &sz, NPY_DOUBLE, &stride1,
126                               const_cast<double*>(x), // not NPY_WRITEABLE
127                               0, NPY_C_CONTIGUOUS | NPY_ALIGNED, NULL);
128   PyObject *gradpy = grad
129     ? PyArray_SimpleNewFromData(1, &sz, NPY_DOUBLE, grad)
130     : PyArray_SimpleNew(1, &sz0, NPY_DOUBLE);
131   
132   PyObject *arglist = Py_BuildValue("OO", xpy, gradpy);
133   PyObject *result = PyEval_CallObject((PyObject *) f, arglist);
134   Py_DECREF(arglist);
135
136   Py_DECREF(gradpy);
137   Py_DECREF(xpy);
138
139   double val = HUGE_VAL;
140   if (PyErr_Occurred()) {
141     Py_XDECREF(result);
142     throw nlopt::forced_stop(); // just stop, don't call PyErr_Clear()
143   }
144   else if (result && PyFloat_Check(result)) {
145     val = PyFloat_AsDouble(result);
146     Py_DECREF(result);
147   }
148   else {
149     Py_XDECREF(result);
150     throw std::invalid_argument("invalid result passed to nlopt");
151   }
152   return val;
153 }
154
155 static void mfunc_python(unsigned m, double *result,
156                          unsigned n, const double *x, double *grad, void *f)
157 {
158   npy_intp nsz = npy_intp(n), msz = npy_intp(m);
159   npy_intp mnsz[2] = {msz, nsz};
160   npy_intp sz0 = 0, stride1 = sizeof(double);
161   PyObject *xpy = PyArray_New(&PyArray_Type, 1, &nsz, NPY_DOUBLE, &stride1,
162                               const_cast<double*>(x), // not NPY_WRITEABLE
163                               0, NPY_C_CONTIGUOUS | NPY_ALIGNED, NULL);
164   PyObject *rpy = PyArray_SimpleNewFromData(1, &msz, NPY_DOUBLE, result);
165   PyObject *gradpy = grad
166     ? PyArray_SimpleNewFromData(2, mnsz, NPY_DOUBLE, grad)
167     : PyArray_SimpleNew(1, &sz0, NPY_DOUBLE);
168   
169   PyObject *arglist = Py_BuildValue("OOO", rpy, xpy, gradpy);
170   PyObject *res = PyEval_CallObject((PyObject *) f, arglist);
171   Py_XDECREF(res);
172   Py_DECREF(arglist);
173
174   Py_DECREF(gradpy);
175   Py_DECREF(rpy);
176   Py_DECREF(xpy);
177
178   if (PyErr_Occurred()) {
179     throw nlopt::forced_stop(); // just stop, don't call PyErr_Clear()
180   }
181 }
182 %}
183
// A Python callable passed where (func, data, munge-destroy, munge-copy) is
// expected: func_python is the C entry point, the callable itself (with one
// extra reference taken by dup_pyfunc) is the user data, and free_pyfunc /
// dup_pyfunc are passed as the two munge callbacks.
%typemap(in)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) {
  void *py_callable = (void*) $input;
  $1 = func_python;
  $2 = dup_pyfunc(py_callable);
  $3 = free_pyfunc;
  $4 = dup_pyfunc;
}
// Overload dispatch check: the argument must be callable.
%typecheck(SWIG_TYPECHECK_POINTER)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) {
  $1 = PyCallable_Check($input) ? 1 : 0;
}
193
// Same arrangement as for scalar objectives, but for vector-valued
// functions: mfunc_python is the C entry point, the callable (with one extra
// reference taken by dup_pyfunc) is the user data, and free_pyfunc /
// dup_pyfunc are passed as the two munge callbacks.
%typemap(in)(nlopt::mfunc mf, void *f_data, nlopt_munge md, nlopt_munge mc) {
  void *py_callable = (void*) $input;
  $1 = mfunc_python;
  $2 = dup_pyfunc(py_callable);
  $3 = free_pyfunc;
  $4 = dup_pyfunc;
}
// Overload dispatch check: the argument must be callable.
%typecheck(SWIG_TYPECHECK_POINTER)(nlopt::mfunc mf, void *f_data, nlopt_munge md, nlopt_munge mc) {
  $1 = PyCallable_Check($input) ? 1 : 0;
}