chiark / gitweb /
An entirely different search program.
[matchsticks-search.git] / partition.py
diff --git a/partition.py b/partition.py
new file mode 100755 (executable)
index 0000000..d0c0dd6
--- /dev/null
@@ -0,0 +1,152 @@
+#!/usr/bin/env python
+
+import sys, subprocess, math
+from fractions import Fraction
+
+def vprint(*args):
+    # In verbose mode, print diagnostics.
+    pass # FIXME: implement verbose mode
+
+def find_partitions(n, minfrag, maxfrag=None, prefix=(), ret=None):
+    """Find all partitions of n into integer pieces at least minfrag.
+
+Returns a list of tuples.
+
+For recursive calls: appends to an existing 'ret' if none is
+specified."""
+    if ret is None:
+        ret = [] # new array
+    maxfrag = min(maxfrag, n) if maxfrag is not None else n
+    if n == 0:
+        ret.append(prefix)
+    else:
+        for frag in xrange(minfrag, maxfrag+1):
+            find_partitions(n-frag, frag, maxfrag, prefix+(frag,), ret)
+    return ret
+
+def try_one_minfrag(n, m, minfrag, d):
+    # Find all possible partitions of the two stick lengths.
+    nparts = find_partitions(n*d, minfrag, m*d)
+    mparts = find_partitions(m*d, minfrag)
+
+    # Winnow by discarding any partition using a length not present in
+    # the other.
+    while True:
+        vprint("Partitions of %d:" % n)
+        for np in nparts:
+            vprint("   ", np)
+        vprint("Partitions of %d:" % m)
+        for mp in mparts:
+            vprint("   ", mp)
+
+        oldlens = len(nparts), len(mparts)
+        nlengths = set(sum(nparts, ()))
+        mlengths = set(sum(mparts, ()))
+        new_nparts = []
+        for np in nparts:
+            s = set([k for k in np if k not in mlengths])
+            if len(s) == 0:
+                new_nparts.append(np)
+            else:
+                vprint("Winnowing %s (can't use %s)" % (
+                    np, ", ".join(map(str,sorted(s)))))
+        new_mparts = []
+        for mp in mparts:
+            s = set([k for k in mp if k not in nlengths])
+            if len(s) == 0:
+                new_mparts.append(mp)
+            else:
+                vprint("Winnowing %s (can't use %s)" % (
+                    mp, ", ".join(map(str,sorted(s)))))
+        nparts = new_nparts
+        mparts = new_mparts
+        if oldlens == (len(nparts), len(mparts)):
+            break # we have converged
+
+    if len(nparts) == 0 or len(mparts) == 0:
+        vprint("No partitions available.")
+        return None
+    # Now we need to look for an integer occurrence count of each
+    # nparts row, summing to m, and one for each mparts row, summing
+    # to n, with the right constraints. We do this by appealing to an
+    # ILP solver :-)
+    ilp_file = ""
+    nvarnames = {}
+    for np in nparts:
+        nvarnames[np] = "_".join(["n"] + map(str,np))
+    mvarnames = {}
+    for mp in mparts:
+        mvarnames[mp] = "_".join(["m"] + map(str,mp))
+    varlist = sorted(nvarnames.values()) + sorted(mvarnames.values())
+    # Have to try to minimise _something_!
+    ilp_file += "min: %s;\n" % " + ".join(["%d * %s" % (i+1, name) for i, name in enumerate(varlist)])
+    for var in varlist:
+        ilp_file += "%s >= 0;\n" % var
+    ilp_file += " + ".join(sorted(nvarnames.values())) + " = %d;\n" % m
+    ilp_file += " + ".join(sorted(mvarnames.values())) + " = %d;\n" % n
+    assert nlengths == mlengths
+    for k in nlengths:
+        ns = []
+        for np in nparts:
+            count = len([x for x in np if x == k])
+            if count > 0:
+                ns.append((count, nvarnames[np]))
+        ms = []
+        for mp in mparts:
+            count = len([x for x in mp if x == k])
+            if count > 0:
+                ms.append((count, mvarnames[mp]))
+        ilp_file += " + ".join(["%d * %s" % t for t in ns]) + " = "
+        ilp_file += " + ".join(["%d * %s" % t for t in ms]) + ";\n"
+    for var in sorted(nvarnames.values()) + sorted(mvarnames.values()):
+        ilp_file += "int %s;\n" % var
+
+    p = subprocess.Popen(["lp_solve", "-lp"], stdin=subprocess.PIPE,
+                         stdout=subprocess.PIPE)
+    stdout, stderr = p.communicate(ilp_file)
+    if p.wait() != 0:
+        vprint("ILP solver failed")
+        return None
+    else:
+        ncounts = {}
+        mcounts = {}
+        for line in stdout.splitlines():
+            words = line.split()
+            if len(words) == 0:
+                pass # rule out for future elifs
+            elif words[0][:2] == "n_" and words[1] != "0":
+                ncounts[tuple(map(int, words[0][2:].split("_")))] = int(words[1])
+            elif words[0][:2] == "m_" and words[1] != "0":
+                mcounts[tuple(map(int, words[0][2:].split("_")))] = int(words[1])
+        return ncounts, mcounts
+
+def search(n, m):
+    if n % m == 0:
+        # Trivial special case.
+        return (m, 1, {(m,):n}, {(m,)*(n/m):m})
+    best = (0,)
+    for d in xrange(1, n+1):
+        for k in xrange(m*d/2, int(math.ceil(best[0]*d)), -1):
+            result = try_one_minfrag(n, m, k, d)
+            if result is not None:
+                best = (Fraction(k, d), d) + result
+                break
+    return best
+
+def search_and_report(n, m):
+    best = search(n, m)
+    d = best[1]
+    print "%d into %d best min fragment found: %s" % (n, m, best[0])
+    print "  Cut up %d sticks of length %d like this:" % (n, m)
+    for row, count in sorted(best[2].items(), reverse=True):
+        print "    %d x (%s)" % (count, " + ".join([str(Fraction(k,d)) for k in row]))
+    print "  Reassemble as %d sticks of length %d like this:" % (m, n)
+    for col, count in sorted(best[3].items(), reverse=True):
+        print "    %d x (%s)" % (count, " + ".join([str(Fraction(k,d)) for k in col]))
+
+def main():
+    m, n = sorted(map(int,sys.argv[1:3]))
+    search_and_report(n, m)
+
+if __name__ == "__main__":
+    main()