#!/usr/bin/env python3

# spigot demo program: implement two different multidimensional
# generalisations of Euclid's algorithm.

import sys
import argparse
import itertools
from functools import reduce
from collections import namedtuple
import spigot

# Python 2/3 agnosticism: Python 3 has no built-in 'xrange', so probe
# for the name and, if the lookup fails, alias it to 'range'.
try:
    xrange
except NameError:
    xrange = range

class Matrix(object):
    """Class implementing matrix arithmetic.

    Also implements vectors: a width-1 matrix is regarded as a vector,
    and allows several extra operations and has a different __repr__
    format.

    Elements are kept in the flat list 'values' in column-major order,
    so the element at row r and column c lives at index h*c + r.
    """

    def __init__(self, h, w, *values):
        """Create a matrix with given height, width and contents.

        After h and w, the remaining arguments 'values' are given in
        *column*-major order, and there must be exactly h*w of them.
        """
        assert len(values) == w*h
        self.w = w
        self.h = h
        self.values = list(values)
        self.kind = "Matrix" if w > 1 else "Vector"

    def printable_dimensions(self):
        "Internal helper function for __repr__."
        return (self.h, self.w) if self.w > 1 else (self.h,)

    def __str__(self):
        """Format as '<col | col | ...>', each column comma-separated."""
        return "<{}>".format(" | ".join(
            ",".join(str(self[r,c]) for r in range(self.h))
            for c in range(self.w)))

    def __repr__(self):
        """Format as 'Matrix(h,w, ...)' or 'Vector(h, ...)'.

        The elements follow in column-major order, with a space
        inserted before the first element of each column.
        """
        return "{}({},{})".format(
            self.kind,
            ",".join("{:d}".format(d) for d in self.printable_dimensions()),
            ",".join(
                (" " if i % self.h == 0 else "") + repr(v)
                for i,v in enumerate(self.values)))

    def valueindex(self, key):
        """Translate a subscript into an index into self.values.

        'key' is either an (r, c) pair or, for vectors only, a single
        row number. Raises IndexError if either coordinate lies outside
        the matrix; negative (from-the-end) coordinates are rejected
        rather than wrapped.
        """
        try:
            r, c = key
        except TypeError:
            assert self.kind == "Vector"
            r, c = key, 0

        # Both coordinates must be in range. The parentheses matter:
        # without them 'not' applies to the row test alone, so a bad
        # column index combined with a valid row would be accepted and
        # could silently address an element of a different column.
        #
        # Raising IndexError here is also what terminates a plain
        # 'for x in vector' loop, since this class has no __iter__ and
        # Python falls back to calling __getitem__ with 0, 1, 2, ...
        if not (0 <= r < self.h and 0 <= c < self.w):
            raise IndexError("matrix index out of range")

        return self.h*c+r

    def __getitem__(self, key):
        """Return the element at key (see valueindex for key forms)."""
        return self.values[self.valueindex(key)]

    def __setitem__(self, key, value):
        """Replace the element at key (see valueindex for key forms)."""
        self.values[self.valueindex(key)] = value

    def slice(self, rlo, rhi, clo=None, chi=None):
        """Extract an arbitrary rectangle from a matrix.

        The returned matrix consists of rows [rlo:rhi] and columns
        [clo:chi] of the original. (With Python semantics, i.e.
        treating the coordinates as spaces between elements, so that
        the dimensions of the matrix will be (rhi-rlo) by (chi-clo),
        not off by one.)

        For vectors, clo and chi can be omitted.
        """
        if clo is None or chi is None:
            # Either give both column bounds or neither.
            assert clo is None and chi is None
            assert self.kind == "Vector"
            clo, chi = 0, 1
        return Matrix(rhi-rlo, chi-clo,
                      *[self[r,c] for c in range(clo,chi)
                        for r in range(rlo, rhi)])

    def compose(self, rhs):
        """Compose two matrices by matrix multiplication.

        Returns self * rhs; the width of self must equal the height of
        rhs.
        """
        assert isinstance(rhs, Matrix)
        assert self.w == rhs.h
        return Matrix(self.h, rhs.w, *[
            sum(self[i,j] * rhs[j,k] for j in range(self.w))
            for k in range(rhs.w) for i in range(self.h)])

    def scale(self, scale):
        """Multiply a matrix by a scalar, returning a new matrix."""
        assert not isinstance(scale, Matrix)
        return Matrix(self.h, self.w, *[scale*v for v in self.values])

    def add(self, rhs):
        """Add two matrices of equal dimensions elementwise."""
        assert self.w == rhs.w
        assert self.h == rhs.h
        return Matrix(self.h, self.w,
                      *[v+w for v,w in zip(self.values, rhs.values)])

    def subdeterminant(self, rs, cs, sign=+1):
        """Internal helper function for determinant and inverse.

        Returns 'sign' times the determinant of the square submatrix
        made of the rows listed in rs and the columns listed in cs.
        Both must be non-empty sequences of equal length.
        """
        assert len(rs) == len(cs)

        # Base case: if we're being asked for the determinant of a 1x1
        # subset of a matrix, return just that element. The accumulated
        # sign is applied here, at the leaves of the recursion.
        c = cs[0]
        if len(cs) == 1:
            return self[rs[0],c] * sign

        cs_without = cs[1:]
        rs_without = lambda r: [i for i in rs if i != r]

        # Inductive case: iterate over each element of the first
        # included column, and for each one, add that element times
        # the sub-determinant of the remaining columns and rows, with
        # alternating sign.
        return sum(
            self[r,c] * self.subdeterminant(rs_without(r), cs_without, s)
            for r, s in zip(rs, itertools.cycle([+sign,-sign])))

    def determinant(self):
        """Return the determinant of a (square, non-empty) matrix."""
        return self.subdeterminant(range(self.h), range(self.w))

    def inverse(self):
        """Return the inverse of a square matrix.

        Each element is computed as a cofactor divided by the
        determinant, so the element type must support division.
        """
        n = self.h
        assert self.w == n

        if n == 1: # trivial special case for which the inductive formula fails
            return Matrix(1, 1, 1/self[0,0])

        det = self.determinant()
        without = lambda k: [i for i in range(n) if i != k]

        # Use the subdeterminant helper function to construct the
        # elements of the adjugate matrix, which becomes the inverse
        # when divided by the determinant. Element (r,c) of the result
        # comes from deleting row c and column r of the original, with
        # sign +1 when r+c is even and -1 when it is odd.
        return Matrix(n, n, *[
            self.subdeterminant(without(c), without(r), 1-(r+c)%2*2)/det
            for c in range(n) for r in range(n)])

    def transpose(self):
        """Transpose a matrix."""
        # Walking the original in row-major order produces exactly the
        # column-major element list of the transpose.
        return Matrix(self.w, self.h, *[
            self[r,c] for r in range(self.h) for c in range(self.w)])

    def length(self):
        """Give the length of a Vector, computed by spigot.hypot."""
        assert self.kind == "Vector"
        return spigot.hypot(*self.values)

    def cross(self, rhs):
        """Return the cross product of two 3-element vectors."""
        assert self.kind == "Vector"
        assert isinstance(rhs, Matrix)
        assert rhs.kind == "Vector"
        assert self.h == rhs.h == 3
        a, b, c = self.values
        u, v, w = rhs.values
        # Build the width-1 result directly, so that this method does
        # not depend on a helper defined outside the class.
        return Matrix(3, 1, b*w-c*v, c*u-a*w, a*v-b*u)

    def unit(self):
        """Scale the input vector to unit length."""
        assert self.kind == "Vector"
        return self.scale(1 / self.length())

    @classmethod
    def identity(cls, n):
        """Constructor function to make an n-by-n identity matrix."""
        return cls(n, n, *[0 if r != c else 1
                           for c in range(n) for r in range(n)])

    @classmethod
    def fromvectors(cls, vecs):
        """Constructor function that takes a list of column vectors.

        'vecs' must be a non-empty iterable of vectors that all have
        the same height; they become the columns of the result, in
        order.
        """
        vs = list(vecs)
        w = len(vs)
        h = vs[0].h
        assert all(v.h == h for v in vs)
        assert all(v.w == 1 for v in vs)
        # Storage is column-major, so the columns' element lists simply
        # follow one another.
        return cls(h, w, *itertools.chain.from_iterable(
            v.values for v in vs))

