chiark / gitweb /
@@@ fltfmt wip
[mLib] / utils / t / fltfmt-testgen
1 #! /usr/bin/python
2 ###
3 ### Generate exhaustive tests for floating-point conversions.
4 ###
5 ### (c) 2024 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the mLib utilities library.
11 ###
12 ### mLib is free software: you can redistribute it and/or modify it under
13 ### the terms of the GNU Library General Public License as published by
14 ### the Free Software Foundation; either version 2 of the License, or (at
15 ### your option) any later version.
16 ###
17 ### mLib is distributed in the hope that it will be useful, but WITHOUT
18 ### ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
19 ### FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
20 ### License for more details.
21 ###
22 ### You should have received a copy of the GNU Library General Public
23 ### License along with mLib.  If not, write to the Free Software
24 ### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25 ### USA.
26
27 ###--------------------------------------------------------------------------
28 ### Imports.
29
30 import sys as SYS
31 import optparse as OP
32 import random as R
33 if SYS.version_info >= (3,):
34   from io import StringIO
35   xrange = range
36   def iterkeys(d): return d.keys()
37 else:
38   from cStringIO import StringIO
39   def iterkeys(d): return d.iterkeys()
40
41 ###--------------------------------------------------------------------------
42 ### Utilities.
43
44 def bit(k): "Return an integer with just bit K set."; return 1 << k
45 def mask(k): "Return an integer with bits 0 to K - 1 set."; return bit(k) - 1
46 M32 = mask(32)
47
48 def explore(wd, lobits, hibits):
49   """
50   Return an iterator over various WD-bit values.
51
52   Suppose that a test wants to explore various WD-bit fields, but WD might be
53   too large to do this exhaustively.  We assume (reasonably, in the case at
54   hand of floating-point formats) that the really interesting bits are those
55   at the low and high ends of the field, and test small subfields at the ends
56   exhaustively, filling in the bits in the middle with zeros, ones, or random
57   data.
58
59   So, the generator behaves as follows.  If WD <= LOBITS + HIBITS + 1 then
60   the iterator will yield all WD-bit values exhaustively.  Otherwise, it
61   yields a sequence which includes all combinations of: every LOBITS-bit
62   pattern in the least significant bits; every HIBITS-bit pattern in the most
63   significant bits; and all-bits-clear, all-bits-set, and a random pattern in
64   the bits in between.
65   """
66
67   if wd <= hibits + lobits + 1:
68     for i in xrange(bit(wd)): yield i
69   else:
70     midbit = bit(wd - hibits - lobits)
71     hishift = wd - hibits
72     m = (midbit - 1) << lobits
73     for hi in xrange(bit(hibits)):
74       top = hi << hishift
75       for lo in xrange(bit(lobits)):
76         while True:
77           fill = R.randrange(midbit)
78           if fill != 0 and fill != midbit - 1: break
79         base = lo | top
80         yield base
81         yield base | (fill << lobits)
82         yield base | m
83
84 class ExploreParameters (object):
85   """
86   Simple structure for exploration parameters; see `explore' for background.
87
88   The `explo' and `exphi' attributes are the low and high subfield sizes
89   for exponent fields, and `siglo' and `sighi' are the low and high subfield
90   sizes for significand fields.
91   """
92   def __init__(me, explo = 0, exphi = 2, siglo = 1, sighi = 3):
93     me.explo, me.exphi = explo, exphi
94     me.siglo, me.sighi = siglo, sighi
95
96 FMTMAP = {} # maps format names to classes
97
98 def with_metaclass(meta, *supers):
99   """
100   Return an arbitrary instance of the metaclass META.
101
102   The class will have SUPERS (default just `object') as its superclasses.
103   This is intended to be used in direct-superclass lists, as a compatibility
104   hack, because the Python 2 and 3 syntaxes are wildly different.
105   """
106   return meta("#<anonymous base %s>" % meta.__name__,
107               supers or (object,), dict())
108
109 class FormatClass (type):
110   """
111   Metaclass for format classes.
112
113   If the class defines a `NAME' attribute then register the class in
114   `FMTMAP'.
115   """
116   def __new__(cls, name, supers, dict):
117     c = type.__new__(cls, name, supers, dict)
118     try: FMTMAP[c.NAME] = c
119     except AttributeError: pass
120     return c
121
122 class IEEEFormat (with_metaclass(FormatClass)):
123   """
124   Floating point format class.
125
126   Concrete subclasses must define the following class attributes.
127
128     * `HIDDENP' -- true if the format uses a `hidden bit' convention for
129       normal numbers.
130
131     * `EXPWD' -- exponent field width, in bits.
132
133     * `PREC' -- precision, in bits, /including/ the hidden bit if any.
134
135   Many useful quantities are derived.
136
137     * `_expbias' is the exponent bias.
138
139     * `_minexp' and `_maxexp' are the minimum and maximum representable
140       exponents.
141
142     * `_sigwd' is the width of the significand field.
143
144     * `_paywords' is the number of words required to represent a NaN payload.
145
146     * `_nbits' is the total number of bits in an encoded value.
147
148     * `_rawbytes' is the number of bytes required for an encoded value.
149   """
150
151   def __init__(me):
152     """
153     Initialize an instance.
154     """
155     me._expbias = mask(me.EXPWD - 1)
156     me._maxexp = me._expbias
157     me._minexp = 1 - me._expbias
158     if me.HIDDENP: me._sigwd = me.PREC - 1
159     else: me._sigwd = me.PREC
160
161     me._paywords = (me._sigwd + 29)//32
162
163     me._nbits = 1 + me.EXPWD + me._sigwd
164     me._rawbytes = (me._nbits + 7)//8
165
166   def decode(me, x):
167     """
168     Decode the encoded floating-point value X, represented as an integer.
169
170     Return five quantities (FLAGS, EXP, FW, FRAC, ERR), corresponding mostly
171     to the `struct floatbits' representation, characterizing the value
172     encoded in X.
173
174       * FLAGS is a list of flag tokens:
175
176          -- `NEG' if the value is negative;
177          -- `ZERO' if the value is exactly zero;
178          -- `INF' if the value is infinite;
179          -- `SNAN' if the value is a signalling NaN; and/or
180          -- `QNAN' if the value is a quiet NaN.
181
182         FLAGS will be empty if the value is a strictly positive finite
183         number.
184
185       * EXP is the exponent, as a signed integer.  This will be `None' if the
186         value is zero, infinite, or a NaN.
187
188       * FW is the length of the fraction, in 32-bit words.  This will be
189         `None' if the value is zero or infinite.
190
191       * FRAC is the fraction or payload.  This will be `None' if the value is
192         zero or infinite; otherwise it will be an integer, 0 <= FRAC <
193         2^{32FW}.  If the value is a NaN, then the FRAC represents the
194         payload, /not/ including the quiet bit, left aligned.  Otherwise,
195         FRAC is normalized so that 2^{32FW-1} <= FRAC < 2^{32FW}, and the
196         value represented is S FRAC 2^{EXP-32FW}, where S = -1 if `NEG' is in
197         FLAGS, or +1 otherwise.  The represented value is unchanged by
198         multiplying or dividing FRAC by an exact power of 2^{32} and
199         (respectively) incrementing or decrementing FW to match, but this
200         will affect the output data in a way that affects the tests.
201
202       * ERR is a list of error tokens:
203
204           -- `INVAL' if the encoded value is erroneous (though decoding
205              continues anyway).
206
207         ERR will be empty if no error occurred.
208     """
209
210     ## Extract fields.
211     sig = x&mask(me._sigwd)
212     biasedexp = (x >> me._sigwd)&mask(me.EXPWD)
213     signbit = (x >> (me._sigwd + me.EXPWD))&1
214     if not me.HIDDENP: unitbit = sig >> me.PREC - 1
215
216     ## Initialize flag lists.
217     flags = []
218     err = []
219
220     ## Capture the sign.  This is always relevant.
221     if signbit: flags.append("NEG")
222
223     ## If the exponent field is all-bits-set then we have infinity or NaN.
224     if biasedexp == mask(me.EXPWD):
225
226       ## If there's no hidden bit then the unit bit should be /set/, but is
227       ## /not/ part of the NaN payload -- or even significant for
228       ## distinguishing a NaN from an infinity.  If it's clear, signal an
229       ## error; if it's set, then clear it so that we don't have to think
230       ## about it again.
231       if not me.HIDDENP:
232         if unitbit: sig &= mask(me._sigwd - 1)
233         else: err.append("INVAL")
234
235       ## If the significand is (now) zero, we have an infinity and there's
236       ## nothing else to do.
237       if not sig:
238         flags.append("INF")
239         frac = fw = exp = None
240
241       ## Otherwise determine the NaN flavour and extract the payload.
242       else:
243         if sig&bit(me.PREC - 2): flags.append("QNAN")
244         else: flags.append("SNAN")
245         shift = 32*me._paywords + 2 - me.PREC
246         frac = (sig&mask(me.PREC - 2)) << shift
247         exp = None
248         fw = me._paywords
249
250     ## Otherwise we have a finite number.  We handle all of these together.
251     else:
252
253       ## If there's no hidden bit, then check that the unit bit matches the
254       ## exponent: it should be clear if the exponent field is all-bits-zero
255       ## (zero or subnormal numbers), and set otherwise (normal numbers).  If
256       ## this isn't the case, signal an error, but continue.  We'll normalize
257       ## the number correctly as we go.
258       if not me.HIDDENP:
259         if (not biasedexp and unitbit) or (biasedexp and not unitbit):
260           err.append("INVAL")
261
262       ## If the exponent is all-bits-zero then set it to 1; otherwise, if the
263       ## format uses a hidden bit then force the unit bit of our significand
264       ## on.  The absolute value is now exactly
265       ##
266       ##        2^{biasedexp-_expbias-PREC+1} sig
267       ##
268       ## in all cases.
269       if not biasedexp: biasedexp = 1
270       elif me.HIDDENP: sig |= bit(me._sigwd)
271
272       ## If the significand is now zero then the value must be zero.
273       if not sig:
274         flags.append("ZERO")
275         frac = fw = exp = None
276
277       ## Otherwise we have a nonzero finite value, which might need
278       ## normalization.
279       else:
280         sigwd = sig.bit_length()
281         fw = (sigwd + 31)//32
282         exp = biasedexp - me._expbias - me.PREC + sigwd + 1
283         frac = sig << (32*fw - sigwd)
284
285     ## All done.
286     return flags, exp, frac, fw, err
287
288   def _dump_as_bytes(me, var, x, wd):
289     """
290     Dump an assignment to VAR of X as a WD-byte binary string.
291
292     Print, on standard output, an assignment `VAR = ...' giving the value of
293     X, in hexadecimal, split with spaces into groups of 8 digits from the
294     right.
295     """
296
297     if not wd:
298       print("%s = #empty" % var)
299     else:
300       out = StringIO()
301       for i in xrange(wd - 1, -1, -1):
302         out.write("%02x" % ((x >> 8*i)&0xff))
303         if i and not i%4: out.write(" ")
304       print("%s = %s" % (var, out.getvalue()))
305
306   def _dump_flags(me, var, flags, zero = "0"):
307     """
308     Dump an assignment to VAR of FLAGS as a list of flags.
309
310     Print, on standard output, an assignment `VAR = ...' giving the named
311     flags.  Print ZERO (default `0') if FLAGS is empty.
312     """
313
314     if flags: print("%s = %s" % (var, " | ".join(flags)))
315     else: print("%s = %s" % (var, zero))
316
317   def genenc(me, ep = ExploreParameters()):
318     """
319     Print, on standard output, tests of encoding floating-point values.
320
321     The tests will cover positive and negative values, with the exponent and
322     signficand fields explored according to the parameters EP.
323     """
324
325     print("[enc%s]" % me.NAME)
326     for s in xrange(2):
327       for e in explore(me.EXPWD, ep.explo, ep.exphi):
328         for m in explore(me.PREC - 1, ep.siglo, ep.sighi):
329           if not me.HIDDENP and e: m |= bit(me.PREC - 1)
330           x = (s << (me.EXPWD + me._sigwd)) | (e << me._sigwd) | m
331           flags, exp, frac, fw, err = me.decode(x)
332           print("")
333           me._dump_flags("f", flags)
334           if exp is not None: print("e = %d" % exp)
335           if frac is not None:
336             while not frac&M32 and fw: frac >>= 32; fw -= 1
337             me._dump_as_bytes("m", frac, 4*fw)
338           me._dump_as_bytes("z", x, me._rawbytes)
339           if err: me._dump_flags("err", err, "OK")
340
341   def gendec(me, ep = ExploreParameters()):
342     """
343     Print, on standard output, tests of decoding floating-point values.
344
345     The tests will cover positive and negative values, with the exponent and
346     signficand fields explored according to the parameters EP.
347     """
348
349     print("[dec%s]" % me.NAME)
350     for s in xrange(2):
351       for e in explore(me.EXPWD, ep.explo, ep.exphi):
352         for m in explore(me._sigwd, ep.siglo, ep.sighi):
353           x = (s << (me.EXPWD + me._sigwd)) | (e << me._sigwd) | m
354           flags, exp, frac, fw, err = me.decode(x)
355           print("")
356           me._dump_as_bytes("x", x, me._rawbytes)
357           me._dump_flags("f", flags)
358           if exp is not None: print("e = %d" % exp)
359           if frac is not None: me._dump_as_bytes("m", frac, 4*fw)
360           if err: me._dump_flags("err", err, "OK")
361
362 class MiniFloat (IEEEFormat):
363   NAME = "mini"
364   EXPWD = 4
365   PREC = 4
366   HIDDENP = True
367
368 class BFloat16 (IEEEFormat):
369   NAME = "bf16"
370   EXPWD = 8
371   PREC = 8
372   HIDDENP = True
373
374 class Binary16 (IEEEFormat):
375   NAME = "f16"
376   EXPWD = 5
377   PREC = 11
378   HIDDENP = True
379
380 class Binary32 (IEEEFormat):
381   NAME = "f32"
382   EXPWD = 8
383   PREC = 24
384   HIDDENP = True
385
386 class Binary64 (IEEEFormat):
387   NAME = "f64"
388   EXPWD = 11
389   PREC = 53
390   HIDDENP = True
391
392 class Binary128 (IEEEFormat):
393   NAME = "f128"
394   EXPWD = 15
395   PREC = 113
396   HIDDENP = True
397
398 class DoubleExtended80 (IEEEFormat):
399   NAME = "idblext80"
400   EXPWD = 15
401   PREC = 64
402   HIDDENP = False
403
404 ###--------------------------------------------------------------------------
405 ### Main program.
406
407 op = OP.OptionParser \
408   (description = "Generate test data for IEEE format encoding and decoding",
409    usage = "usage: %prog [-E LO/HI] [-M LO/HI] [[enc|dec]FORMAT]")
410 for shortopt, longopt, kw in \
411   [("-E", "--explore-exponent",
412     dict(action = "store", metavar = "LO/HI", dest = "expparam",
413          help = "exponent exploration parameters")),
414    ("-M", "--explore-significand",
415     dict(action = "store", metavar = "LO/HI", dest = "sigparam",
416          help = "significand exploration parameters"))]:
417   op.add_option(shortopt, longopt, **kw)
418 opts, args = op.parse_args()
419
420 ep = ExploreParameters()
421 for optattr, loattr, hiattr in [("expparam", "explo", "exphi"),
422                                 ("sigparam", "siglo", "sighi")]:
423   opt = getattr(opts, optattr)
424   if opt is not None:
425     ok = False
426     try: sl = opt.index("/")
427     except ValueError: pass
428     else:
429       try: lo, hi = map(int, (opt[:sl], opt[sl + 1:]))
430       except ValueError: pass
431       else:
432         setattr(ep, loattr, lo)
433         setattr(ep, hiattr, hi)
434         ok = True
435     if not ok: op.error("bad exploration parameter `%s'" % opt)
436
437 if not args:
438   for fmt in iterkeys(FMTMAP):
439     args.append("enc" + fmt)
440     args.append("dec" + fmt)
441 firstp = True
442 for arg in args:
443   tail = fmt = None
444   if arg.startswith("enc"): tail = arg[3:]; gen = lambda f: f.genenc(ep)
445   elif arg.startswith("dec"): tail = arg[3:]; gen = lambda f: f.gendec(ep)
446   if tail is not None: fmt = FMTMAP.get(tail)
447   if not fmt: op.error("unknown test group `%s'" % arg)
448   if firstp: firstp = False
449   else: print("")
450   gen(fmt())
451
452 ###----- That's all, folks --------------------------------------------------