chiark / gitweb /
df6899206b4793b3bf9fab0f6c58a748ea67f52d
[mm] / mm.c
1 /* -*-c-*-
2  *
3  * $Id$
4  *
5  * Simple mastermind game
6  *
7  * (c) 2006 Mark Wooding
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of mm: a simple Mastermind game.
13  *
14  * mm is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU General Public License as published by
16  * the Free Software Foundation; either version 2 of the License, or
17  * (at your option) any later version.
18  * 
19  * mm is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU General Public License for more details.
23  * 
24  * You should have received a copy of the GNU General Public License
25  * along with mm; if not, write to the Free Software Foundation,
26  * Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
27  */
28
29 /*----- Header files ------------------------------------------------------*/
30
31 #include <ctype.h>
32 #include <stdio.h>
33 #include <stdlib.h>
34 #include <string.h>
35 #include <time.h>
36
37 #include <mLib/alloc.h>
38 #include <mLib/mdwopt.h>
39 #include <mLib/quis.h>
40 #include <mLib/report.h>
41 #include <mLib/sub.h>
42
43 /*----- Data structures ---------------------------------------------------*/
44
45 /* --- Digits --- *
46  *
47  * The symbols which make up the code to be guessed.
48  */
49
50 typedef unsigned char dig;
51
52 /* --- The game parameters --- */
53
54 typedef struct mm {
55   dig k;                                /* Number of symbols in the code */
56   dig n;                                /* Number of distinct symbols */
57 } mm;
58
59 /*----- Rating guesses ----------------------------------------------------*/
60
61 /* --- Algorithm --- *
62  *
63  * Rating guesses efficiently is quite important.
64  *
65  * The rate context structure contains a copy of the game parameters, and
66  * three arrays, @v@, @s@ and @t@ allocated from the same @malloc@ed block:
67  *
68  *   * %$v_i$% counts occurrences of the symbol %$i$% in the code.
69  *   * %$s$% is a copy of the code.
70  *   * %$t$% is temporary work space for the rating function.
71  *
72  * The rating function works by taking a pass over the guess, computing a
73  * count table %$v'$%; but for each guess symbol which matches the
74  * correspondng code symbol, decrementing the count %$v_i$% for that symbol
75  * (in a temporary copy of the table %$v$%).  The black count is then the
76  * number of such matches, and the white count is given by
77  *
78  *   %$w = \displaystyle \sum_{0<i\le n} \min(v_i, v'_i).$%
79  *
80  * Thus, the work is %$O(k + n)$%, rather than the %$O(k^2)$% for a
81  * %%na\"\i{}ve%% implementation.
82  */
83
84 typedef struct ratectx {
85   mm m;
86   dig *v;
87   dig *s;
88   dig *t;
89 } ratectx;
90
91 static ratectx *rate_alloc(const mm *m)
92 {
93   ratectx *r;
94   dig *v;
95
96   r = CREATE(ratectx);
97   v = xmalloc((3 * m->n + m->k) * sizeof(dig));
98   r->m = *m;
99   r->v = v;
100   r->s = r->v + m->n;
101   r->t = r->s + m->k;
102   return (r);
103 }
104
105 static void rate_init(ratectx *r, const dig *s)
106 {
107   unsigned i;
108
109   memset(r->v, 0, r->m.n * sizeof(dig));
110   for (i = 0; i < r->m.k; i++)
111     r->v[s[i]]++;
112   memcpy(r->s, s, r->m.k * sizeof(dig));
113 }
114
115 static ratectx *rate_new(const mm *m, const dig *s)
116 {
117   ratectx *r = rate_alloc(m);
118
119   rate_init(r, s);
120   return (r);
121 }
122
123 static void rate(const ratectx *r, const dig *g, unsigned  *b, unsigned *w)
124 {
125   unsigned i;
126   unsigned k = r->m.k, n = r->m.n;
127   dig *v = r->t;
128   dig *vv = v + n;
129   const dig *s = r->s;
130   unsigned bb = 0, ww = 0;
131
132   memset(v, 0, n * sizeof(dig));
133   memcpy(vv, r->v, n * sizeof(dig));
134   for (i = 0; i < k; i++) {
135     if (g[i] != s[i])
136       v[g[i]]++;
137     else {
138       vv[g[i]]--;
139       bb++;
140     }
141   }
142   for (i = 0; i < n; i++)
143     ww += v[i] < vv[i] ? v[i] : vv[i];
144   *b = bb;
145   *w = ww;
146 }
147
148 static void rate_free(ratectx *r) { xfree(r->v); DESTROY(r); }
149
150 /*----- Computer player ---------------------------------------------------*/
151
152 /* --- About the algorithms --- *
153  *
154  * At each stage, we attampt to choose the guess which will give us the most
155  * information, regardless of the outcome.  For each guess candidate, we
156  * count the remaining possible codes for each outcome, and choose the
157  * candidate with the least square sum.  There are wrinkles.
158  *
159  * Firstly the number of possible guesses is large, and the number of
160  * possible codes is huge too; and our algorithm takes time proportional to
161  * the product of the two.  However, a symbol we've never tried before is as
162  * good as any other, so we can narrow the list of candidate guesses by
163  * considering only %%\emph{prototypes}%% where we use only the smallest
164  * untried symbol at any point to represent introducing any new symbol.  The
165  * number of initial prototypes is quite small.  For the four-symbol game,
166  * they are 0000, 0001, 0011, 0012, 0111, 0112, 0122, and 0123.
167  *
168  * Secondly, when the number of possible codes become small, we want to bias
169  * the guess selection algorithm towards possible codes (so that we win if
170  * we're lucky).  Since the algorithm chooses the first guess with the lowest
171  * sum-of-squares value, we simply run through the possible codes before
172  * enumerating the prototype guesses.
173  */
174
175 typedef struct cpc {
176   mm m;                                 /* Game parameters */
177   dig *s; /* n^k * k */                 /* Remaining guesses */
178   size_t ns;                            /* Number of remaining guesses */
179   dig *bg; /* k */                      /* Current best guess */
180   dig *t; /* k */                       /* Scratch-space for prototype */
181   double bmax;                          /* Best guess least-squares score */
182   dig x, bx;                            /* Next unused symbol index */
183   size_t *v; /* (k + 1)^2 */            /* Bin counts for least-squares */
184   ratectx *r;                           /* Rate context for search */
185 } cpc;
186
187 static void print_guess(const mm *m, const dig *d)
188 {
189   unsigned k = m->k, i;
190
191   for (i = 0; i < k; i++) {
192     if (i) putchar(' ');
193     printf("%u", d[i]);
194   }
195 }
196
197 static void dofep(cpc *c, void (*fn)(cpc *c, const dig *g, unsigned x),
198                   unsigned k, unsigned n, unsigned i, unsigned x)
199 {
200   unsigned j;
201   dig *t = c->t;
202
203   if (i == k)
204     fn(c, c->t, x);
205   else {
206     for (j = 0; j < x; j++) {
207       t[i] = j;
208       dofep(c, fn, k, n, i + 1, x);
209     }
210     if (x < n) {
211       t[i] = x;
212       dofep(c, fn, k, n, i + 1, x + 1);
213     }
214   }
215 }
216
217 static void foreach_proto(cpc *c, void (*fn)(cpc *c,
218                                              const dig *g,
219                                              unsigned x))
220
221   unsigned k = c->m.k, n = c->m.n;
222
223   dofep(c, fn, k, n, 0, c->x);
224 }
225
226 static void try_guess(cpc *c, const dig *g, unsigned x)
227 {
228   size_t i;
229   unsigned b, w;
230   const dig *s;
231   unsigned k = c->m.k;
232   size_t *v = c->v;
233   size_t *vp;
234   double max;
235
236   rate_init(c->r, g);
237   memset(v, 0, (k + 1) * (k + 1) * sizeof(size_t));
238   for (i = c->ns, s = c->s; i; i--, s += k) {
239     rate(c->r, s, &b, &w);
240     v[b * (k + 1) + w]++;
241   }
242   max = 0;
243   for (i = (k + 1) * (k + 1), vp = v; i; i--, vp++)
244     max += (double)*vp * (double)*vp;
245   if (c->bmax < 0 || max < c->bmax) {
246     memcpy(c->bg, g, k * sizeof(dig));
247     c->bmax = max;
248     c->bx = x;
249   }
250 }
251
252 static void best_guess(cpc *c)
253 {
254   c->bmax = -1;
255   if (c->ns < 1024) {
256     unsigned k = c->m.k;
257     const dig *s;
258     size_t i;
259
260     for (i = c->ns, s = c->s; i; i--, s += k)
261       try_guess(c, s, c->x);
262   }
263   foreach_proto(c, try_guess);
264   c->x = c->bx;
265 }
266
267 static void filter_guesses(cpc *c, const dig *g, unsigned b, unsigned w)
268 {
269   unsigned k = c->m.k;
270   size_t i;
271   const dig *s;
272   unsigned bb, ww;
273   dig *ss;
274
275   rate_init(c->r, g);
276   for (i = c->ns, s = ss = c->s; i; i--, s += k) {
277     rate(c->r, s, &bb, &ww);
278     if (b == bb && w == ww) {
279       memmove(ss, s, k * sizeof(dig));
280       ss += k;
281     }
282   }
283   c->ns = (ss - c->s) / k;
284 }
285
286 static size_t ipow(size_t b, size_t x)
287 {
288   size_t a = 1;
289   while (x) {
290     if (x & 1)
291       a *= b;
292     b *= b;
293     x >>= 1;
294   }
295   return (a);
296 }
297
298 static void all_guesses(dig **ss, unsigned k, unsigned n,
299                         unsigned i, const dig *b)
300 {
301   unsigned j;
302
303   if (i == k) {
304     (*ss) += k;
305     return;
306   }
307   for (j = 0; j < n; j++) {
308     dig *s = *ss;
309     if (i)
310       memcpy(*ss, b, i * sizeof(dig));
311     s[i] = j;
312     all_guesses(ss, k, n, i + 1, s);
313   }
314 }
315
316 #define THINK(what, how) do {                                           \
317   clock_t _t0, _t1;                                                     \
318   fputs(what "...", stdout);                                            \
319   fflush(stdout);                                                       \
320   _t0 = clock();                                                        \
321   do how while (0);                                                     \
322   _t1 = clock();                                                        \
323   printf(" done (%.2fs)\n", (_t1 - _t0)/(double)CLOCKS_PER_SEC);        \
324 } while (0)
325
326 static cpc *cpc_new(const mm *m)
327 {
328   cpc *c = CREATE(cpc);
329   c->m = *m;
330   c->ns = ipow(c->m.n, c->m.k);
331   c->s = xmalloc((c->ns + 2) * c->m.k * sizeof(dig));
332   c->bg = c->s + c->ns * c->m.k;
333   c->t = c->bg + c->m.k;
334   c->x = 0;
335   c->v = xmalloc((c->m.k + 1) * (c->m.k + 1) * sizeof(size_t));
336   c->r = rate_alloc(m);
337   THINK("Setting up", {
338     dig *ss = c->s; all_guesses(&ss, c->m.k, c->m.n, 0, 0);
339   });
340   return (c);
341 }
342
343 static void cpc_free(cpc *c)
344 {
345   xfree(c->s);
346   xfree(c->v);
347   rate_free(c->r);
348   DESTROY(c);
349 }
350
351 static void cp_rate(void *r, const dig *g, unsigned *b, unsigned *w)
352   { rate(r, g, b, w); }
353
354 static const dig *cp_guess(void *cc)
355 {
356   cpc *c = cc;
357
358   if (c->ns == 0) {
359     printf("Liar!  All solutions ruled out.\n");
360     return (0);
361   }
362   if (c->ns == 1) {
363     fputs("Done!  Solution is ", stdout);
364     print_guess(&c->m, c->s);
365     putchar('\n');
366     return (c->s);
367   }
368   printf("(Possible solutions remaining = %lu)\n",
369          (unsigned long)c->ns);
370   if (c->ns < 32) {
371     const dig *s;
372     size_t i;
373     for (i = c->ns, s = c->s; i; i--, s += c->m.k) {
374       printf("  %2lu: ", (unsigned long)(c->ns - i + 1));
375       print_guess(&c->m, s);
376       putchar('\n');
377     }
378   }
379   THINK("Pondering", {
380     best_guess(c);
381   });
382   return (c->bg);
383 }
384
385 static void cp_update(void *cc, const dig *g, unsigned b, unsigned w)
386 {
387   cpc *c = cc;
388   fputs("My guess = ", stdout);
389   print_guess(&c->m, g);
390   printf("; rating = %u black, %u white\n", b, w);
391   THINK("Filtering", {
392     filter_guesses(c, g, b, w);
393   });
394 }
395
396 /*----- Human player ------------------------------------------------------*/
397
398 typedef struct hpc {
399   mm m;
400   dig *t;
401 } hpc;
402
403 static hpc *hpc_new(const mm *m)
404 {
405   hpc *h = CREATE(hpc);
406   h->t = xmalloc(m->k * sizeof(dig));
407   h->m = *m;
408   return (h);
409 }
410
411 static void hpc_free(hpc *h)
412 {
413   xfree(h->t);
414   DESTROY(h);
415 }
416
417 static void hp_rate(void *mp, const dig *g, unsigned *b, unsigned *w)
418 {
419   mm *m = mp;
420   fputs("Guess = ", stdout);
421   print_guess(m, g);
422   printf("; rating: ");
423   fflush(stdout);
424   scanf("%u %u", b, w);
425 }
426
427 static const dig *hp_guess(void *hh)
428 {
429   hpc *h = hh;
430   unsigned i;
431
432   fputs("Your guess: ", stdout);
433   fflush(stdout);
434   for (i = 0; i < h->m.k; i++) {
435     unsigned x;
436     scanf("%u", &x);
437     h->t[i] = x;
438   }
439   return (h->t);
440 }
441
442 static void hp_update(void *cc, const dig *g, unsigned b, unsigned w)
443 {
444   printf("Rating = %u black, %u white\n", b, w);
445 }
446
447 /*----- Solver player -----------------------------------------------------*/
448
449 typedef struct spc {
450   cpc *c;
451   hpc *h;
452   int i;
453 } spc;
454
455 static spc *spc_new(const mm *m)
456 {
457   spc *s = CREATE(spc);
458   s->c = cpc_new(m);
459   s->h = hpc_new(m);
460   s->i = 0;
461   return (s);
462 }
463
464 static void spc_free(spc *s)
465 {
466   cpc_free(s->c);
467   hpc_free(s->h);
468   DESTROY(s);
469 }
470
471 static const dig *sp_guess(void *ss)
472 {
473   spc *s = ss;
474   hpc *h = s->h;
475   unsigned i;
476   int ch;
477
478 again:
479   if (s->i)
480     return (cp_guess(s->c));
481
482   fputs("Your guess (dot for end): ", stdout);
483   fflush(stdout);
484   do ch = getchar(); while (isspace(ch));
485   if (!isdigit(ch)) { s->i = 1; goto again; }
486   ungetc(ch, stdin);
487   for (i = 0; i < h->m.k; i++) {
488     unsigned x;
489     scanf("%u", &x);
490     h->t[i] = x;
491   }
492   return (h->t);    
493 }
494
495 static void sp_update(void *ss, const dig *g, unsigned b, unsigned w)
496   { spc *s = ss; cp_update(s->c, g, b, w); }
497
498 /*----- Main game logic ---------------------------------------------------*/
499
500 static int play(const mm *m,
501                 void (*ratefn)(void *rr, const dig *g,
502                                unsigned *b, unsigned *w),
503                 void *rr,
504                 const dig *(*guessfn)(void *gg),
505                 void (*updatefn)(void *gg, const dig *g,
506                                  unsigned b, unsigned w),
507                 void *gg)
508 {
509   unsigned b, w;
510   const dig *g;
511   unsigned i;
512
513   i = 0;
514   for (;;) {
515     i++;
516     g = guessfn(gg);
517     if (!g)
518       return (-1);
519     ratefn(rr, g, &b, &w);
520     if (b == m->k)
521       return (i);
522     updatefn(gg, g, b, w);
523   }
524 }
525
526 int main(int argc, char *argv[])
527 {
528   unsigned h = 0;
529   void *rr = 0;
530   void (*ratefn)(void *rr, const dig *g, unsigned *b, unsigned *w) = 0;
531   mm m;
532   int n;
533
534   ego(argv[0]);
535   for (;;) {
536     static struct option opt[] = {
537       { "computer",     0,      0,      'C' },
538       { "human",        0,      0,      'H' },
539       { "solver",       0,      0,      'S' },
540       { 0,              0,      0,      0 }
541     };
542     int i = mdwopt(argc, argv, "CHS", opt, 0, 0, 0);
543     if (i < 0)
544       break;
545     switch (i) {
546       case 'C': h = 0; break;
547       case 'H': h = 1; break;
548       case 'S': h = 2; break;
549       default:
550         exit(1);
551     }
552   }
553   if (argc - optind == 0) {
554     m.k = 4;
555     m.n = 6;
556   } else if (argc - optind < 2)
557     die(1, "bad parameters");
558   else {
559     m.k = atoi(argv[optind++]);
560     m.n = atoi(argv[optind++]);
561     if (argc - optind >= m.k) {
562       dig *s = xmalloc(m.k * sizeof(dig));
563       int i;
564       for (i = 0; i < m.k; i++)
565         s[i] = atoi(argv[optind++]);
566       rr = rate_new(&m, s);
567       ratefn = cp_rate;
568       xfree(s);
569     }
570     if (argc != optind)
571       die(1, "bad parameters");
572   }
573
574   switch (h) {
575     case 1: {
576       hpc *hh = hpc_new(&m);
577       if (!rr) {
578         dig *s = xmalloc(m.k * sizeof(dig));
579         int i;
580         srand(time(0));
581         for (i = 0; i < m.k; i++)
582           s[i] = rand() % m.n;
583         rr = rate_new(&m, s);
584         ratefn = cp_rate;
585         xfree(s);
586       }
587       n = play(&m, ratefn, rr, hp_guess, hp_update, hh);
588       hpc_free(hh);
589     } break;
590     case 0: {
591       cpc *cc = cpc_new(&m);
592       if (rr) 
593         n = play(&m, ratefn, rr, cp_guess, cp_update, cc);
594       else
595         n = play(&m, hp_rate, &m, cp_guess, cp_update, cc);
596       cpc_free(cc);
597     } break;
598     case 2: {
599       spc *ss = spc_new(&m);
600       n = play(&m, hp_rate, &m, sp_guess, sp_update, ss);
601       spc_free(ss);
602     } break;
603     default:
604       abort();
605   }
606   if (n > 0)
607     printf("Solved in %d guesses\n", n);
608   else
609     die(1, "gave up");
610   return (0);
611 }
612
613 /*----- That's all, folks -------------------------------------------------*/