chiark / gitweb /
soak: Implement split, join, and set operations.
authorMark Wooding <mdw@distorted.org.uk>
Tue, 20 Aug 2024 19:18:40 +0000 (20:18 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Tue, 20 Aug 2024 19:30:58 +0000 (20:30 +0100)
Significantly more complicated.  Probably some edge-case bugs around
splitting bounds.

soak

diff --git a/soak b/soak
index 02f4487a29e4d4dbcbaadabd89d752741f1817b7..4428a70c34c4628a5382120dff326c96ae67040f 100755 (executable)
--- a/soak
+++ b/soak
@@ -4,6 +4,7 @@ import base64 as B64
 import errno as E
 import io as IO
 import math as M
 import errno as E
 import io as IO
 import math as M
+import optparse as OP
 import os as OS
 import random as RND
 import re as RX
 import os as OS
 import random as RND
 import re as RX
@@ -11,31 +12,23 @@ import sys as SYS
 import subprocess as SUB
 import time as T
 
 import subprocess as SUB
 import time as T
 
-PROG = OS.path.basename(SYS.argv[0])
-
-MAX = 4096
 SEEDSZ = 32
 SEEDSZ = 32
-CKPT_STEPS = 5000
-NSTEPS = 1000000
-SYNC = False
 
 
-def usage(file):
-  file.write("usage: %s PROG [CKPT]\n" % PROG)
+PROG = OS.path.basename(SYS.argv[0])
 
 def base64_encode(buf):
   ## No, you blundering morons, the result of Base64 encoding is text.
   return B64.b64encode(buf).decode("us-ascii")
 
 def binsearch(list, key, keyfn, lessfn):
 
 def base64_encode(buf):
   ## No, you blundering morons, the result of Base64 encoding is text.
   return B64.b64encode(buf).decode("us-ascii")
 
 def binsearch(list, key, keyfn, lessfn):
-  lo, hi = 0, len(list)
-  if not hi: return None, -1
+  n = len(list)
+  lo, hi = 0, n
   while lo < hi:
     mid = (lo + hi)//2
   while lo < hi:
     mid = (lo + hi)//2
-    if lessfn(key, keyfn(list[mid])): hi = mid
-    elif mid == lo: break
-    else: lo = mid
-  found = list[lo]
-  if lo == hi or lessfn(keyfn(found), key): found = None
+    if lessfn(keyfn(list[mid]), key): lo = mid + 1
+    else: hi = mid
+  if lo < n and not lessfn(key, keyfn(list[lo])): found = list[lo]
+  else: found = None
   return found, lo
 
 class WeightedChoice (object):
   return found, lo
 
 class WeightedChoice (object):
@@ -46,10 +39,10 @@ class WeightedChoice (object):
       acc += wt
       me._choices.append((acc - 1, opt))
     me._total = acc
       acc += wt
       me._choices.append((acc - 1, opt))
     me._total = acc
-  def choose(me):
-    i = RAND.randrange(me._total)
-    _, j = binsearch(me._choices, i,
-                     lambda pair: pair[0], lambda x, y: x < y)
+  def choose(me, rand):
+    i = rand.randrange(me._total)
+    ch, j = binsearch(me._choices, i,
+                      lambda pair: pair[0], lambda x, y: x < y)
     return me._choices[j][1]
 
 class Collection (object):
     return me._choices[j][1]
 
 class Collection (object):
@@ -62,6 +55,8 @@ class Collection (object):
     if me._set is None: me._set = set(me._list)
   def __iter__(me):
     me._ensure_list(); return iter(me._list)
     if me._set is None: me._set = set(me._list)
   def __iter__(me):
     me._ensure_list(); return iter(me._list)
+  def __len__(me):
+    me._ensure_list(); return len(me._list)
   def __contains__(me, x):
     me._ensure_set(); return x in me._set
   def add(me, x):
   def __contains__(me, x):
     me._ensure_set(); return x in me._set
   def add(me, x):
@@ -93,49 +88,140 @@ class Collection (object):
   def lower(me):
     me._ensure_list()
     if not me._list: return None
   def lower(me):
     me._ensure_list()
     if not me._list: return None
-    else: me._list[0]
+    else: return me._list[0]
   def upper(me):
     me._ensure_list()
     if not me._list: return None
   def upper(me):
     me._ensure_list()
     if not me._list: return None
-    else: me._list[-1]
-
-def write_ckpt(seed, tree):
-  newckpt = CKPT + ".new"
-  with open(newckpt, "w") as f:
-    f.write(seed); f.write("\n")
-    f.write(tree); f.write("\n")
-  OS.rename(newckpt, CKPT)
-
-TEST = None
-n = len(SYS.argv)
-if n < 2: usage(SYS.stderr); SYS.exit(2)
-testprog = SYS.argv[1]
-CKPT = None
-seed, tree = None, "_"
-if n >= 3:
-  CKPT = SYS.argv[2]
-  try:
-    with open(CKPT, "r") as f:
-      seed = f.readline()
-      tree = f.readline()
-      rest = f.readline()
-  except OSError as err:
-    if err.errno == E.ENOENT: pass
-    else: raise
+    else: return me._list[-1]
+
+class Options (object):
+  def __init__(me):
+    op = OP.OptionParser\
+         (usage = "%prog [-y] [-c STEPS] [-f FILE] [-l LIMIT] [-n STEPS] PROG")
+    for short, long, kw in \
+        [("-c", "--ckpt-steps",
+          dict(type = "int", metavar = "STEPS",
+               dest = "ckpt_steps", default = 5000,
+               help = "number of steps between checkpoints")),
+         ("-f", "--ckpt-file",
+          dict(type = "string", metavar = "FILE",
+               dest = "ckpt_file", default = "soak.ckpt",
+               help = "file to hold checkpoint information")),
+         ("-l", "--limit",
+          dict(type = "int", metavar = "LIMIT",
+               dest = "limit", default = None,
+               help = "exclusive limit value to store in test trees")),
+         ("-n", "--steps",
+          dict(type = "int", metavar = "STEPS",
+               dest = "nsteps", default = None,
+               help = "number of steps to run before stopping")),
+         ("-y", "--sync",
+          dict(action = "store_true", dest = "sync",
+               help = "check and print state after every step"))]:
+      op.add_option(short, long, **kw)
+    opts, args = op.parse_args()
+    me.limit = opts.limit
+    me.ckpt_file = opts.ckpt_file
+    me.sync = opts.sync
+    me.ckpt_steps = opts.ckpt_steps
+    me.nsteps = opts.nsteps
+    if len(args) != 1: op.print_usage(SYS.stderr); SYS.exit(2)
+    me.testprog = args[0]
+
+class Level (object):
+  def __init__(me, kind, base, limit, tree = "_"):
+    me.coll = Collection(map(int, RX.findall(r"\d+", tree)))
+    me.kind, me.base, me.limit, me.tree = kind, int(base), int(limit), tree
+    me.rlim = int(M.sqrt(me.limit - me.base))
+  def write(me, file):
+    file.write("%s %d %d %s\n" % (me.kind, me.base, me.limit, me.tree))
+  @classmethod
+  def read(cls, file):
+    line = file.readline()
+    if line == "": return None
+    kind, base, limit, tree = line.split(maxsplit = 3)
+    return cls(kind, base, limit, tree)
+
+class State (object):
+  def __init__(me, opts):
+    me._ckpt_file = opts.ckpt_file
+    try:
+      with open(me._ckpt_file, "r") as f:
+        me.seed, = f.readline().split()
+        stack = []
+        while True:
+          lv = Level.read(f)
+          if lv is None: break
+          stack.append(lv)
+      assert stack
+      me.cur = stack.pop()
+      me.stack = stack
+      if opts.limit is not None and me.cur.limit != opts.limit:
+          raise ValueError("checkpointed limit %d /= command-line limit %d" %
+                           (me.cur.limit, opts.limit))
+    except OSError as err:
+      if err.errno != E.ENOENT: raise
+      me.seed = base64_encode(OS.urandom(SEEDSZ))
+      if opts.limit is not None: me.limit = opts.limit
+      else: me.limit = 4096
+      me.stack = []
+      me.cur = Level('base', 0, me.limit)
+      me.write_ckpt(reseed = False)
+    me.rand = RND.Random(me.seed)
+    n, b = 0, me.cur.limit
+    while True:
+      bb = int(M.sqrt(b)) + 4
+      if bb >= b: break
+      n, b = n + 1, bb
+    me.stklim = n
+  def push(me, lv):
+    me.stack.append(me.cur)
+    me.cur = lv
+  def pop(me):
+    assert me.stack
+    lv = me.cur
+    me.cur = me.stack.pop()
+    return lv
+  def write_ckpt(me, reseed = True):
+    if reseed:
+      me.seed = base64_encode(bytes(me.rand.randrange(256)
+                                    for _ in range(SEEDSZ)))
+      me.rand.seed(me.seed)
+    new = me._ckpt_file + ".new"
+    with open(new, "w") as f:
+      f.write("%s\n" % me.seed)
+      for lv in me.stack: lv.write(f)
+      me.cur.write(f)
+    OS.rename(new, me._ckpt_file)
+  def clear_ckpt(me):
+    try: OS.unlink(me._ckpt_file)
+    except OSError as err:
+      if err.errno == E.ENOENT: pass
+      else: raise
+
+def choices():
+  ch = [(896, "addrm1"),
+        (56, "addrmn"),
+        (56, "lookup")]
+
+  sp = len(ST.stack)
+  ch += [(ST.stklim - sp, "split"),
+         (ST.stklim - sp, "push")]
+
+  if ST.cur.kind == "join":
+    ch += [(sp, "join0"), (sp, "join1")]
+  elif ST.cur.kind == "setop":
+    ch += [(sp, "unisect"), (sp, "diffsect")]
+  elif ST.cur.kind == "base":
+    pass
   else:
   else:
-    if rest: raise ValueError("trailing junk in checkpoint file")
-    seed, tree = seed.strip(), tree.strip()
-    SYNC = True
-if seed is None:
-  seed = base64_encode(OS.urandom(SEEDSZ))
-  write_ckpt(seed, tree)
-RAND = RND.Random(seed)
-COLL = Collection(map(int, RX.findall(r"\d+", tree)))
-
-CHOICES = WeightedChoice([(900, "addrm1"),
-                          (50, "addrmn"),
-                          (50, "lookup")])
-STEP = 0
+    raise ValueError("unknown level kind `%s'" % ST.cur.kind)
+
+  return WeightedChoice(ch)
+
+OPTS = Options()
+ST = State(OPTS)
+KID = SUB.Popen([OPTS.testprog], stdin = SUB.PIPE, stdout = SUB.PIPE)
 
 def fail(msg):
   SYS.stderr.write("%s: FAILED: %s\n" % (PROG, msg))
 
 def fail(msg):
   SYS.stderr.write("%s: FAILED: %s\n" % (PROG, msg))
@@ -146,114 +232,176 @@ def fail(msg):
   SYS.stderr.write("%s:   exit status = %d\n" % (PROG, rc))
   SYS.exit(2)
 
   SYS.stderr.write("%s:   exit status = %d\n" % (PROG, rc))
   SYS.exit(2)
 
-KID = SUB.Popen([testprog], stdin = SUB.PIPE, stdout = SUB.PIPE)
-STACK = []
-step = 0
-LO, HI = 0, MAX
-RLIM = int(M.sqrt(MAX))
-
 def put(msg, echo = True):
   try: KID.stdin.write(msg.encode()); KID.stdin.flush()
   except OSError as err: fail("write failed: %s" % err)
 def put(msg, echo = True):
   try: KID.stdin.write(msg.encode()); KID.stdin.flush()
   except OSError as err: fail("write failed: %s" % err)
-  if SYNC and echo: SYS.stdout.write("$ " + msg); SYS.stdout.flush()
+  if OPTS.sync and echo: SYS.stdout.write("$ " + msg); SYS.stdout.flush()
 def get(echo = True):
   try: line = KID.stdout.readline().decode()
   except OSError as err: fail("read failed: %s" % err)
   if line == "": fail("unexpected end of file")
 def get(echo = True):
   try: line = KID.stdout.readline().decode()
   except OSError as err: fail("read failed: %s" % err)
   if line == "": fail("unexpected end of file")
-  if SYNC and echo: SYS.stdout.write(line)
+  if OPTS.sync and echo: SYS.stdout.write(line)
   if line[-1] == "\n": return line[:-1]
   else: return line
 
   if line[-1] == "\n": return line[:-1]
   else: return line
 
-put("= %s\n" % tree)
+def dump_tree():
+  if OPTS.sync:
+    put("D\n:;;END DUMP\n", echo = False)
+    while True:
+      line = get(echo = False)
+      if line == ";;END DUMP": break
+      SYS.stdout.write(line); SYS.stdout.write("\n")
+
+def check_tree():
+  put("i\n", echo = False)
+  line = get(echo = False)
+  if ST.cur.coll: ref = " ".join("%d" % i for i in ST.cur.coll)
+  else: ref = "(empty tree)"
+  if line != ref: fail("iteration mismatch: %s /= %s" % (line, ref))
+  put("!:;;END CHECK\n", echo = False)
+  line = get(echo = False)
+  if line != ";;END CHECK": fail("unexpected output: `%s'" % line)
+
+def snapshot():
+  put("L\n", echo = False)
+  ST.cur.tree = get(echo = False)
 
 
-T0 = T.time()
-TCKPT = 0
+for lv in ST.stack:
+  put("= %s\n" % lv.tree)
+  dump_tree()
+  put("(\n")
+put("= %s\n" % ST.cur.tree)
+dump_tree()
 
 
-while NSTEPS is None or STEP < NSTEPS:
-  if SYNC: SYS.stdout.write("\n;; step %d\n" % STEP)
-  op = CHOICES.choose()
+STEP = 0; nsteps = OPTS.nsteps
+ch = choices()
+while nsteps is None or STEP < nsteps:
+  if OPTS.sync: SYS.stdout.write("\n;; step %d\n" % STEP)
+  op = ch.choose(ST.rand)
 
   if op == "addrm1":
 
   if op == "addrm1":
-    k = RAND.randrange(LO, HI)
-    if k in COLL: COLL.remove(k); put("%d-\n" % k)
-    else: COLL.add(k); put("%d+\n" % k)
+    k = ST.rand.randrange(ST.cur.base, ST.cur.limit)
+    if k in ST.cur.coll: ST.cur.coll.remove(k); put("%d-\n" % k)
+    else: ST.cur.coll.add(k); put("%d+\n" % k)
 
   elif op == "addrmn":
 
   elif op == "addrmn":
-    n = RAND.randrange(RLIM)
-    i = RAND.randrange(LO, HI - n)
+    n = ST.rand.randrange(ST.cur.rlim)
+    i = ST.rand.randrange(ST.cur.base, ST.cur.limit - n)
     m = i + n//2
     m = i + n//2
-    foundp = m in COLL
-    dir = RAND.choice([+1, -1])
+    foundp = m in ST.cur.coll
+    dir = ST.rand.choice([+1, -1])
     rr = range(i, i + n)
     if dir < 0: rr = reversed(rr)
     firstp = True
     buf = IO.StringIO()
     if foundp:
       for j in rr:
     rr = range(i, i + n)
     if dir < 0: rr = reversed(rr)
     firstp = True
     buf = IO.StringIO()
     if foundp:
       for j in rr:
-        if j not in COLL: continue
+        if j not in ST.cur.coll: continue
         if firstp: firstp = False
         else: buf.write(" ")
         if firstp: firstp = False
         else: buf.write(" ")
-        COLL.remove(j); buf.write("%d-" % j)
+        ST.cur.coll.remove(j); buf.write("%d-" % j)
     else:
       for j in rr:
     else:
       for j in rr:
-        if j in COLL: continue
+        if j in ST.cur.coll: continue
         if firstp: firstp = False
         else: buf.write(" ")
         if firstp: firstp = False
         else: buf.write(" ")
-        COLL.add(j); buf.write("%d+" % j)
+        ST.cur.coll.add(j); buf.write("%d+" % j)
     if not firstp: buf.write("\n")
     put(buf.getvalue())
 
   elif op == "lookup":
     if not firstp: buf.write("\n")
     put(buf.getvalue())
 
   elif op == "lookup":
-    k = RAND.randrange(LO, HI)
+    k = ST.rand.randrange(ST.cur.base, ST.cur.limit)
     put("%d? n\n" % k)
     line = get()
     if line == "(nil)":
     put("%d? n\n" % k)
     line = get()
     if line == "(nil)":
-      if k in COLL: fail("key %d unexpectedly missing" % k)
+      if k in ST.cur.coll: fail("key %d unexpectedly missing" % k)
     else:
       m = RX.match(r"^#<node #0x([0-9a-f]{8}) (\d+)>", line)
       if m:
         kk = int(m.group(2))
         if kk != k: fail("search for key %d found %d instead" % (k, kk))
     else:
       m = RX.match(r"^#<node #0x([0-9a-f]{8}) (\d+)>", line)
       if m:
         kk = int(m.group(2))
         if kk != k: fail("search for key %d found %d instead" % (k, kk))
-        elif k not in COLL: fail("key %d unexpectedly found" % k)
+        elif k not in ST.cur.coll: fail("key %d unexpectedly found" % k)
       else:
         fail("unexpected response to lookup: `%s'" % line)
 
       else:
         fail("unexpected response to lookup: `%s'" % line)
 
+  elif op == "split":
+    check_tree()
+    k = ST.rand.randrange(ST.cur.base, ST.cur.limit - ST.cur.rlim - 4)
+    put("%d@ /\n" % k)
+    left, mid, right = ST.cur.coll.split(k)
+    ST.cur.coll = left
+    old_limit, ST.cur.limit = ST.cur.limit, k
+    ST.push(Level("split.mid", k, k + 1, str(mid)))
+    ST.push(Level("split.right", k + 1, old_limit,
+                  " ".join("%d" % x for x in right)))
+    dump_tree(); check_tree(); put(")\n"); ST.pop()
+    dump_tree(); check_tree(); put(")\n"); ST.pop()
+    dump_tree(); check_tree(); snapshot()
+    put("(\n")
+    if ST.cur.coll: new_base = ST.cur.coll.upper() + 2
+    else: new_base = ST.cur.base + 2
+    ST.push(Level("join", new_base, old_limit, "_"))
+    ch = choices()
+
+  elif op == "push":
+    check_tree(); snapshot()
+    put("(\n"); ST.push(Level("setop", ST.cur.base, ST.cur.limit, "_"))
+    ST.stack[-1].limit = ST.cur.limit
+    ch = choices()
+
+  elif op == "join0":
+    lower = ST.stack[-1].coll
+    put("* ~\n")
+    ST.stack[-1].coll = lower.join(None, ST.cur.coll)
+    ST.stack[-1].limit = ST.cur.limit
+    ST.pop()
+    ch = choices()
+
+  elif op == "join1":
+    lower = ST.stack[-1].coll
+    if lower: base = lower.upper() + 1
+    else: base = ST.stack[-1].base + 1
+    if ST.cur.coll: limit = ST.cur.coll.lower()
+    else: limit = ST.cur.limit
+    k = ST.rand.randrange(base, limit)
+    put("%d ~\n" % k)
+    ST.stack[-1].coll = lower.join(k, ST.cur.coll)
+    ST.stack[-1].limit = ST.cur.limit
+    ST.pop()
+    ch = choices()
+
+  elif op == "unisect":
+    put("|\n")
+    ST.cur.coll, ST.stack[-1].coll = ST.cur.coll.unisect(ST.stack[-1].coll)
+    dump_tree(); check_tree(); put(")\n"); ST.pop()
+    ch = choices()
+
+  elif op == "diffsect":
+    put("\\\n")
+    diff, ST.cur.coll = ST.cur.coll.diffsect(ST.stack[-1].coll)
+    ST.push(Level("diffsect.diff", ST.cur.base, ST.cur.limit,
+                  " ".join("%d" % x for x in diff)))
+    dump_tree(); check_tree(); put(")\n"); ST.pop()
+    dump_tree(); check_tree(); put(")\n"); ST.pop()
+    ch = choices()
+
   else:
     raise ValueError("unexpected operation `%s'" % op)
 
   else:
     raise ValueError("unexpected operation `%s'" % op)
 
-  t0 = T.time()
   STEP += 1
   STEP += 1
-  if SYNC:
-    put("D\n:;;END DUMP\n", echo = False)
-    while True:
-      line = get(echo = False)
-      if line == ";;END DUMP": break
-      SYS.stdout.write(line); SYS.stdout.write("\n")
-  if SYNC or STEP == CKPT_STEPS:
-    put("i\n", echo = False)
-    line = get(echo = False)
-    ref = " ".join("%d" % i for i in COLL)
-    if line != ref: fail("iteration mismatch: %s /= %s" % (line, ref))
-    put("!:;;END CHECK\n", echo = False)
-    line = get(echo = False)
-    if line != ";;END CHECK": fail("unexpected output: `%s'" % line)
-  if STEP == CKPT_STEPS:
-    put("L\n", echo = False)
-    tree = get(echo = False)
-    seed = base64_encode(bytes(RAND.randrange(256) for _ in range(SEEDSZ)))
-    RAND.seed(seed)
-    if CKPT is not None: write_ckpt(seed, tree)
-    STEP = 0; NSTEPS -= CKPT_STEPS; SYNC = False
-  t1 = T.time()
-  TCKPT += t1 - t0
-
-if not SYNC: SYS.stdout.write("\n")
-
-T1 = T.time()
-SYS.stderr.write(";; time = %s; checkpointing = %s = %s\n" %
-                 (T1 - T0, TCKPT, TCKPT/(T1 - T0)))
-if CKPT:
-  try: OS.unlink(CKPT)
-  except OSError as err:
-    if err.errno == E.ENOENT: pass
-    else: raise
+  dump_tree()
+  if OPTS.sync or STEP == OPTS.ckpt_steps: check_tree()
+  if STEP == OPTS.ckpt_steps:
+    snapshot()
+    ST.write_ckpt()
+    STEP = 0; OPTS.sync = False
+    if nsteps is not None: nsteps -= OPTS.ckpt_steps
+
+while True:
+  check_tree()
+  if not ST.stack: break
+  put(")\n")
+  dump_tree()
+  check_tree()
+  ST.pop()
+ST.clear_ckpt()