def Vector(h, *values):
    """Build an h-element vector, i.e. a Matrix exactly one column wide.

    The remaining arguments are the h components, in order.
    """
    width = 1
    return Matrix(h, width, *values)

def BasisVector(h, i):
    """Return the h-element vector with 1 in position i and 0 elsewhere."""
    def component(j):
        return 1 if i==j else 0
    return Vector(h, *map(component, range(h)))

# Lightweight pairing of a vector ('vec') with its projection ('proj')
# into some space of interest. Both main algorithms use it, but the
# data type stored in 'proj' differs between them.
VecProj = namedtuple("VecProj", ["vec", "proj"])

def euclid_linear(values, iterations, verbose=None):
    """Euclid's algorithm for making small linear combinations.

    Given a list of n input values defining a vector in R^n, generates
    a sequence of vectors with integer components whose dot products
    with the input vector are as small as can be found.

    The working state is n previous linear combinations (starting with
    the canonical basis vectors, i.e. the numbers themselves). The
    strategy is to identify the vector among those n with the smallest
    projection, and reduce all the rest modulo that one.

    Arguments:
      values      the n input numbers. Dividing two of them must give
                  an object with a to_int(rmode=...) method
                  (presumably spigot numbers, as main passes in).
      iterations  iterable of integer iteration numbers; one result is
                  generated per item, and the numbers themselves are
                  used only in diagnostics.
      verbose     optional file-like object to which diagnostics are
                  written; None (the default) disables them.

    Yields, once per iteration, the list of integer coefficients of
    the candidate whose projection currently compares smallest.
    """

    n = len(values)

    # Function that turns an integer vector into a VecProj containing
    # that vector and its dot product with the input.
    value_matrix = Matrix(1, n, *values)
    project = lambda v: VecProj(v, value_matrix.compose(v)[0])

    # Initial working state.
    vps = [project(BasisVector(n, i)) for i in range(n)]

    for iteration in iterations:
        if verbose:
            verbose.write("iteration #{:d} candidates:\n".format(
                iteration))
            for i, vp in enumerate(vps):
                verbose.write("  #{:d} = {} (proj={:.5e})\n".format(
                    i, vp.vec, vp.proj))

        # Find the vector with the currently smallest projection, and
        # output it.
        s, mvp = min(enumerate(vps), key=lambda pair: pair[1].proj)
        yield mvp.vec.values

        # Now reduce all the other vectors mod that one.
        if verbose:
            verbose.write("reducing mod #{:d}\n".format(s))
        for i, rvp in enumerate(vps):
            if i == s:
                continue # don't accidentally reduce a vector mod itself

            # To reduce rvp mod mvp, identify the _real_ multiple of
            # mvp that you'd have to subtract from rvp to reduce its
            # projection to zero, and then round down to the nearest
            # integer.
            mult = (rvp.proj / mvp.proj).to_int(rmode=spigot.ROUND_DOWN)
            vps[i] = project(rvp.vec.add(mvp.vec.scale(-mult)))

            if verbose:
                verbose.write("  subtracting {:d} * #{:d} from #{:d}\n".format(
                    mult, s, i))

