chiark / gitweb /
Tweak the semantics of dsf_merge() so that the canonical element of
[sgt-puzzles.git] / latin.c
1 #include <assert.h>
2 #include <string.h>
3 #include <stdarg.h>
4
5 #include "puzzles.h"
6 #include "tree234.h"
7 #include "maxflow.h"
8
9 #ifdef STANDALONE_LATIN_TEST
10 #define STANDALONE_SOLVER
11 #endif
12
13 #include "latin.h"
14
15 /* --------------------------------------------------------
16  * Solver.
17  */
18
19 /*
20  * Function called when we are certain that a particular square has
21  * a particular number in it. The y-coordinate passed in here is
22  * transformed.
23  */
24 void latin_solver_place(struct latin_solver *solver, int x, int y, int n)
25 {
26     int i, o = solver->o;
27
28     assert(n <= o);
29     assert(cube(x,y,n));
30
31     /*
32      * Rule out all other numbers in this square.
33      */
34     for (i = 1; i <= o; i++)
35         if (i != n)
36             cube(x,y,i) = FALSE;
37
38     /*
39      * Rule out this number in all other positions in the row.
40      */
41     for (i = 0; i < o; i++)
42         if (i != y)
43             cube(x,i,n) = FALSE;
44
45     /*
46      * Rule out this number in all other positions in the column.
47      */
48     for (i = 0; i < o; i++)
49         if (i != x)
50             cube(i,y,n) = FALSE;
51
52     /*
53      * Enter the number in the result grid.
54      */
55     solver->grid[YUNTRANS(y)*o+x] = n;
56
57     /*
58      * Cross out this number from the list of numbers left to place
59      * in its row, its column and its block.
60      */
61     solver->row[y*o+n-1] = solver->col[x*o+n-1] = TRUE;
62 }
63
64 int latin_solver_elim(struct latin_solver *solver, int start, int step
65 #ifdef STANDALONE_SOLVER
66                       , char *fmt, ...
67 #endif
68                       )
69 {
70     int o = solver->o;
71     int fpos, m, i;
72
73     /*
74      * Count the number of set bits within this section of the
75      * cube.
76      */
77     m = 0;
78     fpos = -1;
79     for (i = 0; i < o; i++)
80         if (solver->cube[start+i*step]) {
81             fpos = start+i*step;
82             m++;
83         }
84
85     if (m == 1) {
86         int x, y, n;
87         assert(fpos >= 0);
88
89         n = 1 + fpos % o;
90         y = fpos / o;
91         x = y / o;
92         y %= o;
93
94         if (!solver->grid[YUNTRANS(y)*o+x]) {
95 #ifdef STANDALONE_SOLVER
96             if (solver_show_working) {
97                 va_list ap;
98                 printf("%*s", solver_recurse_depth*4, "");
99                 va_start(ap, fmt);
100                 vprintf(fmt, ap);
101                 va_end(ap);
102                 printf(":\n%*s  placing %d at (%d,%d)\n",
103                        solver_recurse_depth*4, "", n, x, YUNTRANS(y));
104             }
105 #endif
106             latin_solver_place(solver, x, y, n);
107             return +1;
108         }
109     } else if (m == 0) {
110 #ifdef STANDALONE_SOLVER
111         if (solver_show_working) {
112             va_list ap;
113             printf("%*s", solver_recurse_depth*4, "");
114             va_start(ap, fmt);
115             vprintf(fmt, ap);
116             va_end(ap);
117             printf(":\n%*s  no possibilities available\n",
118                    solver_recurse_depth*4, "");
119         }
120 #endif
121         return -1;
122     }
123
124     return 0;
125 }
126
127 struct latin_solver_scratch {
128     unsigned char *grid, *rowidx, *colidx, *set;
129     int *neighbours, *bfsqueue;
130 #ifdef STANDALONE_SOLVER
131     int *bfsprev;
132 #endif
133 };
134
135 int latin_solver_set(struct latin_solver *solver,
136                      struct latin_solver_scratch *scratch,
137                      int start, int step1, int step2
138 #ifdef STANDALONE_SOLVER
139                      , char *fmt, ...
140 #endif
141                      )
142 {
143     int o = solver->o;
144     int i, j, n, count;
145     unsigned char *grid = scratch->grid;
146     unsigned char *rowidx = scratch->rowidx;
147     unsigned char *colidx = scratch->colidx;
148     unsigned char *set = scratch->set;
149
150     /*
151      * We are passed a o-by-o matrix of booleans. Our first job
152      * is to winnow it by finding any definite placements - i.e.
153      * any row with a solitary 1 - and discarding that row and the
154      * column containing the 1.
155      */
156     memset(rowidx, TRUE, o);
157     memset(colidx, TRUE, o);
158     for (i = 0; i < o; i++) {
159         int count = 0, first = -1;
160         for (j = 0; j < o; j++)
161             if (solver->cube[start+i*step1+j*step2])
162                 first = j, count++;
163
164         if (count == 0) return -1;
165         if (count == 1)
166             rowidx[i] = colidx[first] = FALSE;
167     }
168
169     /*
170      * Convert each of rowidx/colidx from a list of 0s and 1s to a
171      * list of the indices of the 1s.
172      */
173     for (i = j = 0; i < o; i++)
174         if (rowidx[i])
175             rowidx[j++] = i;
176     n = j;
177     for (i = j = 0; i < o; i++)
178         if (colidx[i])
179             colidx[j++] = i;
180     assert(n == j);
181
182     /*
183      * And create the smaller matrix.
184      */
185     for (i = 0; i < n; i++)
186         for (j = 0; j < n; j++)
187             grid[i*o+j] = solver->cube[start+rowidx[i]*step1+colidx[j]*step2];
188
189     /*
190      * Having done that, we now have a matrix in which every row
191      * has at least two 1s in. Now we search to see if we can find
192      * a rectangle of zeroes (in the set-theoretic sense of
193      * `rectangle', i.e. a subset of rows crossed with a subset of
194      * columns) whose width and height add up to n.
195      */
196
197     memset(set, 0, n);
198     count = 0;
199     while (1) {
200         /*
201          * We have a candidate set. If its size is <=1 or >=n-1
202          * then we move on immediately.
203          */
204         if (count > 1 && count < n-1) {
205             /*
206              * The number of rows we need is n-count. See if we can
207              * find that many rows which each have a zero in all
208              * the positions listed in `set'.
209              */
210             int rows = 0;
211             for (i = 0; i < n; i++) {
212                 int ok = TRUE;
213                 for (j = 0; j < n; j++)
214                     if (set[j] && grid[i*o+j]) {
215                         ok = FALSE;
216                         break;
217                     }
218                 if (ok)
219                     rows++;
220             }
221
222             /*
223              * We expect never to be able to get _more_ than
224              * n-count suitable rows: this would imply that (for
225              * example) there are four numbers which between them
226              * have at most three possible positions, and hence it
227              * indicates a faulty deduction before this point or
228              * even a bogus clue.
229              */
230             if (rows > n - count) {
231 #ifdef STANDALONE_SOLVER
232                 if (solver_show_working) {
233                     va_list ap;
234                     printf("%*s", solver_recurse_depth*4,
235                            "");
236                     va_start(ap, fmt);
237                     vprintf(fmt, ap);
238                     va_end(ap);
239                     printf(":\n%*s  contradiction reached\n",
240                            solver_recurse_depth*4, "");
241                 }
242 #endif
243                 return -1;
244             }
245
246             if (rows >= n - count) {
247                 int progress = FALSE;
248
249                 /*
250                  * We've got one! Now, for each row which _doesn't_
251                  * satisfy the criterion, eliminate all its set
252                  * bits in the positions _not_ listed in `set'.
253                  * Return +1 (meaning progress has been made) if we
254                  * successfully eliminated anything at all.
255                  *
256                  * This involves referring back through
257                  * rowidx/colidx in order to work out which actual
258                  * positions in the cube to meddle with.
259                  */
260                 for (i = 0; i < n; i++) {
261                     int ok = TRUE;
262                     for (j = 0; j < n; j++)
263                         if (set[j] && grid[i*o+j]) {
264                             ok = FALSE;
265                             break;
266                         }
267                     if (!ok) {
268                         for (j = 0; j < n; j++)
269                             if (!set[j] && grid[i*o+j]) {
270                                 int fpos = (start+rowidx[i]*step1+
271                                             colidx[j]*step2);
272 #ifdef STANDALONE_SOLVER
273                                 if (solver_show_working) {
274                                     int px, py, pn;
275
276                                     if (!progress) {
277                                         va_list ap;
278                                         printf("%*s", solver_recurse_depth*4,
279                                                "");
280                                         va_start(ap, fmt);
281                                         vprintf(fmt, ap);
282                                         va_end(ap);
283                                         printf(":\n");
284                                     }
285
286                                     pn = 1 + fpos % o;
287                                     py = fpos / o;
288                                     px = py / o;
289                                     py %= o;
290
291                                     printf("%*s  ruling out %d at (%d,%d)\n",
292                                            solver_recurse_depth*4, "",
293                                            pn, px, YUNTRANS(py));
294                                 }
295 #endif
296                                 progress = TRUE;
297                                 solver->cube[fpos] = FALSE;
298                             }
299                     }
300                 }
301
302                 if (progress) {
303                     return +1;
304                 }
305             }
306         }
307
308         /*
309          * Binary increment: change the rightmost 0 to a 1, and
310          * change all 1s to the right of it to 0s.
311          */
312         i = n;
313         while (i > 0 && set[i-1])
314             set[--i] = 0, count--;
315         if (i > 0)
316             set[--i] = 1, count++;
317         else
318             break;                     /* done */
319     }
320
321     return 0;
322 }
323
324 /*
325  * Look for forcing chains. A forcing chain is a path of
326  * pairwise-exclusive squares (i.e. each pair of adjacent squares
327  * in the path are in the same row, column or block) with the
328  * following properties:
329  *
330  *  (a) Each square on the path has precisely two possible numbers.
331  *
332  *  (b) Each pair of squares which are adjacent on the path share
333  *      at least one possible number in common.
334  *
335  *  (c) Each square in the middle of the path shares _both_ of its
336  *      numbers with at least one of its neighbours (not the same
337  *      one with both neighbours).
338  *
339  * These together imply that at least one of the possible number
340  * choices at one end of the path forces _all_ the rest of the
341  * numbers along the path. In order to make real use of this, we
342  * need further properties:
343  *
344  *  (c) Ruling out some number N from the square at one end
345  *      of the path forces the square at the other end to
346  *      take number N.
347  *
348  *  (d) The two end squares are both in line with some third
349  *      square.
350  *
351  *  (e) That third square currently has N as a possibility.
352  *
353  * If we can find all of that lot, we can deduce that at least one
354  * of the two ends of the forcing chain has number N, and that
355  * therefore the mutually adjacent third square does not.
356  *
357  * To find forcing chains, we're going to start a bfs at each
358  * suitable square, once for each of its two possible numbers.
359  */
360 int latin_solver_forcing(struct latin_solver *solver,
361                          struct latin_solver_scratch *scratch)
362 {
363     int o = solver->o;
364     int *bfsqueue = scratch->bfsqueue;
365 #ifdef STANDALONE_SOLVER
366     int *bfsprev = scratch->bfsprev;
367 #endif
368     unsigned char *number = scratch->grid;
369     int *neighbours = scratch->neighbours;
370     int x, y;
371
372     for (y = 0; y < o; y++)
373         for (x = 0; x < o; x++) {
374             int count, t, n;
375
376             /*
377              * If this square doesn't have exactly two candidate
378              * numbers, don't try it.
379              *
380              * In this loop we also sum the candidate numbers,
381              * which is a nasty hack to allow us to quickly find
382              * `the other one' (since we will shortly know there
383              * are exactly two).
384              */
385             for (count = t = 0, n = 1; n <= o; n++)
386                 if (cube(x, y, n))
387                     count++, t += n;
388             if (count != 2)
389                 continue;
390
391             /*
392              * Now attempt a bfs for each candidate.
393              */
394             for (n = 1; n <= o; n++)
395                 if (cube(x, y, n)) {
396                     int orign, currn, head, tail;
397
398                     /*
399                      * Begin a bfs.
400                      */
401                     orign = n;
402
403                     memset(number, o+1, o*o);
404                     head = tail = 0;
405                     bfsqueue[tail++] = y*o+x;
406 #ifdef STANDALONE_SOLVER
407                     bfsprev[y*o+x] = -1;
408 #endif
409                     number[y*o+x] = t - n;
410
411                     while (head < tail) {
412                         int xx, yy, nneighbours, xt, yt, i;
413
414                         xx = bfsqueue[head++];
415                         yy = xx / o;
416                         xx %= o;
417
418                         currn = number[yy*o+xx];
419
420                         /*
421                          * Find neighbours of yy,xx.
422                          */
423                         nneighbours = 0;
424                         for (yt = 0; yt < o; yt++)
425                             neighbours[nneighbours++] = yt*o+xx;
426                         for (xt = 0; xt < o; xt++)
427                             neighbours[nneighbours++] = yy*o+xt;
428
429                         /*
430                          * Try visiting each of those neighbours.
431                          */
432                         for (i = 0; i < nneighbours; i++) {
433                             int cc, tt, nn;
434
435                             xt = neighbours[i] % o;
436                             yt = neighbours[i] / o;
437
438                             /*
439                              * We need this square to not be
440                              * already visited, and to include
441                              * currn as a possible number.
442                              */
443                             if (number[yt*o+xt] <= o)
444                                 continue;
445                             if (!cube(xt, yt, currn))
446                                 continue;
447
448                             /*
449                              * Don't visit _this_ square a second
450                              * time!
451                              */
452                             if (xt == xx && yt == yy)
453                                 continue;
454
455                             /*
456                              * To continue with the bfs, we need
457                              * this square to have exactly two
458                              * possible numbers.
459                              */
460                             for (cc = tt = 0, nn = 1; nn <= o; nn++)
461                                 if (cube(xt, yt, nn))
462                                     cc++, tt += nn;
463                             if (cc == 2) {
464                                 bfsqueue[tail++] = yt*o+xt;
465 #ifdef STANDALONE_SOLVER
466                                 bfsprev[yt*o+xt] = yy*o+xx;
467 #endif
468                                 number[yt*o+xt] = tt - currn;
469                             }
470
471                             /*
472                              * One other possibility is that this
473                              * might be the square in which we can
474                              * make a real deduction: if it's
475                              * adjacent to x,y, and currn is equal
476                              * to the original number we ruled out.
477                              */
478                             if (currn == orign &&
479                                 (xt == x || yt == y)) {
480 #ifdef STANDALONE_SOLVER
481                                 if (solver_show_working) {
482                                     char *sep = "";
483                                     int xl, yl;
484                                     printf("%*sforcing chain, %d at ends of ",
485                                            solver_recurse_depth*4, "", orign);
486                                     xl = xx;
487                                     yl = yy;
488                                     while (1) {
489                                         printf("%s(%d,%d)", sep, xl,
490                                                YUNTRANS(yl));
491                                         xl = bfsprev[yl*o+xl];
492                                         if (xl < 0)
493                                             break;
494                                         yl = xl / o;
495                                         xl %= o;
496                                         sep = "-";
497                                     }
498                                     printf("\n%*s  ruling out %d at (%d,%d)\n",
499                                            solver_recurse_depth*4, "",
500                                            orign, xt, YUNTRANS(yt));
501                                 }
502 #endif
503                                 cube(xt, yt, orign) = FALSE;
504                                 return 1;
505                             }
506                         }
507                     }
508                 }
509         }
510
511     return 0;
512 }
513
514 struct latin_solver_scratch *latin_solver_new_scratch(struct latin_solver *solver)
515 {
516     struct latin_solver_scratch *scratch = snew(struct latin_solver_scratch);
517     int o = solver->o;
518     scratch->grid = snewn(o*o, unsigned char);
519     scratch->rowidx = snewn(o, unsigned char);
520     scratch->colidx = snewn(o, unsigned char);
521     scratch->set = snewn(o, unsigned char);
522     scratch->neighbours = snewn(3*o, int);
523     scratch->bfsqueue = snewn(o*o, int);
524 #ifdef STANDALONE_SOLVER
525     scratch->bfsprev = snewn(o*o, int);
526 #endif
527     return scratch;
528 }
529
530 void latin_solver_free_scratch(struct latin_solver_scratch *scratch)
531 {
532 #ifdef STANDALONE_SOLVER
533     sfree(scratch->bfsprev);
534 #endif
535     sfree(scratch->bfsqueue);
536     sfree(scratch->neighbours);
537     sfree(scratch->set);
538     sfree(scratch->colidx);
539     sfree(scratch->rowidx);
540     sfree(scratch->grid);
541     sfree(scratch);
542 }
543
544 void latin_solver_alloc(struct latin_solver *solver, digit *grid, int o)
545 {
546     int x, y;
547
548     solver->o = o;
549     solver->cube = snewn(o*o*o, unsigned char);
550     solver->grid = grid;                /* write straight back to the input */
551     memset(solver->cube, TRUE, o*o*o);
552
553     solver->row = snewn(o*o, unsigned char);
554     solver->col = snewn(o*o, unsigned char);
555     memset(solver->row, FALSE, o*o);
556     memset(solver->col, FALSE, o*o);
557
558     for (x = 0; x < o; x++)
559         for (y = 0; y < o; y++)
560             if (grid[y*o+x])
561                 latin_solver_place(solver, x, YTRANS(y), grid[y*o+x]);
562 }
563
564 void latin_solver_free(struct latin_solver *solver)
565 {
566     sfree(solver->cube);
567     sfree(solver->row);
568     sfree(solver->col);
569 }
570
571 int latin_solver_diff_simple(struct latin_solver *solver)
572 {
573     int x, y, n, ret, o = solver->o;
574     /*
575      * Row-wise positional elimination.
576      */
577     for (y = 0; y < o; y++)
578         for (n = 1; n <= o; n++)
579             if (!solver->row[y*o+n-1]) {
580                 ret = latin_solver_elim(solver, cubepos(0,y,n), o*o
581 #ifdef STANDALONE_SOLVER
582                                         , "positional elimination,"
583                                         " %d in row %d", n, YUNTRANS(y)
584 #endif
585                                         );
586                 if (ret != 0) return ret;
587             }
588     /*
589      * Column-wise positional elimination.
590      */
591     for (x = 0; x < o; x++)
592         for (n = 1; n <= o; n++)
593             if (!solver->col[x*o+n-1]) {
594                 ret = latin_solver_elim(solver, cubepos(x,0,n), o
595 #ifdef STANDALONE_SOLVER
596                                         , "positional elimination,"
597                                         " %d in column %d", n, x
598 #endif
599                                         );
600                 if (ret != 0) return ret;
601             }
602
603     /*
604      * Numeric elimination.
605      */
606     for (x = 0; x < o; x++)
607         for (y = 0; y < o; y++)
608             if (!solver->grid[YUNTRANS(y)*o+x]) {
609                 ret = latin_solver_elim(solver, cubepos(x,y,1), 1
610 #ifdef STANDALONE_SOLVER
611                                         , "numeric elimination at (%d,%d)", x,
612                                         YUNTRANS(y)
613 #endif
614                                         );
615                 if (ret != 0) return ret;
616             }
617     return 0;
618 }
619
620 int latin_solver_diff_set(struct latin_solver *solver,
621                           struct latin_solver_scratch *scratch,
622                           int extreme)
623 {
624     int x, y, n, ret, o = solver->o;
625
626     if (!extreme) {
627         /*
628          * Row-wise set elimination.
629          */
630         for (y = 0; y < o; y++) {
631             ret = latin_solver_set(solver, scratch, cubepos(0,y,1), o*o, 1
632 #ifdef STANDALONE_SOLVER
633                                    , "set elimination, row %d", YUNTRANS(y)
634 #endif
635                                   );
636             if (ret != 0) return ret;
637         }
638         /*
639          * Column-wise set elimination.
640          */
641         for (x = 0; x < o; x++) {
642             ret = latin_solver_set(solver, scratch, cubepos(x,0,1), o, 1
643 #ifdef STANDALONE_SOLVER
644                                    , "set elimination, column %d", x
645 #endif
646                                   );
647             if (ret != 0) return ret;
648         }
649     } else {
650         /*
651          * Row-vs-column set elimination on a single number
652          * (much tricker for a human to do!)
653          */
654         for (n = 1; n <= o; n++) {
655             ret = latin_solver_set(solver, scratch, cubepos(0,0,n), o*o, o
656 #ifdef STANDALONE_SOLVER
657                                    , "positional set elimination, number %d", n
658 #endif
659                                   );
660             if (ret != 0) return ret;
661         }
662     }
663     return 0;
664 }
665
666 /*
667  * Returns:
668  * 0 for 'didn't do anything' implying it was already solved.
669  * -1 for 'impossible' (no solution)
670  * 1 for 'single solution'
671  * >1 for 'multiple solutions' (you don't get to know how many, and
672  *     the first such solution found will be set.
673  *
674  * and this function may well assert if given an impossible board.
675  */
676 static int latin_solver_recurse
677     (struct latin_solver *solver, int diff_simple, int diff_set_0,
678      int diff_set_1, int diff_forcing, int diff_recursive,
679      usersolver_t const *usersolvers, void *ctx,
680      ctxnew_t ctxnew, ctxfree_t ctxfree)
681 {
682     int best, bestcount;
683     int o = solver->o, x, y, n;
684
685     best = -1;
686     bestcount = o+1;
687
688     for (y = 0; y < o; y++)
689         for (x = 0; x < o; x++)
690             if (!solver->grid[y*o+x]) {
691                 int count;
692
693                 /*
694                  * An unfilled square. Count the number of
695                  * possible digits in it.
696                  */
697                 count = 0;
698                 for (n = 1; n <= o; n++)
699                     if (cube(x,YTRANS(y),n))
700                         count++;
701
702                 /*
703                  * We should have found any impossibilities
704                  * already, so this can safely be an assert.
705                  */
706                 assert(count > 1);
707
708                 if (count < bestcount) {
709                     bestcount = count;
710                     best = y*o+x;
711                 }
712             }
713
714     if (best == -1)
715         /* we were complete already. */
716         return 0;
717     else {
718         int i, j;
719         digit *list, *ingrid, *outgrid;
720         int diff = diff_impossible;    /* no solution found yet */
721
722         /*
723          * Attempt recursion.
724          */
725         y = best / o;
726         x = best % o;
727
728         list = snewn(o, digit);
729         ingrid = snewn(o*o, digit);
730         outgrid = snewn(o*o, digit);
731         memcpy(ingrid, solver->grid, o*o);
732
733         /* Make a list of the possible digits. */
734         for (j = 0, n = 1; n <= o; n++)
735             if (cube(x,YTRANS(y),n))
736                 list[j++] = n;
737
738 #ifdef STANDALONE_SOLVER
739         if (solver_show_working) {
740             char *sep = "";
741             printf("%*srecursing on (%d,%d) [",
742                    solver_recurse_depth*4, "", x, y);
743             for (i = 0; i < j; i++) {
744                 printf("%s%d", sep, list[i]);
745                 sep = " or ";
746             }
747             printf("]\n");
748         }
749 #endif
750
751         /*
752          * And step along the list, recursing back into the
753          * main solver at every stage.
754          */
755         for (i = 0; i < j; i++) {
756             int ret;
757             void *newctx;
758
759             memcpy(outgrid, ingrid, o*o);
760             outgrid[y*o+x] = list[i];
761
762 #ifdef STANDALONE_SOLVER
763             if (solver_show_working)
764                 printf("%*sguessing %d at (%d,%d)\n",
765                        solver_recurse_depth*4, "", list[i], x, y);
766             solver_recurse_depth++;
767 #endif
768
769             if (ctxnew) {
770                 newctx = ctxnew(ctx);
771             } else {
772                 newctx = ctx;
773             }
774             ret = latin_solver(outgrid, o, diff_recursive,
775                                diff_simple, diff_set_0, diff_set_1,
776                                diff_forcing, diff_recursive,
777                                usersolvers, newctx, ctxnew, ctxfree);
778             if (ctxnew)
779                 ctxfree(newctx);
780
781 #ifdef STANDALONE_SOLVER
782             solver_recurse_depth--;
783             if (solver_show_working) {
784                 printf("%*sretracting %d at (%d,%d)\n",
785                        solver_recurse_depth*4, "", list[i], x, y);
786             }
787 #endif
788             /* we recurse as deep as we can, so we should never find
789              * find ourselves giving up on a puzzle without declaring it
790              * impossible.  */
791             assert(ret != diff_unfinished);
792
793             /*
794              * If we have our first solution, copy it into the
795              * grid we will return.
796              */
797             if (diff == diff_impossible && ret != diff_impossible)
798                 memcpy(solver->grid, outgrid, o*o);
799
800             if (ret == diff_ambiguous)
801                 diff = diff_ambiguous;
802             else if (ret == diff_impossible)
803                 /* do not change our return value */;
804             else {
805                 /* the recursion turned up exactly one solution */
806                 if (diff == diff_impossible)
807                     diff = diff_recursive;
808                 else
809                     diff = diff_ambiguous;
810             }
811
812             /*
813              * As soon as we've found more than one solution,
814              * give up immediately.
815              */
816             if (diff == diff_ambiguous)
817                 break;
818         }
819
820         sfree(outgrid);
821         sfree(ingrid);
822         sfree(list);
823
824         if (diff == diff_impossible)
825             return -1;
826         else if (diff == diff_ambiguous)
827             return 2;
828         else {
829             assert(diff == diff_recursive);
830             return 1;
831         }
832     }
833 }
834
835 int latin_solver_main(struct latin_solver *solver, int maxdiff,
836                       int diff_simple, int diff_set_0, int diff_set_1,
837                       int diff_forcing, int diff_recursive,
838                       usersolver_t const *usersolvers, void *ctx,
839                       ctxnew_t ctxnew, ctxfree_t ctxfree)
840 {
841     struct latin_solver_scratch *scratch = latin_solver_new_scratch(solver);
842     int ret, diff = diff_simple;
843
844     assert(maxdiff <= diff_recursive);
845     /*
846      * Now loop over the grid repeatedly trying all permitted modes
847      * of reasoning. The loop terminates if we complete an
848      * iteration without making any progress; we then return
849      * failure or success depending on whether the grid is full or
850      * not.
851      */
852     while (1) {
853         int i;
854
855         cont:
856
857         latin_solver_debug(solver->cube, solver->o);
858
859         for (i = 0; i <= maxdiff; i++) {
860             if (usersolvers[i])
861                 ret = usersolvers[i](solver, ctx);
862             else
863                 ret = 0;
864             if (ret == 0 && i == diff_simple)
865                 ret = latin_solver_diff_simple(solver);
866             if (ret == 0 && i == diff_set_0)
867                 ret = latin_solver_diff_set(solver, scratch, 0);
868             if (ret == 0 && i == diff_set_1)
869                 ret = latin_solver_diff_set(solver, scratch, 1);
870             if (ret == 0 && i == diff_forcing)
871                 ret = latin_solver_forcing(solver, scratch);
872
873             if (ret < 0) {
874                 diff = diff_impossible;
875                 goto got_result;
876             } else if (ret > 0) {
877                 diff = max(diff, i);
878                 goto cont;
879             }
880         }
881
882         /*
883          * If we reach here, we have made no deductions in this
884          * iteration, so the algorithm terminates.
885          */
886         break;
887     }
888
889     /*
890      * Last chance: if we haven't fully solved the puzzle yet, try
891      * recursing based on guesses for a particular square. We pick
892      * one of the most constrained empty squares we can find, which
893      * has the effect of pruning the search tree as much as
894      * possible.
895      */
896     if (maxdiff == diff_recursive) {
897         int nsol = latin_solver_recurse(solver,
898                                         diff_simple, diff_set_0, diff_set_1,
899                                         diff_forcing, diff_recursive,
900                                         usersolvers, ctx, ctxnew, ctxfree);
901         if (nsol < 0) diff = diff_impossible;
902         else if (nsol == 1) diff = diff_recursive;
903         else if (nsol > 1) diff = diff_ambiguous;
904         /* if nsol == 0 then we were complete anyway
905          * (and thus don't need to change diff) */
906     } else {
907         /*
908          * We're forbidden to use recursion, so we just see whether
909          * our grid is fully solved, and return diff_unfinished
910          * otherwise.
911          */
912         int x, y, o = solver->o;
913
914         for (y = 0; y < o; y++)
915             for (x = 0; x < o; x++)
916                 if (!solver->grid[y*o+x])
917                     diff = diff_unfinished;
918     }
919
920     got_result:
921
922 #ifdef STANDALONE_SOLVER
923     if (solver_show_working)
924         printf("%*s%s found\n",
925                solver_recurse_depth*4, "",
926                diff == diff_impossible ? "no solution (impossible)" :
927                diff == diff_unfinished ? "no solution (unfinished)" :
928                diff == diff_ambiguous ? "multiple solutions" :
929                "one solution");
930 #endif
931
932     latin_solver_free_scratch(scratch);
933
934     return diff;
935 }
936
937 int latin_solver(digit *grid, int o, int maxdiff,
938                  int diff_simple, int diff_set_0, int diff_set_1,
939                  int diff_forcing, int diff_recursive,
940                  usersolver_t const *usersolvers, void *ctx,
941                  ctxnew_t ctxnew, ctxfree_t ctxfree)
942 {
943     struct latin_solver solver;
944     int diff;
945
946     latin_solver_alloc(&solver, grid, o);
947     diff = latin_solver_main(&solver, maxdiff,
948                              diff_simple, diff_set_0, diff_set_1,
949                              diff_forcing, diff_recursive,
950                              usersolvers, ctx, ctxnew, ctxfree);
951     latin_solver_free(&solver);
952     return diff;
953 }
954
955 void latin_solver_debug(unsigned char *cube, int o)
956 {
957 #ifdef STANDALONE_SOLVER
958     if (solver_show_working > 1) {
959         struct latin_solver ls, *solver = &ls;
960         char *dbg;
961         int x, y, i, c = 0;
962
963         ls.cube = cube; ls.o = o; /* for cube() to work */
964
965         dbg = snewn(3*o*o*o, char);
966         for (y = 0; y < o; y++) {
967             for (x = 0; x < o; x++) {
968                 for (i = 1; i <= o; i++) {
969                     if (cube(x,y,i))
970                         dbg[c++] = i + '0';
971                     else
972                         dbg[c++] = '.';
973                 }
974                 dbg[c++] = ' ';
975             }
976             dbg[c++] = '\n';
977         }
978         dbg[c++] = '\n';
979         dbg[c++] = '\0';
980
981         printf("%s", dbg);
982         sfree(dbg);
983     }
984 #endif
985 }
986
987 void latin_debug(digit *sq, int o)
988 {
989 #ifdef STANDALONE_SOLVER
990     if (solver_show_working) {
991         int x, y;
992
993         for (y = 0; y < o; y++) {
994             for (x = 0; x < o; x++) {
995                 printf("%2d ", sq[y*o+x]);
996             }
997             printf("\n");
998         }
999         printf("\n");
1000     }
1001 #endif
1002 }
1003
1004 /* --------------------------------------------------------
1005  * Generation.
1006  */
1007
1008 digit *latin_generate(int o, random_state *rs)
1009 {
1010     digit *sq;
1011     int *edges, *backedges, *capacity, *flow;
1012     void *scratch;
1013     int ne, scratchsize;
1014     int i, j, k;
1015     digit *row, *col, *numinv, *num;
1016
1017     /*
1018      * To efficiently generate a latin square in such a way that
1019      * all possible squares are possible outputs from the function,
1020      * we make use of a theorem which states that any r x n latin
1021      * rectangle, with r < n, can be extended into an (r+1) x n
1022      * latin rectangle. In other words, we can reliably generate a
1023      * latin square row by row, by at every stage writing down any
1024      * row at all which doesn't conflict with previous rows, and
1025      * the theorem guarantees that we will never have to backtrack.
1026      *
1027      * To find a viable row at each stage, we can make use of the
1028      * support functions in maxflow.c.
1029      */
1030
1031     sq = snewn(o*o, digit);
1032
1033     /*
1034      * In case this method of generation introduces a really subtle
1035      * top-to-bottom directional bias, we'll generate the rows in
1036      * random order.
1037      */
1038     row = snewn(o, digit);
1039     col = snewn(o, digit);
1040     numinv = snewn(o, digit);
1041     num = snewn(o, digit);
1042     for (i = 0; i < o; i++)
1043         row[i] = i;
1044     shuffle(row, i, sizeof(*row), rs);
1045
1046     /*
1047      * Set up the infrastructure for the maxflow algorithm.
1048      */
1049     scratchsize = maxflow_scratch_size(o * 2 + 2);
1050     scratch = smalloc(scratchsize);
1051     backedges = snewn(o*o + 2*o, int);
1052     edges = snewn((o*o + 2*o) * 2, int);
1053     capacity = snewn(o*o + 2*o, int);
1054     flow = snewn(o*o + 2*o, int);
1055     /* Set up the edge array, and the initial capacities. */
1056     ne = 0;
1057     for (i = 0; i < o; i++) {
1058         /* Each LHS vertex is connected to all RHS vertices. */
1059         for (j = 0; j < o; j++) {
1060             edges[ne*2] = i;
1061             edges[ne*2+1] = j+o;
1062             /* capacity for this edge is set later on */
1063             ne++;
1064         }
1065     }
1066     for (i = 0; i < o; i++) {
1067         /* Each RHS vertex is connected to the distinguished sink vertex. */
1068         edges[ne*2] = i+o;
1069         edges[ne*2+1] = o*2+1;
1070         capacity[ne] = 1;
1071         ne++;
1072     }
1073     for (i = 0; i < o; i++) {
1074         /* And the distinguished source vertex connects to each LHS vertex. */
1075         edges[ne*2] = o*2;
1076         edges[ne*2+1] = i;
1077         capacity[ne] = 1;
1078         ne++;
1079     }
1080     assert(ne == o*o + 2*o);
1081     /* Now set up backedges. */
1082     maxflow_setup_backedges(ne, edges, backedges);
1083     
1084     /*
1085      * Now generate each row of the latin square.
1086      */
1087     for (i = 0; i < o; i++) {
1088         /*
1089          * To prevent maxflow from behaving deterministically, we
1090          * separately permute the columns and the digits for the
1091          * purposes of the algorithm, differently for every row.
1092          */
1093         for (j = 0; j < o; j++)
1094             col[j] = num[j] = j;
1095         shuffle(col, j, sizeof(*col), rs);
1096         shuffle(num, j, sizeof(*num), rs);
1097         /* We need the num permutation in both forward and inverse forms. */
1098         for (j = 0; j < o; j++)
1099             numinv[num[j]] = j;
1100
1101         /*
1102          * Set up the capacities for the maxflow run, by examining
1103          * the existing latin square.
1104          */
1105         for (j = 0; j < o*o; j++)
1106             capacity[j] = 1;
1107         for (j = 0; j < i; j++)
1108             for (k = 0; k < o; k++) {
1109                 int n = num[sq[row[j]*o + col[k]] - 1];
1110                 capacity[k*o+n] = 0;
1111             }
1112
1113         /*
1114          * Run maxflow.
1115          */
1116         j = maxflow_with_scratch(scratch, o*2+2, 2*o, 2*o+1, ne,
1117                                  edges, backedges, capacity, flow, NULL);
1118         assert(j == o);   /* by the above theorem, this must have succeeded */
1119
1120         /*
1121          * And examine the flow array to pick out the new row of
1122          * the latin square.
1123          */
1124         for (j = 0; j < o; j++) {
1125             for (k = 0; k < o; k++) {
1126                 if (flow[j*o+k])
1127                     break;
1128             }
1129             assert(k < o);
1130             sq[row[i]*o + col[j]] = numinv[k] + 1;
1131         }
1132     }
1133
1134     /*
1135      * Done. Free our internal workspaces...
1136      */
1137     sfree(flow);
1138     sfree(capacity);
1139     sfree(edges);
1140     sfree(backedges);
1141     sfree(scratch);
1142     sfree(numinv);
1143     sfree(num);
1144     sfree(col);
1145     sfree(row);
1146
1147     /*
1148      * ... and return our completed latin square.
1149      */
1150     return sq;
1151 }
1152
1153 /* --------------------------------------------------------
1154  * Checking.
1155  */
1156
1157 typedef struct lcparams {
1158     digit elt;
1159     int count;
1160 } lcparams;
1161
1162 static int latin_check_cmp(void *v1, void *v2)
1163 {
1164     lcparams *lc1 = (lcparams *)v1;
1165     lcparams *lc2 = (lcparams *)v2;
1166
1167     if (lc1->elt < lc2->elt) return -1;
1168     if (lc1->elt > lc2->elt) return 1;
1169     return 0;
1170 }
1171
1172 #define ELT(sq,x,y) (sq[((y)*order)+(x)])
1173
1174 /* returns non-zero if sq is not a latin square. */
1175 int latin_check(digit *sq, int order)
1176 {
1177     tree234 *dict = newtree234(latin_check_cmp);
1178     int c, r;
1179     int ret = 0;
1180     lcparams *lcp, lc, *aret;
1181
1182     /* Use a tree234 as a simple hash table, go through the square
1183      * adding elements as we go or incrementing their counts. */
1184     for (c = 0; c < order; c++) {
1185         for (r = 0; r < order; r++) {
1186             lc.elt = ELT(sq, c, r); lc.count = 0;
1187             lcp = find234(dict, &lc, NULL);
1188             if (!lcp) {
1189                 lcp = snew(lcparams);
1190                 lcp->elt = ELT(sq, c, r);
1191                 lcp->count = 1;
1192                 aret = add234(dict, lcp);
1193                 assert(aret == lcp);
1194             } else {
1195                 lcp->count++;
1196             }
1197         }
1198     }
1199
1200     /* There should be precisely 'order' letters in the alphabet,
1201      * each occurring 'order' times (making the OxO tree) */
1202     if (count234(dict) != order) ret = 1;
1203     else {
1204         for (c = 0; (lcp = index234(dict, c)) != NULL; c++) {
1205             if (lcp->count != order) ret = 1;
1206         }
1207     }
1208     for (c = 0; (lcp = index234(dict, c)) != NULL; c++)
1209         sfree(lcp);
1210     freetree234(dict);
1211
1212     return ret;
1213 }
1214
1215
1216 /* --------------------------------------------------------
1217  * Testing (and printing).
1218  */
1219
1220 #ifdef STANDALONE_LATIN_TEST
1221
1222 #include <stdio.h>
1223 #include <time.h>
1224
1225 const char *quis;
1226
1227 static void latin_print(digit *sq, int order)
1228 {
1229     int x, y;
1230
1231     for (y = 0; y < order; y++) {
1232         for (x = 0; x < order; x++) {
1233             printf("%2u ", ELT(sq, x, y));
1234         }
1235         printf("\n");
1236     }
1237     printf("\n");
1238 }
1239
1240 static void gen(int order, random_state *rs, int debug)
1241 {
1242     digit *sq;
1243
1244     solver_show_working = debug;
1245
1246     sq = latin_generate(order, rs);
1247     latin_print(sq, order);
1248     if (latin_check(sq, order)) {
1249         fprintf(stderr, "Square is not a latin square!");
1250         exit(1);
1251     }
1252
1253     sfree(sq);
1254 }
1255
1256 void test_soak(int order, random_state *rs)
1257 {
1258     digit *sq;
1259     int n = 0;
1260     time_t tt_start, tt_now, tt_last;
1261
1262     solver_show_working = 0;
1263     tt_now = tt_start = time(NULL);
1264
1265     while(1) {
1266         sq = latin_generate(order, rs);
1267         sfree(sq);
1268         n++;
1269
1270         tt_last = time(NULL);
1271         if (tt_last > tt_now) {
1272             tt_now = tt_last;
1273             printf("%d total, %3.1f/s\n", n,
1274                    (double)n / (double)(tt_now - tt_start));
1275         }
1276     }
1277 }
1278
1279 void usage_exit(const char *msg)
1280 {
1281     if (msg)
1282         fprintf(stderr, "%s: %s\n", quis, msg);
1283     fprintf(stderr, "Usage: %s [--seed SEED] --soak <params> | [game_id [game_id ...]]\n", quis);
1284     exit(1);
1285 }
1286
1287 int main(int argc, char *argv[])
1288 {
1289     int i, soak = 0;
1290     random_state *rs;
1291     time_t seed = time(NULL);
1292
1293     quis = argv[0];
1294     while (--argc > 0) {
1295         const char *p = *++argv;
1296         if (!strcmp(p, "--soak"))
1297             soak = 1;
1298         else if (!strcmp(p, "--seed")) {
1299             if (argc == 0)
1300                 usage_exit("--seed needs an argument");
1301             seed = (time_t)atoi(*++argv);
1302             argc--;
1303         } else if (*p == '-')
1304                 usage_exit("unrecognised option");
1305         else
1306             break; /* finished options */
1307     }
1308
1309     rs = random_new((void*)&seed, sizeof(time_t));
1310
1311     if (soak == 1) {
1312         if (argc != 1) usage_exit("only one argument for --soak");
1313         test_soak(atoi(*argv), rs);
1314     } else {
1315         if (argc > 0) {
1316             for (i = 0; i < argc; i++) {
1317                 gen(atoi(*argv++), rs, 1);
1318             }
1319         } else {
1320             while (1) {
1321                 i = random_upto(rs, 20) + 1;
1322                 gen(i, rs, 0);
1323             }
1324         }
1325     }
1326     random_free(rs);
1327     return 0;
1328 }
1329
1330 #endif
1331
1332 /* vim: set shiftwidth=4 tabstop=8: */