chiark / gitweb /
Tents: mark squares as non-tents with {Shift,Control}-cursor keys.
[sgt-puzzles.git] / latin.c
1 #include <assert.h>
2 #include <string.h>
3 #include <stdarg.h>
4
5 #include "puzzles.h"
6 #include "tree234.h"
7 #include "maxflow.h"
8
9 #ifdef STANDALONE_LATIN_TEST
10 #define STANDALONE_SOLVER
11 #endif
12
13 #include "latin.h"
14
15 /* --------------------------------------------------------
16  * Solver.
17  */
18
19 static int latin_solver_top(struct latin_solver *solver, int maxdiff,
20                             int diff_simple, int diff_set_0, int diff_set_1,
21                             int diff_forcing, int diff_recursive,
22                             usersolver_t const *usersolvers, void *ctx,
23                             ctxnew_t ctxnew, ctxfree_t ctxfree);
24
25 #ifdef STANDALONE_SOLVER
26 int solver_show_working, solver_recurse_depth;
27 #endif
28
29 /*
30  * Function called when we are certain that a particular square has
31  * a particular number in it. The y-coordinate passed in here is
32  * transformed.
33  */
34 void latin_solver_place(struct latin_solver *solver, int x, int y, int n)
35 {
36     int i, o = solver->o;
37
38     assert(n <= o);
39     assert(cube(x,y,n));
40
41     /*
42      * Rule out all other numbers in this square.
43      */
44     for (i = 1; i <= o; i++)
45         if (i != n)
46             cube(x,y,i) = FALSE;
47
48     /*
49      * Rule out this number in all other positions in the row.
50      */
51     for (i = 0; i < o; i++)
52         if (i != y)
53             cube(x,i,n) = FALSE;
54
55     /*
56      * Rule out this number in all other positions in the column.
57      */
58     for (i = 0; i < o; i++)
59         if (i != x)
60             cube(i,y,n) = FALSE;
61
62     /*
63      * Enter the number in the result grid.
64      */
65     solver->grid[y*o+x] = n;
66
67     /*
68      * Cross out this number from the list of numbers left to place
69      * in its row, its column and its block.
70      */
71     solver->row[y*o+n-1] = solver->col[x*o+n-1] = TRUE;
72 }
73
74 int latin_solver_elim(struct latin_solver *solver, int start, int step
75 #ifdef STANDALONE_SOLVER
76                       , char *fmt, ...
77 #endif
78                       )
79 {
80     int o = solver->o;
81 #ifdef STANDALONE_SOLVER
82     char **names = solver->names;
83 #endif
84     int fpos, m, i;
85
86     /*
87      * Count the number of set bits within this section of the
88      * cube.
89      */
90     m = 0;
91     fpos = -1;
92     for (i = 0; i < o; i++)
93         if (solver->cube[start+i*step]) {
94             fpos = start+i*step;
95             m++;
96         }
97
98     if (m == 1) {
99         int x, y, n;
100         assert(fpos >= 0);
101
102         n = 1 + fpos % o;
103         y = fpos / o;
104         x = y / o;
105         y %= o;
106
107         if (!solver->grid[y*o+x]) {
108 #ifdef STANDALONE_SOLVER
109             if (solver_show_working) {
110                 va_list ap;
111                 printf("%*s", solver_recurse_depth*4, "");
112                 va_start(ap, fmt);
113                 vprintf(fmt, ap);
114                 va_end(ap);
115                 printf(":\n%*s  placing %s at (%d,%d)\n",
116                        solver_recurse_depth*4, "", names[n-1],
117                        x+1, y+1);
118             }
119 #endif
120             latin_solver_place(solver, x, y, n);
121             return +1;
122         }
123     } else if (m == 0) {
124 #ifdef STANDALONE_SOLVER
125         if (solver_show_working) {
126             va_list ap;
127             printf("%*s", solver_recurse_depth*4, "");
128             va_start(ap, fmt);
129             vprintf(fmt, ap);
130             va_end(ap);
131             printf(":\n%*s  no possibilities available\n",
132                    solver_recurse_depth*4, "");
133         }
134 #endif
135         return -1;
136     }
137
138     return 0;
139 }
140
141 struct latin_solver_scratch {
142     unsigned char *grid, *rowidx, *colidx, *set;
143     int *neighbours, *bfsqueue;
144 #ifdef STANDALONE_SOLVER
145     int *bfsprev;
146 #endif
147 };
148
149 int latin_solver_set(struct latin_solver *solver,
150                      struct latin_solver_scratch *scratch,
151                      int start, int step1, int step2
152 #ifdef STANDALONE_SOLVER
153                      , char *fmt, ...
154 #endif
155                      )
156 {
157     int o = solver->o;
158 #ifdef STANDALONE_SOLVER
159     char **names = solver->names;
160 #endif
161     int i, j, n, count;
162     unsigned char *grid = scratch->grid;
163     unsigned char *rowidx = scratch->rowidx;
164     unsigned char *colidx = scratch->colidx;
165     unsigned char *set = scratch->set;
166
167     /*
168      * We are passed a o-by-o matrix of booleans. Our first job
169      * is to winnow it by finding any definite placements - i.e.
170      * any row with a solitary 1 - and discarding that row and the
171      * column containing the 1.
172      */
173     memset(rowidx, TRUE, o);
174     memset(colidx, TRUE, o);
175     for (i = 0; i < o; i++) {
176         int count = 0, first = -1;
177         for (j = 0; j < o; j++)
178             if (solver->cube[start+i*step1+j*step2])
179                 first = j, count++;
180
181         if (count == 0) return -1;
182         if (count == 1)
183             rowidx[i] = colidx[first] = FALSE;
184     }
185
186     /*
187      * Convert each of rowidx/colidx from a list of 0s and 1s to a
188      * list of the indices of the 1s.
189      */
190     for (i = j = 0; i < o; i++)
191         if (rowidx[i])
192             rowidx[j++] = i;
193     n = j;
194     for (i = j = 0; i < o; i++)
195         if (colidx[i])
196             colidx[j++] = i;
197     assert(n == j);
198
199     /*
200      * And create the smaller matrix.
201      */
202     for (i = 0; i < n; i++)
203         for (j = 0; j < n; j++)
204             grid[i*o+j] = solver->cube[start+rowidx[i]*step1+colidx[j]*step2];
205
206     /*
207      * Having done that, we now have a matrix in which every row
208      * has at least two 1s in. Now we search to see if we can find
209      * a rectangle of zeroes (in the set-theoretic sense of
210      * `rectangle', i.e. a subset of rows crossed with a subset of
211      * columns) whose width and height add up to n.
212      */
213
214     memset(set, 0, n);
215     count = 0;
216     while (1) {
217         /*
218          * We have a candidate set. If its size is <=1 or >=n-1
219          * then we move on immediately.
220          */
221         if (count > 1 && count < n-1) {
222             /*
223              * The number of rows we need is n-count. See if we can
224              * find that many rows which each have a zero in all
225              * the positions listed in `set'.
226              */
227             int rows = 0;
228             for (i = 0; i < n; i++) {
229                 int ok = TRUE;
230                 for (j = 0; j < n; j++)
231                     if (set[j] && grid[i*o+j]) {
232                         ok = FALSE;
233                         break;
234                     }
235                 if (ok)
236                     rows++;
237             }
238
239             /*
240              * We expect never to be able to get _more_ than
241              * n-count suitable rows: this would imply that (for
242              * example) there are four numbers which between them
243              * have at most three possible positions, and hence it
244              * indicates a faulty deduction before this point or
245              * even a bogus clue.
246              */
247             if (rows > n - count) {
248 #ifdef STANDALONE_SOLVER
249                 if (solver_show_working) {
250                     va_list ap;
251                     printf("%*s", solver_recurse_depth*4,
252                            "");
253                     va_start(ap, fmt);
254                     vprintf(fmt, ap);
255                     va_end(ap);
256                     printf(":\n%*s  contradiction reached\n",
257                            solver_recurse_depth*4, "");
258                 }
259 #endif
260                 return -1;
261             }
262
263             if (rows >= n - count) {
264                 int progress = FALSE;
265
266                 /*
267                  * We've got one! Now, for each row which _doesn't_
268                  * satisfy the criterion, eliminate all its set
269                  * bits in the positions _not_ listed in `set'.
270                  * Return +1 (meaning progress has been made) if we
271                  * successfully eliminated anything at all.
272                  *
273                  * This involves referring back through
274                  * rowidx/colidx in order to work out which actual
275                  * positions in the cube to meddle with.
276                  */
277                 for (i = 0; i < n; i++) {
278                     int ok = TRUE;
279                     for (j = 0; j < n; j++)
280                         if (set[j] && grid[i*o+j]) {
281                             ok = FALSE;
282                             break;
283                         }
284                     if (!ok) {
285                         for (j = 0; j < n; j++)
286                             if (!set[j] && grid[i*o+j]) {
287                                 int fpos = (start+rowidx[i]*step1+
288                                             colidx[j]*step2);
289 #ifdef STANDALONE_SOLVER
290                                 if (solver_show_working) {
291                                     int px, py, pn;
292
293                                     if (!progress) {
294                                         va_list ap;
295                                         printf("%*s", solver_recurse_depth*4,
296                                                "");
297                                         va_start(ap, fmt);
298                                         vprintf(fmt, ap);
299                                         va_end(ap);
300                                         printf(":\n");
301                                     }
302
303                                     pn = 1 + fpos % o;
304                                     py = fpos / o;
305                                     px = py / o;
306                                     py %= o;
307
308                                     printf("%*s  ruling out %s at (%d,%d)\n",
309                                            solver_recurse_depth*4, "",
310                                            names[pn-1], px+1, py+1);
311                                 }
312 #endif
313                                 progress = TRUE;
314                                 solver->cube[fpos] = FALSE;
315                             }
316                     }
317                 }
318
319                 if (progress) {
320                     return +1;
321                 }
322             }
323         }
324
325         /*
326          * Binary increment: change the rightmost 0 to a 1, and
327          * change all 1s to the right of it to 0s.
328          */
329         i = n;
330         while (i > 0 && set[i-1])
331             set[--i] = 0, count--;
332         if (i > 0)
333             set[--i] = 1, count++;
334         else
335             break;                     /* done */
336     }
337
338     return 0;
339 }
340
341 /*
342  * Look for forcing chains. A forcing chain is a path of
343  * pairwise-exclusive squares (i.e. each pair of adjacent squares
344  * in the path are in the same row, column or block) with the
345  * following properties:
346  *
347  *  (a) Each square on the path has precisely two possible numbers.
348  *
349  *  (b) Each pair of squares which are adjacent on the path share
350  *      at least one possible number in common.
351  *
352  *  (c) Each square in the middle of the path shares _both_ of its
353  *      numbers with at least one of its neighbours (not the same
354  *      one with both neighbours).
355  *
356  * These together imply that at least one of the possible number
357  * choices at one end of the path forces _all_ the rest of the
358  * numbers along the path. In order to make real use of this, we
359  * need further properties:
360  *
361  *  (c) Ruling out some number N from the square at one end
362  *      of the path forces the square at the other end to
363  *      take number N.
364  *
365  *  (d) The two end squares are both in line with some third
366  *      square.
367  *
368  *  (e) That third square currently has N as a possibility.
369  *
370  * If we can find all of that lot, we can deduce that at least one
371  * of the two ends of the forcing chain has number N, and that
372  * therefore the mutually adjacent third square does not.
373  *
374  * To find forcing chains, we're going to start a bfs at each
375  * suitable square, once for each of its two possible numbers.
376  */
377 int latin_solver_forcing(struct latin_solver *solver,
378                          struct latin_solver_scratch *scratch)
379 {
380     int o = solver->o;
381 #ifdef STANDALONE_SOLVER
382     char **names = solver->names;
383 #endif
384     int *bfsqueue = scratch->bfsqueue;
385 #ifdef STANDALONE_SOLVER
386     int *bfsprev = scratch->bfsprev;
387 #endif
388     unsigned char *number = scratch->grid;
389     int *neighbours = scratch->neighbours;
390     int x, y;
391
392     for (y = 0; y < o; y++)
393         for (x = 0; x < o; x++) {
394             int count, t, n;
395
396             /*
397              * If this square doesn't have exactly two candidate
398              * numbers, don't try it.
399              *
400              * In this loop we also sum the candidate numbers,
401              * which is a nasty hack to allow us to quickly find
402              * `the other one' (since we will shortly know there
403              * are exactly two).
404              */
405             for (count = t = 0, n = 1; n <= o; n++)
406                 if (cube(x, y, n))
407                     count++, t += n;
408             if (count != 2)
409                 continue;
410
411             /*
412              * Now attempt a bfs for each candidate.
413              */
414             for (n = 1; n <= o; n++)
415                 if (cube(x, y, n)) {
416                     int orign, currn, head, tail;
417
418                     /*
419                      * Begin a bfs.
420                      */
421                     orign = n;
422
423                     memset(number, o+1, o*o);
424                     head = tail = 0;
425                     bfsqueue[tail++] = y*o+x;
426 #ifdef STANDALONE_SOLVER
427                     bfsprev[y*o+x] = -1;
428 #endif
429                     number[y*o+x] = t - n;
430
431                     while (head < tail) {
432                         int xx, yy, nneighbours, xt, yt, i;
433
434                         xx = bfsqueue[head++];
435                         yy = xx / o;
436                         xx %= o;
437
438                         currn = number[yy*o+xx];
439
440                         /*
441                          * Find neighbours of yy,xx.
442                          */
443                         nneighbours = 0;
444                         for (yt = 0; yt < o; yt++)
445                             neighbours[nneighbours++] = yt*o+xx;
446                         for (xt = 0; xt < o; xt++)
447                             neighbours[nneighbours++] = yy*o+xt;
448
449                         /*
450                          * Try visiting each of those neighbours.
451                          */
452                         for (i = 0; i < nneighbours; i++) {
453                             int cc, tt, nn;
454
455                             xt = neighbours[i] % o;
456                             yt = neighbours[i] / o;
457
458                             /*
459                              * We need this square to not be
460                              * already visited, and to include
461                              * currn as a possible number.
462                              */
463                             if (number[yt*o+xt] <= o)
464                                 continue;
465                             if (!cube(xt, yt, currn))
466                                 continue;
467
468                             /*
469                              * Don't visit _this_ square a second
470                              * time!
471                              */
472                             if (xt == xx && yt == yy)
473                                 continue;
474
475                             /*
476                              * To continue with the bfs, we need
477                              * this square to have exactly two
478                              * possible numbers.
479                              */
480                             for (cc = tt = 0, nn = 1; nn <= o; nn++)
481                                 if (cube(xt, yt, nn))
482                                     cc++, tt += nn;
483                             if (cc == 2) {
484                                 bfsqueue[tail++] = yt*o+xt;
485 #ifdef STANDALONE_SOLVER
486                                 bfsprev[yt*o+xt] = yy*o+xx;
487 #endif
488                                 number[yt*o+xt] = tt - currn;
489                             }
490
491                             /*
492                              * One other possibility is that this
493                              * might be the square in which we can
494                              * make a real deduction: if it's
495                              * adjacent to x,y, and currn is equal
496                              * to the original number we ruled out.
497                              */
498                             if (currn == orign &&
499                                 (xt == x || yt == y)) {
500 #ifdef STANDALONE_SOLVER
501                                 if (solver_show_working) {
502                                     char *sep = "";
503                                     int xl, yl;
504                                     printf("%*sforcing chain, %s at ends of ",
505                                            solver_recurse_depth*4, "",
506                                            names[orign-1]);
507                                     xl = xx;
508                                     yl = yy;
509                                     while (1) {
510                                         printf("%s(%d,%d)", sep, xl+1,
511                                                yl+1);
512                                         xl = bfsprev[yl*o+xl];
513                                         if (xl < 0)
514                                             break;
515                                         yl = xl / o;
516                                         xl %= o;
517                                         sep = "-";
518                                     }
519                                     printf("\n%*s  ruling out %s at (%d,%d)\n",
520                                            solver_recurse_depth*4, "",
521                                            names[orign-1],
522                                            xt+1, yt+1);
523                                 }
524 #endif
525                                 cube(xt, yt, orign) = FALSE;
526                                 return 1;
527                             }
528                         }
529                     }
530                 }
531         }
532
533     return 0;
534 }
535
536 struct latin_solver_scratch *latin_solver_new_scratch(struct latin_solver *solver)
537 {
538     struct latin_solver_scratch *scratch = snew(struct latin_solver_scratch);
539     int o = solver->o;
540     scratch->grid = snewn(o*o, unsigned char);
541     scratch->rowidx = snewn(o, unsigned char);
542     scratch->colidx = snewn(o, unsigned char);
543     scratch->set = snewn(o, unsigned char);
544     scratch->neighbours = snewn(3*o, int);
545     scratch->bfsqueue = snewn(o*o, int);
546 #ifdef STANDALONE_SOLVER
547     scratch->bfsprev = snewn(o*o, int);
548 #endif
549     return scratch;
550 }
551
552 void latin_solver_free_scratch(struct latin_solver_scratch *scratch)
553 {
554 #ifdef STANDALONE_SOLVER
555     sfree(scratch->bfsprev);
556 #endif
557     sfree(scratch->bfsqueue);
558     sfree(scratch->neighbours);
559     sfree(scratch->set);
560     sfree(scratch->colidx);
561     sfree(scratch->rowidx);
562     sfree(scratch->grid);
563     sfree(scratch);
564 }
565
566 void latin_solver_alloc(struct latin_solver *solver, digit *grid, int o)
567 {
568     int x, y;
569
570     solver->o = o;
571     solver->cube = snewn(o*o*o, unsigned char);
572     solver->grid = grid;                /* write straight back to the input */
573     memset(solver->cube, TRUE, o*o*o);
574
575     solver->row = snewn(o*o, unsigned char);
576     solver->col = snewn(o*o, unsigned char);
577     memset(solver->row, FALSE, o*o);
578     memset(solver->col, FALSE, o*o);
579
580     for (x = 0; x < o; x++)
581         for (y = 0; y < o; y++)
582             if (grid[y*o+x])
583                 latin_solver_place(solver, x, y, grid[y*o+x]);
584
585 #ifdef STANDALONE_SOLVER
586     solver->names = NULL;
587 #endif
588 }
589
590 void latin_solver_free(struct latin_solver *solver)
591 {
592     sfree(solver->cube);
593     sfree(solver->row);
594     sfree(solver->col);
595 }
596
597 int latin_solver_diff_simple(struct latin_solver *solver)
598 {
599     int x, y, n, ret, o = solver->o;
600 #ifdef STANDALONE_SOLVER
601     char **names = solver->names;
602 #endif
603
604     /*
605      * Row-wise positional elimination.
606      */
607     for (y = 0; y < o; y++)
608         for (n = 1; n <= o; n++)
609             if (!solver->row[y*o+n-1]) {
610                 ret = latin_solver_elim(solver, cubepos(0,y,n), o*o
611 #ifdef STANDALONE_SOLVER
612                                         , "positional elimination,"
613                                         " %s in row %d", names[n-1],
614                                         y+1
615 #endif
616                                         );
617                 if (ret != 0) return ret;
618             }
619     /*
620      * Column-wise positional elimination.
621      */
622     for (x = 0; x < o; x++)
623         for (n = 1; n <= o; n++)
624             if (!solver->col[x*o+n-1]) {
625                 ret = latin_solver_elim(solver, cubepos(x,0,n), o
626 #ifdef STANDALONE_SOLVER
627                                         , "positional elimination,"
628                                         " %s in column %d", names[n-1], x+1
629 #endif
630                                         );
631                 if (ret != 0) return ret;
632             }
633
634     /*
635      * Numeric elimination.
636      */
637     for (x = 0; x < o; x++)
638         for (y = 0; y < o; y++)
639             if (!solver->grid[y*o+x]) {
640                 ret = latin_solver_elim(solver, cubepos(x,y,1), 1
641 #ifdef STANDALONE_SOLVER
642                                         , "numeric elimination at (%d,%d)",
643                                         x+1, y+1
644 #endif
645                                         );
646                 if (ret != 0) return ret;
647             }
648     return 0;
649 }
650
651 int latin_solver_diff_set(struct latin_solver *solver,
652                           struct latin_solver_scratch *scratch,
653                           int extreme)
654 {
655     int x, y, n, ret, o = solver->o;
656 #ifdef STANDALONE_SOLVER
657     char **names = solver->names;
658 #endif
659
660     if (!extreme) {
661         /*
662          * Row-wise set elimination.
663          */
664         for (y = 0; y < o; y++) {
665             ret = latin_solver_set(solver, scratch, cubepos(0,y,1), o*o, 1
666 #ifdef STANDALONE_SOLVER
667                                    , "set elimination, row %d", y+1
668 #endif
669                                   );
670             if (ret != 0) return ret;
671         }
672         /*
673          * Column-wise set elimination.
674          */
675         for (x = 0; x < o; x++) {
676             ret = latin_solver_set(solver, scratch, cubepos(x,0,1), o, 1
677 #ifdef STANDALONE_SOLVER
678                                    , "set elimination, column %d", x+1
679 #endif
680                                   );
681             if (ret != 0) return ret;
682         }
683     } else {
684         /*
685          * Row-vs-column set elimination on a single number
686          * (much tricker for a human to do!)
687          */
688         for (n = 1; n <= o; n++) {
689             ret = latin_solver_set(solver, scratch, cubepos(0,0,n), o*o, o
690 #ifdef STANDALONE_SOLVER
691                                    , "positional set elimination on %s",
692                                    names[n-1]
693 #endif
694                                   );
695             if (ret != 0) return ret;
696         }
697     }
698     return 0;
699 }
700
701 /*
702  * Returns:
703  * 0 for 'didn't do anything' implying it was already solved.
704  * -1 for 'impossible' (no solution)
705  * 1 for 'single solution'
706  * >1 for 'multiple solutions' (you don't get to know how many, and
707  *     the first such solution found will be set.
708  *
709  * and this function may well assert if given an impossible board.
710  */
711 static int latin_solver_recurse
712     (struct latin_solver *solver, int diff_simple, int diff_set_0,
713      int diff_set_1, int diff_forcing, int diff_recursive,
714      usersolver_t const *usersolvers, void *ctx,
715      ctxnew_t ctxnew, ctxfree_t ctxfree)
716 {
717     int best, bestcount;
718     int o = solver->o, x, y, n;
719 #ifdef STANDALONE_SOLVER
720     char **names = solver->names;
721 #endif
722
723     best = -1;
724     bestcount = o+1;
725
726     for (y = 0; y < o; y++)
727         for (x = 0; x < o; x++)
728             if (!solver->grid[y*o+x]) {
729                 int count;
730
731                 /*
732                  * An unfilled square. Count the number of
733                  * possible digits in it.
734                  */
735                 count = 0;
736                 for (n = 1; n <= o; n++)
737                     if (cube(x,y,n))
738                         count++;
739
740                 /*
741                  * We should have found any impossibilities
742                  * already, so this can safely be an assert.
743                  */
744                 assert(count > 1);
745
746                 if (count < bestcount) {
747                     bestcount = count;
748                     best = y*o+x;
749                 }
750             }
751
752     if (best == -1)
753         /* we were complete already. */
754         return 0;
755     else {
756         int i, j;
757         digit *list, *ingrid, *outgrid;
758         int diff = diff_impossible;    /* no solution found yet */
759
760         /*
761          * Attempt recursion.
762          */
763         y = best / o;
764         x = best % o;
765
766         list = snewn(o, digit);
767         ingrid = snewn(o*o, digit);
768         outgrid = snewn(o*o, digit);
769         memcpy(ingrid, solver->grid, o*o);
770
771         /* Make a list of the possible digits. */
772         for (j = 0, n = 1; n <= o; n++)
773             if (cube(x,y,n))
774                 list[j++] = n;
775
776 #ifdef STANDALONE_SOLVER
777         if (solver_show_working) {
778             char *sep = "";
779             printf("%*srecursing on (%d,%d) [",
780                    solver_recurse_depth*4, "", x+1, y+1);
781             for (i = 0; i < j; i++) {
782                 printf("%s%s", sep, names[list[i]-1]);
783                 sep = " or ";
784             }
785             printf("]\n");
786         }
787 #endif
788
789         /*
790          * And step along the list, recursing back into the
791          * main solver at every stage.
792          */
793         for (i = 0; i < j; i++) {
794             int ret;
795             void *newctx;
796             struct latin_solver subsolver;
797
798             memcpy(outgrid, ingrid, o*o);
799             outgrid[y*o+x] = list[i];
800
801 #ifdef STANDALONE_SOLVER
802             if (solver_show_working)
803                 printf("%*sguessing %s at (%d,%d)\n",
804                        solver_recurse_depth*4, "", names[list[i]-1], x+1, y+1);
805             solver_recurse_depth++;
806 #endif
807
808             if (ctxnew) {
809                 newctx = ctxnew(ctx);
810             } else {
811                 newctx = ctx;
812             }
813             latin_solver_alloc(&subsolver, outgrid, o);
814 #ifdef STANDALONE_SOLVER
815             subsolver.names = solver->names;
816 #endif
817             ret = latin_solver_top(&subsolver, diff_recursive,
818                                    diff_simple, diff_set_0, diff_set_1,
819                                    diff_forcing, diff_recursive,
820                                    usersolvers, newctx, ctxnew, ctxfree);
821             latin_solver_free(&subsolver);
822             if (ctxnew)
823                 ctxfree(newctx);
824
825 #ifdef STANDALONE_SOLVER
826             solver_recurse_depth--;
827             if (solver_show_working) {
828                 printf("%*sretracting %s at (%d,%d)\n",
829                        solver_recurse_depth*4, "", names[list[i]-1], x+1, y+1);
830             }
831 #endif
832             /* we recurse as deep as we can, so we should never find
833              * find ourselves giving up on a puzzle without declaring it
834              * impossible.  */
835             assert(ret != diff_unfinished);
836
837             /*
838              * If we have our first solution, copy it into the
839              * grid we will return.
840              */
841             if (diff == diff_impossible && ret != diff_impossible)
842                 memcpy(solver->grid, outgrid, o*o);
843
844             if (ret == diff_ambiguous)
845                 diff = diff_ambiguous;
846             else if (ret == diff_impossible)
847                 /* do not change our return value */;
848             else {
849                 /* the recursion turned up exactly one solution */
850                 if (diff == diff_impossible)
851                     diff = diff_recursive;
852                 else
853                     diff = diff_ambiguous;
854             }
855
856             /*
857              * As soon as we've found more than one solution,
858              * give up immediately.
859              */
860             if (diff == diff_ambiguous)
861                 break;
862         }
863
864         sfree(outgrid);
865         sfree(ingrid);
866         sfree(list);
867
868         if (diff == diff_impossible)
869             return -1;
870         else if (diff == diff_ambiguous)
871             return 2;
872         else {
873             assert(diff == diff_recursive);
874             return 1;
875         }
876     }
877 }
878
879 static int latin_solver_top(struct latin_solver *solver, int maxdiff,
880                             int diff_simple, int diff_set_0, int diff_set_1,
881                             int diff_forcing, int diff_recursive,
882                             usersolver_t const *usersolvers, void *ctx,
883                             ctxnew_t ctxnew, ctxfree_t ctxfree)
884 {
885     struct latin_solver_scratch *scratch = latin_solver_new_scratch(solver);
886     int ret, diff = diff_simple;
887
888     assert(maxdiff <= diff_recursive);
889     /*
890      * Now loop over the grid repeatedly trying all permitted modes
891      * of reasoning. The loop terminates if we complete an
892      * iteration without making any progress; we then return
893      * failure or success depending on whether the grid is full or
894      * not.
895      */
896     while (1) {
897         int i;
898
899         cont:
900
901         latin_solver_debug(solver->cube, solver->o);
902
903         for (i = 0; i <= maxdiff; i++) {
904             if (usersolvers[i])
905                 ret = usersolvers[i](solver, ctx);
906             else
907                 ret = 0;
908             if (ret == 0 && i == diff_simple)
909                 ret = latin_solver_diff_simple(solver);
910             if (ret == 0 && i == diff_set_0)
911                 ret = latin_solver_diff_set(solver, scratch, 0);
912             if (ret == 0 && i == diff_set_1)
913                 ret = latin_solver_diff_set(solver, scratch, 1);
914             if (ret == 0 && i == diff_forcing)
915                 ret = latin_solver_forcing(solver, scratch);
916
917             if (ret < 0) {
918                 diff = diff_impossible;
919                 goto got_result;
920             } else if (ret > 0) {
921                 diff = max(diff, i);
922                 goto cont;
923             }
924         }
925
926         /*
927          * If we reach here, we have made no deductions in this
928          * iteration, so the algorithm terminates.
929          */
930         break;
931     }
932
933     /*
934      * Last chance: if we haven't fully solved the puzzle yet, try
935      * recursing based on guesses for a particular square. We pick
936      * one of the most constrained empty squares we can find, which
937      * has the effect of pruning the search tree as much as
938      * possible.
939      */
940     if (maxdiff == diff_recursive) {
941         int nsol = latin_solver_recurse(solver,
942                                         diff_simple, diff_set_0, diff_set_1,
943                                         diff_forcing, diff_recursive,
944                                         usersolvers, ctx, ctxnew, ctxfree);
945         if (nsol < 0) diff = diff_impossible;
946         else if (nsol == 1) diff = diff_recursive;
947         else if (nsol > 1) diff = diff_ambiguous;
948         /* if nsol == 0 then we were complete anyway
949          * (and thus don't need to change diff) */
950     } else {
951         /*
952          * We're forbidden to use recursion, so we just see whether
953          * our grid is fully solved, and return diff_unfinished
954          * otherwise.
955          */
956         int x, y, o = solver->o;
957
958         for (y = 0; y < o; y++)
959             for (x = 0; x < o; x++)
960                 if (!solver->grid[y*o+x])
961                     diff = diff_unfinished;
962     }
963
964     got_result:
965
966 #ifdef STANDALONE_SOLVER
967     if (solver_show_working)
968         printf("%*s%s found\n",
969                solver_recurse_depth*4, "",
970                diff == diff_impossible ? "no solution (impossible)" :
971                diff == diff_unfinished ? "no solution (unfinished)" :
972                diff == diff_ambiguous ? "multiple solutions" :
973                "one solution");
974 #endif
975
976     latin_solver_free_scratch(scratch);
977
978     return diff;
979 }
980
981 int latin_solver_main(struct latin_solver *solver, int maxdiff,
982                       int diff_simple, int diff_set_0, int diff_set_1,
983                       int diff_forcing, int diff_recursive,
984                       usersolver_t const *usersolvers, void *ctx,
985                       ctxnew_t ctxnew, ctxfree_t ctxfree)
986 {
987     int diff;
988 #ifdef STANDALONE_SOLVER
989     int o = solver->o;
990     char *text = NULL, **names = NULL;
991 #endif
992
993 #ifdef STANDALONE_SOLVER
994     if (!solver->names) {
995         char *p;
996         int i;
997
998         text = snewn(40 * o, char);
999         p = text;
1000
1001         solver->names = snewn(o, char *);
1002
1003         for (i = 0; i < o; i++) {
1004             solver->names[i] = p;
1005             p += 1 + sprintf(p, "%d", i+1);
1006         }
1007     }
1008 #endif
1009
1010     diff = latin_solver_top(solver, maxdiff,
1011                             diff_simple, diff_set_0, diff_set_1,
1012                             diff_forcing, diff_recursive,
1013                             usersolvers, ctx, ctxnew, ctxfree);
1014
1015 #ifdef STANDALONE_SOLVER
1016     sfree(names);
1017     sfree(text);
1018 #endif
1019
1020     return diff;
1021 }
1022
1023 int latin_solver(digit *grid, int o, int maxdiff,
1024                  int diff_simple, int diff_set_0, int diff_set_1,
1025                  int diff_forcing, int diff_recursive,
1026                  usersolver_t const *usersolvers, void *ctx,
1027                  ctxnew_t ctxnew, ctxfree_t ctxfree)
1028 {
1029     struct latin_solver solver;
1030     int diff;
1031
1032     latin_solver_alloc(&solver, grid, o);
1033     diff = latin_solver_main(&solver, maxdiff,
1034                              diff_simple, diff_set_0, diff_set_1,
1035                              diff_forcing, diff_recursive,
1036                              usersolvers, ctx, ctxnew, ctxfree);
1037     latin_solver_free(&solver);
1038     return diff;
1039 }
1040
1041 void latin_solver_debug(unsigned char *cube, int o)
1042 {
1043 #ifdef STANDALONE_SOLVER
1044     if (solver_show_working > 1) {
1045         struct latin_solver ls, *solver = &ls;
1046         char *dbg;
1047         int x, y, i, c = 0;
1048
1049         ls.cube = cube; ls.o = o; /* for cube() to work */
1050
1051         dbg = snewn(3*o*o*o, char);
1052         for (y = 0; y < o; y++) {
1053             for (x = 0; x < o; x++) {
1054                 for (i = 1; i <= o; i++) {
1055                     if (cube(x,y,i))
1056                         dbg[c++] = i + '0';
1057                     else
1058                         dbg[c++] = '.';
1059                 }
1060                 dbg[c++] = ' ';
1061             }
1062             dbg[c++] = '\n';
1063         }
1064         dbg[c++] = '\n';
1065         dbg[c++] = '\0';
1066
1067         printf("%s", dbg);
1068         sfree(dbg);
1069     }
1070 #endif
1071 }
1072
1073 void latin_debug(digit *sq, int o)
1074 {
1075 #ifdef STANDALONE_SOLVER
1076     if (solver_show_working) {
1077         int x, y;
1078
1079         for (y = 0; y < o; y++) {
1080             for (x = 0; x < o; x++) {
1081                 printf("%2d ", sq[y*o+x]);
1082             }
1083             printf("\n");
1084         }
1085         printf("\n");
1086     }
1087 #endif
1088 }
1089
1090 /* --------------------------------------------------------
1091  * Generation.
1092  */
1093
1094 digit *latin_generate(int o, random_state *rs)
1095 {
1096     digit *sq;
1097     int *edges, *backedges, *capacity, *flow;
1098     void *scratch;
1099     int ne, scratchsize;
1100     int i, j, k;
1101     digit *row, *col, *numinv, *num;
1102
1103     /*
1104      * To efficiently generate a latin square in such a way that
1105      * all possible squares are possible outputs from the function,
1106      * we make use of a theorem which states that any r x n latin
1107      * rectangle, with r < n, can be extended into an (r+1) x n
1108      * latin rectangle. In other words, we can reliably generate a
1109      * latin square row by row, by at every stage writing down any
1110      * row at all which doesn't conflict with previous rows, and
1111      * the theorem guarantees that we will never have to backtrack.
1112      *
1113      * To find a viable row at each stage, we can make use of the
1114      * support functions in maxflow.c.
1115      */
1116
1117     sq = snewn(o*o, digit);
1118
1119     /*
1120      * In case this method of generation introduces a really subtle
1121      * top-to-bottom directional bias, we'll generate the rows in
1122      * random order.
1123      */
1124     row = snewn(o, digit);
1125     col = snewn(o, digit);
1126     numinv = snewn(o, digit);
1127     num = snewn(o, digit);
1128     for (i = 0; i < o; i++)
1129         row[i] = i;
1130     shuffle(row, i, sizeof(*row), rs);
1131
1132     /*
1133      * Set up the infrastructure for the maxflow algorithm.
1134      */
1135     scratchsize = maxflow_scratch_size(o * 2 + 2);
1136     scratch = smalloc(scratchsize);
1137     backedges = snewn(o*o + 2*o, int);
1138     edges = snewn((o*o + 2*o) * 2, int);
1139     capacity = snewn(o*o + 2*o, int);
1140     flow = snewn(o*o + 2*o, int);
1141     /* Set up the edge array, and the initial capacities. */
1142     ne = 0;
1143     for (i = 0; i < o; i++) {
1144         /* Each LHS vertex is connected to all RHS vertices. */
1145         for (j = 0; j < o; j++) {
1146             edges[ne*2] = i;
1147             edges[ne*2+1] = j+o;
1148             /* capacity for this edge is set later on */
1149             ne++;
1150         }
1151     }
1152     for (i = 0; i < o; i++) {
1153         /* Each RHS vertex is connected to the distinguished sink vertex. */
1154         edges[ne*2] = i+o;
1155         edges[ne*2+1] = o*2+1;
1156         capacity[ne] = 1;
1157         ne++;
1158     }
1159     for (i = 0; i < o; i++) {
1160         /* And the distinguished source vertex connects to each LHS vertex. */
1161         edges[ne*2] = o*2;
1162         edges[ne*2+1] = i;
1163         capacity[ne] = 1;
1164         ne++;
1165     }
1166     assert(ne == o*o + 2*o);
1167     /* Now set up backedges. */
1168     maxflow_setup_backedges(ne, edges, backedges);
1169     
1170     /*
1171      * Now generate each row of the latin square.
1172      */
1173     for (i = 0; i < o; i++) {
1174         /*
1175          * To prevent maxflow from behaving deterministically, we
1176          * separately permute the columns and the digits for the
1177          * purposes of the algorithm, differently for every row.
1178          */
1179         for (j = 0; j < o; j++)
1180             col[j] = num[j] = j;
1181         shuffle(col, j, sizeof(*col), rs);
1182         shuffle(num, j, sizeof(*num), rs);
1183         /* We need the num permutation in both forward and inverse forms. */
1184         for (j = 0; j < o; j++)
1185             numinv[num[j]] = j;
1186
1187         /*
1188          * Set up the capacities for the maxflow run, by examining
1189          * the existing latin square.
1190          */
1191         for (j = 0; j < o*o; j++)
1192             capacity[j] = 1;
1193         for (j = 0; j < i; j++)
1194             for (k = 0; k < o; k++) {
1195                 int n = num[sq[row[j]*o + col[k]] - 1];
1196                 capacity[k*o+n] = 0;
1197             }
1198
1199         /*
1200          * Run maxflow.
1201          */
1202         j = maxflow_with_scratch(scratch, o*2+2, 2*o, 2*o+1, ne,
1203                                  edges, backedges, capacity, flow, NULL);
1204         assert(j == o);   /* by the above theorem, this must have succeeded */
1205
1206         /*
1207          * And examine the flow array to pick out the new row of
1208          * the latin square.
1209          */
1210         for (j = 0; j < o; j++) {
1211             for (k = 0; k < o; k++) {
1212                 if (flow[j*o+k])
1213                     break;
1214             }
1215             assert(k < o);
1216             sq[row[i]*o + col[j]] = numinv[k] + 1;
1217         }
1218     }
1219
1220     /*
1221      * Done. Free our internal workspaces...
1222      */
1223     sfree(flow);
1224     sfree(capacity);
1225     sfree(edges);
1226     sfree(backedges);
1227     sfree(scratch);
1228     sfree(numinv);
1229     sfree(num);
1230     sfree(col);
1231     sfree(row);
1232
1233     /*
1234      * ... and return our completed latin square.
1235      */
1236     return sq;
1237 }
1238
1239 digit *latin_generate_rect(int w, int h, random_state *rs)
1240 {
1241     int o = max(w, h), x, y;
1242     digit *latin, *latin_rect;
1243
1244     latin = latin_generate(o, rs);
1245     latin_rect = snewn(w*h, digit);
1246
1247     for (x = 0; x < w; x++) {
1248         for (y = 0; y < h; y++) {
1249             latin_rect[y*w + x] = latin[y*o + x];
1250         }
1251     }
1252
1253     sfree(latin);
1254     return latin_rect;
1255 }
1256
1257 /* --------------------------------------------------------
1258  * Checking.
1259  */
1260
1261 typedef struct lcparams {
1262     digit elt;
1263     int count;
1264 } lcparams;
1265
1266 static int latin_check_cmp(void *v1, void *v2)
1267 {
1268     lcparams *lc1 = (lcparams *)v1;
1269     lcparams *lc2 = (lcparams *)v2;
1270
1271     if (lc1->elt < lc2->elt) return -1;
1272     if (lc1->elt > lc2->elt) return 1;
1273     return 0;
1274 }
1275
1276 #define ELT(sq,x,y) (sq[((y)*order)+(x)])
1277
1278 /* returns non-zero if sq is not a latin square. */
1279 int latin_check(digit *sq, int order)
1280 {
1281     tree234 *dict = newtree234(latin_check_cmp);
1282     int c, r;
1283     int ret = 0;
1284     lcparams *lcp, lc, *aret;
1285
1286     /* Use a tree234 as a simple hash table, go through the square
1287      * adding elements as we go or incrementing their counts. */
1288     for (c = 0; c < order; c++) {
1289         for (r = 0; r < order; r++) {
1290             lc.elt = ELT(sq, c, r); lc.count = 0;
1291             lcp = find234(dict, &lc, NULL);
1292             if (!lcp) {
1293                 lcp = snew(lcparams);
1294                 lcp->elt = ELT(sq, c, r);
1295                 lcp->count = 1;
1296                 aret = add234(dict, lcp);
1297                 assert(aret == lcp);
1298             } else {
1299                 lcp->count++;
1300             }
1301         }
1302     }
1303
1304     /* There should be precisely 'order' letters in the alphabet,
1305      * each occurring 'order' times (making the OxO tree) */
1306     if (count234(dict) != order) ret = 1;
1307     else {
1308         for (c = 0; (lcp = index234(dict, c)) != NULL; c++) {
1309             if (lcp->count != order) ret = 1;
1310         }
1311     }
1312     for (c = 0; (lcp = index234(dict, c)) != NULL; c++)
1313         sfree(lcp);
1314     freetree234(dict);
1315
1316     return ret;
1317 }
1318
1319
1320 /* --------------------------------------------------------
1321  * Testing (and printing).
1322  */
1323
1324 #ifdef STANDALONE_LATIN_TEST
1325
1326 #include <stdio.h>
1327 #include <time.h>
1328
1329 const char *quis;
1330
1331 static void latin_print(digit *sq, int order)
1332 {
1333     int x, y;
1334
1335     for (y = 0; y < order; y++) {
1336         for (x = 0; x < order; x++) {
1337             printf("%2u ", ELT(sq, x, y));
1338         }
1339         printf("\n");
1340     }
1341     printf("\n");
1342 }
1343
1344 static void gen(int order, random_state *rs, int debug)
1345 {
1346     digit *sq;
1347
1348     solver_show_working = debug;
1349
1350     sq = latin_generate(order, rs);
1351     latin_print(sq, order);
1352     if (latin_check(sq, order)) {
1353         fprintf(stderr, "Square is not a latin square!");
1354         exit(1);
1355     }
1356
1357     sfree(sq);
1358 }
1359
1360 void test_soak(int order, random_state *rs)
1361 {
1362     digit *sq;
1363     int n = 0;
1364     time_t tt_start, tt_now, tt_last;
1365
1366     solver_show_working = 0;
1367     tt_now = tt_start = time(NULL);
1368
1369     while(1) {
1370         sq = latin_generate(order, rs);
1371         sfree(sq);
1372         n++;
1373
1374         tt_last = time(NULL);
1375         if (tt_last > tt_now) {
1376             tt_now = tt_last;
1377             printf("%d total, %3.1f/s\n", n,
1378                    (double)n / (double)(tt_now - tt_start));
1379         }
1380     }
1381 }
1382
1383 void usage_exit(const char *msg)
1384 {
1385     if (msg)
1386         fprintf(stderr, "%s: %s\n", quis, msg);
1387     fprintf(stderr, "Usage: %s [--seed SEED] --soak <params> | [game_id [game_id ...]]\n", quis);
1388     exit(1);
1389 }
1390
1391 int main(int argc, char *argv[])
1392 {
1393     int i, soak = 0;
1394     random_state *rs;
1395     time_t seed = time(NULL);
1396
1397     quis = argv[0];
1398     while (--argc > 0) {
1399         const char *p = *++argv;
1400         if (!strcmp(p, "--soak"))
1401             soak = 1;
1402         else if (!strcmp(p, "--seed")) {
1403             if (argc == 0)
1404                 usage_exit("--seed needs an argument");
1405             seed = (time_t)atoi(*++argv);
1406             argc--;
1407         } else if (*p == '-')
1408                 usage_exit("unrecognised option");
1409         else
1410             break; /* finished options */
1411     }
1412
1413     rs = random_new((void*)&seed, sizeof(time_t));
1414
1415     if (soak == 1) {
1416         if (argc != 1) usage_exit("only one argument for --soak");
1417         test_soak(atoi(*argv), rs);
1418     } else {
1419         if (argc > 0) {
1420             for (i = 0; i < argc; i++) {
1421                 gen(atoi(*argv++), rs, 1);
1422             }
1423         } else {
1424             while (1) {
1425                 i = random_upto(rs, 20) + 1;
1426                 gen(i, rs, 0);
1427             }
1428         }
1429     }
1430     random_free(rs);
1431     return 0;
1432 }
1433
1434 #endif
1435
1436 /* vim: set shiftwidth=4 tabstop=8: */