chiark / gitweb /
Merge remote-tracking branch 'origin/HEAD'
[catacomb-python] / rand.c
diff --git a/rand.c b/rand.c
index 54027370dc1be174e4f98e29f5c32dc889ea8718..6fe78bc758990ac36022447731243d3cc1fd714e 100644 (file)
--- a/rand.c
+++ b/rand.c
@@ -583,6 +583,12 @@ static PyTypeObject *gccrand_pytype, *gcrand_pytype, *gclatinrand_pytype;
 typedef grand *gcrand_func(const void *, size_t sz);
 typedef grand *gcirand_func(const void *, size_t sz, uint32);
 typedef grand *gcnrand_func(const void *, size_t sz, const void *);
+typedef grand *gcshakerand_func(const void *, size_t, /* func string */
+                               const void *, size_t,  /* perso string */
+                               const void *, size_t); /* key */
+typedef grand *gcshafuncrand_func(const void *, size_t, /* perso string */
+                                 const void *, size_t); /* key */
+typedef grand *gckmacrand_func(const void *, size_t, const void *, size_t); /* NOTE(review): unused in these hunks; same type as gcshafuncrand_func */
 typedef struct gccrand_info {
   const char *name;
   const octet *keysz;
@@ -671,6 +677,34 @@ end:
   return (0);
 }
 
+static PyObject *gcshakyrand_pynew(PyTypeObject *ty, /* tp_new for RNG_SHAKE and RNG_KMAC types */
+                                  PyObject *arg, PyObject *kw)
+{
+  const gccrand_info *info = GCCRAND_INFO(ty);
+  static char *kwlist_shake[] = { "key", "func", "perso", 0 }; /* RNG_SHAKE */
+  static char *kwlist_func[] = { "key", "perso", 0 }; /* all other kinds */
+  char *k, *f = 0, *p = 0;          /* omitted optional strings stay null... */
+  Py_ssize_t ksz, fsz = 0, psz = 0; /* ...with length 0 */
+  /* Parse "key" plus optional strings; only RNG_SHAKE also accepts "func". */
+  if ((info->f&RNGF_MASK) == RNG_SHAKE
+       ? !PyArg_ParseTupleAndKeywords(arg, kw, "s#|s#s#:new", kwlist_shake,
+                                      &k, &ksz, &f, &fsz, &p, &psz)
+       : !PyArg_ParseTupleAndKeywords(arg, kw, "s#|s#:new", kwlist_func,
+                                      &k, &ksz, &p, &psz))
+    goto end;
+  if (keysz(ksz, info->keysz) != ksz) VALERR("bad key length"); /* keysz() must accept ksz unchanged */
+  return (grand_dopywrap(ty,
+                        (info->f&RNGF_MASK) == RNG_SHAKE /* SHAKE: (func, perso, key); else: (perso, key) */
+                          ? ((gcshakerand_func *)info->func)(f, fsz,
+                                                             p, psz,
+                                                             k, ksz)
+                          : ((gcshafuncrand_func *)info->func)(p, psz,
+                                                               k, ksz),
+                        f_freeme)); /* NOTE(review): presumably the wrapper takes ownership of the grand -- confirm */
+end:
+  return (0); /* NULL signals failure; exception presumably set by the parser or VALERR */
+}
+
 static PyObject *gccrand_pywrap(const gccrand_info *info)
 {
   gccrand_pyobj *g = newtype(gccrand_pytype, 0, info->name);
@@ -689,6 +723,8 @@ static PyObject *gccrand_pywrap(const gccrand_info *info)
   switch (info->f&RNGF_MASK) {
     case RNG_LATIN: g->ty.ht_type.tp_new = gcnrand_pynew; break;
     case RNG_SEAL: g->ty.ht_type.tp_new = gcirand_pynew; break;
+    case RNG_SHAKE: case RNG_KMAC:
+      g->ty.ht_type.tp_new = gcshakyrand_pynew; break;
     default: g->ty.ht_type.tp_new = gcrand_pynew; break;
   }
   typeready(&g->ty.ht_type);