def euclid_ratio(components, iterations, verbose=None):
    """Euclid's algorithm for finding lattice points near a target line.

    The n input values define a vector in R^n. The generator produces
    integer vectors whose perpendicular distances from the line of
    multiples of that vector are as small as can be found.

    The state is a list of n integer vectors, initially the canonical
    basis. Each step takes the vector at the front of the list and
    reduces it modulo all the others at once: every vector is
    projected into the (n-1)-dimensional space normal to the target,
    the front vector's projection is expressed in the basis formed by
    the other n-1 projections, and the resulting coefficients decide
    which integer multiple of each other vector to add.

    Arguments:
      components  the n input numbers.
      iterations  iterable of integer iteration numbers; one result is
                  generated per item, and the numbers themselves are
                  used only in diagnostics.
      verbose     optional file-like object to which diagnostics are
                  written; None (the default) disables them.

    Yields, once per iteration, a pair: the component list of the
    vector just moved to the back of the list, and the list of integer
    multiples of the other vectors that were added to produce it.
    """

    dim = len(components)

    # Build dim-1 vectors that are all normal to the input; together
    # they project a candidate vector into the normal space in which
    # we want to make it small.
    #
    # This basis needs no special properties such as orthogonality: a
    # further change of basis happens below, which would normalise any
    # such property away. So vector i just has two nonzero entries:
    # the *last* component holds v_i and component i holds minus the
    # last input value, which makes the dot product with v cancel.
    # (They are linearly independent provided the last input value is
    # nonzero.)
    normal_elements = []
    for i in range(dim-1):
        normal = [0] * dim
        normal[-1] = components[i]
        normal[i] = -components[-1]
        normal_elements.extend(normal)
    M = Matrix(dim, dim-1, *normal_elements).transpose()

    def project(v):
        # Wrap a candidate vector together with its (dim-1)-dimensional
        # image in the normal space.
        return VecProj(v, M.compose(v))

    # Initial working state.
    vps = [project(BasisVector(dim, i)) for i in range(dim)]

    # Number of consecutive steps that achieved nothing, so that we can
    # abort if it turns out that _no_ vector can be reduced. (Not
    # expected to happen; purely a sanity check.)
    stalled = 0

    for iteration in iterations:
        # The vector being reduced this time, and everything else.
        head = vps[0]
        others = vps[1:]

        if verbose:
            verbose.write("iteration #{:d} candidates:\n".format(
                iteration))
            for i, vp in enumerate(vps):
                verbose.write("  #{:d} = {} (proj={})\n".format(
                    i, vp.vec, vp.proj))
            verbose.write("trying to reduce ({})\n".format(head.vec))

        # Make a matrix whose columns are the negated projections of
        # the other vectors, and apply its inverse to the projection
        # of the head vector.
        basis = Matrix.fromvectors([vp.proj.scale(-1) for vp in others])
        counts = basis.inverse().compose(head.proj)
        if verbose:
            verbose.write("  unrounded counts = {}\n".format(counts))

        # Those are real-valued coefficients: adding that multiple of
        # each other vector to the head would cancel its projection
        # exactly. Only _integer_ multiples are allowed, though, so
        # round each one down, and never go below zero.
        rcounts = [max(0, c.floor_int()) for c in counts.values]
        if verbose:
            verbose.write("  rounded counts = {}\n".format(rcounts))

        if not any(rc > 0 for rc in rcounts):
            # Nothing can be added to this vector; it goes to the back
            # of the list unchanged, and the next step will try
            # whichever vector then arrives at the front.
            if verbose:
                verbose.write("  can't reduce\n")

            stalled += 1
            assert stalled < dim, (
                "unable to make progress reducing any vector!")

            newvp = head
        else:
            # Progress was made, so the run of failures is over.
            stalled = 0

            newvec = head.vec
            for vp, rc in zip(others, rcounts):
                newvec = newvec.add(vp.vec.scale(rc))
            if verbose:
                verbose.write(
                    "  successful reduction to {}\n".format(newvec))
            newvp = project(newvec)

        # In both cases the resulting vector joins the back of the
        # list.
        vps = others + [newvp]

        # Report the resulting vector along with the rounded multiples
        # of the others that went into it.
        yield newvp.vec.values, rcounts

