chiark / gitweb /
nlopt (2.4.2+dfsg-8) unstable; urgency=medium
[nlopt.git] / swig / nlopt-python.i
1 // -*- C++ -*-
2
3 //////////////////////////////////////////////////////////////////////////////
4 // Converting NLopt/C++ exceptions to Python exceptions
5
6 %{
7
8 #define ExceptionSubclass(EXCNAME, EXCDOC)                              \
9   static PyTypeObject MyExc_ ## EXCNAME = {                             \
10     PyVarObject_HEAD_INIT(NULL, 0)                                              \
11       "nlopt." # EXCNAME,                                               \
12       sizeof(PyBaseExceptionObject),                                    \
13       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,                   \
14       Py_TPFLAGS_DEFAULT,                                               \
15       PyDoc_STR(EXCDOC)                                                 \
16   };                                                                    \
17   static void init_ ## EXCNAME(PyObject *m) {                           \
18     MyExc_ ## EXCNAME .tp_base = (PyTypeObject *) PyExc_Exception;      \
19     PyType_Ready(&MyExc_ ## EXCNAME);                                   \
20     Py_INCREF(&MyExc_ ## EXCNAME);                                      \
21     PyModule_AddObject(m, # EXCNAME, (PyObject *) &MyExc_ ## EXCNAME);  \
22   }
23
24
25 ExceptionSubclass(ForcedStop,
26                   "Python version of nlopt::forced_stop exception.")
27
28 ExceptionSubclass(RoundoffLimited,
29                   "Python version of nlopt::roundoff_limited exception.")
30
31 %}
32
33 %init %{
34   init_ForcedStop(m);
35   init_RoundoffLimited(m);
36 %}
37 %pythoncode %{
38   ForcedStop = _nlopt.ForcedStop
39   RoundoffLimited = _nlopt.RoundoffLimited
40   __version__ = str(_nlopt.version_major())+'.'+str(_nlopt.version_minor())+'.'+str(_nlopt.version_bugfix())
41 %}
42
43 %typemap(throws) std::bad_alloc %{
44   PyErr_SetString(PyExc_MemoryError, ($1).what());
45   SWIG_fail;
46 %}
47
48 %typemap(throws) nlopt::forced_stop %{
49   if (!PyErr_Occurred())
50     PyErr_SetString((PyObject*)&MyExc_ForcedStop, "NLopt forced stop");
51   SWIG_fail;
52 %}
53
54 %typemap(throws) nlopt::roundoff_limited %{
55   PyErr_SetString((PyObject*)&MyExc_RoundoffLimited, "NLopt roundoff-limited");
56   SWIG_fail;
57 %}
58
59 //////////////////////////////////////////////////////////////////////////////
60
61 %{
62 #define SWIG_FILE_WITH_INIT
63 %}
64 %include "numpy.i"
65 %init %{
66   import_array();
67 %}
68 %numpy_typemaps(double, NPY_DOUBLE, unsigned)
69
70 //////////////////////////////////////////////////////////////////////////////
71 // numpy.i does not include maps for std::vector<double>, so I add them here,
72 // taking advantage of the conversion functions provided by numpy.i
73
74 // Typemap for input arguments of type const std::vector<double> &
75 %typecheck(SWIG_TYPECHECK_POINTER, fragment="NumPy_Macros")
76   const std::vector<double> &
77 {
78   $1 = is_array($input) || PySequence_Check($input);
79 }
80 %typemap(in, fragment="NumPy_Fragments")
81   const std::vector<double> &
82 (PyArrayObject* array=NULL, int is_new_object=0, std::vector<double> arrayv)
83 {
84   npy_intp size[1] = { -1 };
85   array = obj_to_array_allow_conversion($input, NPY_DOUBLE, &is_new_object);
86   if (!array || !require_dimensions(array, 1) ||
87       !require_size(array, size, 1)) SWIG_fail;
88   arrayv = std::vector<double>(array_size(array,0));
89   $1 = &arrayv;
90   {
91     double *arr_data = (double *) array_data(array);
92     int arr_i, arr_s = array_stride(array,0) / sizeof(double);
93     int arr_sz = array_size(array,0);
94     for (arr_i = 0; arr_i < arr_sz; ++arr_i)
95       arrayv[arr_i] = arr_data[arr_i * arr_s];
96   }
97 }
98 %typemap(freearg)
99   const std::vector<double> &
100 {
101   if (is_new_object$argnum && array$argnum)
102     { Py_DECREF(array$argnum); }
103 }
104
105 // Typemap for return values of type std::vector<double>
106 %typemap(out, fragment="NumPy_Fragments") std::vector<double>
107 {
108   npy_intp sz = $1.size();
109   $result = PyArray_SimpleNew(1, &sz, NPY_DOUBLE);
110   std::memcpy(array_data($result), $1.empty() ? NULL : &$1[0],
111               sizeof(double) * sz);
112 }
113
114 //////////////////////////////////////////////////////////////////////////////
115 // Wrapper for objective function callbacks
116
117 %{
118 static void *free_pyfunc(void *p) { Py_DECREF((PyObject*) p); return p; }
119 static void *dup_pyfunc(void *p) { Py_INCREF((PyObject*) p); return p; }
120
121 #if NPY_API_VERSION < 0x00000007
122 #  define NPY_ARRAY_C_CONTIGUOUS NPY_C_CONTIGUOUS
123 #  define NPY_ARRAY_ALIGNED NPY_ALIGNED
124 #endif
125
126 static double func_python(unsigned n, const double *x, double *grad, void *f)
127 {
128   npy_intp sz = npy_intp(n), sz0 = 0, stride1 = sizeof(double);
129   PyObject *xpy = PyArray_New(&PyArray_Type, 1, &sz, NPY_DOUBLE, &stride1,
130                               const_cast<double*>(x), // not NPY_WRITEABLE
131                               0, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL);
132   PyObject *gradpy = grad
133     ? PyArray_SimpleNewFromData(1, &sz, NPY_DOUBLE, grad)
134     : PyArray_SimpleNew(1, &sz0, NPY_DOUBLE);
135   
136   PyObject *arglist = Py_BuildValue("OO", xpy, gradpy);
137   PyObject *result = PyEval_CallObject((PyObject *) f, arglist);
138   Py_DECREF(arglist);
139
140   Py_DECREF(gradpy);
141   Py_DECREF(xpy);
142
143   double val = HUGE_VAL;
144   if (PyErr_Occurred()) {
145     Py_XDECREF(result);
146     throw nlopt::forced_stop(); // just stop, don't call PyErr_Clear()
147   }
148   else if (result && PyFloat_Check(result)) {
149     val = PyFloat_AsDouble(result);
150     Py_DECREF(result);
151   }
152   else {
153     Py_XDECREF(result);
154     throw std::invalid_argument("invalid result passed to nlopt");
155   }
156   return val;
157 }
158
159 static void mfunc_python(unsigned m, double *result,
160                          unsigned n, const double *x, double *grad, void *f)
161 {
162   npy_intp nsz = npy_intp(n), msz = npy_intp(m);
163   npy_intp mnsz[2] = {msz, nsz};
164   npy_intp sz0 = 0, stride1 = sizeof(double);
165   PyObject *xpy = PyArray_New(&PyArray_Type, 1, &nsz, NPY_DOUBLE, &stride1,
166                               const_cast<double*>(x), // not NPY_WRITEABLE
167                               0, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL);
168   PyObject *rpy = PyArray_SimpleNewFromData(1, &msz, NPY_DOUBLE, result);
169   PyObject *gradpy = grad
170     ? PyArray_SimpleNewFromData(2, mnsz, NPY_DOUBLE, grad)
171     : PyArray_SimpleNew(1, &sz0, NPY_DOUBLE);
172   
173   PyObject *arglist = Py_BuildValue("OOO", rpy, xpy, gradpy);
174   PyObject *res = PyEval_CallObject((PyObject *) f, arglist);
175   Py_XDECREF(res);
176   Py_DECREF(arglist);
177
178   Py_DECREF(gradpy);
179   Py_DECREF(rpy);
180   Py_DECREF(xpy);
181
182   if (PyErr_Occurred()) {
183     throw nlopt::forced_stop(); // just stop, don't call PyErr_Clear()
184   }
185 }
186 %}
187
188 %typemap(in)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) {
189   $1 = func_python;
190   $2 = dup_pyfunc((void*) $input);
191   $3 = free_pyfunc;
192   $4 = dup_pyfunc;
193 }
194 %typecheck(SWIG_TYPECHECK_POINTER)(nlopt::func f, void *f_data, nlopt_munge md, nlopt_munge mc) {
195   $1 = PyCallable_Check($input);
196 }
197
198 %typemap(in)(nlopt::mfunc mf, void *f_data, nlopt_munge md, nlopt_munge mc) {
199   $1 = mfunc_python;
200   $2 = dup_pyfunc((void*) $input);
201   $3 = free_pyfunc;
202   $4 = dup_pyfunc;
203 }
204 %typecheck(SWIG_TYPECHECK_POINTER)(nlopt::mfunc mf, void *f_data, nlopt_munge md, nlopt_munge mc) {
205   $1 = PyCallable_Check($input);
206 }