chiark / gitweb /
Add commentary and licence notices.
[rhodes] / rhodes
1 #! /usr/bin/python
2 ### -*-python-*-
3 ###
4 ### Calculate discrete logs in groups
5 ###
6 ### (c) 2017 Mark Wooding
7 ###
8
9 ###----- Licensing notice ---------------------------------------------------
10 ###
11 ### This file is part of Rhodes, a distributed discrete-log finder.
12 ###
13 ### Rhodes is free software; you can redistribute it and/or modify
14 ### it under the terms of the GNU General Public License as published by
15 ### the Free Software Foundation; either version 2 of the License, or
16 ### (at your option) any later version.
17 ###
18 ### Rhodes is distributed in the hope that it will be useful,
19 ### but WITHOUT ANY WARRANTY; without even the implied warranty of
20 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
21 ### GNU General Public License for more details.
22 ###
23 ### You should have received a copy of the GNU General Public License
24 ### along with Rhodes; if not, write to the Free Software Foundation,
25 ### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
26
27 from sys import argv, stdout, stderr, exit
28 import errno as E
29 import fcntl as F
30 import os as OS
31 import subprocess as S
32 import select as SEL
33 import signal as SIG
34
35 import catacomb as C
36 import sqlite3 as SQL
37
38 ###--------------------------------------------------------------------------
39 ### Miscellaneous utilities.
40
41 class ExpectedError (Exception):
42   pass
43
44 ###--------------------------------------------------------------------------
45 ### Database handling.
46
47 CONNINIT_SQL = """
48 PRAGMA foreign_keys = on;
49 """
50
51 SETUP_SQL = """
52 PRAGMA journal_mode = wal;
53
54 CREATE TABLE top
55         (kind TEXT NOT NULL,            -- `gf2x'
56          groupdesc TEXT NOT NULL,
57          g TEXT NOT NULL,
58          x TEXT NOT NULL,
59          m TEXT NOT NULL,               -- g^m = 1
60          n TEXT DEFAULT NULL);          -- g^n = x
61
62 CREATE TABLE progress
63         (p TEXT PRIMARY KEY NOT NULL,   -- p|m, p prime
64          e INT NOT NULL,                -- e = v_p(m)
65          k INT NOT NULL DEFAULT(0),     -- 0 <= k <= e
66          n TEXT NOT NULL DEFAULT(0),    -- (g^{m/p^k})^n = x^{m/p^k}
67          dpbits INT NOT NULL);          -- 0 for sequential
68 CREATE UNIQUE INDEX progress_by_p_k ON progress (p, k);
69
70 CREATE TABLE workers
71         (pid INT PRIMARY KEY NOT NULL,
72          p TEXT NOT NULL,
73          k INT NOT NULL,
74          FOREIGN KEY (p, k) REFERENCES progress (p, k));
75 CREATE INDEX workers_by_p ON workers (p, k);
76
77 CREATE TABLE points
78         (p TEXT NOT NULL,
79          k INT NOT NULL,
80          z TEXT NOT NULL,               -- g^a x^b = z
81          a TEXT NOT NULL,
82          b TEXT NOT NULL,
83          PRIMARY KEY (p, k, z),
84          FOREIGN KEY (p, k) REFERENCES progress (p, k));
85 """
86
87 def connect_db(dir):
88   db = SQL.connect(OS.path.join(dir, 'db'))
89   db.text_factory = str
90   c = db.cursor()
91   c.executescript(CONNINIT_SQL)
92   return db
93
94 ###--------------------------------------------------------------------------
95 ### Group support.
96
97 GROUPMAP = {}
98
99 class GroupClass (type):
100   def __new__(cls, name, supers, dict):
101     ty = super(GroupClass, cls).__new__(cls, name, supers, dict)
102     try: name = ty.NAME
103     except AttributeError: pass
104     else: GROUPMAP[name] = ty
105     return ty
106
107 class BaseGroup (object):
108   __metaclass__ = GroupClass
109   def __init__(me, desc):
110     me.desc = desc
111   def div(me, x, y):
112     return me.mul(x, me.inv(y))
113
114 class BinaryFieldUnitGroup (BaseGroup):
115   NAME = 'gf2x'
116   def __init__(me, desc):
117     super(BinaryFieldUnitGroup, me).__init__(desc)
118     p = C.GF(desc)
119     if not p.irreduciblep(): raise ExpectedError, 'not irreducible'
120     me._k = C.BinPolyField(p)
121     me.order = me._k.q - 1
122   def elt(me, x):
123     return me._k(C.GF(x))
124   def pow(me, x, n):
125     return x**n
126   def mul(me, x, y):
127     return x*y
128   def inv(me, x):
129     return x.inv()
130   def idp(me, x):
131     return x == me._k.one
132   def eq(me, x, y):
133     return x == y
134   def str(me, x):
135     return str(x)
136
137 def getgroup(kind, desc): return GROUPMAP[kind](desc)
138
139 ###--------------------------------------------------------------------------
140 ### Number-theoretic utilities.
141
142 def factor(n):
143   ff = []
144   proc = S.Popen(['./factor', str(n)], stdout = S.PIPE)
145   for line in proc.stdout:
146     pstr, estr = line.split()
147     ff.append((C.MP(pstr), int(estr)))
148   rc = proc.wait()
149   if rc: raise ExpectedError, 'factor failed: rc = %d' % rc
150   return ff
151
152 ###--------------------------------------------------------------------------
153 ### Command dispatch.
154
155 CMDMAP = {}
156
157 def defcommand(f, name = None):
158   if isinstance(f, basestring):
159     return lambda g: defcommand(g, f)
160   else:
161     if name is None: name = f.__name__
162     CMDMAP[name] = f
163     return f
164
165 ###--------------------------------------------------------------------------
166 ### Job status utilities.
167
168 def get_top(db):
169   c = db.cursor()
170   c.execute("""SELECT kind, groupdesc, g, x, m, n FROM top""")
171   kind, groupdesc, gstr, xstr, mstr, nstr = c.fetchone()
172   G = getgroup(kind, groupdesc)
173   g, x, m = G.elt(gstr), G.elt(xstr), C.MP(mstr)
174   n = nstr is not None and C.MP(nstr) or None
175   return G, g, x, m, n
176
177 def get_job(db):
178   c = db.cursor()
179   c.execute("""SELECT p.p, p.e, p.k, p.n, p.dpbits
180                FROM progress AS p LEFT OUTER JOIN workers AS w
181                        ON p.p = w.p and p.k = w.k
182                WHERE p.k < p.e AND (p.dpbits > 0 OR w.pid IS NULL)
183                LIMIT 1""")
184   row = c.fetchone()
185   if row is None: return None, None, None, None, None
186   else:
187     pstr, e, k, nstr, dpbits = row
188     p, n = C.MP(pstr), C.MP(nstr)
189     return p, e, k, n, dpbits
190
191 def maybe_cleanup_worker(dir, db, pid):
192   c = db.cursor()
193   f = OS.path.join(dir, 'lk.%d' % pid)
194   state = 'LIVE'
195   try: fd = OS.open(f, OS.O_WRONLY)
196   except OSError, err:
197     if err.errno != E.ENOENT: raise ExpectedError, 'open lockfile: %s' % err
198     state = 'STALE'
199   else:
200     try: F.lockf(fd, F.LOCK_EX | F.LOCK_NB)
201     except IOError, err:
202       if err.errno != E.EAGAIN: raise ExpectedError, 'check lock: %s' % err
203     else:
204       state = 'STALE'
205   if state == 'STALE':
206     try: OS.unlink(f)
207     except OSError: pass
208     c.execute("""DELETE FROM workers WHERE pid = ?""", (pid,))
209
210 def maybe_kill_worker(dir, pid):
211   f = OS.path.join(dir, 'lk.%d' % pid)
212   try: fd = OS.open(f, OS.O_RDWR)
213   except OSError, err:
214     if err.errno != E.ENOENT: raise ExpectedError, 'open lockfile: %s' % err
215     return
216   try: F.lockf(fd, F.LOCK_EX | F.LOCK_NB)
217   except IOError, err:
218     if err.errno != E.EAGAIN: raise ExpectedError, 'check lock: %s' % err
219   else: return
220   OS.kill(pid, SIG.SIGTERM)
221   try: OS.unlink(f)
222   except OSError: pass
223
224 ###--------------------------------------------------------------------------
225 ### Setup.
226
227 @defcommand
228 def setup(dir, kind, groupdesc, gstr, xstr):
229
230   ## Get the group.  This will also figure out the group order.
231   G = getgroup(kind, groupdesc)
232
233   ## Figure out the generator order.
234   g = G.elt(gstr)
235   x = G.elt(xstr)
236   ff = []
237   m = G.order
238   for p, e in factor(m):
239     ee = 0
240     for i in xrange(e):
241       mm = m/p
242       t = G.pow(g, mm)
243       if not G.idp(t): break
244       ee += 1; m = mm
245     if ee < e: ff.append((p, e - ee))
246
247   ## Check that x at least has the right order.  This check is imperfect.
248   if not G.idp(G.pow(x, m)): raise ValueError, 'x not in <g>'
249
250   ## Prepare the directory.
251   try: OS.mkdir(dir)
252   except OSError, err: raise ExpectedError, 'mkdir: %s' % err
253
254   ## Prepare the database.
255   db = connect_db(dir)
256   c = db.cursor()
257   c.executescript(SETUP_SQL)
258
259   ## Populate the general information.
260   with db:
261     c.execute("""INSERT INTO top (kind, groupdesc, g, x, m)
262                  VALUES (?, ?, ?, ?, ?)""",
263               (kind, groupdesc, G.str(g), G.str(x), str(m)))
264     for p, e in ff:
265       if p.nbits <= 48: dpbits = 0
266       else: dpbits = p.nbits*2/5
267       c.execute("""INSERT INTO progress (p, e, dpbits) VALUES (?, ?, ?)""",
268                 (str(p), e, dpbits))
269
270 ###--------------------------------------------------------------------------
271 ### Check.
272
273 @defcommand
274 def check(dir):
275   rc = [0]
276   def bad(msg):
277     print >>stderr, '%s: %s' % (PROG, msg)
278     rc[0] = 3
279   db = connect_db(dir)
280   c = db.cursor()
281   G, g, x, m, n = get_top(db)
282   print '## group: %s %s' % (G.NAME, G.desc)
283   print '## g = %s' % G.str(g)
284   print '## x = %s' % G.str(x)
285
286   if not G.idp(G.pow(g, m)):
287     bad('bad generator/order: %s^%d /= 1' % (G.str(g), m))
288   if not G.idp(G.pow(x, m)):
289     bad('x not in group: %s^%d /= 1' % (G.str(x), m))
290
291   ## Clear away old workers that aren't doing anything useful any more.
292   ## For each worker pid, check that its lockfile is still locked; if
293   ## not, it's finished and can be disposed of.
294   c.execute("""SELECT pid FROM workers""")
295   for pid, in c:
296     maybe_cleanup_worker(dir, db, pid)
297   for f in OS.listdir(dir):
298     if f.startswith('lk.'):
299       pid = int(f[3:])
300       maybe_cleanup_worker(dir, db, pid)
301
302   c.execute("""SELECT p.p, p.e, p.k, p.n, p.dpbits, COUNT(d.z)
303                FROM progress AS p LEFT OUTER JOIN points AS d
304                ON p.p = d.p AND p.k = d.k
305                GROUP BY p.p, p.k
306                ORDER BY LENGTH(p.p), p.p""")
307   mm = 1
308   for pstr, e, k, nnstr, dpbits, ndp in c:
309     p, nn = C.MP(pstr), C.MP(nnstr)
310     q = p**e
311     if m%q:
312       bad('incorrect order factorization: %d^%d /| %d' % (p, e, m))
313     mm *= q
314     if G.idp(G.pow(g, m/p)):
315       bad('bad generator/order: %s^{%d/%d} = 1' ^ (G.str(g), m, p))
316     r = m/p**k
317     h = G.pow(g, r*nn)
318     y = G.pow(x, r)
319     if not G.eq(h, y):
320       bad('bad partial log: (%s^{%d/%d^%d})^%d = %s /= %s = %s^{%d/%d^%d}' %
321           (G.str(g), m, p, k, nn, G.str(h), G.str(y), G.str(x), m, p, k))
322     if not dpbits or k == e: dpinfo = ''
323     else: dpinfo = ' [%d: %d]' % (dpbits, ndp)
324     print '## %d: %d/%d%s' % (p, k, e, dpinfo)
325   if mm != m:
326     bad('incomplete factorization: %d /= %d' % (mm, m))
327
328   if n is not None:
329     xx = G.pow(g, n)
330     if not G.eq(xx, x):
331       bad('incorrect log: %s^%d = %s /= %s' %
332           (G.str(g), n, G.str(xx), G.str(x)))
333     print '## DONE: %d' % n
334
335   exit(rc[0])
336
337 ###--------------------------------------------------------------------------
338 ### Done.
339
340 @defcommand
341 def done(dir):
342   db = connect_db(dir)
343   c = db.cursor()
344   G, g, x, m, n = get_top(db)
345   if n is not None:
346     print '## DONE: %d' % n
347     exit(0)
348   p, e, k, n, dpbits = get_job(db)
349   if p is None: exit(2)
350   else: exit(1)
351
352 ###--------------------------------------------------------------------------
353 ### Step.
354
355 @defcommand
356 def step(dir, cmd, *args):
357
358   ## Open the database.
359   db = connect_db(dir)
360   c = db.cursor()
361   ##db.isolation_level = 'EXCLUSIVE'
362
363   ## Prepare our lockfile names.
364   mypid = OS.getpid()
365   nlk = OS.path.join(dir, 'nlk.%d' % mypid)
366   lk = OS.path.join(dir, 'lk.%d' % mypid)
367
368   ## Overall exception handling...
369   try:
370
371     ## Find out what needs doing and start doing it.  For this, we open a
372     ## transaction.
373     with db:
374       G, g, x, m, n = get_top(db)
375       if n is not None: raise ExpectedError, 'job done'
376
377       ## Find something to do.  Either a job that's small enough for us to
378       ## take on alone, and that nobody else has picked up yet, or one that
379       ## everyone's pitching in on.
380       p, e, k, n, dpbits = get_job(db)
381       if p is None: raise ExpectedError, 'no work to do'
382
383       ## Figure out what needs doing.  Let q = p^e, h = g^{m/q}, y = x^{m/q}.
384       ## Currently we have n_0 where
385       ##
386       ##    h^{p^{e-k} n_0} = y^{p^{e-k}}
387       ##
388       ## Suppose n == n_0 + p^k n' (mod p^{k+1}).  Then p^k n' == n - n_0
389       ## (mod p^{k+1}).
390       ##
391       ##    (h^{p^{e-1}})^{n'} = (g^{m/p})^{n'}
392       ##                       = (y/h^{n_0})^{p^{e-k-1}}
393       ##
394       ## so this is the next discrete log to solve.
395       q = p**e
396       o = m/q
397       h, y = G.pow(g, o), G.pow(x, o)
398       hh = G.pow(h, p**(e-1))
399       yy = G.pow(G.div(y, G.pow(h, n)), p**(e-k-1))
400
401       ## Take out a lockfile.
402       fd = OS.open(nlk, OS.O_WRONLY | OS.O_CREAT, 0700)
403       F.lockf(fd, F.LOCK_EX | F.LOCK_NB)
404       OS.rename(nlk, lk)
405
406       ## Record that we're working in the database.  This completes our
407       ## initial transaction.
408       c.execute("""INSERT INTO workers (pid, p, k) VALUES (?, ?, ?)""",
409                 (mypid, str(p), k))
410
411     ## Before we get too stuck in, check for an easy case.
412     if G.idp(yy):
413       dpbits = 0 # no need for distinguished points
414       nn = 0; ni = 0
415     else:
416
417       ## There's nothing else for it.  Start the job up.
418       proc = S.Popen([cmd] + list(args) +
419                      [str(dpbits), G.NAME, G.desc,
420                       G.str(hh), G.str(yy), str(p)],
421                      stdin = S.PIPE, stdout = S.PIPE)
422       f_in, f_out = proc.stdin.fileno(), proc.stdout.fileno()
423
424       ## Now we must look after it until it starts up.  Feed it stuff on stdin
425       ## periodically, so that we notice if our network connectivity is lost.
426       ## Collect its stdout.
427       for fd in [f_in, f_out]:
428         fl = F.fcntl(fd, F.F_GETFL)
429         F.fcntl(fd, F.F_SETFL, fl | OS.O_NONBLOCK)
430       done = False
431       out = ''
432       while not done:
433         rdy, wry, exy = SEL.select([f_out], [], [], 30.0)
434         if rdy:
435           while True:
436             try: b = OS.read(f_out, 4096)
437             except OSError, err:
438               if err.errno == E.EAGAIN: break
439               else: raise ExpectedError, 'read job: %s' % err
440             else:
441               if not len(b): done = True; break
442               else: out += b
443         if not done:
444           try: OS.write(f_in, '.')
445           except OSError, err: raise ExpectedError, 'write job: %s' % err
446       rc = proc.wait()
447       if rc: raise ExpectedError, 'job failed: rc = %d' % rc
448
449       ## Parse the answer.  There are two cases.
450       if not dpbits:
451         nnstr, nistr = out.split()
452         nn, ni = C.MP(nnstr), int(nistr)
453       else:
454         astr, bstr, zstr, nistr = out.split()
455         a, b, z, ni = C.MP(astr), C.MP(bstr), G.elt(zstr), int(nistr)
456
457     ## We have an answer.  Start a new transaction while we think about what
458     ## this means.
459     with db:
460
461       if dpbits:
462
463         ## Check that it's a correct point.
464         zz = G.mul(G.pow(hh, a), G.pow(yy, b))
465         if not G.eq(zz, z):
466           raise ExpectedError, \
467               'job incorrect distinguished point: %s^%d %s^%d = %s /= %s' % \
468               (hh, a, yy, b, zz, z)
469
470         ## Report this (partial) success.
471         print '## [%d, %d/%d: %d]: %d %d -> %s [%d]' % \
472             (p, k, e, dpbits, a, b, G.str(z), ni)
473
474         ## If it's already in the database then we have an answer to the
475         ## problem.
476         c.execute("""SELECT a, b FROM points
477                      WHERE p = ? AND k = ? AND z = ?""",
478                   (str(p), k, str(z)))
479         row = c.fetchone()
480         if row is None:
481           nn = None
482           c.execute("""INSERT INTO points (p, k, a, b, z)
483                         VALUES (?, ?, ?, ?, ?)""",
484                     (str(p), str(k), str(a), str(b), G.str(z)))
485         else:
486           aastr, bbstr = row
487           aa, bb = C.MP(aastr), C.MP(bbstr)
488           if not (b - bb)%p:
489             raise ExpectedError, 'duplicate point :-('
490
491           ## We win!
492           nn = ((a - aa)*p.modinv(bb - b))%p
493           c.execute("""SELECT COUNT(z) FROM points WHERE p = ? AND k = ?""",
494                     (str(p), k))
495           ni, = c.fetchone()
496           print '## [%s, %d/%d: %d] collision %d %d -> %s <- %s %s [#%d]' % \
497               (p, k, e, dpbits, a, b, G.str(z), aa, bb, ni)
498
499       ## If we don't have a final answer then we're done.
500       if nn is None: return
501
502       ## Check that the log we've recovered is correct.
503       yyy = G.pow(hh, nn)
504       if not G.eq(yyy, yy):
505         raise ExpectedError, 'recovered incorrect log: %s^%d = %s /= %s' % \
506             (G.str(hh), nn, G.str(yyy), G.str(yy))
507
508       ## Update the log for this prime power.
509       n += nn*p**k
510       k += 1
511
512       ## Check that this is also correct.
513       yyy = G.pow(h, n*p**(e-k))
514       yy = G.pow(y, p**(e-k))
515       if not G.eq(yyy, yy):
516         raise ExpectedError, 'lifted incorrect log: %s^d = %s /= %s' % \
517             (G.str(h), n, G.str(yyy), G.str(yy))
518
519       ## Kill off the other jobs working on this component.  If we crash now,
520       ## we lose a bunch of work. :-(
521       c.execute("""SELECT pid FROM workers WHERE p = ? AND k = ?""",
522                 (str(p), k - 1))
523       for pid, in c:
524         if pid != mypid: maybe_kill_worker(dir, pid)
525       c.execute("""DELETE FROM workers WHERE p = ? AND k = ?""",
526                 (str(p), k - 1))
527       c.execute("""DELETE FROM points WHERE p = ? AND k = ?""",
528                 (str(p), k - 1))
529
530       ## Looks like we're good: update the progress table.
531       c.execute("""UPDATE progress SET k = ?, n = ? WHERE p = ?""",
532                 (k, str(n), str(p)))
533       print '## [%d, %d/%d]: %d [%d]' % (p, k, e, n, ni)
534
535       ## Quick check: are we done now?
536       c.execute("""SELECT p FROM progress WHERE k < e
537                    LIMIT 1""")
538       row = c.fetchone()
539       if row is None:
540
541         ## Wow.  Time to stitch everything together.
542         c.execute("""SELECT p, e, n FROM progress""")
543         qq, nn = [], []
544         for pstr, e, nstr in c:
545           p, n = C.MP(pstr), C.MP(nstr)
546           qq.append(p**e)
547           nn.append(n)
548         if len(qq) == 1: n = nn[0]
549         else: n = C.MPCRT(qq).solve(nn)
550
551         ## One last check that this is the right answer.
552         xx = G.pow(g, n)
553         if not G.eq(x, xx):
554           raise ExpectedError, \
555               'calculated incorrect final log: %s^d = %s /= %s' \
556               (G.str(g), n, G.str(xx), G.str(x))
557
558         ## We're good.
559         c.execute("""UPDATE top SET n = ?""", (str(n),))
560         print '## DONE: %d' % n
561
562   finally:
563
564     ## Delete our lockfile.
565     for f in [nlk, lk]:
566       try: OS.unlink(f)
567       except OSError: pass
568
569     ## Unregister from the database.
570     with db:
571       c.execute("""DELETE FROM workers WHERE pid = ?""", (mypid,))
572
573 ###--------------------------------------------------------------------------
574 ### Top-level program.
575
576 PROG = argv[0]
577
578 try:
579   CMDMAP[argv[1]](*argv[2:])
580 except ExpectedError, err:
581   print >>stderr, '%s: %s' % (PROG, err.message)
582   exit(3)