chiark / gitweb /
@@@ fltfmt mess
[mLib] / utils / t / fltfmt-testgen
1 #! /usr/bin/python
2 ###
3 ### Generate exhaustive tests for floating-point conversions.
4 ###
5 ### (c) 2024 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the mLib utilities library.
11 ###
12 ### mLib is free software: you can redistribute it and/or modify it under
13 ### the terms of the GNU Library General Public License as published by
14 ### the Free Software Foundation; either version 2 of the License, or (at
15 ### your option) any later version.
16 ###
17 ### mLib is distributed in the hope that it will be useful, but WITHOUT
18 ### ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
19 ### FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
20 ### License for more details.
21 ###
22 ### You should have received a copy of the GNU Library General Public
23 ### License along with mLib.  If not, write to the Free Software
24 ### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25 ### USA.
26
27 ###--------------------------------------------------------------------------
28 ### Imports.
29
30 import sys as SYS
31 import optparse as OP
32 import random as R
33 if SYS.version_info >= (3,):
34   from io import StringIO
35   xrange = range
36   def iterkeys(d): return d.keys()
37 else:
38   from cStringIO import StringIO
39   def iterkeys(d): return d.iterkeys()
40
41 ###--------------------------------------------------------------------------
42 ### Utilities.
43
44 def bit(k): "Return an integer with just bit K set."; return 1 << k
45 def mask(k): "Return an integer with bits 0 to K - 1 set."; return bit(k) - 1
46 M32 = mask(32)
47
48 def explore(wd, lobits, hibits):
49   """
50   Return an iterator over various WD-bit values.
51
52   Suppose that a test wants to explore various WD-bit fields, but WD might be
53   too large to do this exhaustively.  We assume (reasonably, in the case at
54   hand of floating-point formats) that the really interesting bits are those
55   at the low and high ends of the field, and test small subfields at the ends
56   exhaustively, filling in the bits in the middle with zeros, ones, or random
57   data.
58
59   So, the generator behaves as follows.  If WD <= LOBITS + HIBITS + 1 then
60   the iterator will yield all WD-bit values exhaustively.  Otherwise, it
61   yields a sequence which includes all combinations of: every LOBITS-bit
62   pattern in the least significant bits; every HIBITS-bit pattern in the most
63   significant bits; and all-bits-clear, all-bits-set, and a random pattern in
64   the bits in between.
65   """
66
67   if wd <= hibits + lobits + 1:
68     for i in xrange(bit(wd)): yield i
69   else:
70     midbit = bit(wd - hibits - lobits)
71     hishift = wd - hibits
72     m = (midbit - 1) << lobits
73     for hi in xrange(bit(hibits)):
74       top = hi << hishift
75       for lo in xrange(bit(lobits)):
76         base = lo | top
77         yield base
78         yield base | (R.randrange(midbit) << lobits)
79         yield base | m
80
81 class ExploreParameters (object):
82   """
83   Simple structure for exploration parameters; see `explore' for background.
84
85   The `explo' and `exphi' attributes are the low and high subfield sizes
86   for exponent fields, and `siglo' and `sighi' are the low and high subfield
87   sizes for significand fields.
88   """
89   def __init__(me, explo = 0, exphi = 2, siglo = 1, sighi = 3):
90     me.explo, me.exphi = explo, exphi
91     me.siglo, me.sighi = siglo, sighi
92
93 FMTMAP = {} # maps format names to classes
94
95 def with_metaclass(meta, *supers):
96   """
97   Return an arbitrary instance of the metaclass META.
98
99   The class will have SUPERS (default just `object') as its superclasses.
100   This is intended to be used in direct-superclass lists, as a compatibility
101   hack, because the Python 2 and 3 syntaxes are wildly different.
102   """
103   return meta("#<anonymous base %s>" % meta.__name__,
104               supers or (object,), dict())
105
106 class FormatClass (type):
107   """
108   Metaclass for format classes.
109
110   If the class defines a `NAME' attribute then register the class in
111   `FMTMAP'.
112   """
113   def __new__(cls, name, supers, dict):
114     c = type.__new__(cls, name, supers, dict)
115     try: FMTMAP[c.NAME] = c
116     except AttributeError: pass
117     return c
118
119 class IEEEFormat (with_metaclass(FormatClass)):
120   """
121   Floating point format class.
122
123   Concrete subclasses must define the following class attributes.
124
125     * `HIDDENP' -- true if the format uses a `hidden bit' convention for
126       normal numbers.
127
128     * `EXPWD' -- exponent field width, in bits.
129
130     * `PREC' -- precision, in bits, /including/ the hidden bit if any.
131
132   Many useful quantities are derived.
133
134     * `_expbias' is the exponent bias.
135
136     * `_minexp' and `_maxexp' are the minimum and maximum representable
137       exponents.
138
139     * `_sigwd' is the width of the significand field.
140
141     * `_paywords' is the number of words required to represent a NaN payload.
142
143     * `_nbits' is the total number of bits in an encoded value.
144
145     * `_rawbytes' is the number of bytes required for an encoded value.
146   """
147
148   def __init__(me):
149     """
150     Initialize an instance.
151     """
152     me._expbias = mask(me.EXPWD - 1)
153     me._maxexp = me._expbias
154     me._minexp = 1 - me._expbias
155     if me.HIDDENP: me._sigwd = me.PREC - 1
156     else: me._sigwd = me.PREC
157
158     me._paywords = (me._sigwd + 29)//32
159
160     me._nbits = 1 + me.EXPWD + me._sigwd
161     me._rawbytes = (me._nbits + 7)//8
162
163   def decode(me, x):
164     """
165     Decode the encoded floating-point value X, represented as an integer.
166
167     Return five quantities (FLAGS, EXP, FW, FRAC, ERR), corresponding mostly
168     to the `struct floatbits' representation, characterizing the value
169     encoded in X.
170
171       * FLAGS is a list of flag tokens:
172
173          -- `NEG' if the value is negative;
174          -- `ZERO' if the value is exactly zero;
175          -- `INF' if the value is infinite;
176          -- `SNAN' if the value is a signalling NaN; and/or
177          -- `QNAN' if the value is a quiet NaN.
178
179         FLAGS will be empty if the value is a strictly positive finite
180         number.
181
182       * EXP is the exponent, as a signed integer.  This will be `None' if the
183         value is zero, infinite, or a NaN.
184
185       * FW is the length of the fraction, in 32-bit words.  This will be
186         `None' if the value is zero or infinite.
187
188       * FRAC is the fraction or payload.  This will be `None' if the value is
189         zero or infinite; otherwise it will be an integer, 0 <= FRAC <
190         2^{32FW}.  If the value is a NaN, then the FRAC represents the
191         payload, /not/ including the quiet bit, left aligned.  Otherwise,
192         FRAC is normalized so that 2^{32FW-1} <= FRAC < 2^{32FW}, and the
193         value represented is S FRAC 2^{EXP-32FW}, where S = -1 if `NEG' is in
194         FLAGS, or +1 otherwise.  The represented value is unchanged by
195         multiplying or dividing FRAC by an exact power of 2^{32} and
196         (respectively) incrementing or decrementing FW to match, but this
197         will affect the output data in a way that affects the tests.
198
199       * ERR is a list of error tokens:
200
201           -- `INVAL' if the encoded value is erroneous (though decoding
202              continues anyway).
203
204         ERR will be empty if no error occurred.
205     """
206
207     ## Extract fields.
208     sig = x&mask(me._sigwd)
209     biasedexp = (x >> me._sigwd)&mask(me.EXPWD)
210     signbit = (x >> (me._sigwd + me.EXPWD))&1
211     if not me.HIDDENP: unitbit = sig >> me.PREC - 1
212
213     ## Initialize flag lists.
214     flags = []
215     err = []
216
217     ## Capture the sign.  This is always relevant.
218     if signbit: flags.append("NEG")
219
220     ## If the exponent field is all-bits-set then we have infinity or NaN.
221     if biasedexp == mask(me.EXPWD):
222
223       ## If there's no hidden bit then the unit bit should be /set/, but is
224       ## /not/ part of the NaN payload -- or even significant for
225       ## distinguishing a NaN from an infinity.  If it's clear, signal an
226       ## error; if it's set, then clear it so that we don't have to think
227       ## about it again.
228       if not me.HIDDENP:
229         if unitbit: sig &= mask(me._sigwd - 1)
230         else: err.append("INVAL")
231
232       ## If the significand is (now) zero, we have an infinity and there's
233       ## nothing else to do.
234       if not sig:
235         flags.append("INF")
236         frac = fw = exp = None
237
238       ## Otherwise determine the NaN flavour and extract the payload.
239       else:
240         if sig&bit(me.PREC - 2): flags.append("QNAN")
241         else: flags.append("SNAN")
242         shift = 32*me._paywords + 2 - me.PREC
243         frac = (sig&mask(me.PREC - 2)) << shift
244         exp = None
245         fw = me._paywords
246
247     ## Otherwise we have a finite number.  We handle all of these together.
248     else:
249
250       ## If there's no hidden bit, then check that the unit bit matches the
251       ## exponent: it should be clear if the exponent field is all-bits-zero
252       ## (zero or subnormal numbers), and set otherwise (normal numbers).  If
253       ## this isn't the case, signal an error, but continue.  We'll normalize
254       ## the number correctly as we go.
255       if not me.HIDDENP:
256         if (not biasedexp and unitbit) or (biasedexp and not unitbit):
257           err.append("INVAL")
258
259       ## If the exponent is all-bits-zero then set it to 1; otherwise, if the
260       ## format uses a hidden bit then force the unit bit of our significand
261       ## on.  The absolute value is now exactly
262       ##
263       ##        2^{biasedexp-_expbias-PREC+1} sig
264       ##
265       ## in all cases.
266       if not biasedexp: biasedexp = 1
267       elif me.HIDDENP: sig |= bit(me._sigwd)
268
269       ## If the significand is now zero then the value must be zero.
270       if not sig:
271         flags.append("ZERO")
272         frac = fw = exp = None
273
274       ## Otherwise we have a nonzero finite value, which might need
275       ## normalization.
276       else:
277         sigwd = sig.bit_length()
278         fw = (sigwd + 31)//32
279         exp = biasedexp - me._expbias - me.PREC + sigwd + 1
280         frac = sig << (32*fw - sigwd)
281
282     ## All done.
283     return flags, exp, frac, fw, err
284
285   def _dump_as_bytes(me, var, x, wd):
286     """
287     Dump an assignment to VAR of X as a WD-byte binary string.
288
289     Print, on standard output, an assignment `VAR = ...' giving the value of
290     X, in hexadecimal, split with spaces into groups of 8 digits from the
291     right.
292     """
293
294     if not wd:
295       print("%s = #empty" % var)
296     else:
297       out = StringIO()
298       for i in xrange(wd - 1, -1, -1):
299         out.write("%02x" % ((x >> 8*i)&0xff))
300         if i and not i%4: out.write(" ")
301       print("%s = %s" % (var, out.getvalue()))
302
303   def _dump_flags(me, var, flags, zero = "0"):
304     """
305     Dump an assignment to VAR of FLAGS as a list of flags.
306
307     Print, on standard output, an assignment `VAR = ...' giving the named
308     flags.  Print ZERO (default `0') if FLAGS is empty.
309     """
310
311     if flags: print("%s = %s" % (var, " | ".join(flags)))
312     else: print("%s = %s" % (var, zero))
313
314   def genenc(me, ep = ExploreParameters()):
315     """
316     Print, on standard output, tests of encoding floating-point values.
317
318     The tests will cover positive and negative values, with the exponent and
319     signficand fields explored according to the parameters EP.
320     """
321
322     print("[enc%s]" % me.NAME)
323     for s in xrange(2):
324       for e in explore(me.EXPWD, ep.explo, ep.exphi):
325         for m in explore(me.PREC - 1, ep.siglo, ep.sighi):
326           if not me.HIDDENP and e: m |= bit(me.PREC - 1)
327           x = (s << (me.EXPWD + me._sigwd)) | (e << me._sigwd) | m
328           flags, exp, frac, fw, err = me.decode(x)
329           print("")
330           me._dump_flags("f", flags)
331           if exp is not None: print("e = %d" % exp)
332           if frac is not None:
333             while not frac&M32 and fw: frac >>= 32; fw -= 1
334             me._dump_as_bytes("m", frac, 4*fw)
335           me._dump_as_bytes("z", x, me._rawbytes)
336           if err: me._dump_flags("err", err, "OK")
337
338   def gendec(me, ep = ExploreParameters()):
339     """
340     Print, on standard output, tests of decoding floating-point values.
341
342     The tests will cover positive and negative values, with the exponent and
343     signficand fields explored according to the parameters EP.
344     """
345
346     print("[dec%s]" % me.NAME)
347     for s in xrange(2):
348       for e in explore(me.EXPWD, ep.explo, ep.exphi):
349         for m in explore(me._sigwd, ep.siglo, ep.sighi):
350           x = (s << (me.EXPWD + me._sigwd)) | (e << me._sigwd) | m
351           flags, exp, frac, fw, err = me.decode(x)
352           print("")
353           me._dump_as_bytes("x", x, me._rawbytes)
354           me._dump_flags("f", flags)
355           if exp is not None: print("e = %d" % exp)
356           if frac is not None: me._dump_as_bytes("m", frac, 4*fw)
357           if err: me._dump_flags("err", err, "OK")
358
359 class MiniFloat (IEEEFormat):
360   NAME = "mini"
361   EXPWD = 4
362   PREC = 4
363   HIDDENP = True
364
365 class BFloat16 (IEEEFormat):
366   NAME = "bf16"
367   EXPWD = 8
368   PREC = 8
369   HIDDENP = True
370
371 class Binary16 (IEEEFormat):
372   NAME = "f16"
373   EXPWD = 5
374   PREC = 11
375   HIDDENP = True
376
377 class Binary32 (IEEEFormat):
378   NAME = "f32"
379   EXPWD = 8
380   PREC = 24
381   HIDDENP = True
382
383 class Binary64 (IEEEFormat):
384   NAME = "f64"
385   EXPWD = 11
386   PREC = 53
387   HIDDENP = True
388
389 class Binary128 (IEEEFormat):
390   NAME = "f128"
391   EXPWD = 15
392   PREC = 113
393   HIDDENP = True
394
395 class DoubleExtended80 (IEEEFormat):
396   NAME = "idblext80"
397   EXPWD = 15
398   PREC = 64
399   HIDDENP = False
400
401 ###--------------------------------------------------------------------------
402 ### Main program.
403
404 op = OP.OptionParser \
405   (description = "Generate test data for IEEE format encoding and decoding",
406    usage = "usage: %prog [-E LO/HI] [-M LO/HI] [[enc|dec]FORMAT]")
407 for shortopt, longopt, kw in \
408   [("-E", "--explore-exponent",
409     dict(action = "store", metavar = "LO/HI", dest = "expparam",
410          help = "exponent exploration parameters")),
411    ("-M", "--explore-significand",
412     dict(action = "store", metavar = "LO/HI", dest = "sigparam",
413          help = "significand exploration parameters"))]:
414   op.add_option(shortopt, longopt, **kw)
415 opts, args = op.parse_args()
416
417 ep = ExploreParameters()
418 for optattr, loattr, hiattr in [("expparam", "explo", "exphi"),
419                                 ("sigparam", "siglo", "sighi")]:
420   opt = getattr(opts, optattr)
421   if opt is not None:
422     ok = False
423     try: sl = opt.index("/")
424     except ValueError: pass
425     else:
426       try: lo, hi = map(int, (opt[:sl], opt[sl + 1:]))
427       except ValueError: pass
428       else:
429         setattr(ep, loattr, lo)
430         setattr(ep, hiattr, hi)
431         ok = True
432     if not ok: op.error("bad exploration parameter `%s'" % opt)
433
434 if not args:
435   for fmt in iterkeys(FMTMAP):
436     args.append("enc" + fmt)
437     args.append("dec" + fmt)
438 firstp = True
439 for arg in args:
440   tail = fmt = None
441   if arg.startswith("enc"): tail = arg[3:]; gen = lambda f: f.genenc(ep)
442   elif arg.startswith("dec"): tail = arg[3:]; gen = lambda f: f.gendec(ep)
443   if tail is not None: fmt = FMTMAP.get(tail)
444   if not fmt: op.error("unknown test group `%s'" % arg)
445   if firstp: firstp = False
446   else: print("")
447   gen(fmt())
448
449 ###----- That's all, folks --------------------------------------------------