chiark / gitweb /
rhodes: Open lock files with the right permissions for exclusive locking.
[rhodes] / rhodes
1 #! /usr/bin/python
2
3 from sys import argv, stdout, stderr, exit
4 import errno as E
5 import fcntl as F
6 import os as OS
7 import subprocess as S
8 import select as SEL
9 import signal as SIG
10
11 import catacomb as C
12 import sqlite3 as SQL
13
14 class ExpectedError (Exception):
15   pass
16
17 CONNINIT_SQL = """
18 PRAGMA foreign_keys = on;
19 """
20
21 SETUP_SQL = """
22 PRAGMA journal_mode = wal;
23
24 CREATE TABLE top
25         (kind TEXT NOT NULL,            -- `gf2x'
26          groupdesc TEXT NOT NULL,
27          g TEXT NOT NULL,
28          x TEXT NOT NULL,
29          m TEXT NOT NULL,               -- g^m = 1
30          n TEXT DEFAULT NULL);          -- g^n = x
31
32 CREATE TABLE progress
33         (p TEXT PRIMARY KEY NOT NULL,   -- p|m, p prime
34          e INT NOT NULL,                -- e = v_p(m)
35          k INT NOT NULL DEFAULT(0),     -- 0 <= k <= e
36          n TEXT NOT NULL DEFAULT(0),    -- (g^{m/p^k})^n = x^{m/p^k}
37          dpbits INT NOT NULL);          -- 0 for sequential
38 CREATE UNIQUE INDEX progress_by_p_k ON progress (p, k);
39
40 CREATE TABLE workers
41         (pid INT PRIMARY KEY NOT NULL,
42          p TEXT NOT NULL,
43          k INT NOT NULL,
44          FOREIGN KEY (p, k) REFERENCES progress (p, k));
45 CREATE INDEX workers_by_p ON workers (p, k);
46
47 CREATE TABLE points
48         (p TEXT NOT NULL,
49          k INT NOT NULL,
50          z TEXT NOT NULL,               -- g^a x^b = z
51          a TEXT NOT NULL,
52          b TEXT NOT NULL,
53          PRIMARY KEY (p, k, z),
54          FOREIGN KEY (p, k) REFERENCES progress (p, k));
55 """
56
57 GROUPMAP = {}
58
59 class GroupClass (type):
60   def __new__(cls, name, supers, dict):
61     ty = super(GroupClass, cls).__new__(cls, name, supers, dict)
62     try: name = ty.NAME
63     except AttributeError: pass
64     else: GROUPMAP[name] = ty
65     return ty
66
67 class BaseGroup (object):
68   __metaclass__ = GroupClass
69   def __init__(me, desc):
70     me.desc = desc
71   def div(me, x, y):
72     return me.mul(x, me.inv(y))
73
74 class BinaryFieldUnitGroup (BaseGroup):
75   NAME = 'gf2x'
76   def __init__(me, desc):
77     super(BinaryFieldUnitGroup, me).__init__(desc)
78     p = C.GF(desc)
79     if not p.irreduciblep(): raise ExpectedError, 'not irreducible'
80     me._k = C.BinPolyField(p)
81     me.order = me._k.q - 1
82   def elt(me, x):
83     return me._k(C.GF(x))
84   def pow(me, x, n):
85     return x**n
86   def mul(me, x, y):
87     return x*y
88   def inv(me, x):
89     return x.inv()
90   def idp(me, x):
91     return x == me._k.one
92   def eq(me, x, y):
93     return x == y
94   def str(me, x):
95     return str(x)
96
97 def getgroup(kind, desc): return GROUPMAP[kind](desc)
98
99 def factor(n):
100   ff = []
101   proc = S.Popen(['./factor', str(n)], stdout = S.PIPE)
102   for line in proc.stdout:
103     pstr, estr = line.split()
104     ff.append((C.MP(pstr), int(estr)))
105   rc = proc.wait()
106   if rc: raise ExpectedError, 'factor failed: rc = %d' % rc
107   return ff
108
109 CMDMAP = {}
110
111 def defcommand(f, name = None):
112   if isinstance(f, basestring):
113     return lambda g: defcommand(g, f)
114   else:
115     if name is None: name = f.__name__
116     CMDMAP[name] = f
117     return f
118
119 def connect_db(dir):
120   db = SQL.connect(OS.path.join(dir, 'db'))
121   db.text_factory = str
122   c = db.cursor()
123   c.executescript(CONNINIT_SQL)
124   return db
125
126 @defcommand
127 def setup(dir, kind, groupdesc, gstr, xstr):
128
129   ## Get the group.  This will also figure out the group order.
130   G = getgroup(kind, groupdesc)
131
132   ## Figure out the generator order.
133   g = G.elt(gstr)
134   x = G.elt(xstr)
135   ff = []
136   m = G.order
137   for p, e in factor(m):
138     ee = 0
139     for i in xrange(e):
140       mm = m/p
141       t = G.pow(g, mm)
142       if not G.idp(t): break
143       ee += 1; m = mm
144     if ee < e: ff.append((p, e - ee))
145
146   ## Check that x at least has the right order.  This check is imperfect.
147   if not G.idp(G.pow(x, m)): raise ValueError, 'x not in <g>'
148
149   ## Prepare the directory.
150   try: OS.mkdir(dir)
151   except OSError, err: raise ExpectedError, 'mkdir: %s' % err
152
153   ## Prepare the database.
154   db = connect_db(dir)
155   c = db.cursor()
156   c.executescript(SETUP_SQL)
157
158   ## Populate the general information.
159   with db:
160     c.execute("""INSERT INTO top (kind, groupdesc, g, x, m)
161                  VALUES (?, ?, ?, ?, ?)""",
162               (kind, groupdesc, G.str(g), G.str(x), str(m)))
163     for p, e in ff:
164       if p.nbits <= 48: dpbits = 0
165       else: dpbits = p.nbits*2/5
166       c.execute("""INSERT INTO progress (p, e, dpbits) VALUES (?, ?, ?)""",
167                 (str(p), e, dpbits))
168
169 def get_top(db):
170   c = db.cursor()
171   c.execute("""SELECT kind, groupdesc, g, x, m, n FROM top""")
172   kind, groupdesc, gstr, xstr, mstr, nstr = c.fetchone()
173   G = getgroup(kind, groupdesc)
174   g, x, m = G.elt(gstr), G.elt(xstr), C.MP(mstr)
175   n = nstr is not None and C.MP(nstr) or None
176   return G, g, x, m, n
177
178 @defcommand
179 def check(dir):
180   rc = [0]
181   def bad(msg):
182     print >>stderr, '%s: %s' % (PROG, msg)
183     rc[0] = 3
184   db = connect_db(dir)
185   c = db.cursor()
186   G, g, x, m, n = get_top(db)
187   print '## group: %s %s' % (G.NAME, G.desc)
188   print '## g = %s' % G.str(g)
189   print '## x = %s' % G.str(x)
190
191   if not G.idp(G.pow(g, m)):
192     bad('bad generator/order: %s^%d /= 1' % (G.str(g), m))
193   if not G.idp(G.pow(x, m)):
194     bad('x not in group: %s^%d /= 1' % (G.str(x), m))
195
196   ## Clear away old workers that aren't doing anything useful any more.
197   ## For each worker pid, check that its lockfile is still locked; if
198   ## not, it's finished and can be disposed of.
199   c.execute("""SELECT pid FROM workers""")
200   for pid, in c:
201     maybe_cleanup_worker(dir, db, pid)
202   for f in OS.listdir(dir):
203     if f.startswith('lk.'):
204       pid = int(f[3:])
205       maybe_cleanup_worker(dir, db, pid)
206
207   c.execute("""SELECT p.p, p.e, p.k, p.n, p.dpbits, COUNT(d.z)
208                FROM progress AS p LEFT OUTER JOIN points AS d
209                ON p.p = d.p AND p.k = d.k
210                GROUP BY p.p, p.k
211                ORDER BY LENGTH(p.p), p.p""")
212   mm = 1
213   for pstr, e, k, nnstr, dpbits, ndp in c:
214     p, nn = C.MP(pstr), C.MP(nnstr)
215     q = p**e
216     if m%q:
217       bad('incorrect order factorization: %d^%d /| %d' % (p, e, m))
218     mm *= q
219     if G.idp(G.pow(g, m/p)):
220       bad('bad generator/order: %s^{%d/%d} = 1' ^ (G.str(g), m, p))
221     r = m/p**k
222     h = G.pow(g, r*nn)
223     y = G.pow(x, r)
224     if not G.eq(h, y):
225       bad('bad partial log: (%s^{%d/%d^%d})^%d = %s /= %s = %s^{%d/%d^%d}' %
226           (G.str(g), m, p, k, nn, G.str(h), G.str(y), G.str(x), m, p, k))
227     if not dpbits or k == e: dpinfo = ''
228     else: dpinfo = ' [%d: %d]' % (dpbits, ndp)
229     print '## %d: %d/%d%s' % (p, k, e, dpinfo)
230   if mm != m:
231     bad('incomplete factorization: %d /= %d' % (mm, m))
232
233   if n is not None:
234     xx = G.pow(g, n)
235     if not G.eq(xx, x):
236       bad('incorrect log: %s^%d = %s /= %s' %
237           (G.str(g), n, G.str(xx), G.str(x)))
238     print '## DONE: %d' % n
239
240   exit(rc[0])
241
242 def get_job(db):
243   c = db.cursor()
244   c.execute("""SELECT p.p, p.e, p.k, p.n, p.dpbits
245                FROM progress AS p LEFT OUTER JOIN workers AS w
246                        ON p.p = w.p and p.k = w.k
247                WHERE p.k < p.e AND (p.dpbits > 0 OR w.pid IS NULL)
248                LIMIT 1""")
249   row = c.fetchone()
250   if row is None: return None, None, None, None, None
251   else:
252     pstr, e, k, nstr, dpbits = row
253     p, n = C.MP(pstr), C.MP(nstr)
254     return p, e, k, n, dpbits
255
256 @defcommand
257 def done(dir):
258   db = connect_db(dir)
259   c = db.cursor()
260   G, g, x, m, n = get_top(db)
261   if n is not None:
262     print '## DONE: %d' % n
263     exit(0)
264   p, e, k, n, dpbits = get_job(db)
265   if p is None: exit(2)
266   else: exit(1)
267
268 def maybe_cleanup_worker(dir, db, pid):
269   c = db.cursor()
270   f = OS.path.join(dir, 'lk.%d' % pid)
271   state = 'LIVE'
272   try: fd = OS.open(f, OS.O_WRONLY)
273   except OSError, err:
274     if err.errno != E.ENOENT: raise ExpectedError, 'open lockfile: %s' % err
275     state = 'STALE'
276   else:
277     try: F.lockf(fd, F.LOCK_EX | F.LOCK_NB)
278     except IOError, err:
279       if err.errno != E.EAGAIN: raise ExpectedError, 'check lock: %s' % err
280     else:
281       state = 'STALE'
282   if state == 'STALE':
283     try: OS.unlink(f)
284     except OSError: pass
285     c.execute("""DELETE FROM workers WHERE pid = ?""", (pid,))
286
287 def maybe_kill_worker(dir, pid):
288   f = OS.path.join(dir, 'lk.%d' % pid)
289   try: fd = OS.open(f, OS.O_RDWR)
290   except OSError, err:
291     if err.errno != E.ENOENT: raise ExpectedError, 'open lockfile: %s' % err
292     return
293   try: F.lockf(fd, F.LOCK_EX | F.LOCK_NB)
294   except IOError, err:
295     if err.errno != E.EAGAIN: raise ExpectedError, 'check lock: %s' % err
296   else: return
297   OS.kill(pid, SIG.SIGTERM)
298   try: OS.unlink(f)
299   except OSError: pass
300
301 @defcommand
302 def step(dir, cmd, *args):
303
304   ## Open the database.
305   db = connect_db(dir)
306   c = db.cursor()
307   ##db.isolation_level = 'EXCLUSIVE'
308
309   ## Prepare our lockfile names.
310   mypid = OS.getpid()
311   nlk = OS.path.join(dir, 'nlk.%d' % mypid)
312   lk = OS.path.join(dir, 'lk.%d' % mypid)
313
314   ## Overall exception handling...
315   try:
316
317     ## Find out what needs doing and start doing it.  For this, we open a
318     ## transaction.
319     with db:
320       G, g, x, m, n = get_top(db)
321       if n is not None: raise ExpectedError, 'job done'
322
323       ## Find something to do.  Either a job that's small enough for us to
324       ## take on alone, and that nobody else has picked up yet, or one that
325       ## everyone's pitching in on.
326       p, e, k, n, dpbits = get_job(db)
327       if p is None: raise ExpectedError, 'no work to do'
328
329       ## Figure out what needs doing.  Let q = p^e, h = g^{m/q}, y = x^{m/q}.
330       ## Currently we have n_0 where
331       ##
332       ##    h^{p^{e-k} n_0} = y^{p^{e-k}}
333       ##
334       ## Suppose n == n_0 + p^k n' (mod p^{k+1}).  Then p^k n' == n - n_0
335       ## (mod p^{k+1}).
336       ##
337       ##    (h^{p^{e-1}})^{n'} = (g^{m/p})^{n'}
338       ##                       = (y/h^{n_0})^{p^{e-k-1}}
339       ##
340       ## so this is the next discrete log to solve.
341       q = p**e
342       o = m/q
343       h, y = G.pow(g, o), G.pow(x, o)
344       hh = G.pow(h, p**(e-1))
345       yy = G.pow(G.div(y, G.pow(h, n)), p**(e-k-1))
346
347       ## Take out a lockfile.
348       fd = OS.open(nlk, OS.O_WRONLY | OS.O_CREAT, 0700)
349       F.lockf(fd, F.LOCK_EX | F.LOCK_NB)
350       OS.rename(nlk, lk)
351
352       ## Record that we're working in the database.  This completes our
353       ## initial transaction.
354       c.execute("""INSERT INTO workers (pid, p, k) VALUES (?, ?, ?)""",
355                 (mypid, str(p), k))
356
357     ## Before we get too stuck in, check for an easy case.
358     if G.idp(yy):
359       dpbits = 0 # no need for distinguished points
360       nn = 0; ni = 0
361     else:
362
363       ## There's nothing else for it.  Start the job up.
364       proc = S.Popen([cmd] + list(args) +
365                      [str(dpbits), G.NAME, G.desc,
366                       G.str(hh), G.str(yy), str(p)],
367                      stdin = S.PIPE, stdout = S.PIPE)
368       f_in, f_out = proc.stdin.fileno(), proc.stdout.fileno()
369
370       ## Now we must look after it until it starts up.  Feed it stuff on stdin
371       ## periodically, so that we notice if our network connectivity is lost.
372       ## Collect its stdout.
373       for fd in [f_in, f_out]:
374         fl = F.fcntl(fd, F.F_GETFL)
375         F.fcntl(fd, F.F_SETFL, fl | OS.O_NONBLOCK)
376       done = False
377       out = ''
378       while not done:
379         rdy, wry, exy = SEL.select([f_out], [], [], 30.0)
380         if rdy:
381           while True:
382             try: b = OS.read(f_out, 4096)
383             except OSError, err:
384               if err.errno == E.EAGAIN: break
385               else: raise ExpectedError, 'read job: %s' % err
386             else:
387               if not len(b): done = True; break
388               else: out += b
389         if not done:
390           try: OS.write(f_in, '.')
391           except OSError, err: raise ExpectedError, 'write job: %s' % err
392       rc = proc.wait()
393       if rc: raise ExpectedError, 'job failed: rc = %d' % rc
394
395       ## Parse the answer.  There are two cases.
396       if not dpbits:
397         nnstr, nistr = out.split()
398         nn, ni = C.MP(nnstr), int(nistr)
399       else:
400         astr, bstr, zstr, nistr = out.split()
401         a, b, z, ni = C.MP(astr), C.MP(bstr), G.elt(zstr), int(nistr)
402
403     ## We have an answer.  Start a new transaction while we think about what
404     ## this means.
405     with db:
406
407       if dpbits:
408
409         ## Check that it's a correct point.
410         zz = G.mul(G.pow(hh, a), G.pow(yy, b))
411         if not G.eq(zz, z):
412           raise ExpectedError, \
413               'job incorrect distinguished point: %s^%d %s^%d = %s /= %s' % \
414               (hh, a, yy, b, zz, z)
415
416         ## Report this (partial) success.
417         print '## [%d, %d/%d: %d]: %d %d -> %s [%d]' % \
418             (p, k, e, dpbits, a, b, G.str(z), ni)
419
420         ## If it's already in the database then we have an answer to the
421         ## problem.
422         c.execute("""SELECT a, b FROM points
423                      WHERE p = ? AND k = ? AND z = ?""",
424                   (str(p), k, str(z)))
425         row = c.fetchone()
426         if row is None:
427           nn = None
428           c.execute("""INSERT INTO points (p, k, a, b, z)
429                         VALUES (?, ?, ?, ?, ?)""",
430                     (str(p), str(k), str(a), str(b), G.str(z)))
431         else:
432           aastr, bbstr = row
433           aa, bb = C.MP(aastr), C.MP(bbstr)
434           if not (b - bb)%p:
435             raise ExpectedError, 'duplicate point :-('
436
437           ## We win!
438           nn = ((a - aa)*p.modinv(bb - b))%p
439           c.execute("""SELECT COUNT(z) FROM points WHERE p = ? AND k = ?""",
440                     (str(p), k))
441           ni, = c.fetchone()
442           print '## [%s, %d/%d: %d] collision %d %d -> %s <- %s %s [#%d]' % \
443               (p, k, e, dpbits, a, b, G.str(z), aa, bb, ni)
444
445       ## If we don't have a final answer then we're done.
446       if nn is None: return
447
448       ## Check that the log we've recovered is correct.
449       yyy = G.pow(hh, nn)
450       if not G.eq(yyy, yy):
451         raise ExpectedError, 'recovered incorrect log: %s^%d = %s /= %s' % \
452             (G.str(hh), nn, G.str(yyy), G.str(yy))
453
454       ## Update the log for this prime power.
455       n += nn*p**k
456       k += 1
457
458       ## Check that this is also correct.
459       yyy = G.pow(h, n*p**(e-k))
460       yy = G.pow(y, p**(e-k))
461       if not G.eq(yyy, yy):
462         raise ExpectedError, 'lifted incorrect log: %s^d = %s /= %s' % \
463             (G.str(h), n, G.str(yyy), G.str(yy))
464
465       ## Kill off the other jobs working on this component.  If we crash now,
466       ## we lose a bunch of work. :-(
467       c.execute("""SELECT pid FROM workers WHERE p = ? AND k = ?""",
468                 (str(p), k))
469       for pid, in c: maybe_kill_worker(dir, pid)
470       c.execute("""DELETE FROM workers WHERE p = ? AND k = ?""",
471                 (str(p), k - 1))
472       c.execute("""DELETE FROM points WHERE p = ? AND k = ?""",
473                 (str(p), k - 1))
474
475       ## Looks like we're good: update the progress table.
476       c.execute("""UPDATE progress SET k = ?, n = ? WHERE p = ?""",
477                 (k, str(n), str(p)))
478       print '## [%d, %d/%d]: %d [%d]' % (p, k, e, n, ni)
479
480       ## Quick check: are we done now?
481       c.execute("""SELECT p FROM progress WHERE k < e
482                    LIMIT 1""")
483       row = c.fetchone()
484       if row is None:
485
486         ## Wow.  Time to stitch everything together.
487         c.execute("""SELECT p, e, n FROM progress""")
488         qq, nn = [], []
489         for pstr, e, nstr in c:
490           p, n = C.MP(pstr), C.MP(nstr)
491           qq.append(p**e)
492           nn.append(n)
493         if len(qq) == 1: n = nn[0]
494         else: n = C.MPCRT(qq).solve(nn)
495
496         ## One last check that this is the right answer.
497         xx = G.pow(g, n)
498         if not G.eq(x, xx):
499           raise ExpectedError, \
500               'calculated incorrect final log: %s^d = %s /= %s' \
501               (G.str(g), n, G.str(xx), G.str(x))
502
503         ## We're good.
504         c.execute("""UPDATE top SET n = ?""", (str(n),))
505         print '## DONE: %d' % n
506
507   finally:
508
509     ## Delete our lockfile.
510     for f in [nlk, lk]:
511       try: OS.unlink(f)
512       except OSError: pass
513
514     ## Unregister from the database.
515     with db:
516       c.execute("""DELETE FROM workers WHERE pid = ?""", (mypid,))
517
518 PROG = argv[0]
519
520 try:
521   CMDMAP[argv[1]](*argv[2:])
522 except ExpectedError, err:
523   print >>stderr, '%s: %s' % (PROG, err.message)
524   exit(3)