chiark / gitweb /
Normalise Unequal (and latin.c) so that solver diagnostics start
[sgt-puzzles.git] / latin.c
1 #include <assert.h>
2 #include <string.h>
3 #include <stdarg.h>
4
5 #include "puzzles.h"
6 #include "tree234.h"
7 #include "maxflow.h"
8
9 #ifdef STANDALONE_LATIN_TEST
10 #define STANDALONE_SOLVER
11 #endif
12
13 #include "latin.h"
14
15 /* --------------------------------------------------------
16  * Solver.
17  */
18
19 #ifdef STANDALONE_SOLVER
20 int solver_show_working, solver_recurse_depth;
21 #endif
22
23 /*
24  * Function called when we are certain that a particular square has
25  * a particular number in it. The y-coordinate passed in here is
26  * transformed.
27  */
28 void latin_solver_place(struct latin_solver *solver, int x, int y, int n)
29 {
30     int i, o = solver->o;
31
32     assert(n <= o);
33     assert(cube(x,y,n));
34
35     /*
36      * Rule out all other numbers in this square.
37      */
38     for (i = 1; i <= o; i++)
39         if (i != n)
40             cube(x,y,i) = FALSE;
41
42     /*
43      * Rule out this number in all other positions in the row.
44      */
45     for (i = 0; i < o; i++)
46         if (i != y)
47             cube(x,i,n) = FALSE;
48
49     /*
50      * Rule out this number in all other positions in the column.
51      */
52     for (i = 0; i < o; i++)
53         if (i != x)
54             cube(i,y,n) = FALSE;
55
56     /*
57      * Enter the number in the result grid.
58      */
59     solver->grid[YUNTRANS(y)*o+x] = n;
60
61     /*
62      * Cross out this number from the list of numbers left to place
63      * in its row, its column and its block.
64      */
65     solver->row[y*o+n-1] = solver->col[x*o+n-1] = TRUE;
66 }
67
68 int latin_solver_elim(struct latin_solver *solver, int start, int step
69 #ifdef STANDALONE_SOLVER
70                       , char *fmt, ...
71 #endif
72                       )
73 {
74     int o = solver->o;
75     int fpos, m, i;
76
77     /*
78      * Count the number of set bits within this section of the
79      * cube.
80      */
81     m = 0;
82     fpos = -1;
83     for (i = 0; i < o; i++)
84         if (solver->cube[start+i*step]) {
85             fpos = start+i*step;
86             m++;
87         }
88
89     if (m == 1) {
90         int x, y, n;
91         assert(fpos >= 0);
92
93         n = 1 + fpos % o;
94         y = fpos / o;
95         x = y / o;
96         y %= o;
97
98         if (!solver->grid[YUNTRANS(y)*o+x]) {
99 #ifdef STANDALONE_SOLVER
100             if (solver_show_working) {
101                 va_list ap;
102                 printf("%*s", solver_recurse_depth*4, "");
103                 va_start(ap, fmt);
104                 vprintf(fmt, ap);
105                 va_end(ap);
106                 printf(":\n%*s  placing %d at (%d,%d)\n",
107                        solver_recurse_depth*4, "", n, x+1, YUNTRANS(y)+1);
108             }
109 #endif
110             latin_solver_place(solver, x, y, n);
111             return +1;
112         }
113     } else if (m == 0) {
114 #ifdef STANDALONE_SOLVER
115         if (solver_show_working) {
116             va_list ap;
117             printf("%*s", solver_recurse_depth*4, "");
118             va_start(ap, fmt);
119             vprintf(fmt, ap);
120             va_end(ap);
121             printf(":\n%*s  no possibilities available\n",
122                    solver_recurse_depth*4, "");
123         }
124 #endif
125         return -1;
126     }
127
128     return 0;
129 }
130
131 struct latin_solver_scratch {
132     unsigned char *grid, *rowidx, *colidx, *set;
133     int *neighbours, *bfsqueue;
134 #ifdef STANDALONE_SOLVER
135     int *bfsprev;
136 #endif
137 };
138
139 int latin_solver_set(struct latin_solver *solver,
140                      struct latin_solver_scratch *scratch,
141                      int start, int step1, int step2
142 #ifdef STANDALONE_SOLVER
143                      , char *fmt, ...
144 #endif
145                      )
146 {
147     int o = solver->o;
148     int i, j, n, count;
149     unsigned char *grid = scratch->grid;
150     unsigned char *rowidx = scratch->rowidx;
151     unsigned char *colidx = scratch->colidx;
152     unsigned char *set = scratch->set;
153
154     /*
155      * We are passed a o-by-o matrix of booleans. Our first job
156      * is to winnow it by finding any definite placements - i.e.
157      * any row with a solitary 1 - and discarding that row and the
158      * column containing the 1.
159      */
160     memset(rowidx, TRUE, o);
161     memset(colidx, TRUE, o);
162     for (i = 0; i < o; i++) {
163         int count = 0, first = -1;
164         for (j = 0; j < o; j++)
165             if (solver->cube[start+i*step1+j*step2])
166                 first = j, count++;
167
168         if (count == 0) return -1;
169         if (count == 1)
170             rowidx[i] = colidx[first] = FALSE;
171     }
172
173     /*
174      * Convert each of rowidx/colidx from a list of 0s and 1s to a
175      * list of the indices of the 1s.
176      */
177     for (i = j = 0; i < o; i++)
178         if (rowidx[i])
179             rowidx[j++] = i;
180     n = j;
181     for (i = j = 0; i < o; i++)
182         if (colidx[i])
183             colidx[j++] = i;
184     assert(n == j);
185
186     /*
187      * And create the smaller matrix.
188      */
189     for (i = 0; i < n; i++)
190         for (j = 0; j < n; j++)
191             grid[i*o+j] = solver->cube[start+rowidx[i]*step1+colidx[j]*step2];
192
193     /*
194      * Having done that, we now have a matrix in which every row
195      * has at least two 1s in. Now we search to see if we can find
196      * a rectangle of zeroes (in the set-theoretic sense of
197      * `rectangle', i.e. a subset of rows crossed with a subset of
198      * columns) whose width and height add up to n.
199      */
200
201     memset(set, 0, n);
202     count = 0;
203     while (1) {
204         /*
205          * We have a candidate set. If its size is <=1 or >=n-1
206          * then we move on immediately.
207          */
208         if (count > 1 && count < n-1) {
209             /*
210              * The number of rows we need is n-count. See if we can
211              * find that many rows which each have a zero in all
212              * the positions listed in `set'.
213              */
214             int rows = 0;
215             for (i = 0; i < n; i++) {
216                 int ok = TRUE;
217                 for (j = 0; j < n; j++)
218                     if (set[j] && grid[i*o+j]) {
219                         ok = FALSE;
220                         break;
221                     }
222                 if (ok)
223                     rows++;
224             }
225
226             /*
227              * We expect never to be able to get _more_ than
228              * n-count suitable rows: this would imply that (for
229              * example) there are four numbers which between them
230              * have at most three possible positions, and hence it
231              * indicates a faulty deduction before this point or
232              * even a bogus clue.
233              */
234             if (rows > n - count) {
235 #ifdef STANDALONE_SOLVER
236                 if (solver_show_working) {
237                     va_list ap;
238                     printf("%*s", solver_recurse_depth*4,
239                            "");
240                     va_start(ap, fmt);
241                     vprintf(fmt, ap);
242                     va_end(ap);
243                     printf(":\n%*s  contradiction reached\n",
244                            solver_recurse_depth*4, "");
245                 }
246 #endif
247                 return -1;
248             }
249
250             if (rows >= n - count) {
251                 int progress = FALSE;
252
253                 /*
254                  * We've got one! Now, for each row which _doesn't_
255                  * satisfy the criterion, eliminate all its set
256                  * bits in the positions _not_ listed in `set'.
257                  * Return +1 (meaning progress has been made) if we
258                  * successfully eliminated anything at all.
259                  *
260                  * This involves referring back through
261                  * rowidx/colidx in order to work out which actual
262                  * positions in the cube to meddle with.
263                  */
264                 for (i = 0; i < n; i++) {
265                     int ok = TRUE;
266                     for (j = 0; j < n; j++)
267                         if (set[j] && grid[i*o+j]) {
268                             ok = FALSE;
269                             break;
270                         }
271                     if (!ok) {
272                         for (j = 0; j < n; j++)
273                             if (!set[j] && grid[i*o+j]) {
274                                 int fpos = (start+rowidx[i]*step1+
275                                             colidx[j]*step2);
276 #ifdef STANDALONE_SOLVER
277                                 if (solver_show_working) {
278                                     int px, py, pn;
279
280                                     if (!progress) {
281                                         va_list ap;
282                                         printf("%*s", solver_recurse_depth*4,
283                                                "");
284                                         va_start(ap, fmt);
285                                         vprintf(fmt, ap);
286                                         va_end(ap);
287                                         printf(":\n");
288                                     }
289
290                                     pn = 1 + fpos % o;
291                                     py = fpos / o;
292                                     px = py / o;
293                                     py %= o;
294
295                                     printf("%*s  ruling out %d at (%d,%d)\n",
296                                            solver_recurse_depth*4, "",
297                                            pn, px+1, YUNTRANS(py)+1);
298                                 }
299 #endif
300                                 progress = TRUE;
301                                 solver->cube[fpos] = FALSE;
302                             }
303                     }
304                 }
305
306                 if (progress) {
307                     return +1;
308                 }
309             }
310         }
311
312         /*
313          * Binary increment: change the rightmost 0 to a 1, and
314          * change all 1s to the right of it to 0s.
315          */
316         i = n;
317         while (i > 0 && set[i-1])
318             set[--i] = 0, count--;
319         if (i > 0)
320             set[--i] = 1, count++;
321         else
322             break;                     /* done */
323     }
324
325     return 0;
326 }
327
328 /*
329  * Look for forcing chains. A forcing chain is a path of
330  * pairwise-exclusive squares (i.e. each pair of adjacent squares
331  * in the path are in the same row, column or block) with the
332  * following properties:
333  *
334  *  (a) Each square on the path has precisely two possible numbers.
335  *
336  *  (b) Each pair of squares which are adjacent on the path share
337  *      at least one possible number in common.
338  *
339  *  (c) Each square in the middle of the path shares _both_ of its
340  *      numbers with at least one of its neighbours (not the same
341  *      one with both neighbours).
342  *
343  * These together imply that at least one of the possible number
344  * choices at one end of the path forces _all_ the rest of the
345  * numbers along the path. In order to make real use of this, we
346  * need further properties:
347  *
348  *  (c) Ruling out some number N from the square at one end
349  *      of the path forces the square at the other end to
350  *      take number N.
351  *
352  *  (d) The two end squares are both in line with some third
353  *      square.
354  *
355  *  (e) That third square currently has N as a possibility.
356  *
357  * If we can find all of that lot, we can deduce that at least one
358  * of the two ends of the forcing chain has number N, and that
359  * therefore the mutually adjacent third square does not.
360  *
361  * To find forcing chains, we're going to start a bfs at each
362  * suitable square, once for each of its two possible numbers.
363  */
364 int latin_solver_forcing(struct latin_solver *solver,
365                          struct latin_solver_scratch *scratch)
366 {
367     int o = solver->o;
368     int *bfsqueue = scratch->bfsqueue;
369 #ifdef STANDALONE_SOLVER
370     int *bfsprev = scratch->bfsprev;
371 #endif
372     unsigned char *number = scratch->grid;
373     int *neighbours = scratch->neighbours;
374     int x, y;
375
376     for (y = 0; y < o; y++)
377         for (x = 0; x < o; x++) {
378             int count, t, n;
379
380             /*
381              * If this square doesn't have exactly two candidate
382              * numbers, don't try it.
383              *
384              * In this loop we also sum the candidate numbers,
385              * which is a nasty hack to allow us to quickly find
386              * `the other one' (since we will shortly know there
387              * are exactly two).
388              */
389             for (count = t = 0, n = 1; n <= o; n++)
390                 if (cube(x, y, n))
391                     count++, t += n;
392             if (count != 2)
393                 continue;
394
395             /*
396              * Now attempt a bfs for each candidate.
397              */
398             for (n = 1; n <= o; n++)
399                 if (cube(x, y, n)) {
400                     int orign, currn, head, tail;
401
402                     /*
403                      * Begin a bfs.
404                      */
405                     orign = n;
406
407                     memset(number, o+1, o*o);
408                     head = tail = 0;
409                     bfsqueue[tail++] = y*o+x;
410 #ifdef STANDALONE_SOLVER
411                     bfsprev[y*o+x] = -1;
412 #endif
413                     number[y*o+x] = t - n;
414
415                     while (head < tail) {
416                         int xx, yy, nneighbours, xt, yt, i;
417
418                         xx = bfsqueue[head++];
419                         yy = xx / o;
420                         xx %= o;
421
422                         currn = number[yy*o+xx];
423
424                         /*
425                          * Find neighbours of yy,xx.
426                          */
427                         nneighbours = 0;
428                         for (yt = 0; yt < o; yt++)
429                             neighbours[nneighbours++] = yt*o+xx;
430                         for (xt = 0; xt < o; xt++)
431                             neighbours[nneighbours++] = yy*o+xt;
432
433                         /*
434                          * Try visiting each of those neighbours.
435                          */
436                         for (i = 0; i < nneighbours; i++) {
437                             int cc, tt, nn;
438
439                             xt = neighbours[i] % o;
440                             yt = neighbours[i] / o;
441
442                             /*
443                              * We need this square to not be
444                              * already visited, and to include
445                              * currn as a possible number.
446                              */
447                             if (number[yt*o+xt] <= o)
448                                 continue;
449                             if (!cube(xt, yt, currn))
450                                 continue;
451
452                             /*
453                              * Don't visit _this_ square a second
454                              * time!
455                              */
456                             if (xt == xx && yt == yy)
457                                 continue;
458
459                             /*
460                              * To continue with the bfs, we need
461                              * this square to have exactly two
462                              * possible numbers.
463                              */
464                             for (cc = tt = 0, nn = 1; nn <= o; nn++)
465                                 if (cube(xt, yt, nn))
466                                     cc++, tt += nn;
467                             if (cc == 2) {
468                                 bfsqueue[tail++] = yt*o+xt;
469 #ifdef STANDALONE_SOLVER
470                                 bfsprev[yt*o+xt] = yy*o+xx;
471 #endif
472                                 number[yt*o+xt] = tt - currn;
473                             }
474
475                             /*
476                              * One other possibility is that this
477                              * might be the square in which we can
478                              * make a real deduction: if it's
479                              * adjacent to x,y, and currn is equal
480                              * to the original number we ruled out.
481                              */
482                             if (currn == orign &&
483                                 (xt == x || yt == y)) {
484 #ifdef STANDALONE_SOLVER
485                                 if (solver_show_working) {
486                                     char *sep = "";
487                                     int xl, yl;
488                                     printf("%*sforcing chain, %d at ends of ",
489                                            solver_recurse_depth*4, "", orign);
490                                     xl = xx;
491                                     yl = yy;
492                                     while (1) {
493                                         printf("%s(%d,%d)", sep, xl+1,
494                                                YUNTRANS(yl)+1);
495                                         xl = bfsprev[yl*o+xl];
496                                         if (xl < 0)
497                                             break;
498                                         yl = xl / o;
499                                         xl %= o;
500                                         sep = "-";
501                                     }
502                                     printf("\n%*s  ruling out %d at (%d,%d)\n",
503                                            solver_recurse_depth*4, "",
504                                            orign, xt+1, YUNTRANS(yt)+1);
505                                 }
506 #endif
507                                 cube(xt, yt, orign) = FALSE;
508                                 return 1;
509                             }
510                         }
511                     }
512                 }
513         }
514
515     return 0;
516 }
517
518 struct latin_solver_scratch *latin_solver_new_scratch(struct latin_solver *solver)
519 {
520     struct latin_solver_scratch *scratch = snew(struct latin_solver_scratch);
521     int o = solver->o;
522     scratch->grid = snewn(o*o, unsigned char);
523     scratch->rowidx = snewn(o, unsigned char);
524     scratch->colidx = snewn(o, unsigned char);
525     scratch->set = snewn(o, unsigned char);
526     scratch->neighbours = snewn(3*o, int);
527     scratch->bfsqueue = snewn(o*o, int);
528 #ifdef STANDALONE_SOLVER
529     scratch->bfsprev = snewn(o*o, int);
530 #endif
531     return scratch;
532 }
533
534 void latin_solver_free_scratch(struct latin_solver_scratch *scratch)
535 {
536 #ifdef STANDALONE_SOLVER
537     sfree(scratch->bfsprev);
538 #endif
539     sfree(scratch->bfsqueue);
540     sfree(scratch->neighbours);
541     sfree(scratch->set);
542     sfree(scratch->colidx);
543     sfree(scratch->rowidx);
544     sfree(scratch->grid);
545     sfree(scratch);
546 }
547
548 void latin_solver_alloc(struct latin_solver *solver, digit *grid, int o)
549 {
550     int x, y;
551
552     solver->o = o;
553     solver->cube = snewn(o*o*o, unsigned char);
554     solver->grid = grid;                /* write straight back to the input */
555     memset(solver->cube, TRUE, o*o*o);
556
557     solver->row = snewn(o*o, unsigned char);
558     solver->col = snewn(o*o, unsigned char);
559     memset(solver->row, FALSE, o*o);
560     memset(solver->col, FALSE, o*o);
561
562     for (x = 0; x < o; x++)
563         for (y = 0; y < o; y++)
564             if (grid[y*o+x])
565                 latin_solver_place(solver, x, YTRANS(y), grid[y*o+x]);
566 }
567
568 void latin_solver_free(struct latin_solver *solver)
569 {
570     sfree(solver->cube);
571     sfree(solver->row);
572     sfree(solver->col);
573 }
574
575 int latin_solver_diff_simple(struct latin_solver *solver)
576 {
577     int x, y, n, ret, o = solver->o;
578     /*
579      * Row-wise positional elimination.
580      */
581     for (y = 0; y < o; y++)
582         for (n = 1; n <= o; n++)
583             if (!solver->row[y*o+n-1]) {
584                 ret = latin_solver_elim(solver, cubepos(0,y,n), o*o
585 #ifdef STANDALONE_SOLVER
586                                         , "positional elimination,"
587                                         " %d in row %d", n, YUNTRANS(y)+1
588 #endif
589                                         );
590                 if (ret != 0) return ret;
591             }
592     /*
593      * Column-wise positional elimination.
594      */
595     for (x = 0; x < o; x++)
596         for (n = 1; n <= o; n++)
597             if (!solver->col[x*o+n-1]) {
598                 ret = latin_solver_elim(solver, cubepos(x,0,n), o
599 #ifdef STANDALONE_SOLVER
600                                         , "positional elimination,"
601                                         " %d in column %d", n, x+1
602 #endif
603                                         );
604                 if (ret != 0) return ret;
605             }
606
607     /*
608      * Numeric elimination.
609      */
610     for (x = 0; x < o; x++)
611         for (y = 0; y < o; y++)
612             if (!solver->grid[YUNTRANS(y)*o+x]) {
613                 ret = latin_solver_elim(solver, cubepos(x,y,1), 1
614 #ifdef STANDALONE_SOLVER
615                                         , "numeric elimination at (%d,%d)",
616                                         x+1, YUNTRANS(y)+1
617 #endif
618                                         );
619                 if (ret != 0) return ret;
620             }
621     return 0;
622 }
623
624 int latin_solver_diff_set(struct latin_solver *solver,
625                           struct latin_solver_scratch *scratch,
626                           int extreme)
627 {
628     int x, y, n, ret, o = solver->o;
629
630     if (!extreme) {
631         /*
632          * Row-wise set elimination.
633          */
634         for (y = 0; y < o; y++) {
635             ret = latin_solver_set(solver, scratch, cubepos(0,y,1), o*o, 1
636 #ifdef STANDALONE_SOLVER
637                                    , "set elimination, row %d", YUNTRANS(y)+1
638 #endif
639                                   );
640             if (ret != 0) return ret;
641         }
642         /*
643          * Column-wise set elimination.
644          */
645         for (x = 0; x < o; x++) {
646             ret = latin_solver_set(solver, scratch, cubepos(x,0,1), o, 1
647 #ifdef STANDALONE_SOLVER
648                                    , "set elimination, column %d", x+1
649 #endif
650                                   );
651             if (ret != 0) return ret;
652         }
653     } else {
654         /*
655          * Row-vs-column set elimination on a single number
656          * (much tricker for a human to do!)
657          */
658         for (n = 1; n <= o; n++) {
659             ret = latin_solver_set(solver, scratch, cubepos(0,0,n), o*o, o
660 #ifdef STANDALONE_SOLVER
661                                    , "positional set elimination, number %d", n
662 #endif
663                                   );
664             if (ret != 0) return ret;
665         }
666     }
667     return 0;
668 }
669
670 /*
671  * Returns:
672  * 0 for 'didn't do anything' implying it was already solved.
673  * -1 for 'impossible' (no solution)
674  * 1 for 'single solution'
675  * >1 for 'multiple solutions' (you don't get to know how many, and
676  *     the first such solution found will be set.
677  *
678  * and this function may well assert if given an impossible board.
679  */
680 static int latin_solver_recurse
681     (struct latin_solver *solver, int diff_simple, int diff_set_0,
682      int diff_set_1, int diff_forcing, int diff_recursive,
683      usersolver_t const *usersolvers, void *ctx,
684      ctxnew_t ctxnew, ctxfree_t ctxfree)
685 {
686     int best, bestcount;
687     int o = solver->o, x, y, n;
688
689     best = -1;
690     bestcount = o+1;
691
692     for (y = 0; y < o; y++)
693         for (x = 0; x < o; x++)
694             if (!solver->grid[y*o+x]) {
695                 int count;
696
697                 /*
698                  * An unfilled square. Count the number of
699                  * possible digits in it.
700                  */
701                 count = 0;
702                 for (n = 1; n <= o; n++)
703                     if (cube(x,YTRANS(y),n))
704                         count++;
705
706                 /*
707                  * We should have found any impossibilities
708                  * already, so this can safely be an assert.
709                  */
710                 assert(count > 1);
711
712                 if (count < bestcount) {
713                     bestcount = count;
714                     best = y*o+x;
715                 }
716             }
717
718     if (best == -1)
719         /* we were complete already. */
720         return 0;
721     else {
722         int i, j;
723         digit *list, *ingrid, *outgrid;
724         int diff = diff_impossible;    /* no solution found yet */
725
726         /*
727          * Attempt recursion.
728          */
729         y = best / o;
730         x = best % o;
731
732         list = snewn(o, digit);
733         ingrid = snewn(o*o, digit);
734         outgrid = snewn(o*o, digit);
735         memcpy(ingrid, solver->grid, o*o);
736
737         /* Make a list of the possible digits. */
738         for (j = 0, n = 1; n <= o; n++)
739             if (cube(x,YTRANS(y),n))
740                 list[j++] = n;
741
742 #ifdef STANDALONE_SOLVER
743         if (solver_show_working) {
744             char *sep = "";
745             printf("%*srecursing on (%d,%d) [",
746                    solver_recurse_depth*4, "", x+1, y+1);
747             for (i = 0; i < j; i++) {
748                 printf("%s%d", sep, list[i]);
749                 sep = " or ";
750             }
751             printf("]\n");
752         }
753 #endif
754
755         /*
756          * And step along the list, recursing back into the
757          * main solver at every stage.
758          */
759         for (i = 0; i < j; i++) {
760             int ret;
761             void *newctx;
762
763             memcpy(outgrid, ingrid, o*o);
764             outgrid[y*o+x] = list[i];
765
766 #ifdef STANDALONE_SOLVER
767             if (solver_show_working)
768                 printf("%*sguessing %d at (%d,%d)\n",
769                        solver_recurse_depth*4, "", list[i], x+1, y+1);
770             solver_recurse_depth++;
771 #endif
772
773             if (ctxnew) {
774                 newctx = ctxnew(ctx);
775             } else {
776                 newctx = ctx;
777             }
778             ret = latin_solver(outgrid, o, diff_recursive,
779                                diff_simple, diff_set_0, diff_set_1,
780                                diff_forcing, diff_recursive,
781                                usersolvers, newctx, ctxnew, ctxfree);
782             if (ctxnew)
783                 ctxfree(newctx);
784
785 #ifdef STANDALONE_SOLVER
786             solver_recurse_depth--;
787             if (solver_show_working) {
788                 printf("%*sretracting %d at (%d,%d)\n",
789                        solver_recurse_depth*4, "", list[i], x+1, y+1);
790             }
791 #endif
792             /* we recurse as deep as we can, so we should never find
793              * find ourselves giving up on a puzzle without declaring it
794              * impossible.  */
795             assert(ret != diff_unfinished);
796
797             /*
798              * If we have our first solution, copy it into the
799              * grid we will return.
800              */
801             if (diff == diff_impossible && ret != diff_impossible)
802                 memcpy(solver->grid, outgrid, o*o);
803
804             if (ret == diff_ambiguous)
805                 diff = diff_ambiguous;
806             else if (ret == diff_impossible)
807                 /* do not change our return value */;
808             else {
809                 /* the recursion turned up exactly one solution */
810                 if (diff == diff_impossible)
811                     diff = diff_recursive;
812                 else
813                     diff = diff_ambiguous;
814             }
815
816             /*
817              * As soon as we've found more than one solution,
818              * give up immediately.
819              */
820             if (diff == diff_ambiguous)
821                 break;
822         }
823
824         sfree(outgrid);
825         sfree(ingrid);
826         sfree(list);
827
828         if (diff == diff_impossible)
829             return -1;
830         else if (diff == diff_ambiguous)
831             return 2;
832         else {
833             assert(diff == diff_recursive);
834             return 1;
835         }
836     }
837 }
838
839 int latin_solver_main(struct latin_solver *solver, int maxdiff,
840                       int diff_simple, int diff_set_0, int diff_set_1,
841                       int diff_forcing, int diff_recursive,
842                       usersolver_t const *usersolvers, void *ctx,
843                       ctxnew_t ctxnew, ctxfree_t ctxfree)
844 {
845     struct latin_solver_scratch *scratch = latin_solver_new_scratch(solver);
846     int ret, diff = diff_simple;
847
848     assert(maxdiff <= diff_recursive);
849     /*
850      * Now loop over the grid repeatedly trying all permitted modes
851      * of reasoning. The loop terminates if we complete an
852      * iteration without making any progress; we then return
853      * failure or success depending on whether the grid is full or
854      * not.
855      */
856     while (1) {
857         int i;
858
859         cont:
860
861         latin_solver_debug(solver->cube, solver->o);
862
863         for (i = 0; i <= maxdiff; i++) {
864             if (usersolvers[i])
865                 ret = usersolvers[i](solver, ctx);
866             else
867                 ret = 0;
868             if (ret == 0 && i == diff_simple)
869                 ret = latin_solver_diff_simple(solver);
870             if (ret == 0 && i == diff_set_0)
871                 ret = latin_solver_diff_set(solver, scratch, 0);
872             if (ret == 0 && i == diff_set_1)
873                 ret = latin_solver_diff_set(solver, scratch, 1);
874             if (ret == 0 && i == diff_forcing)
875                 ret = latin_solver_forcing(solver, scratch);
876
877             if (ret < 0) {
878                 diff = diff_impossible;
879                 goto got_result;
880             } else if (ret > 0) {
881                 diff = max(diff, i);
882                 goto cont;
883             }
884         }
885
886         /*
887          * If we reach here, we have made no deductions in this
888          * iteration, so the algorithm terminates.
889          */
890         break;
891     }
892
893     /*
894      * Last chance: if we haven't fully solved the puzzle yet, try
895      * recursing based on guesses for a particular square. We pick
896      * one of the most constrained empty squares we can find, which
897      * has the effect of pruning the search tree as much as
898      * possible.
899      */
900     if (maxdiff == diff_recursive) {
901         int nsol = latin_solver_recurse(solver,
902                                         diff_simple, diff_set_0, diff_set_1,
903                                         diff_forcing, diff_recursive,
904                                         usersolvers, ctx, ctxnew, ctxfree);
905         if (nsol < 0) diff = diff_impossible;
906         else if (nsol == 1) diff = diff_recursive;
907         else if (nsol > 1) diff = diff_ambiguous;
908         /* if nsol == 0 then we were complete anyway
909          * (and thus don't need to change diff) */
910     } else {
911         /*
912          * We're forbidden to use recursion, so we just see whether
913          * our grid is fully solved, and return diff_unfinished
914          * otherwise.
915          */
916         int x, y, o = solver->o;
917
918         for (y = 0; y < o; y++)
919             for (x = 0; x < o; x++)
920                 if (!solver->grid[y*o+x])
921                     diff = diff_unfinished;
922     }
923
924     got_result:
925
926 #ifdef STANDALONE_SOLVER
927     if (solver_show_working)
928         printf("%*s%s found\n",
929                solver_recurse_depth*4, "",
930                diff == diff_impossible ? "no solution (impossible)" :
931                diff == diff_unfinished ? "no solution (unfinished)" :
932                diff == diff_ambiguous ? "multiple solutions" :
933                "one solution");
934 #endif
935
936     latin_solver_free_scratch(scratch);
937
938     return diff;
939 }
940
941 int latin_solver(digit *grid, int o, int maxdiff,
942                  int diff_simple, int diff_set_0, int diff_set_1,
943                  int diff_forcing, int diff_recursive,
944                  usersolver_t const *usersolvers, void *ctx,
945                  ctxnew_t ctxnew, ctxfree_t ctxfree)
946 {
947     struct latin_solver solver;
948     int diff;
949
950     latin_solver_alloc(&solver, grid, o);
951     diff = latin_solver_main(&solver, maxdiff,
952                              diff_simple, diff_set_0, diff_set_1,
953                              diff_forcing, diff_recursive,
954                              usersolvers, ctx, ctxnew, ctxfree);
955     latin_solver_free(&solver);
956     return diff;
957 }
958
959 void latin_solver_debug(unsigned char *cube, int o)
960 {
961 #ifdef STANDALONE_SOLVER
962     if (solver_show_working > 1) {
963         struct latin_solver ls, *solver = &ls;
964         char *dbg;
965         int x, y, i, c = 0;
966
967         ls.cube = cube; ls.o = o; /* for cube() to work */
968
969         dbg = snewn(3*o*o*o, char);
970         for (y = 0; y < o; y++) {
971             for (x = 0; x < o; x++) {
972                 for (i = 1; i <= o; i++) {
973                     if (cube(x,y,i))
974                         dbg[c++] = i + '0';
975                     else
976                         dbg[c++] = '.';
977                 }
978                 dbg[c++] = ' ';
979             }
980             dbg[c++] = '\n';
981         }
982         dbg[c++] = '\n';
983         dbg[c++] = '\0';
984
985         printf("%s", dbg);
986         sfree(dbg);
987     }
988 #endif
989 }
990
991 void latin_debug(digit *sq, int o)
992 {
993 #ifdef STANDALONE_SOLVER
994     if (solver_show_working) {
995         int x, y;
996
997         for (y = 0; y < o; y++) {
998             for (x = 0; x < o; x++) {
999                 printf("%2d ", sq[y*o+x]);
1000             }
1001             printf("\n");
1002         }
1003         printf("\n");
1004     }
1005 #endif
1006 }
1007
1008 /* --------------------------------------------------------
1009  * Generation.
1010  */
1011
1012 digit *latin_generate(int o, random_state *rs)
1013 {
1014     digit *sq;
1015     int *edges, *backedges, *capacity, *flow;
1016     void *scratch;
1017     int ne, scratchsize;
1018     int i, j, k;
1019     digit *row, *col, *numinv, *num;
1020
1021     /*
1022      * To efficiently generate a latin square in such a way that
1023      * all possible squares are possible outputs from the function,
1024      * we make use of a theorem which states that any r x n latin
1025      * rectangle, with r < n, can be extended into an (r+1) x n
1026      * latin rectangle. In other words, we can reliably generate a
1027      * latin square row by row, by at every stage writing down any
1028      * row at all which doesn't conflict with previous rows, and
1029      * the theorem guarantees that we will never have to backtrack.
1030      *
1031      * To find a viable row at each stage, we can make use of the
1032      * support functions in maxflow.c.
1033      */
1034
1035     sq = snewn(o*o, digit);
1036
1037     /*
1038      * In case this method of generation introduces a really subtle
1039      * top-to-bottom directional bias, we'll generate the rows in
1040      * random order.
1041      */
1042     row = snewn(o, digit);
1043     col = snewn(o, digit);
1044     numinv = snewn(o, digit);
1045     num = snewn(o, digit);
1046     for (i = 0; i < o; i++)
1047         row[i] = i;
1048     shuffle(row, i, sizeof(*row), rs);
1049
1050     /*
1051      * Set up the infrastructure for the maxflow algorithm.
1052      */
1053     scratchsize = maxflow_scratch_size(o * 2 + 2);
1054     scratch = smalloc(scratchsize);
1055     backedges = snewn(o*o + 2*o, int);
1056     edges = snewn((o*o + 2*o) * 2, int);
1057     capacity = snewn(o*o + 2*o, int);
1058     flow = snewn(o*o + 2*o, int);
1059     /* Set up the edge array, and the initial capacities. */
1060     ne = 0;
1061     for (i = 0; i < o; i++) {
1062         /* Each LHS vertex is connected to all RHS vertices. */
1063         for (j = 0; j < o; j++) {
1064             edges[ne*2] = i;
1065             edges[ne*2+1] = j+o;
1066             /* capacity for this edge is set later on */
1067             ne++;
1068         }
1069     }
1070     for (i = 0; i < o; i++) {
1071         /* Each RHS vertex is connected to the distinguished sink vertex. */
1072         edges[ne*2] = i+o;
1073         edges[ne*2+1] = o*2+1;
1074         capacity[ne] = 1;
1075         ne++;
1076     }
1077     for (i = 0; i < o; i++) {
1078         /* And the distinguished source vertex connects to each LHS vertex. */
1079         edges[ne*2] = o*2;
1080         edges[ne*2+1] = i;
1081         capacity[ne] = 1;
1082         ne++;
1083     }
1084     assert(ne == o*o + 2*o);
1085     /* Now set up backedges. */
1086     maxflow_setup_backedges(ne, edges, backedges);
1087     
1088     /*
1089      * Now generate each row of the latin square.
1090      */
1091     for (i = 0; i < o; i++) {
1092         /*
1093          * To prevent maxflow from behaving deterministically, we
1094          * separately permute the columns and the digits for the
1095          * purposes of the algorithm, differently for every row.
1096          */
1097         for (j = 0; j < o; j++)
1098             col[j] = num[j] = j;
1099         shuffle(col, j, sizeof(*col), rs);
1100         shuffle(num, j, sizeof(*num), rs);
1101         /* We need the num permutation in both forward and inverse forms. */
1102         for (j = 0; j < o; j++)
1103             numinv[num[j]] = j;
1104
1105         /*
1106          * Set up the capacities for the maxflow run, by examining
1107          * the existing latin square.
1108          */
1109         for (j = 0; j < o*o; j++)
1110             capacity[j] = 1;
1111         for (j = 0; j < i; j++)
1112             for (k = 0; k < o; k++) {
1113                 int n = num[sq[row[j]*o + col[k]] - 1];
1114                 capacity[k*o+n] = 0;
1115             }
1116
1117         /*
1118          * Run maxflow.
1119          */
1120         j = maxflow_with_scratch(scratch, o*2+2, 2*o, 2*o+1, ne,
1121                                  edges, backedges, capacity, flow, NULL);
1122         assert(j == o);   /* by the above theorem, this must have succeeded */
1123
1124         /*
1125          * And examine the flow array to pick out the new row of
1126          * the latin square.
1127          */
1128         for (j = 0; j < o; j++) {
1129             for (k = 0; k < o; k++) {
1130                 if (flow[j*o+k])
1131                     break;
1132             }
1133             assert(k < o);
1134             sq[row[i]*o + col[j]] = numinv[k] + 1;
1135         }
1136     }
1137
1138     /*
1139      * Done. Free our internal workspaces...
1140      */
1141     sfree(flow);
1142     sfree(capacity);
1143     sfree(edges);
1144     sfree(backedges);
1145     sfree(scratch);
1146     sfree(numinv);
1147     sfree(num);
1148     sfree(col);
1149     sfree(row);
1150
1151     /*
1152      * ... and return our completed latin square.
1153      */
1154     return sq;
1155 }
1156
1157 /* --------------------------------------------------------
1158  * Checking.
1159  */
1160
1161 typedef struct lcparams {
1162     digit elt;
1163     int count;
1164 } lcparams;
1165
1166 static int latin_check_cmp(void *v1, void *v2)
1167 {
1168     lcparams *lc1 = (lcparams *)v1;
1169     lcparams *lc2 = (lcparams *)v2;
1170
1171     if (lc1->elt < lc2->elt) return -1;
1172     if (lc1->elt > lc2->elt) return 1;
1173     return 0;
1174 }
1175
1176 #define ELT(sq,x,y) (sq[((y)*order)+(x)])
1177
1178 /* returns non-zero if sq is not a latin square. */
1179 int latin_check(digit *sq, int order)
1180 {
1181     tree234 *dict = newtree234(latin_check_cmp);
1182     int c, r;
1183     int ret = 0;
1184     lcparams *lcp, lc, *aret;
1185
1186     /* Use a tree234 as a simple hash table, go through the square
1187      * adding elements as we go or incrementing their counts. */
1188     for (c = 0; c < order; c++) {
1189         for (r = 0; r < order; r++) {
1190             lc.elt = ELT(sq, c, r); lc.count = 0;
1191             lcp = find234(dict, &lc, NULL);
1192             if (!lcp) {
1193                 lcp = snew(lcparams);
1194                 lcp->elt = ELT(sq, c, r);
1195                 lcp->count = 1;
1196                 aret = add234(dict, lcp);
1197                 assert(aret == lcp);
1198             } else {
1199                 lcp->count++;
1200             }
1201         }
1202     }
1203
1204     /* There should be precisely 'order' letters in the alphabet,
1205      * each occurring 'order' times (making the OxO tree) */
1206     if (count234(dict) != order) ret = 1;
1207     else {
1208         for (c = 0; (lcp = index234(dict, c)) != NULL; c++) {
1209             if (lcp->count != order) ret = 1;
1210         }
1211     }
1212     for (c = 0; (lcp = index234(dict, c)) != NULL; c++)
1213         sfree(lcp);
1214     freetree234(dict);
1215
1216     return ret;
1217 }
1218
1219
1220 /* --------------------------------------------------------
1221  * Testing (and printing).
1222  */
1223
1224 #ifdef STANDALONE_LATIN_TEST
1225
1226 #include <stdio.h>
1227 #include <time.h>
1228
1229 const char *quis;
1230
1231 static void latin_print(digit *sq, int order)
1232 {
1233     int x, y;
1234
1235     for (y = 0; y < order; y++) {
1236         for (x = 0; x < order; x++) {
1237             printf("%2u ", ELT(sq, x, y));
1238         }
1239         printf("\n");
1240     }
1241     printf("\n");
1242 }
1243
1244 static void gen(int order, random_state *rs, int debug)
1245 {
1246     digit *sq;
1247
1248     solver_show_working = debug;
1249
1250     sq = latin_generate(order, rs);
1251     latin_print(sq, order);
1252     if (latin_check(sq, order)) {
1253         fprintf(stderr, "Square is not a latin square!");
1254         exit(1);
1255     }
1256
1257     sfree(sq);
1258 }
1259
1260 void test_soak(int order, random_state *rs)
1261 {
1262     digit *sq;
1263     int n = 0;
1264     time_t tt_start, tt_now, tt_last;
1265
1266     solver_show_working = 0;
1267     tt_now = tt_start = time(NULL);
1268
1269     while(1) {
1270         sq = latin_generate(order, rs);
1271         sfree(sq);
1272         n++;
1273
1274         tt_last = time(NULL);
1275         if (tt_last > tt_now) {
1276             tt_now = tt_last;
1277             printf("%d total, %3.1f/s\n", n,
1278                    (double)n / (double)(tt_now - tt_start));
1279         }
1280     }
1281 }
1282
1283 void usage_exit(const char *msg)
1284 {
1285     if (msg)
1286         fprintf(stderr, "%s: %s\n", quis, msg);
1287     fprintf(stderr, "Usage: %s [--seed SEED] --soak <params> | [game_id [game_id ...]]\n", quis);
1288     exit(1);
1289 }
1290
1291 int main(int argc, char *argv[])
1292 {
1293     int i, soak = 0;
1294     random_state *rs;
1295     time_t seed = time(NULL);
1296
1297     quis = argv[0];
1298     while (--argc > 0) {
1299         const char *p = *++argv;
1300         if (!strcmp(p, "--soak"))
1301             soak = 1;
1302         else if (!strcmp(p, "--seed")) {
1303             if (argc == 0)
1304                 usage_exit("--seed needs an argument");
1305             seed = (time_t)atoi(*++argv);
1306             argc--;
1307         } else if (*p == '-')
1308                 usage_exit("unrecognised option");
1309         else
1310             break; /* finished options */
1311     }
1312
1313     rs = random_new((void*)&seed, sizeof(time_t));
1314
1315     if (soak == 1) {
1316         if (argc != 1) usage_exit("only one argument for --soak");
1317         test_soak(atoi(*argv), rs);
1318     } else {
1319         if (argc > 0) {
1320             for (i = 0; i < argc; i++) {
1321                 gen(atoi(*argv++), rs, 1);
1322             }
1323         } else {
1324             while (1) {
1325                 i = random_upto(rs, 20) + 1;
1326                 gen(i, rs, 0);
1327             }
1328         }
1329     }
1330     random_free(rs);
1331     return 0;
1332 }
1333
1334 #endif
1335
1336 /* vim: set shiftwidth=4 tabstop=8: */