chiark / gitweb /
mp.c, catacomb/__init__.py, pyke/: Fix mixed-mode arithmetic involving `float'.
authorMark Wooding <mdw@distorted.org.uk>
Tue, 22 Oct 2019 17:31:57 +0000 (18:31 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Sat, 11 Apr 2020 11:44:21 +0000 (12:44 +0100)
This is a bit embarrassing.

>>> import catacomb as C
>>> x = C.MP(5)
>>> x == 5.1
True
>>> x < 5.1
False
>>> r = x/2
>>> r
5/2
>>> r == 2
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: an integer is required
>>> r == 2.5
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: an integer is required
>>> r*1.0
5/2
>>> r*1.1
5/2
>>> r*2.0
5
>>> r*2.5
5

Fix this nonsense.

  * Change the `obvious' arithmetic operators so that they notice that
    one of the operands is a float.  Handle this by converting to a
    Python bignum and letting Python handle the arithmetic.  The result
    is a float, which seems sensible inexact contagion.

  * Introduce a rich-comparison method which also detects a float
    operand and hands off to Python.  Python seems to get this right,
    comparing the float to the bignum in its full precision, so that's a
    win.

  * Also, modify the `IntRat' code to apply inexact contagion in the
    same way.  Comparisons may be imperfect here, but that's
    surprisingly hard to get right.

The new results:

>>> import catacomb as C
>>> x = C.MP(5)
>>> x == 5.1
False
>>> x < 5.1
True
>>> r = x/2
>>> r
5/2
>>> r == 2
False
>>> r == 2.5
True
>>> r*1.0
2.5
>>> r*1.1
2.75
>>> r*2.0
5.0
>>> r*2.5
6.25

catacomb/__init__.py
mp.c
pyke/pyke.c
pyke/pyke.h
t/t-mp.py

index 94efb6fa6897b5ea7a6c89857a0887bc5fe86272..4a98cedb3a20925c4bf6ce722e4e925a1cc74067 100644 (file)
@@ -84,6 +84,9 @@ def _bin(s): return s
 def _iteritems(dict): return dict.iteritems()
 def _itervalues(dict): return dict.itervalues()
 
+## The built-in bignum type.
+_long = long
+
 ## How to fix a name back into the right identifier.  Alas, the rules are not
 ## consistent.
 def _fixname(name):
@@ -399,6 +402,10 @@ class BaseRat (object):
 
 class IntRat (BaseRat):
   RING = MP
+  def __new__(cls, a, b):
+    if isinstance(a, float) or isinstance(b, float): return a/b
+    return super(IntRat, cls).__new__(cls, a, b)
+  def __float__(me): return float(me._n)/float(me._d)
 
 class GFRat (BaseRat):
   RING = GF
@@ -412,8 +419,12 @@ class _tmp:
   def mont(x): return MPMont(x)
   def barrett(x): return MPBarrett(x)
   def reduce(x): return MPReduce(x)
-  def __truediv__(me, you): return IntRat(me, you)
-  def __rtruediv__(me, you): return IntRat(you, me)
+  def __truediv__(me, you):
+    if isinstance(you, float): return _long(me)/you
+    else: return IntRat(me, you)
+  def __rtruediv__(me, you):
+    if isinstance(you, float): return you/_long(me)
+    else: return IntRat(you, me)
   __div__ = __truediv__
   __rdiv__ = __rtruediv__
   _repr_pretty_ = _pp_str
diff --git a/mp.c b/mp.c
index fe5e7dcb8c6df6ad5d1d25f0e331b87cb84a772c..8f66d1773376954109410bd077daf1e65564ebb3 100644 (file)
--- a/mp.c
+++ b/mp.c
@@ -209,7 +209,7 @@ mp *tomp(PyObject *o)
   PyObject *l;
   mp *x;
 
-  if (!o)
+  if (!o || PyFloat_Check(o))
     return (0);
   else if (MP_PYCHECK(o) || GF_PYCHECK(o))
     return (MP_COPY(MP_X(o)));
@@ -322,6 +322,26 @@ static int gfbinop(PyObject *x, PyObject *y, mp **xx, mp **yy)
   return (0);
 }
 
+#define FPBINOP(name, pyop)                                            \
+  static PyObject *mp_py##name(PyObject *x, PyObject *y) {             \
+    mp *xx, *yy, *zz;                                                  \
+    PyObject *l, *rc;                                                  \
+    if (PyFloat_Check(x)) {                                            \
+      l = mp_topylong(MP_X(y)); rc = PyNumber_##pyop(x, l);            \
+      Py_DECREF(l); return (rc);                                       \
+    } else if (PyFloat_Check(y)) {                                     \
+      l = mp_topylong(MP_X(x)); rc = PyNumber_##pyop(l, y);            \
+      Py_DECREF(l); return (rc);                                       \
+    }                                                                  \
+    if (mpbinop(x, y, &xx, &yy)) RETURN_NOTIMPL;                       \
+    zz = mp_##name(MP_NEW, xx, yy);                                    \
+    MP_DROP(xx); MP_DROP(yy);                                          \
+    return (mp_pywrap(zz));                                            \
+  }
+FPBINOP(add, Add)
+FPBINOP(sub, Subtract)
+FPBINOP(mul, Multiply)
+
 #define gf_and mp_and
 #define gf_or mp_or
 #define gf_xor mp_xor
@@ -333,9 +353,6 @@ static int gfbinop(PyObject *x, PyObject *y, mp **xx, mp **yy)
     MP_DROP(xx); MP_DROP(yy);                                          \
     return (pre##_pywrap(zz));                                         \
   }
-BINOP(mp, add)
-BINOP(mp, sub)
-BINOP(mp, mul)
 BINOP(mp, and2c)
 BINOP(mp, or2c)
 BINOP(mp, xor2c)
@@ -527,6 +544,20 @@ COERCE(mp, MP)
 COERCE(gf, GF)
 #undef COERCE
 
+static PyObject *mp_pyrichcompare(PyObject *x, PyObject *y, int op)
+{
+  mp *xx, *yy;
+  PyObject *l, *rc;
+  if (PyFloat_Check(y)) {
+    l = mp_topylong(MP_X(x)); rc = PyObject_RichCompare(l, y, op);
+    Py_DECREF(l); return (rc);
+  }
+  if (mpbinop(x, y, &xx, &yy)) RETURN_NOTIMPL;
+  rc = enrich_compare(op, mp_cmp(xx, yy));
+  MP_DROP(xx); MP_DROP(yy);
+  return (rc);
+}
+
 static int mp_pycompare(PyObject *x, PyObject *y)
   { return mp_cmp(MP_X(x), MP_X(y)); }
 
@@ -980,7 +1011,7 @@ static const PyTypeObject mp_pytype_skel = {
 
   0,                                   /* @tp_traverse@ */
   0,                                   /* @tp_clear@ */
-  0,                                   /* @tp_richcompare@ */
+  mp_pyrichcompare,                    /* @tp_richcompare@ */
   0,                                   /* @tp_weaklistoffset@ */
   0,                                   /* @tp_iter@ */
   0,                                   /* @tp_iternext@ */
index cef4f822882d7036aaa1ce0dde7d067ba964be3d..1d4458b0239cdba71229241d77f6854b282f5db2 100644 (file)
@@ -128,6 +128,22 @@ PyObject *abstract_pynew(PyTypeObject *ty, PyObject *arg, PyObject *kw)
   return (0);
 }
 
+PyObject *enrich_compare(int op, int cmp)
+{
+  int r = -1;
+
+  switch (op) {
+    case Py_LT: r = cmp <  0; break;
+    case Py_LE: r = cmp <= 0; break;
+    case Py_EQ: r = cmp == 0; break;
+    case Py_NE: r = cmp != 0; break;
+    case Py_GE: r = cmp >= 0; break;
+    case Py_GT: r = cmp >  0; break;
+    default: assert(0);
+  }
+  return (getbool(r));
+}
+
 /*----- Saving and restoring exceptions ----------------------------------*/
 
 void report_lost_exception_v(struct excinfo *exc,
index 654f7517439ccef52003f3f9a809f9985af6f2a7..cd135075acbaaad47fef25aeca4efd6eb912691e 100644 (file)
@@ -279,6 +279,11 @@ extern PyObject *getulong(unsigned long); /* any kind of unsigned integer */
 extern PyObject *abstract_pynew(PyTypeObject *, PyObject *, PyObject *);
   /* A `tp_new' function which refuses to make the object. */
 
