chiark / gitweb /
math/mpgen, symm/multigen: Various minor cleanups.
[catacomb] / math / mpgen
1 #! @PYTHON@
2 ###
3 ### Generate multiprecision integer representations
4 ###
5 ### (c) 2013 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of Catacomb.
11 ###
12 ### Catacomb is free software; you can redistribute it and/or modify
13 ### it under the terms of the GNU Library General Public License as
14 ### published by the Free Software Foundation; either version 2 of the
15 ### License, or (at your option) any later version.
16 ###
17 ### Catacomb is distributed in the hope that it will be useful,
18 ### but WITHOUT ANY WARRANTY; without even the implied warranty of
19 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20 ### GNU Library General Public License for more details.
21 ###
22 ### You should have received a copy of the GNU Library General Public
23 ### License along with Catacomb; if not, write to the Free
24 ### Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
25 ### MA 02111-1307, USA.
26
27 from __future__ import with_statement
28
29 import re as RX
30 import optparse as OP
31 import types as TY
32
33 from sys import stdout
34
35 ###--------------------------------------------------------------------------
36 ### Random utilities.
37
38 def write_header(mode, name):
39   stdout.write("""\
40 /* -*-c-*- GENERATED by mpgen (%s)
41  *
42  * %s
43  */
44
45 """ % (mode, name))
46
47 def write_banner(text):
48   stdout.write("/*----- %s %s*/\n" % (text, '-' * (66 - len(text))))
49
50 class struct (object): pass
51
52 R_IDBAD = RX.compile('[^0-9A-Za-z]')
53 def fix_name(name): return R_IDBAD.sub('_', name)
54
55 ###--------------------------------------------------------------------------
56 ### Determining the appropriate types.
57
58 TYPEMAP = {}
59
60 class IntClass (type):
61   def __new__(cls, name, supers, dict):
62     c = type.__new__(cls, name, supers, dict)
63     try: TYPEMAP[c.tag] = c
64     except AttributeError: pass
65     return c
66
67 class BasicIntType (object):
68   __metaclass__ = IntClass
69   preamble = ''
70   typedef_prefix = ''
71   literalfmt = '%su'
72   def __init__(me, bits, rank):
73     me.bits = bits
74     me.rank = rank
75     me.litwd = len(me.literal(0))
76   def literal(me, value, fmt = None):
77     if fmt is None: fmt = '0x%0' + str((me.bits + 3)//4) + 'x'
78     return me.literalfmt % (fmt % value)
79
80 class UnsignedCharType (BasicIntType):
81   tag = 'uchar'
82   name = 'unsigned char'
83
84 class UnsignedShortType (BasicIntType):
85   tag = 'ushort'
86   name = 'unsigned short'
87
88 class UnsignedIntType (BasicIntType):
89   tag = 'uint'
90   name = 'unsigned int'
91
92 class UnsignedLongType (BasicIntType):
93   tag = 'ulong'
94   name = 'unsigned long'
95   literalfmt = '%sul'
96
97 class UnsignedLongLongType (BasicIntType):
98   tag = 'ullong'
99   name = 'unsigned long long'
100   preamble = """
101 #if __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ >= 91)
102 #  define CATACOMB_GCC_EXTENSION __extension__
103 #else
104 #  define CATACOMB_GCC_EXTENSION
105 #endif
106 """
107   typedef_prefix = 'CATACOMB_GCC_EXTENSION '
108   literalfmt = 'CATACOMB_GCC_EXTENSION %sull'
109
110 class UIntMaxType (BasicIntType):
111   tag = 'uintmax'
112   name = 'uintmax_t'
113   preamble = "\n#include <stdint.h>\n"
114
115 class TypeChoice (object):
116   def __init__(me, tifile):
117
118     ## Load the captured type information.
119     me.ti = TY.ModuleType('typeinfo')
120     execfile(opts.typeinfo, me.ti.__dict__)
121
122     ## Build a map of the available types.
123     tymap = {}
124     byrank = []
125     for tag, bits in me.ti.TYPEINFO:
126       rank = len(byrank)
127       tymap[tag] = rank
128       byrank.append(TYPEMAP[tag](bits, rank))
129
130     ## First pass: determine a suitable word size.  The criteria are (a)
131     ## there exists another type at least twice as long (so that we can do a
132     ## single x single -> double multiplication), and (b) operations on a
133     ## word are efficient (so we'd prefer a plain machine word).  We'll start
134     ## at `int' and work down.  Maybe this won't work: there's a plan B.
135     mpwbits = 0
136     i = tymap['uint']
137     while not mpwbits and i >= 0:
138       ibits = byrank[i].bits
139       for j in xrange(i + 1, len(byrank)):
140         if byrank[j].bits >= 2*ibits:
141           mpwbits = ibits
142           break
143
144     ## If that didn't work, then we'll start with the largest type available
145     ## and go with half its size.
146     if not mpwbits:
147       mpwbits = byrank[-1].bits//2
148
149     ## Make sure we've not ended up somewhere really silly.
150     if mpwbits < 16:
151       raise Exception, "`mpw' type is too small: your C environment is weird"
152
153     ## Now figure out suitable types for `mpw' and `mpd'.
154     def find_type(bits, what):
155       for ty in byrank:
156         if ty.bits >= bits: return ty
157       raise Exception, \
158           "failed to find suitable %d-bit type, for %s" % (bits, what)
159
160     ## Store our decisions.
161     me.mpwbits = mpwbits
162     me.mpw = find_type(mpwbits, 'mpw')
163     me.mpd = find_type(mpwbits*2, 'mpd')
164
165 ###--------------------------------------------------------------------------
166 ### Outputting constant multiprecision integers.
167
168 MARGIN = 72
169
170 def write_preamble():
171   stdout.write("""
172 #include <mLib/macros.h>
173 #define MP_(name, flags) \\
174   { (/*unconst*/ mpw *)name##__mpw, \\
175     (/*unconst*/ mpw *)name##__mpw + N(name##__mpw), \\
176     N(name##__mpw), 0, MP_CONST | flags, 0 }
177 #define ZERO_MP { 0, 0, 0, 0, MP_CONST, 0 }
178 #define POS_MP(name) MP_(name, 0)
179 #define NEG_MP(name) MP_(name, MP_NEG)
180 """)
181
182 def write_limbs(name, x):
183   if not x: return
184   stdout.write("\nstatic const mpw %s__mpw[] = {" % name)
185   sep = ''
186   pos = MARGIN
187   if x < 0: x = -x
188   mask = (1 << TC.mpwbits) - 1
189
190   while x > 0:
191     w, x = x & mask, x >> TC.mpwbits
192     f = TC.mpw.literal(w)
193     if pos + 2 + len(f) <= MARGIN:
194       stdout.write(sep + ' ' + f)
195     else:
196       pos = 2
197       stdout.write(sep + '\n  ' + f)
198     pos += len(f) + 2
199     sep = ','
200
201   stdout.write("\n};\n")
202
203 def mp_body(name, x):
204   return "%s_MP(%s)" % (x >= 0 and "POS" or "NEG", name)
205
206 ###--------------------------------------------------------------------------
207 ### Mode definition machinery.
208
209 MODEMAP = {}
210
211 def defmode(func):
212   name = func.func_name
213   if name.startswith('m_'): name = name[2:]
214   MODEMAP[name] = func
215   return func
216
217 ###--------------------------------------------------------------------------
218 ### The basic types header.
219
220 @defmode
221 def m_mptypes():
222   write_header("mptypes", "mptypes.h")
223   stdout.write("""\
224 #ifndef CATACOMB_MPTYPES_H
225 #define CATACOMB_MPTYPES_H
226 """)
227
228   have = set([TC.mpw, TC.mpd])
229   for t in have:
230     stdout.write(t.preamble)
231
232   for label, t, bits in [('mpw', TC.mpw, TC.mpwbits),
233                          ('mpd', TC.mpd, TC.mpwbits*2)]:
234     LABEL = label.upper()
235     stdout.write("\n%stypedef %s %s;\n" % (t.typedef_prefix, t.name, label))
236     stdout.write("#define %s_BITS %d\n" % (LABEL, bits))
237     i = 1
238     while 2*i < bits: i *= 2
239     stdout.write("#define %s_P2 %d\n" % (LABEL, i))
240     stdout.write("#define %s_MAX %s\n" % (LABEL,
241                                           t.literal((1 << bits) - 1, "%d")))
242
243   stdout.write("\n#endif\n")
244
245 ###--------------------------------------------------------------------------
246 ### Constant tables.
247
248 @defmode
249 def m_mplimits_c():
250   write_header("mplimits_c", "mplimits.c")
251   stdout.write('#include "mplimits.h"\n')
252   write_preamble()
253   seen = {}
254   v = []
255   def write(x):
256     if not x or x in seen: return
257     seen[x] = 1
258     write_limbs('limits_%d' % len(v), x)
259     v.append(x)
260   for tag, lo, hi in TC.ti.LIMITS:
261     write(lo)
262     write(hi)
263
264   stdout.write("\nmp mp_limits[] = {")
265   i = 0
266   sep = "\n  "
267   for x in v:
268     stdout.write("%s%s_MP(limits_%d)" % (sep, x < 0 and "NEG" or "POS", i))
269     i += 1
270     sep = ",\n  "
271   stdout.write("\n};\n");
272
273 @defmode
274 def m_mplimits_h():
275   write_header("mplimits_h", "mplimits.h")
276   stdout.write("""\
277 #ifndef CATACOMB_MPLIMITS_H
278 #define CATACOMB_MPLIMITS_H
279
280 #ifndef CATACOMB_MP_H
281 #  include "mp.h"
282 #endif
283
284 extern mp mp_limits[];
285
286 """)
287
288   seen = { 0: "MP_ZERO" }
289   slot = [0]
290   def find(x):
291     try:
292       r = seen[x]
293     except KeyError:
294       r = seen[x] = '(&mp_limits[%d])' % slot[0]
295       slot[0] += 1
296     return r
297   for tag, lo, hi in TC.ti.LIMITS:
298     stdout.write("#define MP_%s_MIN %s\n" % (tag, find(lo)))
299     stdout.write("#define MP_%s_MAX %s\n" % (tag, find(hi)))
300
301   stdout.write("\n#endif\n")
302
303 ###--------------------------------------------------------------------------
304 ### Group tables.
305
306 class GroupTableClass (type):
307   def __new__(cls, name, supers, dict):
308     c = type.__new__(cls, name, supers, dict)
309     try: mode = c.mode
310     except AttributeError: pass
311     else: MODEMAP[c.mode] = c.run
312     return c
313
314 class GroupTable (object):
315   __metaclass__ = GroupTableClass
316   keyword = 'group'
317   slots = []
318   def __init__(me):
319     me.st = st = struct()
320     st.nextmp = 0
321     st.mpmap = { None: 'NO_MP', 0: 'ZERO_MP' }
322     st.d = {}
323     st.name = None
324     me._names = []
325     me._defs = set()
326     me._slotmap = dict([(s.name, s) for s in me.slots])
327     me._headslots = [s for s in me.slots if s.headline]
328   def _flush(me):
329     if me.st.name is None: return
330     stdout.write("/* --- %s --- */\n" % me.st.name)
331     for s in me.slots: s.setup(me.st)
332     stdout.write("\nstatic %s c_%s = {" % (me.data_t, fix_name(me.st.name)))
333     sep = "\n  "
334     for s in me.slots:
335       stdout.write(sep)
336       s.write(me.st)
337       sep = ",\n  "
338     stdout.write("\n};\n\n")
339     me.st.d = {}
340     me.st.name = None
341   @classmethod
342   def run(cls, input):
343     me = cls()
344     write_header(me.mode, me.filename)
345     stdout.write('#include "%s"\n' % me.header)
346     write_preamble()
347     stdout.write("#define NO_MP { 0, 0, 0, 0, 0, 0 }\n\n")
348     write_banner("Group data")
349     stdout.write('\n')
350     with open(input) as file:
351       for line in file:
352         ff = line.split()
353         if not ff or ff[0].startswith('#'): continue
354         if ff[0] == 'alias':
355           if len(ff) != 3: raise Exception, "wrong number of alias arguments"
356           me._flush()
357           me._names.append((ff[1], ff[2]))
358         elif ff[0] == me.keyword:
359           if len(ff) < 2 or len(ff) > 2 + len(me._headslots):
360             raise Exception, "bad number of headline arguments"
361           me._flush()
362           me.st.name = name = ff[1]
363           me._defs.add(name)
364           me._names.append((name, name))
365           for f, s in zip(ff[2:], me._headslots): s.set(me.st, f)
366         elif ff[0] in me._slotmap:
367           if len(ff) != 2:
368             raise Exception, "bad number of values for slot `%s'" % ff[0]
369           me._slotmap[ff[0]].set(me.st, ff[1])
370         else:
371           raise Exception, "unknown keyword `%s'" % ff[0]
372     me._flush()
373     write_banner("Main table")
374     stdout.write("\nconst %s %s[] = {\n" % (me.entry_t, me.tabname))
375     for a, n in me._names:
376       if n not in me._defs:
377         raise Exception, "alias `%s' refers to unknown group `%s'" % (a, n)
378       stdout.write('  { "%s", &c_%s },\n' % (a, fix_name(n)))
379     stdout.write("  { 0, 0 }\n};\n\n")
380     write_banner("That's all, folks")
381
382 class BaseSlot (object):
383   def __init__(me, name, headline = False, omitp = None, allowp = None):
384     me.name = name
385     me.headline = headline
386     me._omitp = omitp
387     me._allowp = allowp
388   def set(me, st, value):
389     if me._allowp and not me._allowp(st, value):
390       raise Exception, "slot `%s' not allowed here" % me.name
391     st.d[me] = value
392   def setup(me, st):
393     if me not in st.d and (not me._omitp or not me._omitp(st)):
394       raise Exception, "missing slot `%s'" % me.name
395
396 class EnumSlot (BaseSlot):
397   def __init__(me, name, prefix, values, **kw):
398     super(EnumSlot, me).__init__(name, **kw)
399     me._values = set(values)
400     me._prefix = prefix
401   def set(me, st, value):
402     if value not in me._values:
403       raise Exception, "invalid %s value `%s'" % (me.name, value)
404     super(EnumSlot, me).set(st, value)
405   def write(me, st):
406     try: stdout.write('%s_%s' % (me._prefix, st.d[me].upper()))
407     except KeyError: stdout.write('0')
408
409 class MPSlot (BaseSlot):
410   def set(me, st, value):
411     super(MPSlot, me).set(st, long(value, 0))
412   def setup(me, st):
413     super(MPSlot, me).setup(st)
414     v = st.d.get(me)
415     if v not in st.mpmap:
416       write_limbs('v%d' % st.nextmp, v)
417       st.mpmap[v] = mp_body('v%d' % st.nextmp, v)
418       st.nextmp += 1
419   def write(me, st):
420     stdout.write(st.mpmap[st.d.get(me)])
421
422 class BinaryGroupTable (GroupTable):
423   mode = 'bintab'
424   filename = 'bintab.c'
425   header = 'bintab.h'
426   data_t = 'bindata'
427   entry_t = 'binentry'
428   tabname = 'bintab'
429   slots = [MPSlot('p'), MPSlot('q'), MPSlot('g')]
430
431 class EllipticCurveTable (GroupTable):
432   mode = 'ectab'
433   filename = 'ectab.c'
434   header = 'ectab.h'
435   keyword = 'curve'
436   data_t = 'ecdata'
437   entry_t = 'ecentry'
438   tabname = 'ectab'
439   _typeslot = EnumSlot('type', 'FTAG',
440                        ['prime', 'niceprime', 'binpoly', 'binnorm'],
441                        headline = True)
442   slots = [_typeslot,
443            MPSlot('p'),
444            MPSlot('beta',
445                   allowp = lambda st, _:
446                     st.d[EllipticCurveTable._typeslot] == 'binnorm',
447                   omitp = lambda st:
448                     st.d[EllipticCurveTable._typeslot] != 'binnorm'),
449            MPSlot('a'), MPSlot('b'), MPSlot('r'), MPSlot('h'),
450            MPSlot('gx'), MPSlot('gy')]
451
452 class PrimeGroupTable (GroupTable):
453   mode = 'ptab'
454   filename = 'ptab.c'
455   header = 'ptab.h'
456   data_t = 'pdata'
457   entry_t = 'pentry'
458   tabname = 'ptab'
459   slots = [MPSlot('p'), MPSlot('q'), MPSlot('g')]
460
461 ###--------------------------------------------------------------------------
462 ### Main program.
463
464 op = OP.OptionParser(
465   description = 'Generate multiprecision integer representations',
466   usage = 'usage: %prog [-t TYPEINFO] MODE [ARGS ...]',
467   version = 'Catacomb, version @VERSION@')
468 for shortopt, longopt, kw in [
469   ('-t', '--typeinfo', dict(
470       action = 'store', metavar = 'PATH', dest = 'typeinfo',
471       help = 'alternative typeinfo file'))]:
472   op.add_option(shortopt, longopt, **kw)
473 op.set_defaults(typeinfo = './typeinfo.py')
474 opts, args = op.parse_args()
475
476 if len(args) < 1: op.error('missing MODE')
477 mode = args[0]
478
479 TC = TypeChoice(opts.typeinfo)
480
481 try: modefunc = MODEMAP[mode]
482 except KeyError: op.error("unknown mode `%s'" % mode)
483 modefunc(*args[1:])
484
485 ###----- That's all, folks --------------------------------------------------