chiark / gitweb /
James H has helpfully provided yet more silly operators for the -A
[sgt-puzzles.git] / unfinished / numgame.c
1 /*
2  * This program implements a breadth-first search which
3  * exhaustively solves the Countdown numbers game, and related
4  * games with slightly different rule sets such as `Flippo'.
5  * 
6  * Currently it is simply a standalone command-line utility to
7  * which you provide a set of numbers and it tells you everything
8  * it can make together with how many different ways it can be
9  * made. I would like ultimately to turn it into the generator for
10  * a Puzzles puzzle, but I haven't even started on writing a
11  * Puzzles user interface yet.
12  */
13
14 /*
15  * TODO:
16  * 
17  *  - start thinking about difficulty ratings
18  *     + anything involving associative operations will be flagged
19  *       as many-paths because of the associative options (e.g.
20  *       2*3*4 can be (2*3)*4 or 2*(3*4), or indeed (2*4)*3). This
21  *       is probably a _good_ thing, since those are unusually
22  *       easy.
23  *     + tree-structured calculations ((a*b)/(c+d)) have multiple
24  *       paths because the independent branches of the tree can be
25  *       evaluated in either order, whereas straight-line
26  *       calculations with no branches will be considered easier.
27  *       Can we do anything about this? It's certainly not clear to
28  *       me that tree-structure calculations are _easier_, although
29  *       I'm also not convinced they're harder.
30  *     + I think for a realistic difficulty assessment we must also
31  *       consider the `obviousness' of the arithmetic operations in
32  *       some heuristic sense, and also (in Countdown) how many
33  *       numbers ended up being used.
34  *  - actually try some generations
35  *  - at this point we're probably ready to start on the Puzzles
36  *    integration.
37  */
38
39 #include <stdio.h>
40 #include <string.h>
41 #include <limits.h>
42 #include <assert.h>
43 #include <math.h>
44
45 #include "puzzles.h"
46 #include "tree234.h"
47
48 /*
49  * To search for numbers we can make, we employ a breadth-first
50  * search across the space of sets of input numbers. That is, for
51  * example, we start with the set (3,6,25,50,75,100); we apply
52  * moves which involve combining two numbers (e.g. adding the 50
53  * and the 75 takes us to the set (3,6,25,100,125); and then we see
54  * if we ever end up with a set containing (say) 952.
55  * 
56  * If the rules are changed so that all the numbers must be used,
57  * this is easy to adjust to: we simply see if we end up with a set
58  * containing _only_ (say) 952.
59  * 
60  * Obviously, we can vary the rules about permitted arithmetic
61  * operations simply by altering the set of valid moves in the bfs.
62  * However, there's one common rule in this sort of puzzle which
63  * takes a little more thought, and that's _concatenation_. For
64  * example, if you are given (say) four 4s and required to make 10,
65  * you are permitted to combine two of the 4s into a 44 to begin
66  * with, making (44-4)/4 = 10. However, you are generally not
67  * allowed to concatenate two numbers that _weren't_ both in the
68  * original input set (you couldn't multiply two 4s to get 16 and
69  * then concatenate a 4 on to it to make 164), so concatenation is
70  * not an operation which is valid in all situations.
71  * 
72  * We could enforce this restriction by storing a flag alongside
73  * each number indicating whether or not it's an original number;
74  * the rules being that concatenation of two numbers is only valid
75  * if they both have the original flag, and that its output _also_
76  * has the original flag (so that you can concatenate three 4s into
77  * a 444), but that applying any other arithmetic operation clears
78  * the original flag on the output. However, we can get marginally
79  * simpler than that by observing that since concatenation has to
80  * happen to a number before any other operation, we can simply
81  * place all the concatenations at the start of the search. In
82  * other words, we have a global flag on an entire number _set_
83  * which indicates whether we are still permitted to perform
84  * concatenations; if so, we can concatenate any of the numbers in
85  * that set. Performing any other operation clears the flag.
86  */
87
88 #define SETFLAG_CONCAT 1               /* we can do concatenation */
89
90 struct sets;
91
92 struct ancestor {
93     struct set *prev;                  /* index of ancestor set in set list */
94     unsigned char pa, pb, po, pr;      /* operation that got here from prev */
95 };
96
97 struct set {
98     int *numbers;                      /* rationals stored as n,d pairs */
99     short nnumbers;                    /* # of rationals, so half # of ints */
100     short flags;                       /* SETFLAG_CONCAT only, at present */
101     int npaths;                        /* number of ways to reach this set */
102     struct ancestor a;                 /* primary ancestor */
103     struct ancestor *as;               /* further ancestors, if we care */
104     int nas, assize;
105 };
106
107 struct output {
108     int number;
109     struct set *set;
110     int index;                         /* which number in the set is it? */
111     int npaths;                        /* number of ways to reach this */
112 };
113
114 #define SETLISTLEN 1024
115 #define NUMBERLISTLEN 32768
116 #define OUTPUTLISTLEN 1024
117 struct operation;
118 struct sets {
119     struct set **setlists;
120     int nsets, nsetlists, setlistsize;
121     tree234 *settree;
122     int **numberlists;
123     int nnumbers, nnumberlists, numberlistsize;
124     struct output **outputlists;
125     int noutputs, noutputlists, outputlistsize;
126     tree234 *outputtree;
127     const struct operation *const *ops;
128 };
129
130 #define OPFLAG_NEEDS_CONCAT 1
131 #define OPFLAG_KEEPS_CONCAT 2
132 #define OPFLAG_UNARY        4
133 #define OPFLAG_UNARYPREFIX  8
134
135 struct operation {
136     /*
137      * Most operations should be shown in the output working, but
138      * concatenation should not; we just take the result of the
139      * concatenation and assume that it's obvious how it was
140      * derived.
141      */
142     int display;
143
144     /*
145      * Text display of the operator, in expressions and for
146      * debugging respectively.
147      */
148     char *text, *dbgtext;
149
150     /*
151      * Flags dictating when the operator can be applied.
152      */
153     int flags;
154
155     /*
156      * Priority of the operator (for avoiding unnecessary
157      * parentheses when formatting it into a string).
158      */
159     int priority;
160
161     /*
162      * Associativity of the operator. Bit 0 means we need parens
163      * when the left operand of one of these operators is another
164      * instance of it, e.g. (2^3)^4. Bit 1 means we need parens
165      * when the right operand is another instance of the same
166      * operator, e.g. 2-(3-4). Thus:
167      * 
168      *  - this field is 0 for a fully associative operator, since
169      *    we never need parens.
170      *  - it's 1 for a right-associative operator.
171      *  - it's 2 for a left-associative operator.
172      *  - it's 3 for a _non_-associative operator (which always
173      *    uses parens just to be sure).
174      */
175     int assoc;
176
177     /*
178      * Whether the operator is commutative. Saves time in the
179      * search if we don't have to try it both ways round.
180      */
181     int commutes;
182
183     /*
184      * Function which implements the operator. Returns TRUE on
185      * success, FALSE on failure. Takes two rationals and writes
186      * out a third.
187      */
188     int (*perform)(int *a, int *b, int *output);
189 };
190
191 struct rules {
192     const struct operation *const *ops;
193     int use_all;
194 };
195
196 #define MUL(r, a, b) do { \
197     (r) = (a) * (b); \
198     if ((b) && (a) && (r) / (b) != (a)) return FALSE; \
199 } while (0)
200
201 #define ADD(r, a, b) do { \
202     (r) = (a) + (b); \
203     if ((a) > 0 && (b) > 0 && (r) < 0) return FALSE; \
204     if ((a) < 0 && (b) < 0 && (r) > 0) return FALSE; \
205 } while (0)
206
207 #define OUT(output, n, d) do { \
208     int g = gcd((n),(d)); \
209     if (g < 0) g = -g; \
210     if ((d) < 0) g = -g; \
211     if (g == -1 && (n) < -INT_MAX) return FALSE; \
212     if (g == -1 && (d) < -INT_MAX) return FALSE; \
213     (output)[0] = (n)/g; \
214     (output)[1] = (d)/g; \
215     assert((output)[1] > 0); \
216 } while (0)
217
218 static int gcd(int x, int y)
219 {
220     while (x != 0 && y != 0) {
221         int t = x;
222         x = y;
223         y = t % y;
224     }
225
226     return abs(x + y);                 /* i.e. whichever one isn't zero */
227 }
228
229 static int perform_add(int *a, int *b, int *output)
230 {
231     int at, bt, tn, bn;
232     /*
233      * a0/a1 + b0/b1 = (a0*b1 + b0*a1) / (a1*b1)
234      */
235     MUL(at, a[0], b[1]);
236     MUL(bt, b[0], a[1]);
237     ADD(tn, at, bt);
238     MUL(bn, a[1], b[1]);
239     OUT(output, tn, bn);
240     return TRUE;
241 }
242
243 static int perform_sub(int *a, int *b, int *output)
244 {
245     int at, bt, tn, bn;
246     /*
247      * a0/a1 - b0/b1 = (a0*b1 - b0*a1) / (a1*b1)
248      */
249     MUL(at, a[0], b[1]);
250     MUL(bt, b[0], a[1]);
251     ADD(tn, at, -bt);
252     MUL(bn, a[1], b[1]);
253     OUT(output, tn, bn);
254     return TRUE;
255 }
256
257 static int perform_mul(int *a, int *b, int *output)
258 {
259     int tn, bn;
260     /*
261      * a0/a1 * b0/b1 = (a0*b0) / (a1*b1)
262      */
263     MUL(tn, a[0], b[0]);
264     MUL(bn, a[1], b[1]);
265     OUT(output, tn, bn);
266     return TRUE;
267 }
268
269 static int perform_div(int *a, int *b, int *output)
270 {
271     int tn, bn;
272
273     /*
274      * Division by zero is outlawed.
275      */
276     if (b[0] == 0)
277         return FALSE;
278
279     /*
280      * a0/a1 / b0/b1 = (a0*b1) / (a1*b0)
281      */
282     MUL(tn, a[0], b[1]);
283     MUL(bn, a[1], b[0]);
284     OUT(output, tn, bn);
285     return TRUE;
286 }
287
288 static int perform_exact_div(int *a, int *b, int *output)
289 {
290     int tn, bn;
291
292     /*
293      * Division by zero is outlawed.
294      */
295     if (b[0] == 0)
296         return FALSE;
297
298     /*
299      * a0/a1 / b0/b1 = (a0*b1) / (a1*b0)
300      */
301     MUL(tn, a[0], b[1]);
302     MUL(bn, a[1], b[0]);
303     OUT(output, tn, bn);
304
305     /*
306      * Exact division means we require the result to be an integer.
307      */
308     return (output[1] == 1);
309 }
310
311 static int max_p10(int n, int *p10_r)
312 {
313     /*
314      * Find the smallest power of ten strictly greater than n.
315      *
316      * Special case: we must return at least 10, even if n is
317      * zero. (This is because this function is used for finding
318      * the power of ten by which to multiply a number being
319      * concatenated to the front of n, and concatenating 1 to 0
320      * should yield 10 and not 1.)
321      */
322     int p10 = 10;
323     while (p10 <= (INT_MAX/10) && p10 <= n)
324         p10 *= 10;
325     if (p10 > INT_MAX/10)
326         return FALSE;                  /* integer overflow */
327     *p10_r = p10;
328     return TRUE;
329 }
330
331 static int perform_concat(int *a, int *b, int *output)
332 {
333     int t1, t2, p10;
334
335     /*
336      * We can't concatenate anything which isn't a non-negative
337      * integer.
338      */
339     if (a[1] != 1 || b[1] != 1 || a[0] < 0 || b[0] < 0)
340         return FALSE;
341
342     /*
343      * For concatenation, we can safely assume leading zeroes
344      * aren't an issue. It isn't clear whether they `should' be
345      * allowed, but it turns out not to matter: concatenating a
346      * leading zero on to a number in order to harmlessly get rid
347      * of the zero is never necessary because unwanted zeroes can
348      * be disposed of by adding them to something instead. So we
349      * disallow them always.
350      *
351      * The only other possibility is that you might want to
352      * concatenate a leading zero on to something and then
353      * concatenate another non-zero digit on to _that_ (to make,
354      * for example, 106); but that's also unnecessary, because you
355      * can make 106 just as easily by concatenating the 0 on to the
356      * _end_ of the 1 first.
357      */
358     if (a[0] == 0)
359         return FALSE;
360
361     if (!max_p10(b[0], &p10)) return FALSE;
362
363     MUL(t1, p10, a[0]);
364     ADD(t2, t1, b[0]);
365     OUT(output, t2, 1);
366     return TRUE;
367 }
368
369 #define IPOW(ret, x, y) do { \
370     int ipow_limit = (y); \
371     if ((x) == 1 || (x) == 0) ipow_limit = 1; \
372     else if ((x) == -1) ipow_limit &= 1; \
373     (ret) = 1; \
374     while (ipow_limit-- > 0) { \
375         int tmp; \
376         MUL(tmp, ret, x); \
377         ret = tmp; \
378     } \
379 } while (0)
380
381 static int perform_exp(int *a, int *b, int *output)
382 {
383     int an, ad, xn, xd;
384
385     /*
386      * Exponentiation is permitted if the result is rational. This
387      * means that:
388      * 
389      *  - first we see whether we can take the (denominator-of-b)th
390      *    root of a and get a rational; if not, we give up.
391      * 
392      *  - then we do take that root of a
393      * 
394      *  - then we multiply by itself (numerator-of-b) times.
395      */
396     if (b[1] > 1) {
397         an = (int)(0.5 + pow(a[0], 1.0/b[1]));
398         ad = (int)(0.5 + pow(a[1], 1.0/b[1]));
399         IPOW(xn, an, b[1]);
400         IPOW(xd, ad, b[1]);
401         if (xn != a[0] || xd != a[1])
402             return FALSE;
403     } else {
404         an = a[0];
405         ad = a[1];
406     }
407     if (b[0] >= 0) {
408         IPOW(xn, an, b[0]);
409         IPOW(xd, ad, b[0]);
410     } else {
411         IPOW(xd, an, -b[0]);
412         IPOW(xn, ad, -b[0]);
413     }
414     if (xd == 0)
415         return FALSE;
416
417     OUT(output, xn, xd);
418     return TRUE;
419 }
420
421 static int perform_factorial(int *a, int *b, int *output)
422 {
423     int ret, t, i;
424
425     /*
426      * Factorials of non-negative integers are permitted.
427      */
428     if (a[1] != 1 || a[0] < 0)
429         return FALSE;
430
431     /*
432      * However, a special case: we don't take a factorial of
433      * anything which would thereby remain the same.
434      */
435     if (a[0] == 1 || a[0] == 2)
436         return FALSE;
437
438     ret = 1;
439     for (i = 1; i <= a[0]; i++) {
440         MUL(t, ret, i);
441         ret = t;
442     }
443
444     OUT(output, ret, 1);
445     return TRUE;
446 }
447
448 static int perform_decimal(int *a, int *b, int *output)
449 {
450     int p10;
451
452     /*
453      * Add a decimal digit to the front of a number;
454      * fail if it's not an integer.
455      * So, 1 --> 0.1, 15 --> 0.15,
456      * or, rather, 1 --> 1/10, 15 --> 15/100,
457      * x --> x / (smallest power of 10 > than x)
458      *
459      */
460     if (a[1] != 1) return FALSE;
461
462     if (!max_p10(a[0], &p10)) return FALSE;
463
464     OUT(output, a[0], p10);
465     return TRUE;
466 }
467
468 static int perform_recur(int *a, int *b, int *output)
469 {
470     int p10, tn, bn;
471
472     /*
473      * This converts a number like .4 to .44444..., or .45 to .45454...
474      * The input number must be -1 < a < 1.
475      *
476      * Calculate the smallest power of 10 that divides the denominator exactly,
477      * returning if no such power of 10 exists. Then multiply the numerator
478      * up accordingly, and the new denominator becomes that power of 10 - 1.
479      */
480     if (abs(a[0]) >= abs(a[1])) return FALSE; /* -1 < a < 1 */
481
482     p10 = 10;
483     while (p10 <= (INT_MAX/10)) {
484         if ((a[1] <= p10) && (p10 % a[1]) == 0) goto found;
485         p10 *= 10;
486     }
487     return FALSE;
488 found:
489     tn = a[0] * (p10 / a[1]);
490     bn = p10 - 1;
491
492     OUT(output, tn, bn);
493     return TRUE;
494 }
495
496 static int perform_root(int *a, int *b, int *output)
497 {
498     /*
499      * A root B is: 1           iff a == 0
500      *              B ^ (1/A)   otherwise
501      */
502     int ainv[2], res;
503
504     if (a[0] == 0) {
505         OUT(output, 1, 1);
506         return TRUE;
507     }
508
509     OUT(ainv, a[1], a[0]);
510     res = perform_exp(b, ainv, output);
511     return res;
512 }
513
514 const static struct operation op_add = {
515     TRUE, "+", "+", 0, 10, 0, TRUE, perform_add
516 };
517 const static struct operation op_sub = {
518     TRUE, "-", "-", 0, 10, 2, FALSE, perform_sub
519 };
520 const static struct operation op_mul = {
521     TRUE, "*", "*", 0, 20, 0, TRUE, perform_mul
522 };
523 const static struct operation op_div = {
524     TRUE, "/", "/", 0, 20, 2, FALSE, perform_div
525 };
526 const static struct operation op_xdiv = {
527     TRUE, "/", "/", 0, 20, 2, FALSE, perform_exact_div
528 };
529 const static struct operation op_concat = {
530     FALSE, "", "concat", OPFLAG_NEEDS_CONCAT | OPFLAG_KEEPS_CONCAT,
531         1000, 0, FALSE, perform_concat
532 };
533 const static struct operation op_exp = {
534     TRUE, "^", "^", 0, 30, 1, FALSE, perform_exp
535 };
536 const static struct operation op_factorial = {
537     TRUE, "!", "!", OPFLAG_UNARY, 40, 0, FALSE, perform_factorial
538 };
539 const static struct operation op_decimal = {
540     TRUE, ".", ".", OPFLAG_UNARY | OPFLAG_UNARYPREFIX | OPFLAG_NEEDS_CONCAT | OPFLAG_KEEPS_CONCAT, 50, 0, FALSE, perform_decimal
541 };
542 const static struct operation op_recur = {
543     TRUE, "...", "recur", OPFLAG_UNARY | OPFLAG_NEEDS_CONCAT, 45, 2, FALSE, perform_recur
544 };
545 const static struct operation op_root = {
546     TRUE, "v~", "root", 0, 30, 1, FALSE, perform_root
547 };
548
549 /*
550  * In Countdown, divisions resulting in fractions are disallowed.
551  * http://www.askoxford.com/wordgames/countdown/rules/
552  */
553 const static struct operation *const ops_countdown[] = {
554     &op_add, &op_mul, &op_sub, &op_xdiv, NULL
555 };
556 const static struct rules rules_countdown = {
557     ops_countdown, FALSE
558 };
559
560 /*
561  * A slightly different rule set which handles the reasonably well
562  * known puzzle of making 24 using two 3s and two 8s. For this we
563  * need rational rather than integer division.
564  */
565 const static struct operation *const ops_3388[] = {
566     &op_add, &op_mul, &op_sub, &op_div, NULL
567 };
568 const static struct rules rules_3388 = {
569     ops_3388, TRUE
570 };
571
572 /*
573  * A still more permissive rule set usable for the four-4s problem
574  * and similar things. Permits concatenation.
575  */
576 const static struct operation *const ops_four4s[] = {
577     &op_add, &op_mul, &op_sub, &op_div, &op_concat, NULL
578 };
579 const static struct rules rules_four4s = {
580     ops_four4s, TRUE
581 };
582
583 /*
584  * The most permissive ruleset I can think of. Permits
585  * exponentiation, and also silly unary operators like factorials.
586  */
587 const static struct operation *const ops_anythinggoes[] = {
588     &op_add, &op_mul, &op_sub, &op_div, &op_concat, &op_exp, &op_factorial, 
589     &op_decimal, &op_recur, &op_root, NULL
590 };
591 const static struct rules rules_anythinggoes = {
592     ops_anythinggoes, TRUE
593 };
594
595 #define ratcmp(a,op,b) ( (long long)(a)[0] * (b)[1] op \
596                          (long long)(b)[0] * (a)[1] )
597
598 static int addtoset(struct set *set, int newnumber[2])
599 {
600     int i, j;
601
602     /* Find where we want to insert the new number */
603     for (i = 0; i < set->nnumbers &&
604          ratcmp(set->numbers+2*i, <, newnumber); i++);
605
606     /* Move everything else up */
607     for (j = set->nnumbers; j > i; j--) {
608         set->numbers[2*j] = set->numbers[2*j-2];
609         set->numbers[2*j+1] = set->numbers[2*j-1];
610     }
611
612     /* Insert the new number */
613     set->numbers[2*i] = newnumber[0];
614     set->numbers[2*i+1] = newnumber[1];
615
616     set->nnumbers++;
617
618     return i;
619 }
620
621 #define ensure(array, size, newlen, type) do { \
622     if ((newlen) > (size)) { \
623         (size) = (newlen) + 512; \
624         (array) = sresize((array), (size), type); \
625     } \
626 } while (0)
627
628 static int setcmp(void *av, void *bv)
629 {
630     struct set *a = (struct set *)av;
631     struct set *b = (struct set *)bv;
632     int i;
633
634     if (a->nnumbers < b->nnumbers)
635         return -1;
636     else if (a->nnumbers > b->nnumbers)
637         return +1;
638
639     if (a->flags < b->flags)
640         return -1;
641     else if (a->flags > b->flags)
642         return +1;
643
644     for (i = 0; i < a->nnumbers; i++) {
645         if (ratcmp(a->numbers+2*i, <, b->numbers+2*i))
646             return -1;
647         else if (ratcmp(a->numbers+2*i, >, b->numbers+2*i))
648             return +1;
649     }
650
651     return 0;
652 }
653
654 static int outputcmp(void *av, void *bv)
655 {
656     struct output *a = (struct output *)av;
657     struct output *b = (struct output *)bv;
658
659     if (a->number < b->number)
660         return -1;
661     else if (a->number > b->number)
662         return +1;
663
664     return 0;
665 }
666
667 static int outputfindcmp(void *av, void *bv)
668 {
669     int *a = (int *)av;
670     struct output *b = (struct output *)bv;
671
672     if (*a < b->number)
673         return -1;
674     else if (*a > b->number)
675         return +1;
676
677     return 0;
678 }
679
680 static void addset(struct sets *s, struct set *set, int multiple,
681                    struct set *prev, int pa, int po, int pb, int pr)
682 {
683     struct set *s2;
684     int npaths = (prev ? prev->npaths : 1);
685
686     assert(set == s->setlists[s->nsets / SETLISTLEN] + s->nsets % SETLISTLEN);
687     s2 = add234(s->settree, set);
688     if (s2 == set) {
689         /*
690          * New set added to the tree.
691          */
692         set->a.prev = prev;
693         set->a.pa = pa;
694         set->a.po = po;
695         set->a.pb = pb;
696         set->a.pr = pr;
697         set->npaths = npaths;
698         s->nsets++;
699         s->nnumbers += 2 * set->nnumbers;
700         set->as = NULL;
701         set->nas = set->assize = 0;
702     } else {
703         /*
704          * Rediscovered an existing set. Update its npaths.
705          */
706         s2->npaths += npaths;
707         /*
708          * And optionally enter it as an additional ancestor.
709          */
710         if (multiple) {
711             if (s2->nas >= s2->assize) {
712                 s2->assize = s2->nas * 3 / 2 + 4;
713                 s2->as = sresize(s2->as, s2->assize, struct ancestor);
714             }
715             s2->as[s2->nas].prev = prev;
716             s2->as[s2->nas].pa = pa;
717             s2->as[s2->nas].po = po;
718             s2->as[s2->nas].pb = pb;
719             s2->as[s2->nas].pr = pr;
720             s2->nas++;
721         }
722     }
723 }
724
725 static struct set *newset(struct sets *s, int nnumbers, int flags)
726 {
727     struct set *sn;
728
729     ensure(s->setlists, s->setlistsize, s->nsets/SETLISTLEN+1, struct set *);
730     while (s->nsetlists <= s->nsets / SETLISTLEN)
731         s->setlists[s->nsetlists++] = snewn(SETLISTLEN, struct set);
732     sn = s->setlists[s->nsets / SETLISTLEN] + s->nsets % SETLISTLEN;
733
734     if (s->nnumbers + nnumbers * 2 > s->nnumberlists * NUMBERLISTLEN)
735         s->nnumbers = s->nnumberlists * NUMBERLISTLEN;
736     ensure(s->numberlists, s->numberlistsize,
737            s->nnumbers/NUMBERLISTLEN+1, int *);
738     while (s->nnumberlists <= s->nnumbers / NUMBERLISTLEN)
739         s->numberlists[s->nnumberlists++] = snewn(NUMBERLISTLEN, int);
740     sn->numbers = s->numberlists[s->nnumbers / NUMBERLISTLEN] +
741         s->nnumbers % NUMBERLISTLEN;
742
743     /*
744      * Start the set off empty.
745      */
746     sn->nnumbers = 0;
747
748     sn->flags = flags;
749
750     return sn;
751 }
752
753 static int addoutput(struct sets *s, struct set *ss, int index, int *n)
754 {
755     struct output *o, *o2;
756
757     /*
758      * Target numbers are always integers.
759      */
760     if (ss->numbers[2*index+1] != 1)
761         return FALSE;
762
763     ensure(s->outputlists, s->outputlistsize, s->noutputs/OUTPUTLISTLEN+1,
764            struct output *);
765     while (s->noutputlists <= s->noutputs / OUTPUTLISTLEN)
766         s->outputlists[s->noutputlists++] = snewn(OUTPUTLISTLEN,
767                                                   struct output);
768     o = s->outputlists[s->noutputs / OUTPUTLISTLEN] +
769         s->noutputs % OUTPUTLISTLEN;
770
771     o->number = ss->numbers[2*index];
772     o->set = ss;
773     o->index = index;
774     o->npaths = ss->npaths;
775     o2 = add234(s->outputtree, o);
776     if (o2 != o) {
777         o2->npaths += o->npaths;
778     } else {
779         s->noutputs++;
780     }
781     *n = o->number;
782     return TRUE;
783 }
784
785 static struct sets *do_search(int ninputs, int *inputs,
786                               const struct rules *rules, int *target,
787                               int debug, int multiple)
788 {
789     struct sets *s;
790     struct set *sn;
791     int qpos, i;
792     const struct operation *const *ops = rules->ops;
793
794     s = snew(struct sets);
795     s->setlists = NULL;
796     s->nsets = s->nsetlists = s->setlistsize = 0;
797     s->numberlists = NULL;
798     s->nnumbers = s->nnumberlists = s->numberlistsize = 0;
799     s->outputlists = NULL;
800     s->noutputs = s->noutputlists = s->outputlistsize = 0;
801     s->settree = newtree234(setcmp);
802     s->outputtree = newtree234(outputcmp);
803     s->ops = ops;
804
805     /*
806      * Start with the input set.
807      */
808     sn = newset(s, ninputs, SETFLAG_CONCAT);
809     for (i = 0; i < ninputs; i++) {
810         int newnumber[2];
811         newnumber[0] = inputs[i];
812         newnumber[1] = 1;
813         addtoset(sn, newnumber);
814     }
815     addset(s, sn, multiple, NULL, 0, 0, 0, 0);
816
817     /*
818      * Now perform the breadth-first search: keep looping over sets
819      * until we run out of steam.
820      */
821     qpos = 0;
822     while (qpos < s->nsets) {
823         struct set *ss = s->setlists[qpos / SETLISTLEN] + qpos % SETLISTLEN;
824         struct set *sn;
825         int i, j, k, m;
826
827         if (debug) {
828             int i;
829             printf("processing set:");
830             for (i = 0; i < ss->nnumbers; i++) {
831                 printf(" %d", ss->numbers[2*i]);
832                 if (ss->numbers[2*i+1] != 1)
833                     printf("/%d", ss->numbers[2*i+1]);
834             }
835             printf("\n");
836         }
837
838         /*
839          * Record all the valid output numbers in this state. We
840          * can always do this if there's only one number in the
841          * state; otherwise, we can only do it if we aren't
842          * required to use all the numbers in coming to our answer.
843          */
844         if (ss->nnumbers == 1 || !rules->use_all) {
845             for (i = 0; i < ss->nnumbers; i++) {
846                 int n;
847
848                 if (addoutput(s, ss, i, &n) && target && n == *target)
849                     return s;
850             }
851         }
852
853         /*
854          * Try every possible operation from this state.
855          */
856         for (k = 0; ops[k] && ops[k]->perform; k++) {
857             if ((ops[k]->flags & OPFLAG_NEEDS_CONCAT) &&
858                 !(ss->flags & SETFLAG_CONCAT))
859                 continue;              /* can't use this operation here */
860             for (i = 0; i < ss->nnumbers; i++) {
861                 int jlimit = (ops[k]->flags & OPFLAG_UNARY ? 1 : ss->nnumbers);
862                 for (j = 0; j < jlimit; j++) {
863                     int n[2];
864                     int pa, po, pb, pr;
865
866                     if (!(ops[k]->flags & OPFLAG_UNARY)) {
867                         if (i == j)
868                             continue;  /* can't combine a number with itself */
869                         if (i > j && ops[k]->commutes)
870                             continue;  /* no need to do this both ways round */
871                     }
872                     if (!ops[k]->perform(ss->numbers+2*i, ss->numbers+2*j, n))
873                         continue;      /* operation failed */
874
875                     sn = newset(s, ss->nnumbers-1, ss->flags);
876
877                     if (!(ops[k]->flags & OPFLAG_KEEPS_CONCAT))
878                         sn->flags &= ~SETFLAG_CONCAT;
879
880                     for (m = 0; m < ss->nnumbers; m++) {
881                         if (m == i || (!(ops[k]->flags & OPFLAG_UNARY) &&
882                                        m == j))
883                             continue;
884                         sn->numbers[2*sn->nnumbers] = ss->numbers[2*m];
885                         sn->numbers[2*sn->nnumbers + 1] = ss->numbers[2*m + 1];
886                         sn->nnumbers++;
887                     }
888                     pa = i;
889                     if (ops[k]->flags & OPFLAG_UNARY)
890                         pb = sn->nnumbers+10;
891                     else
892                         pb = j;
893                     po = k;
894                     pr = addtoset(sn, n);
895                     addset(s, sn, multiple, ss, pa, po, pb, pr);
896                     if (debug) {
897                         int i;
898                         if (ops[k]->flags & OPFLAG_UNARYPREFIX)
899                             printf("  %s %d ->", ops[po]->dbgtext, pa);
900                         else if (ops[k]->flags & OPFLAG_UNARY)
901                             printf("  %d %s ->", pa, ops[po]->dbgtext);
902                         else
903                             printf("  %d %s %d ->", pa, ops[po]->dbgtext, pb);
904                         for (i = 0; i < sn->nnumbers; i++) {
905                             printf(" %d", sn->numbers[2*i]);
906                             if (sn->numbers[2*i+1] != 1)
907                                 printf("/%d", sn->numbers[2*i+1]);
908                         }
909                         printf("\n");
910                     }
911                 }
912             }
913         }
914
915         qpos++;
916     }
917
918     return s;
919 }
920
921 static void free_sets(struct sets *s)
922 {
923     int i;
924
925     freetree234(s->settree);
926     freetree234(s->outputtree);
927     for (i = 0; i < s->nsetlists; i++)
928         sfree(s->setlists[i]);
929     sfree(s->setlists);
930     for (i = 0; i < s->nnumberlists; i++)
931         sfree(s->numberlists[i]);
932     sfree(s->numberlists);
933     for (i = 0; i < s->noutputlists; i++)
934         sfree(s->outputlists[i]);
935     sfree(s->outputlists);
936     sfree(s);
937 }
938
939 /*
940  * Print a text formula for producing a given output.
941  */
942 void print_recurse(struct sets *s, struct set *ss, int pathindex, int index,
943                    int priority, int assoc, int child);
944 void print_recurse_inner(struct sets *s, struct set *ss,
945                          struct ancestor *a, int pathindex, int index,
946                          int priority, int assoc, int child)
947 {
948     if (a->prev && index != a->pr) {
949         int pi;
950
951         /*
952          * This number was passed straight down from this set's
953          * predecessor. Find its index in the previous set and
954          * recurse to there.
955          */
956         pi = index;
957         assert(pi != a->pr);
958         if (pi > a->pr)
959             pi--;
960         if (pi >= min(a->pa, a->pb)) {
961             pi++;
962             if (pi >= max(a->pa, a->pb))
963                 pi++;
964         }
965         print_recurse(s, a->prev, pathindex, pi, priority, assoc, child);
966     } else if (a->prev && index == a->pr &&
967                s->ops[a->po]->display) {
968         /*
969          * This number was created by a displayed operator in the
970          * transition from this set to its predecessor. Hence we
971          * write an open paren, then recurse into the first
972          * operand, then write the operator, then the second
973          * operand, and finally close the paren.
974          */
975         char *op;
976         int parens, thispri, thisassoc;
977
978         /*
979          * Determine whether we need parentheses.
980          */
981         thispri = s->ops[a->po]->priority;
982         thisassoc = s->ops[a->po]->assoc;
983         parens = (thispri < priority ||
984                   (thispri == priority && (assoc & child)));
985
986         if (parens)
987             putchar('(');
988
989         if (s->ops[a->po]->flags & OPFLAG_UNARYPREFIX)
990             for (op = s->ops[a->po]->text; *op; op++)
991                 putchar(*op);
992
993         print_recurse(s, a->prev, pathindex, a->pa, thispri, thisassoc, 1);
994
995         if (!(s->ops[a->po]->flags & OPFLAG_UNARYPREFIX))
996             for (op = s->ops[a->po]->text; *op; op++)
997                 putchar(*op);
998
999         if (!(s->ops[a->po]->flags & OPFLAG_UNARY))
1000             print_recurse(s, a->prev, pathindex, a->pb, thispri, thisassoc, 2);
1001
1002         if (parens)
1003             putchar(')');
1004     } else {
1005         /*
1006          * This number is either an original, or something formed
1007          * by a non-displayed operator (concatenation). Either way,
1008          * we display it as is.
1009          */
1010         printf("%d", ss->numbers[2*index]);
1011         if (ss->numbers[2*index+1] != 1)
1012             printf("/%d", ss->numbers[2*index+1]);
1013     }
1014 }
1015 void print_recurse(struct sets *s, struct set *ss, int pathindex, int index,
1016                    int priority, int assoc, int child)
1017 {
1018     if (!ss->a.prev || pathindex < ss->a.prev->npaths) {
1019         print_recurse_inner(s, ss, &ss->a, pathindex,
1020                             index, priority, assoc, child);
1021     } else {
1022         int i;
1023         pathindex -= ss->a.prev->npaths;
1024         for (i = 0; i < ss->nas; i++) {
1025             if (pathindex < ss->as[i].prev->npaths) {
1026                 print_recurse_inner(s, ss, &ss->as[i], pathindex,
1027                                     index, priority, assoc, child);
1028                 break;
1029             }
1030             pathindex -= ss->as[i].prev->npaths;
1031         }
1032     }
1033 }
1034 void print(int pathindex, struct sets *s, struct output *o)
1035 {
1036     print_recurse(s, o->set, pathindex, o->index, 0, 0, 0);
1037 }
1038
1039 /*
1040  * gcc -g -O0 -o numgame numgame.c -I.. ../{malloc,tree234,nullfe}.c -lm
1041  */
1042 int main(int argc, char **argv)
1043 {
1044     int doing_opts = TRUE;
1045     const struct rules *rules = NULL;
1046     char *pname = argv[0];
1047     int got_target = FALSE, target = 0;
1048     int numbers[10], nnumbers = 0;
1049     int verbose = FALSE;
1050     int pathcounts = FALSE;
1051     int multiple = FALSE;
1052     int debug_bfs = FALSE;
1053
1054     struct output *o;
1055     struct sets *s;
1056     int i, start, limit;
1057
1058     while (--argc) {
1059         char *p = *++argv;
1060         int c;
1061
1062         if (doing_opts && *p == '-') {
1063             p++;
1064
1065             if (!strcmp(p, "-")) {
1066                 doing_opts = FALSE;
1067                 continue;
1068             } else if (*p == '-') {
1069                 p++;
1070                 if (!strcmp(p, "debug-bfs")) {
1071                     debug_bfs = TRUE;
1072                 } else {
1073                     fprintf(stderr, "%s: option '--%s' not recognised\n",
1074                             pname, p);
1075                 }
1076             } else while (*p) switch (c = *p++) {
1077               case 'C':
1078                 rules = &rules_countdown;
1079                 break;
1080               case 'B':
1081                 rules = &rules_3388;
1082                 break;
1083               case 'D':
1084                 rules = &rules_four4s;
1085                 break;
1086               case 'A':
1087                 rules = &rules_anythinggoes;
1088                 break;
1089               case 'v':
1090                 verbose = TRUE;
1091                 break;
1092               case 'p':
1093                 pathcounts = TRUE;
1094                 break;
1095               case 'm':
1096                 multiple = TRUE;
1097                 break;
1098               case 't':
1099                 {
1100                     char *v;
1101                     if (*p) {
1102                         v = p;
1103                         p = NULL;
1104                     } else if (--argc) {
1105                         v = *++argv;
1106                     } else {
1107                         fprintf(stderr, "%s: option '-%c' expects an"
1108                                 " argument\n", pname, c);
1109                         return 1;
1110                     }
1111                     switch (c) {
1112                       case 't':
1113                         got_target = TRUE;
1114                         target = atoi(v);
1115                         break;
1116                     }
1117                 }
1118                 break;
1119               default:
1120                 fprintf(stderr, "%s: option '-%c' not"
1121                         " recognised\n", pname, c);
1122                 return 1;
1123             }
1124         } else {
1125             if (nnumbers >= lenof(numbers)) {
1126                 fprintf(stderr, "%s: internal limit of %d numbers exceeded\n",
1127                         pname, lenof(numbers));
1128                 return 1;
1129             } else {
1130                 numbers[nnumbers++] = atoi(p);
1131             }
1132         }
1133     }
1134
1135     if (!rules) {
1136         fprintf(stderr, "%s: no rule set specified; use -C,-B,-D,-A\n", pname);
1137         return 1;
1138     }
1139
1140     if (!nnumbers) {
1141         fprintf(stderr, "%s: no input numbers specified\n", pname);
1142         return 1;
1143     }
1144
1145     s = do_search(nnumbers, numbers, rules, (got_target ? &target : NULL),
1146                   debug_bfs, multiple);
1147
1148     if (got_target) {
1149         o = findrelpos234(s->outputtree, &target, outputfindcmp,
1150                           REL234_LE, &start);
1151         if (!o)
1152             start = -1;
1153         o = findrelpos234(s->outputtree, &target, outputfindcmp,
1154                           REL234_GE, &limit);
1155         if (!o)
1156             limit = -1;
1157         assert(start != -1 || limit != -1);
1158         if (start == -1)
1159             start = limit;
1160         else if (limit == -1)
1161             limit = start;
1162         limit++;
1163     } else {
1164         start = 0;
1165         limit = count234(s->outputtree);
1166     }
1167
1168     for (i = start; i < limit; i++) {
1169         char buf[256];
1170
1171         o = index234(s->outputtree, i);
1172
1173         sprintf(buf, "%d", o->number);
1174
1175         if (pathcounts)
1176             sprintf(buf + strlen(buf), " [%d]", o->npaths);
1177
1178         if (got_target || verbose) {
1179             int j, npaths;
1180
1181             if (multiple)
1182                 npaths = o->npaths;
1183             else
1184                 npaths = 1;
1185
1186             for (j = 0; j < npaths; j++) {
1187                 printf("%s = ", buf);
1188                 print(j, s, o);
1189                 putchar('\n');
1190             }
1191         } else {
1192             printf("%s\n", buf);
1193         }
1194     }
1195
1196     free_sets(s);
1197
1198     return 0;
1199 }