+extern PyObject *enrich_compare(int /*op*/, int /*cmp*/);
+  /* Use a traditional compare-against-zero comparison result CMP to answer a
+   * modern Python `tp_richcompare' operation OP.
+   */
+
 #ifndef CONVERT_CAREFULLY
 #  define CONVERT_CAREFULLY(newty, expty, obj)                         \
      (!sizeof(*(expty *)0 = (obj)) + (/*unconst*/ newty)(obj))
index ff373ca31b99008b6c96a5d84a03fd898a401cba..a88fa8d1ac80773018acb3e15c97354b7068c346 100644 (file)
--- a/t/t-mp.py
+++ b/t/t-mp.py
@@ -135,6 +135,20 @@ class TestMP (U.TestCase):
     me.assertTrue(y < x)
     me.assertFalse(x < x)
 
+  def test_float(me):
+    x, y = C.MP(169), 24.0
+    for fn in [T.add, T.sub, T.mul, T.div]:
+      me.assertEqual(type(fn(x, y)), float)
+      me.assertEqual(type(fn(y, x)), float)
+    me.assertEqual(x, 169.0)
+    me.assertNotEqual(x, 169.1)
+    me.assertNotEqual(x, 168.9)
+    me.assertTrue(x > 168.9)
+    me.assertTrue(x < 169.1)
+    z = 1.0
+    while z == z + 1: z *= 2.0
+    me.assertNotEqual(C.MP(int(z)) + 1, z)
+
   def test_bits(me):
     x, y, zero = C.MP(169), C.MP(-24), C.MP(0)
     me.assertTrue(x.testbit(0))