chiark / gitweb /
Fix completion checking in Killer Solo.
[sgt-puzzles.git] / unfinished / numgame.c
1 /*
2  * This program implements a breadth-first search which
3  * exhaustively solves the Countdown numbers game, and related
4  * games with slightly different rule sets such as `Flippo'.
5  * 
6  * Currently it is simply a standalone command-line utility to
7  * which you provide a set of numbers and it tells you everything
8  * it can make together with how many different ways it can be
9  * made. I would like ultimately to turn it into the generator for
10  * a Puzzles puzzle, but I haven't even started on writing a
11  * Puzzles user interface yet.
12  */
13
14 /*
15  * TODO:
16  * 
17  *  - start thinking about difficulty ratings
18  *     + anything involving associative operations will be flagged
19  *       as many-paths because of the associative options (e.g.
20  *       2*3*4 can be (2*3)*4 or 2*(3*4), or indeed (2*4)*3). This
21  *       is probably a _good_ thing, since those are unusually
22  *       easy.
23  *     + tree-structured calculations ((a*b)/(c+d)) have multiple
24  *       paths because the independent branches of the tree can be
25  *       evaluated in either order, whereas straight-line
26  *       calculations with no branches will be considered easier.
27  *       Can we do anything about this? It's certainly not clear to
28  *       me that tree-structure calculations are _easier_, although
29  *       I'm also not convinced they're harder.
30  *     + I think for a realistic difficulty assessment we must also
31  *       consider the `obviousness' of the arithmetic operations in
32  *       some heuristic sense, and also (in Countdown) how many
33  *       numbers ended up being used.
34  *  - actually try some generations
35  *  - at this point we're probably ready to start on the Puzzles
36  *    integration.
37  */
38
39 #include <stdio.h>
40 #include <string.h>
41 #include <limits.h>
42 #include <assert.h>
43 #include <math.h>
44
45 #include "puzzles.h"
46 #include "tree234.h"
47
48 /*
49  * To search for numbers we can make, we employ a breadth-first
50  * search across the space of sets of input numbers. That is, for
51  * example, we start with the set (3,6,25,50,75,100); we apply
52  * moves which involve combining two numbers (e.g. adding the 50
53  * and the 75 takes us to the set (3,6,25,100,125); and then we see
54  * if we ever end up with a set containing (say) 952.
55  * 
56  * If the rules are changed so that all the numbers must be used,
57  * this is easy to adjust to: we simply see if we end up with a set
58  * containing _only_ (say) 952.
59  * 
60  * Obviously, we can vary the rules about permitted arithmetic
61  * operations simply by altering the set of valid moves in the bfs.
62  * However, there's one common rule in this sort of puzzle which
63  * takes a little more thought, and that's _concatenation_. For
64  * example, if you are given (say) four 4s and required to make 10,
65  * you are permitted to combine two of the 4s into a 44 to begin
66  * with, making (44-4)/4 = 10. However, you are generally not
67  * allowed to concatenate two numbers that _weren't_ both in the
68  * original input set (you couldn't multiply two 4s to get 16 and
69  * then concatenate a 4 on to it to make 164), so concatenation is
70  * not an operation which is valid in all situations.
71  * 
72  * We could enforce this restriction by storing a flag alongside
73  * each number indicating whether or not it's an original number;
74  * the rules being that concatenation of two numbers is only valid
75  * if they both have the original flag, and that its output _also_
76  * has the original flag (so that you can concatenate three 4s into
77  * a 444), but that applying any other arithmetic operation clears
78  * the original flag on the output. However, we can get marginally
79  * simpler than that by observing that since concatenation has to
80  * happen to a number before any other operation, we can simply
81  * place all the concatenations at the start of the search. In
82  * other words, we have a global flag on an entire number _set_
83  * which indicates whether we are still permitted to perform
84  * concatenations; if so, we can concatenate any of the numbers in
85  * that set. Performing any other operation clears the flag.
86  */
87
88 #define SETFLAG_CONCAT 1               /* we can do concatenation */
89
90 struct sets;
91
92 struct ancestor {
93     struct set *prev;                  /* index of ancestor set in set list */
94     unsigned char pa, pb, po, pr;      /* operation that got here from prev */
95 };
96
97 struct set {
98     int *numbers;                      /* rationals stored as n,d pairs */
99     short nnumbers;                    /* # of rationals, so half # of ints */
100     short flags;                       /* SETFLAG_CONCAT only, at present */
101     int npaths;                        /* number of ways to reach this set */
102     struct ancestor a;                 /* primary ancestor */
103     struct ancestor *as;               /* further ancestors, if we care */
104     int nas, assize;
105 };
106
107 struct output {
108     int number;
109     struct set *set;
110     int index;                         /* which number in the set is it? */
111     int npaths;                        /* number of ways to reach this */
112 };
113
114 #define SETLISTLEN 1024
115 #define NUMBERLISTLEN 32768
116 #define OUTPUTLISTLEN 1024
117 struct operation;
118 struct sets {
119     struct set **setlists;
120     int nsets, nsetlists, setlistsize;
121     tree234 *settree;
122     int **numberlists;
123     int nnumbers, nnumberlists, numberlistsize;
124     struct output **outputlists;
125     int noutputs, noutputlists, outputlistsize;
126     tree234 *outputtree;
127     const struct operation *const *ops;
128 };
129
130 #define OPFLAG_NEEDS_CONCAT 1
131 #define OPFLAG_KEEPS_CONCAT 2
132 #define OPFLAG_UNARY        4
133 #define OPFLAG_UNARYPREFIX  8
134 #define OPFLAG_FN           16
135
136 struct operation {
137     /*
138      * Most operations should be shown in the output working, but
139      * concatenation should not; we just take the result of the
140      * concatenation and assume that it's obvious how it was
141      * derived.
142      */
143     int display;
144
145     /*
146      * Text display of the operator, in expressions and for
147      * debugging respectively.
148      */
149     char *text, *dbgtext;
150
151     /*
152      * Flags dictating when the operator can be applied.
153      */
154     int flags;
155
156     /*
157      * Priority of the operator (for avoiding unnecessary
158      * parentheses when formatting it into a string).
159      */
160     int priority;
161
162     /*
163      * Associativity of the operator. Bit 0 means we need parens
164      * when the left operand of one of these operators is another
165      * instance of it, e.g. (2^3)^4. Bit 1 means we need parens
166      * when the right operand is another instance of the same
167      * operator, e.g. 2-(3-4). Thus:
168      * 
169      *  - this field is 0 for a fully associative operator, since
170      *    we never need parens.
171      *  - it's 1 for a right-associative operator.
172      *  - it's 2 for a left-associative operator.
173      *  - it's 3 for a _non_-associative operator (which always
174      *    uses parens just to be sure).
175      */
176     int assoc;
177
178     /*
179      * Whether the operator is commutative. Saves time in the
180      * search if we don't have to try it both ways round.
181      */
182     int commutes;
183
184     /*
185      * Function which implements the operator. Returns TRUE on
186      * success, FALSE on failure. Takes two rationals and writes
187      * out a third.
188      */
189     int (*perform)(int *a, int *b, int *output);
190 };
191
192 struct rules {
193     const struct operation *const *ops;
194     int use_all;
195 };
196
197 #define MUL(r, a, b) do { \
198     (r) = (a) * (b); \
199     if ((b) && (a) && (r) / (b) != (a)) return FALSE; \
200 } while (0)
201
202 #define ADD(r, a, b) do { \
203     (r) = (a) + (b); \
204     if ((a) > 0 && (b) > 0 && (r) < 0) return FALSE; \
205     if ((a) < 0 && (b) < 0 && (r) > 0) return FALSE; \
206 } while (0)
207
208 #define OUT(output, n, d) do { \
209     int g = gcd((n),(d)); \
210     if (g < 0) g = -g; \
211     if ((d) < 0) g = -g; \
212     if (g == -1 && (n) < -INT_MAX) return FALSE; \
213     if (g == -1 && (d) < -INT_MAX) return FALSE; \
214     (output)[0] = (n)/g; \
215     (output)[1] = (d)/g; \
216     assert((output)[1] > 0); \
217 } while (0)
218
219 static int gcd(int x, int y)
220 {
221     while (x != 0 && y != 0) {
222         int t = x;
223         x = y;
224         y = t % y;
225     }
226
227     return abs(x + y);                 /* i.e. whichever one isn't zero */
228 }
229
230 static int perform_add(int *a, int *b, int *output)
231 {
232     int at, bt, tn, bn;
233     /*
234      * a0/a1 + b0/b1 = (a0*b1 + b0*a1) / (a1*b1)
235      */
236     MUL(at, a[0], b[1]);
237     MUL(bt, b[0], a[1]);
238     ADD(tn, at, bt);
239     MUL(bn, a[1], b[1]);
240     OUT(output, tn, bn);
241     return TRUE;
242 }
243
244 static int perform_sub(int *a, int *b, int *output)
245 {
246     int at, bt, tn, bn;
247     /*
248      * a0/a1 - b0/b1 = (a0*b1 - b0*a1) / (a1*b1)
249      */
250     MUL(at, a[0], b[1]);
251     MUL(bt, b[0], a[1]);
252     ADD(tn, at, -bt);
253     MUL(bn, a[1], b[1]);
254     OUT(output, tn, bn);
255     return TRUE;
256 }
257
258 static int perform_mul(int *a, int *b, int *output)
259 {
260     int tn, bn;
261     /*
262      * a0/a1 * b0/b1 = (a0*b0) / (a1*b1)
263      */
264     MUL(tn, a[0], b[0]);
265     MUL(bn, a[1], b[1]);
266     OUT(output, tn, bn);
267     return TRUE;
268 }
269
270 static int perform_div(int *a, int *b, int *output)
271 {
272     int tn, bn;
273
274     /*
275      * Division by zero is outlawed.
276      */
277     if (b[0] == 0)
278         return FALSE;
279
280     /*
281      * a0/a1 / b0/b1 = (a0*b1) / (a1*b0)
282      */
283     MUL(tn, a[0], b[1]);
284     MUL(bn, a[1], b[0]);
285     OUT(output, tn, bn);
286     return TRUE;
287 }
288
289 static int perform_exact_div(int *a, int *b, int *output)
290 {
291     int tn, bn;
292
293     /*
294      * Division by zero is outlawed.
295      */
296     if (b[0] == 0)
297         return FALSE;
298
299     /*
300      * a0/a1 / b0/b1 = (a0*b1) / (a1*b0)
301      */
302     MUL(tn, a[0], b[1]);
303     MUL(bn, a[1], b[0]);
304     OUT(output, tn, bn);
305
306     /*
307      * Exact division means we require the result to be an integer.
308      */
309     return (output[1] == 1);
310 }
311
312 static int max_p10(int n, int *p10_r)
313 {
314     /*
315      * Find the smallest power of ten strictly greater than n.
316      *
317      * Special case: we must return at least 10, even if n is
318      * zero. (This is because this function is used for finding
319      * the power of ten by which to multiply a number being
320      * concatenated to the front of n, and concatenating 1 to 0
321      * should yield 10 and not 1.)
322      */
323     int p10 = 10;
324     while (p10 <= (INT_MAX/10) && p10 <= n)
325         p10 *= 10;
326     if (p10 > INT_MAX/10)
327         return FALSE;                  /* integer overflow */
328     *p10_r = p10;
329     return TRUE;
330 }
331
332 static int perform_concat(int *a, int *b, int *output)
333 {
334     int t1, t2, p10;
335
336     /*
337      * We can't concatenate anything which isn't a non-negative
338      * integer.
339      */
340     if (a[1] != 1 || b[1] != 1 || a[0] < 0 || b[0] < 0)
341         return FALSE;
342
343     /*
344      * For concatenation, we can safely assume leading zeroes
345      * aren't an issue. It isn't clear whether they `should' be
346      * allowed, but it turns out not to matter: concatenating a
347      * leading zero on to a number in order to harmlessly get rid
348      * of the zero is never necessary because unwanted zeroes can
349      * be disposed of by adding them to something instead. So we
350      * disallow them always.
351      *
352      * The only other possibility is that you might want to
353      * concatenate a leading zero on to something and then
354      * concatenate another non-zero digit on to _that_ (to make,
355      * for example, 106); but that's also unnecessary, because you
356      * can make 106 just as easily by concatenating the 0 on to the
357      * _end_ of the 1 first.
358      */
359     if (a[0] == 0)
360         return FALSE;
361
362     if (!max_p10(b[0], &p10)) return FALSE;
363
364     MUL(t1, p10, a[0]);
365     ADD(t2, t1, b[0]);
366     OUT(output, t2, 1);
367     return TRUE;
368 }
369
370 #define IPOW(ret, x, y) do { \
371     int ipow_limit = (y); \
372     if ((x) == 1 || (x) == 0) ipow_limit = 1; \
373     else if ((x) == -1) ipow_limit &= 1; \
374     (ret) = 1; \
375     while (ipow_limit-- > 0) { \
376         int tmp; \
377         MUL(tmp, ret, x); \
378         ret = tmp; \
379     } \
380 } while (0)
381
382 static int perform_exp(int *a, int *b, int *output)
383 {
384     int an, ad, xn, xd;
385
386     /*
387      * Exponentiation is permitted if the result is rational. This
388      * means that:
389      * 
390      *  - first we see whether we can take the (denominator-of-b)th
391      *    root of a and get a rational; if not, we give up.
392      * 
393      *  - then we do take that root of a
394      * 
395      *  - then we multiply by itself (numerator-of-b) times.
396      */
397     if (b[1] > 1) {
398         an = (int)(0.5 + pow(a[0], 1.0/b[1]));
399         ad = (int)(0.5 + pow(a[1], 1.0/b[1]));
400         IPOW(xn, an, b[1]);
401         IPOW(xd, ad, b[1]);
402         if (xn != a[0] || xd != a[1])
403             return FALSE;
404     } else {
405         an = a[0];
406         ad = a[1];
407     }
408     if (b[0] >= 0) {
409         IPOW(xn, an, b[0]);
410         IPOW(xd, ad, b[0]);
411     } else {
412         IPOW(xd, an, -b[0]);
413         IPOW(xn, ad, -b[0]);
414     }
415     if (xd == 0)
416         return FALSE;
417
418     OUT(output, xn, xd);
419     return TRUE;
420 }
421
422 static int perform_factorial(int *a, int *b, int *output)
423 {
424     int ret, t, i;
425
426     /*
427      * Factorials of non-negative integers are permitted.
428      */
429     if (a[1] != 1 || a[0] < 0)
430         return FALSE;
431
432     /*
433      * However, a special case: we don't take a factorial of
434      * anything which would thereby remain the same.
435      */
436     if (a[0] == 1 || a[0] == 2)
437         return FALSE;
438
439     ret = 1;
440     for (i = 1; i <= a[0]; i++) {
441         MUL(t, ret, i);
442         ret = t;
443     }
444
445     OUT(output, ret, 1);
446     return TRUE;
447 }
448
449 static int perform_decimal(int *a, int *b, int *output)
450 {
451     int p10;
452
453     /*
454      * Add a decimal digit to the front of a number;
455      * fail if it's not an integer.
456      * So, 1 --> 0.1, 15 --> 0.15,
457      * or, rather, 1 --> 1/10, 15 --> 15/100,
458      * x --> x / (smallest power of 10 > than x)
459      *
460      */
461     if (a[1] != 1) return FALSE;
462
463     if (!max_p10(a[0], &p10)) return FALSE;
464
465     OUT(output, a[0], p10);
466     return TRUE;
467 }
468
469 static int perform_recur(int *a, int *b, int *output)
470 {
471     int p10, tn, bn;
472
473     /*
474      * This converts a number like .4 to .44444..., or .45 to .45454...
475      * The input number must be -1 < a < 1.
476      *
477      * Calculate the smallest power of 10 that divides the denominator exactly,
478      * returning if no such power of 10 exists. Then multiply the numerator
479      * up accordingly, and the new denominator becomes that power of 10 - 1.
480      */
481     if (abs(a[0]) >= abs(a[1])) return FALSE; /* -1 < a < 1 */
482
483     p10 = 10;
484     while (p10 <= (INT_MAX/10)) {
485         if ((a[1] <= p10) && (p10 % a[1]) == 0) goto found;
486         p10 *= 10;
487     }
488     return FALSE;
489 found:
490     tn = a[0] * (p10 / a[1]);
491     bn = p10 - 1;
492
493     OUT(output, tn, bn);
494     return TRUE;
495 }
496
497 static int perform_root(int *a, int *b, int *output)
498 {
499     /*
500      * A root B is: 1           iff a == 0
501      *              B ^ (1/A)   otherwise
502      */
503     int ainv[2], res;
504
505     if (a[0] == 0) {
506         OUT(output, 1, 1);
507         return TRUE;
508     }
509
510     OUT(ainv, a[1], a[0]);
511     res = perform_exp(b, ainv, output);
512     return res;
513 }
514
515 static int perform_perc(int *a, int *b, int *output)
516 {
517     if (a[0] == 0) return FALSE; /* 0% = 0, uninteresting. */
518     if (a[1] > (INT_MAX/100)) return FALSE;
519
520     OUT(output, a[0], a[1]*100);
521     return TRUE;
522 }
523
524 static int perform_gamma(int *a, int *b, int *output)
525 {
526     int asub1[2];
527
528     /*
529      * gamma(a) = (a-1)!
530      *
531      * special case not caught by perform_fact: gamma(1) is 1 so
532      * don't bother.
533      */
534     if (a[0] == 1 && a[1] == 1) return FALSE;
535
536     OUT(asub1, a[0]-a[1], a[1]);
537     return perform_factorial(asub1, b, output);
538 }
539
540 static int perform_sqrt(int *a, int *b, int *output)
541 {
542     int half[2] = { 1, 2 };
543
544     /*
545      * sqrt(0) == 0, sqrt(1) == 1: don't perform unary noops.
546      */
547     if (a[0] == 0 || (a[0] == 1 && a[1] == 1)) return FALSE;
548
549     return perform_exp(a, half, output);
550 }
551
552 const static struct operation op_add = {
553     TRUE, "+", "+", 0, 10, 0, TRUE, perform_add
554 };
555 const static struct operation op_sub = {
556     TRUE, "-", "-", 0, 10, 2, FALSE, perform_sub
557 };
558 const static struct operation op_mul = {
559     TRUE, "*", "*", 0, 20, 0, TRUE, perform_mul
560 };
561 const static struct operation op_div = {
562     TRUE, "/", "/", 0, 20, 2, FALSE, perform_div
563 };
564 const static struct operation op_xdiv = {
565     TRUE, "/", "/", 0, 20, 2, FALSE, perform_exact_div
566 };
567 const static struct operation op_concat = {
568     FALSE, "", "concat", OPFLAG_NEEDS_CONCAT | OPFLAG_KEEPS_CONCAT,
569         1000, 0, FALSE, perform_concat
570 };
571 const static struct operation op_exp = {
572     TRUE, "^", "^", 0, 30, 1, FALSE, perform_exp
573 };
574 const static struct operation op_factorial = {
575     TRUE, "!", "!", OPFLAG_UNARY, 40, 0, FALSE, perform_factorial
576 };
577 const static struct operation op_decimal = {
578     TRUE, ".", ".", OPFLAG_UNARY | OPFLAG_UNARYPREFIX | OPFLAG_NEEDS_CONCAT | OPFLAG_KEEPS_CONCAT, 50, 0, FALSE, perform_decimal
579 };
580 const static struct operation op_recur = {
581     TRUE, "...", "recur", OPFLAG_UNARY | OPFLAG_NEEDS_CONCAT, 45, 2, FALSE, perform_recur
582 };
583 const static struct operation op_root = {
584     TRUE, "v~", "root", 0, 30, 1, FALSE, perform_root
585 };
586 const static struct operation op_perc = {
587     TRUE, "%", "%", OPFLAG_UNARY | OPFLAG_NEEDS_CONCAT, 45, 1, FALSE, perform_perc
588 };
589 const static struct operation op_gamma = {
590     TRUE, "gamma", "gamma", OPFLAG_UNARY | OPFLAG_UNARYPREFIX | OPFLAG_FN, 1, 3, FALSE, perform_gamma
591 };
592 const static struct operation op_sqrt = {
593     TRUE, "v~", "sqrt", OPFLAG_UNARY | OPFLAG_UNARYPREFIX, 30, 1, FALSE, perform_sqrt
594 };
595
596 /*
597  * In Countdown, divisions resulting in fractions are disallowed.
598  * http://www.askoxford.com/wordgames/countdown/rules/
599  */
600 const static struct operation *const ops_countdown[] = {
601     &op_add, &op_mul, &op_sub, &op_xdiv, NULL
602 };
603 const static struct rules rules_countdown = {
604     ops_countdown, FALSE
605 };
606
607 /*
608  * A slightly different rule set which handles the reasonably well
609  * known puzzle of making 24 using two 3s and two 8s. For this we
610  * need rational rather than integer division.
611  */
612 const static struct operation *const ops_3388[] = {
613     &op_add, &op_mul, &op_sub, &op_div, NULL
614 };
615 const static struct rules rules_3388 = {
616     ops_3388, TRUE
617 };
618
619 /*
620  * A still more permissive rule set usable for the four-4s problem
621  * and similar things. Permits concatenation.
622  */
623 const static struct operation *const ops_four4s[] = {
624     &op_add, &op_mul, &op_sub, &op_div, &op_concat, NULL
625 };
626 const static struct rules rules_four4s = {
627     ops_four4s, TRUE
628 };
629
630 /*
631  * The most permissive ruleset I can think of. Permits
632  * exponentiation, and also silly unary operators like factorials.
633  */
634 const static struct operation *const ops_anythinggoes[] = {
635     &op_add, &op_mul, &op_sub, &op_div, &op_concat, &op_exp, &op_factorial, 
636     &op_decimal, &op_recur, &op_root, &op_perc, &op_gamma, &op_sqrt, NULL
637 };
638 const static struct rules rules_anythinggoes = {
639     ops_anythinggoes, TRUE
640 };
641
642 #define ratcmp(a,op,b) ( (long long)(a)[0] * (b)[1] op \
643                          (long long)(b)[0] * (a)[1] )
644
645 static int addtoset(struct set *set, int newnumber[2])
646 {
647     int i, j;
648
649     /* Find where we want to insert the new number */
650     for (i = 0; i < set->nnumbers &&
651          ratcmp(set->numbers+2*i, <, newnumber); i++);
652
653     /* Move everything else up */
654     for (j = set->nnumbers; j > i; j--) {
655         set->numbers[2*j] = set->numbers[2*j-2];
656         set->numbers[2*j+1] = set->numbers[2*j-1];
657     }
658
659     /* Insert the new number */
660     set->numbers[2*i] = newnumber[0];
661     set->numbers[2*i+1] = newnumber[1];
662
663     set->nnumbers++;
664
665     return i;
666 }
667
668 #define ensure(array, size, newlen, type) do { \
669     if ((newlen) > (size)) { \
670         (size) = (newlen) + 512; \
671         (array) = sresize((array), (size), type); \
672     } \
673 } while (0)
674
675 static int setcmp(void *av, void *bv)
676 {
677     struct set *a = (struct set *)av;
678     struct set *b = (struct set *)bv;
679     int i;
680
681     if (a->nnumbers < b->nnumbers)
682         return -1;
683     else if (a->nnumbers > b->nnumbers)
684         return +1;
685
686     if (a->flags < b->flags)
687         return -1;
688     else if (a->flags > b->flags)
689         return +1;
690
691     for (i = 0; i < a->nnumbers; i++) {
692         if (ratcmp(a->numbers+2*i, <, b->numbers+2*i))
693             return -1;
694         else if (ratcmp(a->numbers+2*i, >, b->numbers+2*i))
695             return +1;
696     }
697
698     return 0;
699 }
700
701 static int outputcmp(void *av, void *bv)
702 {
703     struct output *a = (struct output *)av;
704     struct output *b = (struct output *)bv;
705
706     if (a->number < b->number)
707         return -1;
708     else if (a->number > b->number)
709         return +1;
710
711     return 0;
712 }
713
714 static int outputfindcmp(void *av, void *bv)
715 {
716     int *a = (int *)av;
717     struct output *b = (struct output *)bv;
718
719     if (*a < b->number)
720         return -1;
721     else if (*a > b->number)
722         return +1;
723
724     return 0;
725 }
726
727 static void addset(struct sets *s, struct set *set, int multiple,
728                    struct set *prev, int pa, int po, int pb, int pr)
729 {
730     struct set *s2;
731     int npaths = (prev ? prev->npaths : 1);
732
733     assert(set == s->setlists[s->nsets / SETLISTLEN] + s->nsets % SETLISTLEN);
734     s2 = add234(s->settree, set);
735     if (s2 == set) {
736         /*
737          * New set added to the tree.
738          */
739         set->a.prev = prev;
740         set->a.pa = pa;
741         set->a.po = po;
742         set->a.pb = pb;
743         set->a.pr = pr;
744         set->npaths = npaths;
745         s->nsets++;
746         s->nnumbers += 2 * set->nnumbers;
747         set->as = NULL;
748         set->nas = set->assize = 0;
749     } else {
750         /*
751          * Rediscovered an existing set. Update its npaths.
752          */
753         s2->npaths += npaths;
754         /*
755          * And optionally enter it as an additional ancestor.
756          */
757         if (multiple) {
758             if (s2->nas >= s2->assize) {
759                 s2->assize = s2->nas * 3 / 2 + 4;
760                 s2->as = sresize(s2->as, s2->assize, struct ancestor);
761             }
762             s2->as[s2->nas].prev = prev;
763             s2->as[s2->nas].pa = pa;
764             s2->as[s2->nas].po = po;
765             s2->as[s2->nas].pb = pb;
766             s2->as[s2->nas].pr = pr;
767             s2->nas++;
768         }
769     }
770 }
771
772 static struct set *newset(struct sets *s, int nnumbers, int flags)
773 {
774     struct set *sn;
775
776     ensure(s->setlists, s->setlistsize, s->nsets/SETLISTLEN+1, struct set *);
777     while (s->nsetlists <= s->nsets / SETLISTLEN)
778         s->setlists[s->nsetlists++] = snewn(SETLISTLEN, struct set);
779     sn = s->setlists[s->nsets / SETLISTLEN] + s->nsets % SETLISTLEN;
780
781     if (s->nnumbers + nnumbers * 2 > s->nnumberlists * NUMBERLISTLEN)
782         s->nnumbers = s->nnumberlists * NUMBERLISTLEN;
783     ensure(s->numberlists, s->numberlistsize,
784            s->nnumbers/NUMBERLISTLEN+1, int *);
785     while (s->nnumberlists <= s->nnumbers / NUMBERLISTLEN)
786         s->numberlists[s->nnumberlists++] = snewn(NUMBERLISTLEN, int);
787     sn->numbers = s->numberlists[s->nnumbers / NUMBERLISTLEN] +
788         s->nnumbers % NUMBERLISTLEN;
789
790     /*
791      * Start the set off empty.
792      */
793     sn->nnumbers = 0;
794
795     sn->flags = flags;
796
797     return sn;
798 }
799
800 static int addoutput(struct sets *s, struct set *ss, int index, int *n)
801 {
802     struct output *o, *o2;
803
804     /*
805      * Target numbers are always integers.
806      */
807     if (ss->numbers[2*index+1] != 1)
808         return FALSE;
809
810     ensure(s->outputlists, s->outputlistsize, s->noutputs/OUTPUTLISTLEN+1,
811            struct output *);
812     while (s->noutputlists <= s->noutputs / OUTPUTLISTLEN)
813         s->outputlists[s->noutputlists++] = snewn(OUTPUTLISTLEN,
814                                                   struct output);
815     o = s->outputlists[s->noutputs / OUTPUTLISTLEN] +
816         s->noutputs % OUTPUTLISTLEN;
817
818     o->number = ss->numbers[2*index];
819     o->set = ss;
820     o->index = index;
821     o->npaths = ss->npaths;
822     o2 = add234(s->outputtree, o);
823     if (o2 != o) {
824         o2->npaths += o->npaths;
825     } else {
826         s->noutputs++;
827     }
828     *n = o->number;
829     return TRUE;
830 }
831
832 static struct sets *do_search(int ninputs, int *inputs,
833                               const struct rules *rules, int *target,
834                               int debug, int multiple)
835 {
836     struct sets *s;
837     struct set *sn;
838     int qpos, i;
839     const struct operation *const *ops = rules->ops;
840
841     s = snew(struct sets);
842     s->setlists = NULL;
843     s->nsets = s->nsetlists = s->setlistsize = 0;
844     s->numberlists = NULL;
845     s->nnumbers = s->nnumberlists = s->numberlistsize = 0;
846     s->outputlists = NULL;
847     s->noutputs = s->noutputlists = s->outputlistsize = 0;
848     s->settree = newtree234(setcmp);
849     s->outputtree = newtree234(outputcmp);
850     s->ops = ops;
851
852     /*
853      * Start with the input set.
854      */
855     sn = newset(s, ninputs, SETFLAG_CONCAT);
856     for (i = 0; i < ninputs; i++) {
857         int newnumber[2];
858         newnumber[0] = inputs[i];
859         newnumber[1] = 1;
860         addtoset(sn, newnumber);
861     }
862     addset(s, sn, multiple, NULL, 0, 0, 0, 0);
863
864     /*
865      * Now perform the breadth-first search: keep looping over sets
866      * until we run out of steam.
867      */
868     qpos = 0;
869     while (qpos < s->nsets) {
870         struct set *ss = s->setlists[qpos / SETLISTLEN] + qpos % SETLISTLEN;
871         struct set *sn;
872         int i, j, k, m;
873
874         if (debug) {
875             int i;
876             printf("processing set:");
877             for (i = 0; i < ss->nnumbers; i++) {
878                 printf(" %d", ss->numbers[2*i]);
879                 if (ss->numbers[2*i+1] != 1)
880                     printf("/%d", ss->numbers[2*i+1]);
881             }
882             printf("\n");
883         }
884
885         /*
886          * Record all the valid output numbers in this state. We
887          * can always do this if there's only one number in the
888          * state; otherwise, we can only do it if we aren't
889          * required to use all the numbers in coming to our answer.
890          */
891         if (ss->nnumbers == 1 || !rules->use_all) {
892             for (i = 0; i < ss->nnumbers; i++) {
893                 int n;
894
895                 if (addoutput(s, ss, i, &n) && target && n == *target)
896                     return s;
897             }
898         }
899
900         /*
901          * Try every possible operation from this state.
902          */
903         for (k = 0; ops[k] && ops[k]->perform; k++) {
904             if ((ops[k]->flags & OPFLAG_NEEDS_CONCAT) &&
905                 !(ss->flags & SETFLAG_CONCAT))
906                 continue;              /* can't use this operation here */
907             for (i = 0; i < ss->nnumbers; i++) {
908                 int jlimit = (ops[k]->flags & OPFLAG_UNARY ? 1 : ss->nnumbers);
909                 for (j = 0; j < jlimit; j++) {
910                     int n[2], newnn = ss->nnumbers;
911                     int pa, po, pb, pr;
912
913                     if (!(ops[k]->flags & OPFLAG_UNARY)) {
914                         if (i == j)
915                             continue;  /* can't combine a number with itself */
916                         if (i > j && ops[k]->commutes)
917                             continue;  /* no need to do this both ways round */
918                         newnn--;
919                     }
920                     if (!ops[k]->perform(ss->numbers+2*i, ss->numbers+2*j, n))
921                         continue;      /* operation failed */
922
923                     sn = newset(s, newnn, ss->flags);
924
925                     if (!(ops[k]->flags & OPFLAG_KEEPS_CONCAT))
926                         sn->flags &= ~SETFLAG_CONCAT;
927
928                     for (m = 0; m < ss->nnumbers; m++) {
929                         if (m == i || (!(ops[k]->flags & OPFLAG_UNARY) &&
930                                        m == j))
931                             continue;
932                         sn->numbers[2*sn->nnumbers] = ss->numbers[2*m];
933                         sn->numbers[2*sn->nnumbers + 1] = ss->numbers[2*m + 1];
934                         sn->nnumbers++;
935                     }
936                     pa = i;
937                     if (ops[k]->flags & OPFLAG_UNARY)
938                         pb = sn->nnumbers+10;
939                     else
940                         pb = j;
941                     po = k;
942                     pr = addtoset(sn, n);
943                     addset(s, sn, multiple, ss, pa, po, pb, pr);
944                     if (debug) {
945                         int i;
946                         if (ops[k]->flags & OPFLAG_UNARYPREFIX)
947                             printf("  %s %d ->", ops[po]->dbgtext, pa);
948                         else if (ops[k]->flags & OPFLAG_UNARY)
949                             printf("  %d %s ->", pa, ops[po]->dbgtext);
950                         else
951                             printf("  %d %s %d ->", pa, ops[po]->dbgtext, pb);
952                         for (i = 0; i < sn->nnumbers; i++) {
953                             printf(" %d", sn->numbers[2*i]);
954                             if (sn->numbers[2*i+1] != 1)
955                                 printf("/%d", sn->numbers[2*i+1]);
956                         }
957                         printf("\n");
958                     }
959                 }
960             }
961         }
962
963         qpos++;
964     }
965
966     return s;
967 }
968
969 static void free_sets(struct sets *s)
970 {
971     int i;
972
973     freetree234(s->settree);
974     freetree234(s->outputtree);
975     for (i = 0; i < s->nsetlists; i++)
976         sfree(s->setlists[i]);
977     sfree(s->setlists);
978     for (i = 0; i < s->nnumberlists; i++)
979         sfree(s->numberlists[i]);
980     sfree(s->numberlists);
981     for (i = 0; i < s->noutputlists; i++)
982         sfree(s->outputlists[i]);
983     sfree(s->outputlists);
984     sfree(s);
985 }
986
987 /*
988  * Print a text formula for producing a given output.
989  */
990 void print_recurse(struct sets *s, struct set *ss, int pathindex, int index,
991                    int priority, int assoc, int child);
992 void print_recurse_inner(struct sets *s, struct set *ss,
993                          struct ancestor *a, int pathindex, int index,
994                          int priority, int assoc, int child)
995 {
996     if (a->prev && index != a->pr) {
997         int pi;
998
999         /*
1000          * This number was passed straight down from this set's
1001          * predecessor. Find its index in the previous set and
1002          * recurse to there.
1003          */
1004         pi = index;
1005         assert(pi != a->pr);
1006         if (pi > a->pr)
1007             pi--;
1008         if (pi >= min(a->pa, a->pb)) {
1009             pi++;
1010             if (pi >= max(a->pa, a->pb))
1011                 pi++;
1012         }
1013         print_recurse(s, a->prev, pathindex, pi, priority, assoc, child);
1014     } else if (a->prev && index == a->pr &&
1015                s->ops[a->po]->display) {
1016         /*
1017          * This number was created by a displayed operator in the
1018          * transition from this set to its predecessor. Hence we
1019          * write an open paren, then recurse into the first
1020          * operand, then write the operator, then the second
1021          * operand, and finally close the paren.
1022          */
1023         char *op;
1024         int parens, thispri, thisassoc;
1025
1026         /*
1027          * Determine whether we need parentheses.
1028          */
1029         thispri = s->ops[a->po]->priority;
1030         thisassoc = s->ops[a->po]->assoc;
1031         parens = (thispri < priority ||
1032                   (thispri == priority && (assoc & child)));
1033
1034         if (parens)
1035             putchar('(');
1036
1037         if (s->ops[a->po]->flags & OPFLAG_UNARYPREFIX)
1038             for (op = s->ops[a->po]->text; *op; op++)
1039                 putchar(*op);
1040
1041         if (s->ops[a->po]->flags & OPFLAG_FN)
1042             putchar('(');
1043
1044         print_recurse(s, a->prev, pathindex, a->pa, thispri, thisassoc, 1);
1045
1046         if (s->ops[a->po]->flags & OPFLAG_FN)
1047             putchar(')');
1048
1049         if (!(s->ops[a->po]->flags & OPFLAG_UNARYPREFIX))
1050             for (op = s->ops[a->po]->text; *op; op++)
1051                 putchar(*op);
1052
1053         if (!(s->ops[a->po]->flags & OPFLAG_UNARY))
1054             print_recurse(s, a->prev, pathindex, a->pb, thispri, thisassoc, 2);
1055
1056         if (parens)
1057             putchar(')');
1058     } else {
1059         /*
1060          * This number is either an original, or something formed
1061          * by a non-displayed operator (concatenation). Either way,
1062          * we display it as is.
1063          */
1064         printf("%d", ss->numbers[2*index]);
1065         if (ss->numbers[2*index+1] != 1)
1066             printf("/%d", ss->numbers[2*index+1]);
1067     }
1068 }
1069 void print_recurse(struct sets *s, struct set *ss, int pathindex, int index,
1070                    int priority, int assoc, int child)
1071 {
1072     if (!ss->a.prev || pathindex < ss->a.prev->npaths) {
1073         print_recurse_inner(s, ss, &ss->a, pathindex,
1074                             index, priority, assoc, child);
1075     } else {
1076         int i;
1077         pathindex -= ss->a.prev->npaths;
1078         for (i = 0; i < ss->nas; i++) {
1079             if (pathindex < ss->as[i].prev->npaths) {
1080                 print_recurse_inner(s, ss, &ss->as[i], pathindex,
1081                                     index, priority, assoc, child);
1082                 break;
1083             }
1084             pathindex -= ss->as[i].prev->npaths;
1085         }
1086     }
1087 }
1088 void print(int pathindex, struct sets *s, struct output *o)
1089 {
1090     print_recurse(s, o->set, pathindex, o->index, 0, 0, 0);
1091 }
1092
1093 /*
1094  * gcc -g -O0 -o numgame numgame.c -I.. ../{malloc,tree234,nullfe}.c -lm
1095  */
1096 int main(int argc, char **argv)
1097 {
1098     int doing_opts = TRUE;
1099     const struct rules *rules = NULL;
1100     char *pname = argv[0];
1101     int got_target = FALSE, target = 0;
1102     int numbers[10], nnumbers = 0;
1103     int verbose = FALSE;
1104     int pathcounts = FALSE;
1105     int multiple = FALSE;
1106     int debug_bfs = FALSE;
1107     int got_range = FALSE, rangemin = 0, rangemax = 0;
1108
1109     struct output *o;
1110     struct sets *s;
1111     int i, start, limit;
1112
1113     while (--argc) {
1114         char *p = *++argv;
1115         int c;
1116
1117         if (doing_opts && *p == '-') {
1118             p++;
1119
1120             if (!strcmp(p, "-")) {
1121                 doing_opts = FALSE;
1122                 continue;
1123             } else if (*p == '-') {
1124                 p++;
1125                 if (!strcmp(p, "debug-bfs")) {
1126                     debug_bfs = TRUE;
1127                 } else {
1128                     fprintf(stderr, "%s: option '--%s' not recognised\n",
1129                             pname, p);
1130                 }
1131             } else while (p && *p) switch (c = *p++) {
1132               case 'C':
1133                 rules = &rules_countdown;
1134                 break;
1135               case 'B':
1136                 rules = &rules_3388;
1137                 break;
1138               case 'D':
1139                 rules = &rules_four4s;
1140                 break;
1141               case 'A':
1142                 rules = &rules_anythinggoes;
1143                 break;
1144               case 'v':
1145                 verbose = TRUE;
1146                 break;
1147               case 'p':
1148                 pathcounts = TRUE;
1149                 break;
1150               case 'm':
1151                 multiple = TRUE;
1152                 break;
1153               case 't':
1154               case 'r':
1155                 {
1156                     char *v;
1157                     if (*p) {
1158                         v = p;
1159                         p = NULL;
1160                     } else if (--argc) {
1161                         v = *++argv;
1162                     } else {
1163                         fprintf(stderr, "%s: option '-%c' expects an"
1164                                 " argument\n", pname, c);
1165                         return 1;
1166                     }
1167                     switch (c) {
1168                       case 't':
1169                         got_target = TRUE;
1170                         target = atoi(v);
1171                         break;
1172                       case 'r':
1173                         {
1174                              char *sep = strchr(v, '-');
1175                              got_range = TRUE;
1176                              if (sep) {
1177                                  rangemin = atoi(v);
1178                                  rangemax = atoi(sep+1);
1179                              } else {
1180                                  rangemin = 0;
1181                                  rangemax = atoi(v);
1182                              }
1183                         }
1184                         break;
1185                     }
1186                 }
1187                 break;
1188               default:
1189                 fprintf(stderr, "%s: option '-%c' not"
1190                         " recognised\n", pname, c);
1191                 return 1;
1192             }
1193         } else {
1194             if (nnumbers >= lenof(numbers)) {
1195                 fprintf(stderr, "%s: internal limit of %d numbers exceeded\n",
1196                         pname, (int)lenof(numbers));
1197                 return 1;
1198             } else {
1199                 numbers[nnumbers++] = atoi(p);
1200             }
1201         }
1202     }
1203
1204     if (!rules) {
1205         fprintf(stderr, "%s: no rule set specified; use -C,-B,-D,-A\n", pname);
1206         return 1;
1207     }
1208
1209     if (!nnumbers) {
1210         fprintf(stderr, "%s: no input numbers specified\n", pname);
1211         return 1;
1212     }
1213
1214     if (got_range) {
1215         if (got_target) {
1216             fprintf(stderr, "%s: only one of -t and -r may be specified\n", pname);
1217             return 1;
1218         }
1219         if (rangemin >= rangemax) {
1220             fprintf(stderr, "%s: range not sensible (%d - %d)\n", pname, rangemin, rangemax);
1221             return 1;
1222         }
1223     }
1224
1225     s = do_search(nnumbers, numbers, rules, (got_target ? &target : NULL),
1226                   debug_bfs, multiple);
1227
1228     if (got_target) {
1229         o = findrelpos234(s->outputtree, &target, outputfindcmp,
1230                           REL234_LE, &start);
1231         if (!o)
1232             start = -1;
1233         o = findrelpos234(s->outputtree, &target, outputfindcmp,
1234                           REL234_GE, &limit);
1235         if (!o)
1236             limit = -1;
1237         assert(start != -1 || limit != -1);
1238         if (start == -1)
1239             start = limit;
1240         else if (limit == -1)
1241             limit = start;
1242         limit++;
1243     } else if (got_range) {
1244         if (!findrelpos234(s->outputtree, &rangemin, outputfindcmp,
1245                            REL234_GE, &start) ||
1246             !findrelpos234(s->outputtree, &rangemax, outputfindcmp,
1247                            REL234_LE, &limit)) {
1248             printf("No solutions available in specified range %d-%d\n", rangemin, rangemax);
1249             return 1;
1250         }
1251         limit++;
1252     } else {
1253         start = 0;
1254         limit = count234(s->outputtree);
1255     }
1256
1257     for (i = start; i < limit; i++) {
1258         char buf[256];
1259
1260         o = index234(s->outputtree, i);
1261
1262         sprintf(buf, "%d", o->number);
1263
1264         if (pathcounts)
1265             sprintf(buf + strlen(buf), " [%d]", o->npaths);
1266
1267         if (got_target || verbose) {
1268             int j, npaths;
1269
1270             if (multiple)
1271                 npaths = o->npaths;
1272             else
1273                 npaths = 1;
1274
1275             for (j = 0; j < npaths; j++) {
1276                 printf("%s = ", buf);
1277                 print(j, s, o);
1278                 putchar('\n');
1279             }
1280         } else {
1281             printf("%s\n", buf);
1282         }
1283     }
1284
1285     free_sets(s);
1286
1287     return 0;
1288 }
1289
1290 /* vim: set shiftwidth=4 tabstop=8: */