chiark / gitweb /
An entirely different search program.
[matchsticks-search.git] / partition.py
1 #!/usr/bin/env python
2
3 import sys, subprocess, math
4 from fractions import Fraction
5
6 def vprint(*args):
7     # In verbose mode, print diagnostics.
8     pass # FIXME: implement verbose mode
9
10 def find_partitions(n, minfrag, maxfrag=None, prefix=(), ret=None):
11     """Find all partitions of n into integer pieces at least minfrag.
12
13 Returns a list of tuples.
14
15 For recursive calls: appends to an existing 'ret' if none is
16 specified."""
17     if ret is None:
18         ret = [] # new array
19     maxfrag = min(maxfrag, n) if maxfrag is not None else n
20     if n == 0:
21         ret.append(prefix)
22     else:
23         for frag in xrange(minfrag, maxfrag+1):
24             find_partitions(n-frag, frag, maxfrag, prefix+(frag,), ret)
25     return ret
26
27 def try_one_minfrag(n, m, minfrag, d):
28     # Find all possible partitions of the two stick lengths.
29     nparts = find_partitions(n*d, minfrag, m*d)
30     mparts = find_partitions(m*d, minfrag)
31
32     # Winnow by discarding any partition using a length not present in
33     # the other.
34     while True:
35         vprint("Partitions of %d:" % n)
36         for np in nparts:
37             vprint("   ", np)
38         vprint("Partitions of %d:" % m)
39         for mp in mparts:
40             vprint("   ", mp)
41
42         oldlens = len(nparts), len(mparts)
43         nlengths = set(sum(nparts, ()))
44         mlengths = set(sum(mparts, ()))
45         new_nparts = []
46         for np in nparts:
47             s = set([k for k in np if k not in mlengths])
48             if len(s) == 0:
49                 new_nparts.append(np)
50             else:
51                 vprint("Winnowing %s (can't use %s)" % (
52                     np, ", ".join(map(str,sorted(s)))))
53         new_mparts = []
54         for mp in mparts:
55             s = set([k for k in mp if k not in nlengths])
56             if len(s) == 0:
57                 new_mparts.append(mp)
58             else:
59                 vprint("Winnowing %s (can't use %s)" % (
60                     mp, ", ".join(map(str,sorted(s)))))
61         nparts = new_nparts
62         mparts = new_mparts
63         if oldlens == (len(nparts), len(mparts)):
64             break # we have converged
65
66     if len(nparts) == 0 or len(mparts) == 0:
67         vprint("No partitions available.")
68         return None
69     # Now we need to look for an integer occurrence count of each
70     # nparts row, summing to m, and one for each mparts row, summing
71     # to n, with the right constraints. We do this by appealing to an
72     # ILP solver :-)
73     ilp_file = ""
74     nvarnames = {}
75     for np in nparts:
76         nvarnames[np] = "_".join(["n"] + map(str,np))
77     mvarnames = {}
78     for mp in mparts:
79         mvarnames[mp] = "_".join(["m"] + map(str,mp))
80     varlist = sorted(nvarnames.values()) + sorted(mvarnames.values())
81     # Have to try to minimise _something_!
82     ilp_file += "min: %s;\n" % " + ".join(["%d * %s" % (i+1, name) for i, name in enumerate(varlist)])
83     for var in varlist:
84         ilp_file += "%s >= 0;\n" % var
85     ilp_file += " + ".join(sorted(nvarnames.values())) + " = %d;\n" % m
86     ilp_file += " + ".join(sorted(mvarnames.values())) + " = %d;\n" % n
87     assert nlengths == mlengths
88     for k in nlengths:
89         ns = []
90         for np in nparts:
91             count = len([x for x in np if x == k])
92             if count > 0:
93                 ns.append((count, nvarnames[np]))
94         ms = []
95         for mp in mparts:
96             count = len([x for x in mp if x == k])
97             if count > 0:
98                 ms.append((count, mvarnames[mp]))
99         ilp_file += " + ".join(["%d * %s" % t for t in ns]) + " = "
100         ilp_file += " + ".join(["%d * %s" % t for t in ms]) + ";\n"
101     for var in sorted(nvarnames.values()) + sorted(mvarnames.values()):
102         ilp_file += "int %s;\n" % var
103
104     p = subprocess.Popen(["lp_solve", "-lp"], stdin=subprocess.PIPE,
105                          stdout=subprocess.PIPE)
106     stdout, stderr = p.communicate(ilp_file)
107     if p.wait() != 0:
108         vprint("ILP solver failed")
109         return None
110     else:
111         ncounts = {}
112         mcounts = {}
113         for line in stdout.splitlines():
114             words = line.split()
115             if len(words) == 0:
116                 pass # rule out for future elifs
117             elif words[0][:2] == "n_" and words[1] != "0":
118                 ncounts[tuple(map(int, words[0][2:].split("_")))] = int(words[1])
119             elif words[0][:2] == "m_" and words[1] != "0":
120                 mcounts[tuple(map(int, words[0][2:].split("_")))] = int(words[1])
121         return ncounts, mcounts
122
123 def search(n, m):
124     if n % m == 0:
125         # Trivial special case.
126         return (m, 1, {(m,):n}, {(m,)*(n/m):m})
127     best = (0,)
128     for d in xrange(1, n+1):
129         for k in xrange(m*d/2, int(math.ceil(best[0]*d)), -1):
130             result = try_one_minfrag(n, m, k, d)
131             if result is not None:
132                 best = (Fraction(k, d), d) + result
133                 break
134     return best
135
136 def search_and_report(n, m):
137     best = search(n, m)
138     d = best[1]
139     print "%d into %d best min fragment found: %s" % (n, m, best[0])
140     print "  Cut up %d sticks of length %d like this:" % (n, m)
141     for row, count in sorted(best[2].items(), reverse=True):
142         print "    %d x (%s)" % (count, " + ".join([str(Fraction(k,d)) for k in row]))
143     print "  Reassemble as %d sticks of length %d like this:" % (m, n)
144     for col, count in sorted(best[3].items(), reverse=True):
145         print "    %d x (%s)" % (count, " + ".join([str(Fraction(k,d)) for k in col]))
146
147 def main():
148     m, n = sorted(map(int,sys.argv[1:3]))
149     search_and_report(n, m)
150
151 if __name__ == "__main__":
152     main()