def main():
    """Command-line entry point.

    Parses the arguments, evaluates each component with spigot.eval,
    runs the selected algorithm, and writes each result to standard
    output as one line of space-separated integers.
    """
    class ListOfAtLeastTwo(argparse.Action):
        """argparse action storing a list that must have >= 2 items."""
        def __call__(self, parser, args, values, option_string=None):
            if len(values) < 2:
                raise argparse.ArgumentError(
                    self, "expected at least two components")
            setattr(args, self.dest, values)

    parser = argparse.ArgumentParser(
        description="Run generalisations of Euclid's algorithm on more than "
        "two input values.")
    parser.add_argument("component", nargs="+", action=ListOfAtLeastTwo,
                        help="Components of the target ratio.")
    parser.add_argument("-v", "--verbose", action="store_const",
                        const=sys.stdout, help="Print detailed diagnostics.")

    # Exactly one mode option is required. Each stores a pair in
    # args.action: the generator function to run, and a function that
    # picks the sequence to print out of each result it yields.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("-r", "--ratio", action="store_const",
                       dest="action", const=(euclid_ratio, lambda t:t[0]),
                       help="Find integer tuples approximating the n-ary "
                       "ratio of the input values.")
    group.add_argument("-C", "--ratio-coefficients", action="store_const",
                       dest="action", const=(euclid_ratio, lambda t:t[1]),
                       help="Print the multiples of each vector added to "
                       "the reduced one in ratio mode.")
    group.add_argument("-l", "--linear-combination", "--combination",
                       action="store_const",
                       dest="action", const=(euclid_linear, lambda x:x),
                       help="Find integer tuples representing small linear "
                       "combinations of the input values.")
    parser.add_argument("-i", "--iterations", type=int,
                        help="Number of iterations to perform.")
    args = parser.parse_args()

    inputs = [spigot.eval(c) for c in args.component]
    # Without -i, keep going until interrupted.
    iterations = (itertools.count() if args.iterations is None
                  else range(args.iterations))
    algorithm, result_extractor = args.action
    generator = (result_extractor(result)
                 for result in algorithm(inputs, iterations, args.verbose))

    for values in generator:
        sys.stdout.write(" ".join("{:d}".format(x) for x in values) + "\n")
        # Flush per line so that each result is visible immediately,
        # since the run may never terminate on its own.
        sys.stdout.flush()

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
