chiark / gitweb /
Fix a bug causing some solutions to be missed, e.g. (16,5).
[matchsticks-search.git] / partition.py
1 #!/usr/bin/env python
2
3 import sys, subprocess, math, getopt
4 from fractions import Fraction
5
6 verbose = False
7 def vprint(*args):
8     # In verbose mode, print diagnostics.
9     if verbose:
10         sys.stdout.write(" ".join(map(str,args)) + "\n")
11
12 def find_partitions(n, minfrag, maxfrag=None, prefix=(), ret=None):
13     """Find all partitions of n into integer pieces at least minfrag.
14
15 Returns a list of tuples.
16
17 For recursive calls: appends to an existing 'ret' if none is
18 specified."""
19     if ret is None:
20         ret = [] # new array
21     maxfrag = min(maxfrag, n) if maxfrag is not None else n
22     if n == 0:
23         ret.append(prefix)
24     else:
25         for frag in xrange(minfrag, maxfrag+1):
26             find_partitions(n-frag, frag, maxfrag, prefix+(frag,), ret)
27     return ret
28
29 def try_one_minfrag(n, m, minfrag, d):
30     # Find all possible partitions of the two stick lengths.
31     nparts = find_partitions(n*d, minfrag, m*d)
32     mparts = find_partitions(m*d, minfrag)
33
34     # Winnow by discarding any partition using a length not present in
35     # the other.
36     while True:
37         vprint("Partitions of %d:" % n)
38         for np in nparts:
39             vprint("   ", np)
40         vprint("Partitions of %d:" % m)
41         for mp in mparts:
42             vprint("   ", mp)
43
44         oldlens = len(nparts), len(mparts)
45         nlengths = set(sum(nparts, ()))
46         mlengths = set(sum(mparts, ()))
47         new_nparts = []
48         for np in nparts:
49             s = set([k for k in np if k not in mlengths])
50             if len(s) == 0:
51                 new_nparts.append(np)
52             else:
53                 vprint("Winnowing %s (can't use %s)" % (
54                     np, ", ".join(map(str,sorted(s)))))
55         new_mparts = []
56         for mp in mparts:
57             s = set([k for k in mp if k not in nlengths])
58             if len(s) == 0:
59                 new_mparts.append(mp)
60             else:
61                 vprint("Winnowing %s (can't use %s)" % (
62                     mp, ", ".join(map(str,sorted(s)))))
63         nparts = new_nparts
64         mparts = new_mparts
65         if oldlens == (len(nparts), len(mparts)):
66             break # we have converged
67
68     if len(nparts) == 0 or len(mparts) == 0:
69         vprint("No partitions available.")
70         return None
71     # Now we need to look for an integer occurrence count of each
72     # nparts row, summing to m, and one for each mparts row, summing
73     # to n, with the right constraints. We do this by appealing to an
74     # ILP solver :-)
75     ilp_file = ""
76     nvarnames = {}
77     for np in nparts:
78         nvarnames[np] = "_".join(["n"] + map(str,np))
79     mvarnames = {}
80     for mp in mparts:
81         mvarnames[mp] = "_".join(["m"] + map(str,mp))
82     varlist = sorted(nvarnames.values()) + sorted(mvarnames.values())
83     # Have to try to minimise _something_!
84     ilp_file += "min: %s;\n" % " + ".join(["%d * %s" % (i+1, name) for i, name in enumerate(varlist)])
85     for var in varlist:
86         ilp_file += "%s >= 0;\n" % var
87     ilp_file += " + ".join(sorted(nvarnames.values())) + " = %d;\n" % m
88     ilp_file += " + ".join(sorted(mvarnames.values())) + " = %d;\n" % n
89     assert nlengths == mlengths
90     for k in nlengths:
91         ns = []
92         for np in nparts:
93             count = len([x for x in np if x == k])
94             if count > 0:
95                 ns.append((count, nvarnames[np]))
96         ms = []
97         for mp in mparts:
98             count = len([x for x in mp if x == k])
99             if count > 0:
100                 ms.append((count, mvarnames[mp]))
101         ilp_file += " + ".join(["%d * %s" % t for t in ns]) + " = "
102         ilp_file += " + ".join(["%d * %s" % t for t in ms]) + ";\n"
103     for var in sorted(nvarnames.values()) + sorted(mvarnames.values()):
104         ilp_file += "int %s;\n" % var
105
106     p = subprocess.Popen(["lp_solve", "-lp"], stdin=subprocess.PIPE,
107                          stdout=subprocess.PIPE)
108     stdout, stderr = p.communicate(ilp_file)
109     if p.wait() != 0:
110         vprint("ILP solver failed")
111         return None
112     else:
113         ncounts = {}
114         mcounts = {}
115         for line in stdout.splitlines():
116             words = line.split()
117             if len(words) == 0:
118                 pass # rule out for future elifs
119             elif words[0][:2] == "n_" and words[1] != "0":
120                 ncounts[tuple(map(int, words[0][2:].split("_")))] = int(words[1])
121             elif words[0][:2] == "m_" and words[1] != "0":
122                 mcounts[tuple(map(int, words[0][2:].split("_")))] = int(words[1])
123         return ncounts, mcounts
124
125 def search(n, m):
126     if n % m == 0:
127         # Trivial special case.
128         return (m, 1, {(m,):n}, {(m,)*(n/m):m})
129     best = (0,)
130     for d in xrange(1, n+1):
131         for k in xrange(m*d/2, int(math.ceil(best[0]*d))-1, -1):
132             result = try_one_minfrag(n, m, k, d)
133             if result is not None:
134                 best = (Fraction(k, d), d) + result
135                 break
136     return best
137
138 def search_and_report(n, m):
139     best = search(n, m)
140     d = best[1]
141     print "%d into %d best min fragment found: %s" % (n, m, best[0])
142     print "  Cut up %d sticks of length %d like this:" % (n, m)
143     for row, count in sorted(best[2].items(), reverse=True):
144         print "    %d x (%s)" % (count, " + ".join([str(Fraction(k,d)) for k in row]))
145     print "  Reassemble as %d sticks of length %d like this:" % (m, n)
146     for col, count in sorted(best[3].items(), reverse=True):
147         print "    %d x (%s)" % (count, " + ".join([str(Fraction(k,d)) for k in col]))
148
149 def main():
150     global verbose
151     opts, args = getopt.gnu_getopt(sys.argv[1:], "v")
152     for opt, val in opts:
153         if opt == "-v":
154             verbose = True
155         else:
156             assert False, "unrecognised option '%s'" % opt
157     m, n = sorted(map(int,args[:2]))
158     search_and_report(n, m)
159
160 if __name__ == "__main__":
161     main()