chiark / gitweb /
scheme callbacks now take grad as vector argument, and must use side-effects to modify it
authorstevenj <stevenj@alum.mit.edu>
Mon, 14 Jun 2010 22:27:41 +0000 (18:27 -0400)
committerstevenj <stevenj@alum.mit.edu>
Mon, 14 Jun 2010 22:27:41 +0000 (18:27 -0400)
darcs-hash:20100614222741-c8de0-d8e26b0d530e3644688b4170c5760abea7985c4f.gz

swig/nlopt-guile.i

index 8f108d01a961e4f8bbe4fe89288083c656ef4d01..d9987a30e0fad1de08fe1957c1eacfc20f11448b 100644 (file)
@@ -33,29 +33,23 @@ static void *free_guilefunc(void *p) {
 static void *dup_guilefunc(void *p) { 
   scm_gc_protect_object((SCM) p); return p; }
 
-// func wrapper around Guile function (val . grad) = f(x)
+// func wrapper around Guile function val = f(x, grad)
 static double func_guile(unsigned n, const double *x, double *grad, void *f) {
   SCM xscm = scm_c_make_vector(n, SCM_UNSPECIFIED);
   for (unsigned i = 0; i < n; ++i)
-    scm_c_vector_set_x(xscm, i, scm_make_real(x[i]));
-  SCM ret = scm_call_1((SCM) f, xscm);
-  if (scm_real_p(ret)) {
-    if (grad) throw std::invalid_argument("missing gradient");
-    return scm_to_double(ret);
-  }
-  else if (scm_is_pair(ret)) { /* must be (cons value gradient) */
-    SCM valscm = SCM_CAR(ret), grad_scm = grad_scm;
-    if (grad
-       && scm_is_vector(grad_scm)
-       && scm_c_vector_length(grad_scm) == n) {
-      for (unsigned i = 0; i < n; ++i)
-       grad[i] = scm_to_double(scm_c_vector_ref(grad_scm, i));
+    SCM_SIMPLE_VECTOR_SET(xscm, i, scm_make_real(x[i]));
+  SCM grad_scm = grad ? scm_c_make_vector(n, SCM_UNSPECIFIED) : SCM_BOOL_F;
+  SCM ret = scm_call_2((SCM) f, xscm, grad_scm);
+  if (!scm_real_p(ret))
+    throw std::invalid_argument("invalid result passed to nlopt");
+  if (grad) {
+    for (unsigned i = 0; i < n; ++i) {
+      if (!scm_real_p(ret)) 
+       throw std::invalid_argument("invalid gradient passed to nlopt");
+      grad[i] = scm_to_double(SCM_SIMPLE_VECTOR_REF(grad_scm, i));
     }
-    else throw std::invalid_argument("invalid gradient");
-    if (scm_real_p(valscm))
-      return scm_to_double(valscm);
   }
-  throw std::invalid_argument("invalid result passed to nlopt");
+  return scm_to_double(ret);
 }
 %}