chiark / gitweb /
mm.6: Add a manpage.
[mm] / mm.c
1 /* -*-c-*-
2  *
3  * $Id$
4  *
5  * Simple mastermind game
6  *
7  * (c) 2006 Mark Wooding
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of mm: a simple Mastermind game.
13  *
14  * mm is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU General Public License as published by
16  * the Free Software Foundation; either version 2 of the License, or
17  * (at your option) any later version.
18  * 
19  * mm is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU General Public License for more details.
23  * 
24  * You should have received a copy of the GNU General Public License
25  * along with mm; if not, write to the Free Software Foundation,
26  * Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
27  */
28
29 /*----- Header files ------------------------------------------------------*/
30
31 #include <ctype.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <time.h>
36
37 #include <mLib/alloc.h>
38 #include <mLib/mdwopt.h>
39 #include <mLib/quis.h>
40 #include <mLib/report.h>
41 #include <mLib/sub.h>
42
43 /*----- Data structures ---------------------------------------------------*/
44
45 /* --- Digits --- *
46  *
47  * The symbols which make up the code to be guessed.
48  */
49
50 typedef unsigned char dig;
51
52 /* --- The game parameters --- */
53
54 typedef struct mm {
55   dig k;                                /* Number of symbols in the code */
56   dig n;                                /* Number of distinct symbols */
57 } mm;
58
59 /*----- Rating guesses ----------------------------------------------------*/
60
61 /* --- Algorithm --- *
62  *
63  * Rating guesses efficiently is quite important.
64  *
65  * The rate context structure contains a copy of the game parameters, and
66  * three arrays, @v@, @s@ and @t@ allocated from the same @malloc@ed block:
67  *
68  *   * %$v_i$% counts occurrences of the symbol %$i$% in the code.
69  *   * %$s$% is a copy of the code.
70  *   * %$t$% is temporary work space for the rating function.
71  *
72  * The rating function works by taking a pass over the guess, computing a
73  * count table %$v'$%; but for each guess symbol which matches the
74  * correspondng code symbol, decrementing the count %$v_i$% for that symbol
75  * (in a temporary copy of the table %$v$%).  The black count is then the
76  * number of such matches, and the white count is given by
77  *
78  *   %$w = \displaystyle \sum_{0<i\le n} \min(v_i, v'_i).$%
79  *
80  * Thus, the work is %$O(k + n)$%, rather than the %$O(k^2)$% for a
81  * %%na\"\i{}ve%% implementation.
82  */
83
84 typedef struct ratectx {
85   mm m;
86   dig *v;
87   dig *s;
88   dig *t;
89 } ratectx;
90
91 static ratectx *rate_alloc(const mm *m)
92 {
93   ratectx *r;
94   dig *v;
95
96   r = CREATE(ratectx);
97   v = xmalloc((3 * m->n + m->k) * sizeof(dig));
98   r->m = *m;
99   r->v = v;
100   r->s = r->v + m->n;
101   r->t = r->s + m->k;
102   return (r);
103 }
104
105 static void rate_init(ratectx *r, const dig *s)
106 {
107   unsigned i;
108
109   memset(r->v, 0, r->m.n * sizeof(dig));
110   for (i = 0; i < r->m.k; i++)
111     r->v[s[i]]++;
112   memcpy(r->s, s, r->m.k * sizeof(dig));
113 }
114
115 static ratectx *rate_new(const mm *m, const dig *s)
116 {
117   ratectx *r = rate_alloc(m);
118   rate_init(r, s);
119   return (r);
120 }
121
122 static void rate(const ratectx *r, const dig *g, unsigned  *b, unsigned *w)
123 {
124   unsigned i;
125   unsigned k = r->m.k, n = r->m.n;
126   dig *v = r->t;
127   dig *vv = v + n;
128   const dig *s = r->s;
129   unsigned bb = 0, ww = 0;
130
131   memset(v, 0, n * sizeof(dig));
132   memcpy(vv, r->v, n * sizeof(dig));
133   for (i = 0; i < k; i++) {
134     if (g[i] != s[i])
135       v[g[i]]++;
136     else {
137       vv[g[i]]--;
138       bb++;
139     }
140   }
141   for (i = 0; i < n; i++)
142     ww += v[i] < vv[i] ? v[i] : vv[i];
143   *b = bb;
144   *w = ww;
145 }
146
147 static void rate_free(ratectx *r)
148 {
149   xfree(r->v);
150   DESTROY(r);
151 }
152
153 /*----- Computer player ---------------------------------------------------*/
154
155 /* --- About the algorithms --- *
156  *
157  * At each stage, we attampt to choose the guess which will give us the most
158  * information, regardless of the outcome.  For each guess candidate, we
159  * count the remaining possible codes for each outcome, and choose the
160  * candidate with the least square sum.  There are wrinkles.
161  *
162  * Firstly the number of possible guesses is large, and the number of
163  * possible codes is huge too; and our algorithm takes time proportional to
164  * the product of the two.  However, a symbol we've never tried before is as
165  * good as any other, so we can narrow the list of candidate guesses by
166  * considering only %%\emph{prototypes}%% where we use only the smallest
167  * untried symbol at any point to represent introducing any new symbol.  The
168  * number of initial prototypes is quite small.  For the four-symbol game,
169  * they are 0000, 0001, 0011, 0012, 0111, 0112, 0122, and 0123.
170  *
171  * Secondly, when the number of possible codes become small, we want to bias
172  * the guess selection algorithm towards possible codes (so that we win if
173  * we're lucky).  Since the algorithm chooses the first guess with the lowest
174  * sum-of-squares value, we simply run through the possible codes before
175  * enumerating the prototype guesses.
176  */
177
178 typedef struct cpc {
179   mm m;                                 /* Game parameters */
180   dig *s; /* n^k * k */                 /* Remaining guesses */
181   size_t ns;                            /* Number of remaining guesses */
182   dig *bg; /* k */                      /* Current best guess */
183   dig *t; /* k */                       /* Scratch-space for prototype */
184   double bmax;                          /* Best guess least-squares score */
185   dig x, bx;                            /* Next unused symbol index */
186   size_t *v; /* (k + 1)^2 */            /* Bin counts for least-squares */
187   ratectx *r;                           /* Rate context for search */
188 } cpc;
189
190 static void print_guess(const mm *m, const dig *d)
191 {
192   unsigned k = m->k, i;
193
194   for (i = 0; i < k; i++) {
195     if (i) putchar(' ');
196     printf("%u", d[i]);
197   }
198 }
199
200 static void dofep(cpc *c, void (*fn)(cpc *c, const dig *g, unsigned x),
201                   unsigned k, unsigned n, unsigned i, unsigned x)
202 {
203   unsigned j;
204   dig *t = c->t;
205
206   if (i == k)
207     fn(c, c->t, x);
208   else {
209     for (j = 0; j < x; j++) {
210       t[i] = j;
211       dofep(c, fn, k, n, i + 1, x);
212     }
213     if (x < n) {
214       t[i] = x;
215       dofep(c, fn, k, n, i + 1, x + 1);
216     }
217   }
218 }
219
220 static void foreach_proto(cpc *c, void (*fn)(cpc *c,
221                                              const dig *g,
222                                              unsigned x))
223
224   unsigned k = c->m.k, n = c->m.n;
225   dofep(c, fn, k, n, 0, c->x);
226 }
227
228 static void try_guess(cpc *c, const dig *g, unsigned x)
229 {
230   size_t i;
231   unsigned b, w;
232   const dig *s;
233   unsigned k = c->m.k;
234   size_t *v = c->v;
235   size_t *vp;
236   double max;
237
238   rate_init(c->r, g);
239   memset(v, 0, (k + 1) * (k + 1) * sizeof(size_t));
240   for (i = c->ns, s = c->s; i; i--, s += k) {
241     rate(c->r, s, &b, &w);
242     v[b * (k + 1) + w]++;
243   }
244   max = 0;
245   for (i = (k + 1) * (k + 1), vp = v; i; i--, vp++)
246     max += (double)*vp * (double)*vp;
247   if (c->bmax < 0 || max < c->bmax) {
248     memcpy(c->bg, g, k * sizeof(dig));
249     c->bmax = max;
250     c->bx = x;
251   }
252 }
253
254 static void best_guess(cpc *c)
255 {
256   c->bmax = -1;
257   if (c->ns < 1024) {
258     unsigned k = c->m.k;
259     const dig *s;
260     size_t i;
261
262     for (i = c->ns, s = c->s; i; i--, s += k)
263       try_guess(c, s, c->x);
264   }
265   foreach_proto(c, try_guess);
266   c->x = c->bx;
267 }
268
269 static void filter_guesses(cpc *c, const dig *g, unsigned b, unsigned w)
270 {
271   unsigned k = c->m.k;
272   size_t i;
273   const dig *s;
274   unsigned bb, ww;
275   dig *ss;
276
277   rate_init(c->r, g);
278   for (i = c->ns, s = ss = c->s; i; i--, s += k) {
279     rate(c->r, s, &bb, &ww);
280     if (b == bb && w == ww) {
281       memmove(ss, s, k * sizeof(dig));
282       ss += k;
283     }
284   }
285   c->ns = (ss - c->s) / k;
286 }
287
288 static size_t ipow(size_t b, size_t x)
289 {
290   size_t a = 1;
291   while (x) {
292     if (x & 1)
293       a *= b;
294     b *= b;
295     x >>= 1;
296   }
297   return (a);
298 }
299
300 static void all_guesses(dig **ss, unsigned k, unsigned n,
301                         unsigned i, const dig *b)
302 {
303   unsigned j;
304
305   if (i == k) {
306     (*ss) += k;
307     return;
308   }
309   for (j = 0; j < n; j++) {
310     dig *s = *ss;
311     if (i)
312       memcpy(*ss, b, i * sizeof(dig));
313     s[i] = j;
314     all_guesses(ss, k, n, i + 1, s);
315   }
316 }
317
318 #define THINK(what, how) do {                                           \
319   clock_t _t0, _t1;                                                     \
320   fputs(what "...", stdout);                                            \
321   fflush(stdout);                                                       \
322   _t0 = clock();                                                        \
323   do how while (0);                                                     \
324   _t1 = clock();                                                        \
325   printf(" done (%.2fs)\n", (_t1 - _t0)/(double)CLOCKS_PER_SEC);        \
326 } while (0)
327
328 static cpc *cpc_new(const mm *m)
329 {
330   cpc *c = CREATE(cpc);
331   c->m = *m;
332   c->ns = ipow(c->m.n, c->m.k);
333   c->s = xmalloc((c->ns + 2) * c->m.k * sizeof(dig));
334   c->bg = c->s + c->ns * c->m.k;
335   c->t = c->bg + c->m.k;
336   c->x = 0;
337   c->v = xmalloc((c->m.k + 1) * (c->m.k + 1) * sizeof(size_t));
338   c->r = rate_alloc(m);
339   THINK("Setting up", {
340     dig *ss = c->s; all_guesses(&ss, c->m.k, c->m.n, 0, 0);
341   });
342   return (c);
343 }
344
345 static void cpc_free(cpc *c)
346 {
347   xfree(c->s);
348   xfree(c->v);
349   rate_free(c->r);
350   DESTROY(c);
351 }
352
353 static void cp_rate(void *r, const dig *g, unsigned *b, unsigned *w)
354 {
355   rate(r, g, b, w);
356 }
357
358 static const dig *cp_guess(void *cc)
359 {
360   cpc *c = cc;
361
362   if (c->ns == 0) {
363     printf("Liar!  All solutions ruled out.\n");
364     return (0);
365   }
366   if (c->ns == 1) {
367     fputs("Done!  Solution is ", stdout);
368     print_guess(&c->m, c->s);
369     putchar('\n');
370     return (c->s);
371   }
372   printf("(Possible solutions remaining = %lu)\n",
373          (unsigned long)c->ns);
374   if (c->ns < 32) {
375     const dig *s;
376     size_t i;
377     for (i = c->ns, s = c->s; i; i--, s += c->m.k) {
378       printf("  %2lu: ", (unsigned long)(c->ns - i + 1));
379       print_guess(&c->m, s);
380       putchar('\n');
381     }
382   }
383   THINK("Pondering", {
384     best_guess(c);
385   });
386   return (c->bg);
387 }
388
389 static void cp_update(void *cc, const dig *g, unsigned b, unsigned w)
390 {
391   cpc *c = cc;
392   fputs("My guess = ", stdout);
393   print_guess(&c->m, g);
394   printf("; rating = %u black, %u white\n", b, w);
395   THINK("Filtering", {
396     filter_guesses(c, g, b, w);
397   });
398 }
399
400 /*----- Human player ------------------------------------------------------*/
401
402 typedef struct hpc {
403   mm m;
404   dig *t;
405 } hpc;
406
407 static hpc *hpc_new(const mm *m)
408 {
409   hpc *h = CREATE(hpc);
410   h->t = xmalloc(m->k * sizeof(dig));
411   h->m = *m;
412   return (h);
413 }
414
415 static void hpc_free(hpc *h)
416 {
417   xfree(h->t);
418   DESTROY(h);
419 }
420
421 static void hp_rate(void *mp, const dig *g, unsigned *b, unsigned *w)
422 {
423   mm *m = mp;
424   fputs("Guess = ", stdout);
425   print_guess(m, g);
426   printf("; rating: ");
427   fflush(stdout);
428   scanf("%u %u", b, w);
429 }
430
431 static const dig *hp_guess(void *hh)
432 {
433   hpc *h = hh;
434   unsigned i;
435
436   fputs("Your guess: ", stdout);
437   fflush(stdout);
438   for (i = 0; i < h->m.k; i++) {
439     unsigned x;
440     scanf("%u", &x);
441     h->t[i] = x;
442   }
443   return (h->t);
444 }
445
446 static void hp_update(void *cc, const dig *g, unsigned b, unsigned w)
447 {
448   printf("Rating = %u black, %u white\n", b, w);
449 }
450
451 /*----- Solver player -----------------------------------------------------*/
452
453 typedef struct spc {
454   cpc *c;
455   hpc *h;
456   int i;
457 } spc;
458
459 static spc *spc_new(const mm *m)
460 {
461   spc *s = CREATE(spc);
462   s->c = cpc_new(m);
463   s->h = hpc_new(m);
464   s->i = 0;
465   return (s);
466 }
467
468 static void spc_free(spc *s)
469 {
470   cpc_free(s->c);
471   hpc_free(s->h);
472   DESTROY(s);
473 }
474
475 static const dig *sp_guess(void *ss)
476 {
477   spc *s = ss;
478   hpc *h = s->h;
479   unsigned i;
480   int ch;
481
482 again:
483   if (s->i)
484     return (cp_guess(s->c));
485
486   fputs("Your guess (dot for end): ", stdout);
487   fflush(stdout);
488   do ch = getchar(); while (isspace(ch));
489   if (!isdigit(ch)) { s->i = 1; goto again; }
490   ungetc(ch, stdin);
491   for (i = 0; i < h->m.k; i++) {
492     unsigned x;
493     scanf("%u", &x);
494     h->t[i] = x;
495   }
496   return (h->t);    
497 }
498
499 static void sp_update(void *ss, const dig *g, unsigned b, unsigned w)
500 {
501   spc *s = ss;
502   cp_update(s->c, g, b, w);
503 }
504
505 /*----- Main game logic ---------------------------------------------------*/
506
507 static int play(const mm *m,
508                 void (*ratefn)(void *rr, const dig *g,
509                                unsigned *b, unsigned *w),
510                 void *rr,
511                 const dig *(*guessfn)(void *gg),
512                 void (*updatefn)(void *gg, const dig *g,
513                                  unsigned b, unsigned w),
514                 void *gg)
515 {
516   unsigned b, w;
517   const dig *g;
518   unsigned i;
519
520   i = 0;
521   for (;;) {
522     i++;
523     g = guessfn(gg);
524     if (!g)
525       return (-1);
526     ratefn(rr, g, &b, &w);
527     if (b == m->k)
528       return (i);
529     updatefn(gg, g, b, w);
530   }
531 }
532
533 int main(int argc, char *argv[])
534 {
535   unsigned h = 0;
536   void *rr = 0;
537   void (*ratefn)(void *rr, const dig *g, unsigned *b, unsigned *w) = 0;
538   mm m;
539   int n;
540
541   ego(argv[0]);
542   for (;;) {
543     static struct option opt[] = {
544       { "computer",     0,      0,      'C' },
545       { "human",        0,      0,      'H' },
546       { "solver",       0,      0,      'S' },
547       { 0,              0,      0,      0 }
548     };
549     int i = mdwopt(argc, argv, "CHS", opt, 0, 0, 0);
550     if (i < 0)
551       break;
552     switch (i) {
553       case 'C': h = 0; break;
554       case 'H': h = 1; break;
555       case 'S': h = 2; break;
556       default:
557         exit(1);
558     }
559   }
560   if (argc - optind == 0) {
561     m.k = 4;
562     m.n = 6;
563   } else if (argc - optind < 2)
564     die(1, "bad parameters");
565   else {
566     m.k = atoi(argv[optind++]);
567     m.n = atoi(argv[optind++]);
568     if (argc - optind >= m.k) {
569       dig *s = xmalloc(m.k * sizeof(dig));
570       int i;
571       for (i = 0; i < m.k; i++)
572         s[i] = atoi(argv[optind++]);
573       rr = rate_new(&m, s);
574       ratefn = cp_rate;
575       xfree(s);
576     }
577     if (argc != optind)
578       die(1, "bad parameters");
579   }
580
581   switch (h) {
582     case 1: {
583       hpc *hh = hpc_new(&m);
584       if (!rr) {
585         dig *s = xmalloc(m.k * sizeof(dig));
586         int i;
587         srand(time(0));
588         for (i = 0; i < m.k; i++)
589           s[i] = rand() % m.n;
590         rr = rate_new(&m, s);
591         ratefn = cp_rate;
592         xfree(s);
593       }
594       n = play(&m, ratefn, rr, hp_guess, hp_update, hh);
595       hpc_free(hh);
596     } break;
597     case 0: {
598       cpc *cc = cpc_new(&m);
599       if (rr) 
600         n = play(&m, ratefn, rr, cp_guess, cp_update, cc);
601       else
602         n = play(&m, hp_rate, &m, cp_guess, cp_update, cc);
603       cpc_free(cc);
604     } break;
605     case 2: {
606       spc *ss = spc_new(&m);
607       n = play(&m, hp_rate, &m, sp_guess, sp_update, ss);
608       spc_free(ss);
609     } break;
610     default:
611       abort();
612   }
613   if (n > 0)
614     printf("Solved in %d guesses\n", n);
615   else
616     die(1, "gave up");
617   return (0);
618 }
619
620 /*----- That's all, folks -------------------------------------------------*/