chiark / gitweb /
Script to tabulate results as a web page.
[matchsticks-search.git] / partition.py
1 #!/usr/bin/env python
2
3 import sys, subprocess, math, getopt
4 from fractions import Fraction
5
6 import bounds
7
8 verbose = False
9 def vprint(*args):
10     # In verbose mode, print diagnostics.
11     if verbose:
12         sys.stdout.write(" ".join(map(str,args)) + "\n")
13
14 def find_partitions(n, minfrag, maxfrag=None, prefix=(), ret=None):
15     """Find all partitions of n into integer pieces at least minfrag.
16
17 Returns a list of tuples.
18
19 For recursive calls: appends to an existing 'ret' if none is
20 specified."""
21     if ret is None:
22         ret = [] # new array
23     maxfrag = min(maxfrag, n) if maxfrag is not None else n
24     if n == 0:
25         ret.append(prefix)
26     else:
27         for frag in xrange(minfrag, maxfrag+1):
28             find_partitions(n-frag, frag, maxfrag, prefix+(frag,), ret)
29     return ret
30
31 def try_one_minfrag(n, m, minfrag, d):
32     # Find all possible partitions of the two stick lengths.
33     nparts = find_partitions(n*d, minfrag, m*d)
34     mparts = find_partitions(m*d, minfrag)
35
36     # Winnow by discarding any partition using a length not present in
37     # the other.
38     while True:
39         vprint("Partitions of %d:" % n)
40         for np in nparts:
41             vprint("   ", np)
42         vprint("Partitions of %d:" % m)
43         for mp in mparts:
44             vprint("   ", mp)
45
46         oldlens = len(nparts), len(mparts)
47         nlengths = set(sum(nparts, ()))
48         mlengths = set(sum(mparts, ()))
49         new_nparts = []
50         for np in nparts:
51             s = set([k for k in np if k not in mlengths])
52             if len(s) == 0:
53                 new_nparts.append(np)
54             else:
55                 vprint("Winnowing %s (can't use %s)" % (
56                     np, ", ".join(map(str,sorted(s)))))
57         new_mparts = []
58         for mp in mparts:
59             s = set([k for k in mp if k not in nlengths])
60             if len(s) == 0:
61                 new_mparts.append(mp)
62             else:
63                 vprint("Winnowing %s (can't use %s)" % (
64                     mp, ", ".join(map(str,sorted(s)))))
65         nparts = new_nparts
66         mparts = new_mparts
67         if oldlens == (len(nparts), len(mparts)):
68             break # we have converged
69
70     if len(nparts) == 0 or len(mparts) == 0:
71         vprint("No partitions available.")
72         return None
73     # Now we need to look for an integer occurrence count of each
74     # nparts row, summing to m, and one for each mparts row, summing
75     # to n, with the right constraints. We do this by appealing to an
76     # ILP solver :-)
77     ilp_file = ""
78     nvarnames = {}
79     for np in nparts:
80         nvarnames[np] = "_".join(["n"] + map(str,np))
81     mvarnames = {}
82     for mp in mparts:
83         mvarnames[mp] = "_".join(["m"] + map(str,mp))
84     varlist = sorted(nvarnames.values()) + sorted(mvarnames.values())
85     # Have to try to minimise _something_!
86     ilp_file += "min: %s;\n" % " + ".join(["%d * %s" % (i+1, name) for i, name in enumerate(varlist)])
87     for var in varlist:
88         ilp_file += "%s >= 0;\n" % var
89     ilp_file += " + ".join(sorted(nvarnames.values())) + " = %d;\n" % m
90     ilp_file += " + ".join(sorted(mvarnames.values())) + " = %d;\n" % n
91     assert nlengths == mlengths
92     for k in nlengths:
93         ns = []
94         for np in nparts:
95             count = len([x for x in np if x == k])
96             if count > 0:
97                 ns.append((count, nvarnames[np]))
98         ms = []
99         for mp in mparts:
100             count = len([x for x in mp if x == k])
101             if count > 0:
102                 ms.append((count, mvarnames[mp]))
103         ilp_file += " + ".join(["%d * %s" % t for t in ns]) + " = "
104         ilp_file += " + ".join(["%d * %s" % t for t in ms]) + ";\n"
105     for var in sorted(nvarnames.values()) + sorted(mvarnames.values()):
106         ilp_file += "int %s;\n" % var
107
108     p = subprocess.Popen(["lp_solve", "-lp"], stdin=subprocess.PIPE,
109                          stdout=subprocess.PIPE)
110     stdout, stderr = p.communicate(ilp_file)
111     if p.wait() != 0:
112         vprint("ILP solver failed")
113         return None
114     else:
115         ncounts = {}
116         mcounts = {}
117         for line in stdout.splitlines():
118             words = line.split()
119             if len(words) == 0:
120                 pass # rule out for future elifs
121             elif words[0][:2] == "n_" and words[1] != "0":
122                 ncounts[tuple(map(int, words[0][2:].split("_")))] = int(words[1])
123             elif words[0][:2] == "m_" and words[1] != "0":
124                 mcounts[tuple(map(int, words[0][2:].split("_")))] = int(words[1])
125         return ncounts, mcounts
126
127 def search(n, m):
128     if n % m == 0:
129         # Trivial special case.
130         return (m, 1, {(m,):n}, {(m,)*(n/m):m})
131     best = (0,)
132     bound, _ = bounds.upper_bound(n, m)
133     for d in xrange(1, n+1):
134         for k in xrange(int(bound*d), int(math.ceil(best[0]*d))-1, -1):
135             result = try_one_minfrag(n, m, k, d)
136             if result is not None:
137                 best = (Fraction(k, d), d) + result
138                 break
139         if best == bound:
140             break # terminate early if we know it's the best answer
141     return best
142
143 def search_and_report(n, m):
144     best = search(n, m)
145     d = best[1]
146     print "%d into %d best min fragment found: %s" % (n, m, best[0])
147     print "  Cut up %d sticks of length %d like this:" % (n, m)
148     for row, count in sorted(best[3].items(), reverse=True):
149         print "    %d x (%s)" % (count, " + ".join([str(Fraction(k,d)) for k in row]))
150     print "  Reassemble as %d sticks of length %d like this:" % (m, n)
151     for col, count in sorted(best[2].items(), reverse=True):
152         print "    %d x (%s)" % (count, " + ".join([str(Fraction(k,d)) for k in col]))
153
154 def main():
155     global verbose
156     opts, args = getopt.gnu_getopt(sys.argv[1:], "v")
157     for opt, val in opts:
158         if opt == "-v":
159             verbose = True
160         else:
161             assert False, "unrecognised option '%s'" % opt
162     m, n = sorted(map(int,args[:2]))
163     search_and_report(n, m)
164
165 if __name__ == "__main__":
166     main()