chiark / gitweb /
mm.c (rate): Fix comment typo.
[mm] / mm.c
1 /* -*-c-*-
2  *
3  * Simple mastermind game
4  *
5  * (c) 2006 Mark Wooding
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of mm: a simple Mastermind game.
11  *
12  * mm is free software; you can redistribute it and/or modify
13  * it under the terms of the GNU General Public License as published by
14  * the Free Software Foundation; either version 2 of the License, or
15  * (at your option) any later version.
16  *
17  * mm is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU General Public License for more details.
21  *
22  * You should have received a copy of the GNU General Public License
23  * along with mm; if not, write to the Free Software Foundation,
24  * Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
25  */
26
27 /*----- Header files ------------------------------------------------------*/
28
29 #include <assert.h>
30 #include <ctype.h>
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <time.h>
35
36 #include <mLib/alloc.h>
37 #include <mLib/darray.h>
38 #include <mLib/mdwopt.h>
39 #include <mLib/quis.h>
40 #include <mLib/report.h>
41 #include <mLib/sub.h>
42
43 /*----- Data structures ---------------------------------------------------*/
44
45 /* --- Digits --- *
46  *
47  * The symbols which make up the code to be guessed.
48  */
49
50 typedef unsigned char dig;
51
52 /* --- The game parameters --- */
53
54 typedef struct mm {
55   dig k;                                /* Number of symbols in the code */
56   dig n;                                /* Number of distinct symbols */
57 } mm;
58
59 /*----- Rating guesses ----------------------------------------------------*/
60
61 /* --- Algorithm --- *
62  *
63  * Rating guesses efficiently is quite important.
64  *
65  * The rate context structure contains a copy of the game parameters, and
66  * three arrays, @v@, @s@ and @t@ allocated from the same @malloc@ed block:
67  *
68  *   * %$v_i$% counts occurrences of the symbol %$i$% in the code.
69  *   * %$s$% is a copy of the code.
70  *   * %$t$% is temporary work space for the rating function.
71  *
72  * The rating function works by taking a pass over the guess, computing a
73  * count table %$v'$%; but for each guess symbol which matches the
74  * corresponding code symbol, decrementing the count %$v_i$% for that symbol
75  * (in a temporary copy of the table %$v$%).  The black count is then the
76  * number of such matches, and the white count is given by
77  *
78  *   %$w = \displaystyle \sum_{0<i\le n} \min(v_i, v'_i).$%
79  *
80  * Thus, the work is %$O(k + n)$%, rather than the %$O(k^2)$% for a
81  * %%na\"\i{}ve%% implementation.
82  */
83
84 typedef struct ratectx {
85   mm m;
86   dig *v;
87   dig *s;
88   dig *t;
89 } ratectx;
90
91 static ratectx *rate_alloc(const mm *m)
92 {
93   ratectx *r;
94   dig *v;
95
96   r = CREATE(ratectx);
97   v = xmalloc((3 * m->n + m->k) * sizeof(dig));
98   r->m = *m;
99   r->v = v;
100   r->s = r->v + m->n;
101   r->t = r->s + m->k;
102   return (r);
103 }
104
105 static void rate_init(ratectx *r, const dig *s)
106 {
107   unsigned i;
108
109   memset(r->v, 0, r->m.n * sizeof(dig));
110   for (i = 0; i < r->m.k; i++)
111     r->v[s[i]]++;
112   memcpy(r->s, s, r->m.k * sizeof(dig));
113 }
114
115 static ratectx *rate_new(const mm *m, const dig *s)
116 {
117   ratectx *r = rate_alloc(m);
118
119   rate_init(r, s);
120   return (r);
121 }
122
123 static void rate(const ratectx *r, const dig *g, unsigned  *b, unsigned *w)
124 {
125   unsigned i;
126   unsigned k = r->m.k, n = r->m.n;
127   dig *v = r->t;
128   dig *vv = v + n;
129   const dig *s = r->s;
130   unsigned bb = 0, ww = 0;
131
132   memset(v, 0, n * sizeof(dig));
133   memcpy(vv, r->v, n * sizeof(dig));
134   for (i = 0; i < k; i++) {
135     if (g[i] != s[i])
136       v[g[i]]++;
137     else {
138       vv[g[i]]--;
139       bb++;
140     }
141   }
142   for (i = 0; i < n; i++)
143     ww += v[i] < vv[i] ? v[i] : vv[i];
144   *b = bb;
145   *w = ww;
146 }
147
148 static void rate_free(ratectx *r) { xfree(r->v); DESTROY(r); }
149
150 /*----- Computer player ---------------------------------------------------*/
151
152 /* --- About the algorithms --- *
153  *
154  * At each stage, we attampt to choose the guess which will give us the most
155  * information, regardless of the outcome.  For each guess candidate, we
156  * count the remaining possible codes for each outcome, and choose the
157  * candidate with the least square sum.  There are wrinkles.
158  *
159  * Firstly the number of possible guesses is large, and the number of
160  * possible codes is huge too; and our algorithm takes time proportional to
161  * the product of the two.  However, a symbol we've never tried before is as
162  * good as any other, so we can narrow the list of candidate guesses by
163  * considering only %%\emph{prototypes}%% where we use only the smallest
164  * untried symbol at any point to represent introducing any new symbol.  The
165  * number of initial prototypes is quite small.  For the four-symbol game,
166  * they are 0000, 0001, 0011, 0012, 0111, 0112, 0122, and 0123.
167  *
168  * Secondly, when the number of possible codes become small, we want to bias
169  * the guess selection algorithm towards possible codes (so that we win if
170  * we're lucky).  Since the algorithm chooses the first guess with the lowest
171  * sum-of-squares value, we simply run through the possible codes before
172  * enumerating the prototype guesses.
173  */
174
175 typedef struct cpc {
176   mm m;                                 /* Game parameters */
177   unsigned f;                           /* Various flags */
178 #define CPCF_QUIET 1u                   /*   Don't produce lots of output */
179   dig *s; /* n^k * k */                 /* Remaining guesses */
180   size_t ns;                            /* Number of remaining guesses */
181   dig *bg; /* k */                      /* Current best guess */
182   dig *t; /* k */                       /* Scratch-space for prototype */
183   double bmax;                          /* Best guess least-squares score */
184   dig x, bx;                            /* Next unused symbol index */
185   size_t *v; /* (k + 1)^2 */            /* Bin counts for least-squares */
186   ratectx *r;                           /* Rate context for search */
187 } cpc;
188
189 static void print_guess(const mm *m, const dig *d)
190 {
191   unsigned k = m->k, i;
192
193   for (i = 0; i < k; i++) {
194     if (i) putchar(' ');
195     printf("%u", d[i]);
196   }
197 }
198
199 static void dofep(cpc *c, void (*fn)(cpc *c, const dig *g, unsigned x),
200                   unsigned k, unsigned n, unsigned i, unsigned x)
201 {
202   unsigned j;
203   dig *t = c->t;
204
205   if (i == k)
206     fn(c, c->t, x);
207   else {
208     for (j = 0; j < x; j++) {
209       t[i] = j;
210       dofep(c, fn, k, n, i + 1, x);
211     }
212     if (x < n) {
213       t[i] = x;
214       dofep(c, fn, k, n, i + 1, x + 1);
215     }
216   }
217 }
218
219 static void foreach_proto(cpc *c, void (*fn)(cpc *c,
220                                              const dig *g,
221                                              unsigned x))
222 {
223   unsigned k = c->m.k, n = c->m.n;
224
225   dofep(c, fn, k, n, 0, c->x);
226 }
227
228 static void try_guess(cpc *c, const dig *g, unsigned x)
229 {
230   size_t i;
231   unsigned b, w;
232   const dig *s;
233   unsigned k = c->m.k;
234   size_t *v = c->v;
235   size_t *vp;
236   double max;
237
238   rate_init(c->r, g);
239   memset(v, 0, (k + 1) * (k + 1) * sizeof(size_t));
240   for (i = c->ns, s = c->s; i; i--, s += k) {
241     rate(c->r, s, &b, &w);
242     v[b * (k + 1) + w]++;
243   }
244   max = 0;
245   for (i = (k + 1) * (k + 1), vp = v; i; i--, vp++)
246     max += (double)*vp * (double)*vp;
247   if (c->bmax < 0 || max < c->bmax) {
248     memcpy(c->bg, g, k * sizeof(dig));
249     c->bmax = max;
250     c->bx = x;
251   }
252 }
253
254 static void best_guess(cpc *c)
255 {
256   c->bmax = -1;
257   if (c->ns < 1024) {
258     unsigned k = c->m.k;
259     const dig *s;
260     size_t i;
261
262     for (i = c->ns, s = c->s; i; i--, s += k)
263       try_guess(c, s, c->x);
264   }
265   foreach_proto(c, try_guess);
266   c->x = c->bx;
267 }
268
269 static void filter_guesses(cpc *c, const dig *g, unsigned b, unsigned w)
270 {
271   unsigned k = c->m.k;
272   size_t i;
273   const dig *s;
274   unsigned bb, ww;
275   dig *ss;
276
277   rate_init(c->r, g);
278   for (i = c->ns, s = ss = c->s; i; i--, s += k) {
279     rate(c->r, s, &bb, &ww);
280     if (b == bb && w == ww) {
281       memmove(ss, s, k * sizeof(dig));
282       ss += k;
283     }
284   }
285   c->ns = (ss - c->s) / k;
286 }
287
288 static size_t ipow(size_t b, size_t x)
289 {
290   size_t a = 1;
291   while (x) {
292     if (x & 1)
293       a *= b;
294     b *= b;
295     x >>= 1;
296   }
297   return (a);
298 }
299
300 static void all_guesses(dig **ss, unsigned k, unsigned n,
301                         unsigned i, const dig *b)
302 {
303   unsigned j;
304
305   if (i == k) {
306     (*ss) += k;
307     return;
308   }
309   for (j = 0; j < n; j++) {
310     dig *s = *ss;
311     if (i)
312       memcpy(*ss, b, i * sizeof(dig));
313     s[i] = j;
314     all_guesses(ss, k, n, i + 1, s);
315   }
316 }
317
318 #define THINK(c, what, how) do {                                        \
319   clock_t _t0 = 0, _t1;                                                 \
320   if (!(c->f & CPCF_QUIET)) {                                           \
321     fputs(what "...", stdout);                                          \
322     fflush(stdout);                                                     \
323     _t0 = clock();                                                      \
324   }                                                                     \
325   do how while (0);                                                     \
326   if (!(c->f & CPCF_QUIET)) {                                           \
327     _t1 = clock();                                                      \
328     printf(" done (%.2fs)\n", (_t1 - _t0)/(double)CLOCKS_PER_SEC);      \
329   }                                                                     \
330 } while (0)
331
332 static cpc *cpc_new(const mm *m, unsigned f)
333 {
334   cpc *c = CREATE(cpc);
335
336   c->f = f;
337   c->m = *m;
338   c->ns = ipow(c->m.n, c->m.k);
339   c->s = xmalloc((c->ns + 2) * c->m.k * sizeof(dig));
340   c->bg = c->s + c->ns * c->m.k;
341   c->t = c->bg + c->m.k;
342   c->x = 0;
343   c->v = xmalloc((c->m.k + 1) * (c->m.k + 1) * sizeof(size_t));
344   c->r = rate_alloc(m);
345   THINK(c, "Setting up", {
346     dig *ss = c->s; all_guesses(&ss, c->m.k, c->m.n, 0, 0);
347   });
348   return (c);
349 }
350
351 static void cpc_free(cpc *c)
352 {
353   xfree(c->s);
354   xfree(c->v);
355   rate_free(c->r);
356   DESTROY(c);
357 }
358
359 static void cp_rate(void *r, const dig *g, unsigned *b, unsigned *w)
360   { rate(r, g, b, w); }
361
362 static const dig *cp_guess(void *cc)
363 {
364   cpc *c = cc;
365
366   if (c->ns == 0) {
367     if (!(c->f & CPCF_QUIET))
368       printf("Liar!  All solutions ruled out.\n");
369     return (0);
370   }
371   if (c->ns == 1) {
372     if (!(c->f & CPCF_QUIET)) {
373       fputs("Done!  Solution is ", stdout);
374       print_guess(&c->m, c->s);
375       putchar('\n');
376     }
377     return (c->s);
378   }
379   if (!(c->f & CPCF_QUIET)) {
380     printf("(Possible solutions remaining = %lu)\n",
381            (unsigned long)c->ns);
382     if (c->ns < 32) {
383       const dig *s;
384       size_t i;
385       for (i = c->ns, s = c->s; i; i--, s += c->m.k) {
386         printf("  %2lu: ", (unsigned long)(c->ns - i + 1));
387         print_guess(&c->m, s);
388         putchar('\n');
389       }
390     }
391   }
392   THINK(c, "Pondering", {
393     best_guess(c);
394   });
395   return (c->bg);
396 }
397
398 static void cp_update(void *cc, const dig *g, unsigned b, unsigned w)
399 {
400   cpc *c = cc;
401
402   if (!(c->f & CPCF_QUIET)) {
403     fputs("My guess = ", stdout);
404     print_guess(&c->m, g);
405     printf("; rating = %u black, %u white\n", b, w);
406   }
407   THINK(c, "Filtering", {
408     filter_guesses(c, g, b, w);
409   });
410 }
411
412 /*----- Human player ------------------------------------------------------*/
413
414 typedef struct hpc {
415   mm m;
416   dig *t;
417 } hpc;
418
419 static hpc *hpc_new(const mm *m)
420 {
421   hpc *h = CREATE(hpc);
422   h->t = xmalloc(m->k * sizeof(dig));
423   h->m = *m;
424   return (h);
425 }
426
427 static void hpc_free(hpc *h)
428 {
429   xfree(h->t);
430   DESTROY(h);
431 }
432
433 static void hp_rate(void *mp, const dig *g, unsigned *b, unsigned *w)
434 {
435   mm *m = mp;
436   fputs("Guess = ", stdout);
437   print_guess(m, g);
438   printf("; rating: ");
439   fflush(stdout);
440   scanf("%u %u", b, w);
441 }
442
443 static const dig *hp_guess(void *hh)
444 {
445   hpc *h = hh;
446   unsigned i;
447
448   fputs("Your guess: ", stdout);
449   fflush(stdout);
450   for (i = 0; i < h->m.k; i++) {
451     unsigned x;
452     scanf("%u", &x);
453     h->t[i] = x;
454   }
455   return (h->t);
456 }
457
458 static void hp_update(void *cc, const dig *g, unsigned b, unsigned w)
459 {
460   printf("Rating = %u black, %u white\n", b, w);
461 }
462
463 /*----- Solver player -----------------------------------------------------*/
464
465 typedef struct spc {
466   cpc *c;
467   hpc *h;
468   int i;
469 } spc;
470
471 static spc *spc_new(const mm *m)
472 {
473   spc *s = CREATE(spc);
474   s->c = cpc_new(m, 0);
475   s->h = hpc_new(m);
476   s->i = 0;
477   return (s);
478 }
479
480 static void spc_free(spc *s)
481 {
482   cpc_free(s->c);
483   hpc_free(s->h);
484   DESTROY(s);
485 }
486
487 static const dig *sp_guess(void *ss)
488 {
489   spc *s = ss;
490   hpc *h = s->h;
491   unsigned i;
492   int ch;
493
494 again:
495   if (s->i)
496     return (cp_guess(s->c));
497
498   fputs("Your guess (dot for end): ", stdout);
499   fflush(stdout);
500   do ch = getchar(); while (isspace(ch));
501   if (!isdigit(ch)) { s->i = 1; goto again; }
502   ungetc(ch, stdin);
503   for (i = 0; i < h->m.k; i++) {
504     unsigned x;
505     scanf("%u", &x);
506     h->t[i] = x;
507   }
508   return (h->t);
509 }
510
511 static void sp_update(void *ss, const dig *g, unsigned b, unsigned w)
512   { spc *s = ss; cp_update(s->c, g, b, w); }
513
514 /*----- Full tournament stuff ---------------------------------------------*/
515
516 DA_DECL(uint_v, unsigned);
517
518 typedef struct allstats {
519   const mm *m;
520   unsigned f;
521 #define AF_VERBOSE 1u
522   uint_v gmap;
523   unsigned long g;
524   unsigned long n;
525   clock_t t;
526 } allstats;
527
528 static void dorunone(allstats *a, dig *s)
529 {
530   ratectx *r = rate_new(a->m, s);
531   clock_t t = 0, t0, t1;
532   cpc *c;
533   int n = 0;
534   const dig *g;
535   unsigned b, w;
536
537   if (a->f & AF_VERBOSE) {
538     print_guess(a->m, s);
539     fputs(": ", stdout);
540     fflush(stdout);
541   }
542
543   c = cpc_new(a->m, CPCF_QUIET);
544   for (;;) {
545     t0 = clock();
546     g = cp_guess(c);
547     t1 = clock();
548     t += t1 - t0;
549     assert(g);
550     n++;
551     rate(r, g, &b, &w);
552     if (b == a->m->k)
553       break;
554     t0 = clock();
555     cp_update(c, g, b, w);
556     t1 = clock();
557     t += t1 - t0;
558   }
559   a->t += t;
560   a->g += n;
561   while (DA_LEN(&a->gmap) <= n)
562     DA_PUSH(&a->gmap, 0);
563   DA(&a->gmap)[n]++;
564   rate_free(r);
565   cpc_free(c);
566
567   if (a->f & AF_VERBOSE)
568     printf("%2u (%5.2fs)\n", n, t/(double)CLOCKS_PER_SEC);
569 }
570
571 static void dorunall(allstats *a, dig *s, unsigned i)
572 {
573   dig j;
574
575   if (i >= a->m->k) {
576     dorunone(a, s);
577     a->n++;
578   } else {
579     for (j = 0; j < a->m->n; j++) {
580       s[i] = j;
581       dorunall(a, s, i + 1);
582     }
583   }
584 }
585
586 static void run_all(const mm *m)
587 {
588   dig *s = xmalloc(m->k * sizeof(dig));
589   allstats a;
590   unsigned i;
591
592   a.m = m;
593   a.f = AF_VERBOSE;
594   DA_CREATE(&a.gmap);
595   a.n = 0;
596   a.g = 0;
597   a.t = 0;
598   dorunall(&a, s, 0);
599   xfree(s);
600
601   for (i = 1; i < DA_LEN(&a.gmap); i++)
602     printf("%2u guesses: %5u games\n", i, DA(&a.gmap)[i]);
603   printf("Average: %.4f (%.2fs)\n",
604          (double)a.g/a.n, a.t/(a.n * (double)CLOCKS_PER_SEC));
605 }
606
607 /*----- Main game logic ---------------------------------------------------*/
608
609 static int play(const mm *m,
610                 void (*ratefn)(void *rr, const dig *g,
611                                unsigned *b, unsigned *w),
612                 void *rr,
613                 const dig *(*guessfn)(void *gg),
614                 void (*updatefn)(void *gg, const dig *g,
615                                  unsigned b, unsigned w),
616                 void *gg)
617 {
618   unsigned b, w;
619   const dig *g;
620   unsigned i;
621
622   i = 0;
623   for (;;) {
624     i++;
625     g = guessfn(gg);
626     if (!g)
627       return (-1);
628     ratefn(rr, g, &b, &w);
629     if (b == m->k)
630       return (i);
631     updatefn(gg, g, b, w);
632   }
633 }
634
635 int main(int argc, char *argv[])
636 {
637   unsigned h = 0;
638   void *rr = 0;
639   void (*ratefn)(void *rr, const dig *g, unsigned *b, unsigned *w) = 0;
640   mm m;
641   int n;
642
643   ego(argv[0]);
644   for (;;) {
645     static struct option opt[] = {
646       { "computer",     0,      0,      'C' },
647       { "human",        0,      0,      'H' },
648       { "solver",       0,      0,      'S' },
649       { "all",          0,      0,      'a' },
650       { 0,              0,      0,      0 }
651     };
652     int i = mdwopt(argc, argv, "CHSa", opt, 0, 0, 0);
653     if (i < 0)
654       break;
655     switch (i) {
656       case 'C': h = 0; break;
657       case 'H': h = 1; break;
658       case 'S': h = 2; break;
659       case 'a': h = 99; break;
660       default:
661         exit(1);
662     }
663   }
664   if (argc - optind == 0) {
665     m.k = 4;
666     m.n = 6;
667   } else if (argc - optind < 2)
668     die(1, "bad parameters");
669   else {
670     m.k = atoi(argv[optind++]);
671     m.n = atoi(argv[optind++]);
672     if (argc - optind >= m.k) {
673       dig *s = xmalloc(m.k * sizeof(dig));
674       int i;
675       for (i = 0; i < m.k; i++)
676         s[i] = atoi(argv[optind++]);
677       rr = rate_new(&m, s);
678       ratefn = cp_rate;
679       xfree(s);
680     }
681     if (argc != optind)
682       die(1, "bad parameters");
683   }
684
685   switch (h) {
686     case 1: {
687       hpc *hh = hpc_new(&m);
688       if (!rr) {
689         dig *s = xmalloc(m.k * sizeof(dig));
690         int i;
691         srand(time(0));
692         for (i = 0; i < m.k; i++)
693           s[i] = rand() % m.n;
694         rr = rate_new(&m, s);
695         ratefn = cp_rate;
696         xfree(s);
697       }
698       n = play(&m, ratefn, rr, hp_guess, hp_update, hh);
699       hpc_free(hh);
700     } break;
701     case 0: {
702       cpc *cc = cpc_new(&m, 0);
703       if (rr)
704         n = play(&m, ratefn, rr, cp_guess, cp_update, cc);
705       else
706         n = play(&m, hp_rate, &m, cp_guess, cp_update, cc);
707       cpc_free(cc);
708     } break;
709     case 2: {
710       spc *ss = spc_new(&m);
711       n = play(&m, hp_rate, &m, sp_guess, sp_update, ss);
712       spc_free(ss);
713     } break;
714     case 99:
715       run_all(&m);
716       return (0);
717       break;
718     default:
719       abort();
720   }
721   if (n > 0)
722     printf("Solved in %d guesses\n", n);
723   else
724     die(1, "gave up");
725   return (0);
726 }
727
728 /*----- That's all, folks -------------------------------------------------*/