#!/usr/bin/env python3

# Demo application of the spigot Python module.
#
# This program expects two integer arguments a,b, and will use a
# technique based on continued fractions to efficiently search for the
# smallest n such that a^n begins with the digits of b. (More exactly,
# such that frac(log_base(a^n)) lies between frac(log_base(b)) and
# frac(log_base(b+1)), where 'base' is the radix selected with -b and
# defaults to 10.)

import sys, os, string, argparse, traceback, signal, itertools, numbers
try:
    from math import gcd
except ImportError: # this moved modules between Python versions
    from fractions import gcd
import spigot

def debug(dbgclass, *args):
    """Write a diagnostic line to stdout if dbgclass is enabled.

The line has the form "[dbgclass] arg1 arg2 ...", with each argument
converted by str() and separated by single spaces.

The set of enabled classes lives in the module global
dbgclasses_enabled, which is only created by main(). If that global
does not exist yet (for example because this file was imported as a
module rather than run as a script), no classes are treated as enabled
instead of failing with NameError."""
    enabled = globals().get("dbgclasses_enabled", ())
    if dbgclass in enabled:
        sys.stdout.write("[{}] {}\n".format(dbgclass, " ".join(map(str,args))))

def format_in_base(n, base):
    """Return the integer n written in the given base, as a string.

Digits above 9 are written as upper-case letters, so base must lie in
the range 2..36; anything else raises ValueError. Zero is written as
"0", and a negative n is written as "-" followed by the representation
of its absolute value."""
    if not 2 <= base <= 36:
        raise ValueError("base must be in the range 2..36")
    if n < 0:
        # Floor division never brings a negative number to zero (it
        # sticks at -1), so format the magnitude and add the sign.
        return "-" + format_in_base(-n, base)
    digits = string.digits + string.ascii_uppercase
    out = []
    while True:
        n, digit = divmod(n, base)
        out.append(digits[digit])
        if n == 0:
            # Digits were collected least significant first.
            return "".join(reversed(out))

def make_f(x):
    """Return a function f with f(n) = frac(x*n), as a spigot.

f(0) is special-cased to a spigot constructed directly from 0, rather
than being passed through the expression evaluator."""
    def f(n):
        if n != 0:
            return spigot.eval("frac(x*n)", {"x":x, "n":n})
        return spigot.Spigot(0)
    return f

def floor_within_limits(spig, lo, hi):
    """Return floor(spig), clamped to the range [lo,hi].

Either limit may be None, meaning no clamping on that side."""

    # Why this exists: spig might non-obviously be an exact integer
    # lying _outside_ [lo,hi], which would be an exactness hazard for
    # a plain floor. We cannot help if the exact integer is inside the
    # range, but outside it we only need to establish that spig is out
    # of range, not exactly which out-of-range value it has.
    if lo is not None:
        if lo == hi:
            # Only one value is permitted, so spig need not be examined.
            return lo
        if spig < lo+1:
            return lo
    if hi is not None:
        if spig >= hi:
            return hi
    return int(spigot.floor(spig))
def ceil_within_limits(spig, lo, hi):
    """Return ceil(spig), clamped to the range [lo,hi].

Implemented as -floor(-spig) with the limits negated and swapped.
Either limit may be None, meaning no clamping on that side; a None
limit is passed through unchanged rather than negated."""
    negated_lo = None if hi is None else -hi
    negated_hi = None if lo is None else -lo
    return -floor_within_limits(-spig, negated_lo, negated_hi)

class Firstupto(object):
    """Given some real k, return the smallest m such that frac(xm) < k.

We expect our caller to give us k values in decreasing sequence, and
we make use of that assumption to increase efficiency.

If our caller knows that frac(xM) is exactly equal to k for some M,
then passing that M as an extra parameter to our __call__ function
will cause us to notice that and not suffer an exactness hazard when
we reach it.

Instances keep state between calls: b_low, b_step and b_high are taken
from successive convergents of x, and are only ever replaced by values
from later convergents, never rewound."""
    def __init__(self, x):
        """Set up the search state for the spigot x.

Reads the first three items from x.to_convergents() and keeps element
[1] of each (presumably the convergent's denominator - confirm against
the spigot module) as b_low, b_step and b_high respectively."""
        self.xconv = x.to_convergents()
        self.b_low = next(self.xconv)[1]
        try:
            self.b_step = next(self.xconv)[1]
            self.b_high = next(self.xconv)[1]
        except StopIteration:
            # Fewer than three convergents were available. Fall back
            # to a progression containing only b_low itself.
            debug("firstupto", "continued fraction terminated extremely early")
            self.b_step = 1
            self.b_high = self.b_low
        # self.f(n) evaluates frac(x*n); see make_f.
        self.f = make_f(x)

    def __call__(self, k, M=None):
        """Return the m described in the class docstring, or None.

k is the threshold. M, if not None, is an integer for which the
caller knows frac(x*M) equals k exactly, so that M must not be
returned and must not be allowed to cause an exactness hazard.

None is returned if the convergents of x run out before a suitable m
is found."""
        debug("firstupto", "called with k = {} and M = {}".format(k,M))
        while True:
            # Each pass examines the arithmetic progression b_low,
            # b_low + b_step, ..., up to b_high.
            debug("firstupto", "b = {", self.b_low,
                  self.b_step, self.b_high, "}")
            fl = self.f(self.b_low)
            fs = self.f(self.b_step)
            debug("firstupto", "fl = {}".format(fl))
            debug("firstupto", "fs = {}".format(fs))
            # i_sp is a spigot whose ceiling is the index i of the
            # progression element we want.
            if self.b_low != M:
                debug("firstupto", "normal expression: (k-fl)/(fs-1)")
                i_sp = spigot.eval("(k-fl)/(fs-1)",
                                   {"fl":fl, "fs":fs, "k":k})
            else:
                # A roundabout alternative way to compute the same
                # thing, which avoids running into an exactness hazard
                # when the numerator (k-fl) of the above version would
                # have turned out to be exactly zero.
                fl1 = self.f(self.b_low+self.b_step)
                debug("firstupto", "fl1 = {}".format(fl1))
                debug("firstupto", "b_low = M: (k-fl1)/(fs-1) + 1")
                i_sp = spigot.eval("(k-fl1)/(fs-1) + 1",
                                   {"fl1":fl1, "fs":fs, "k":k})
            #debug("firstupto", "i_sp = {}".format(i_sp))
            # One more than the largest index whose element is still
            # <= b_high. An i clamped to this value therefore fails
            # the range check below and means "not in this progression".
            upper_bound = (self.b_high - self.b_low) // self.b_step + 1
            if (M is not None and
                self.b_low <= M <= self.b_high and
                (M - self.b_low) % self.b_step == 0):
                # Special case: one of the values taken by the
                # arithmetic progression from b_low to b_high in steps
                # of b_step yields *exactly* k when given to f. That
                # means we can't return that value, so try each side
                # of it.
                i_avoid = (M - self.b_low) // self.b_step
                debug("firstupto", "have to avoid", i_avoid)
                debug("firstupto", "try [", 0, ",", i_avoid, "]")
                i = ceil_within_limits(i_sp, 0, i_avoid)
                if i == i_avoid:
                    # The answer is not below i_avoid, so look for it
                    # strictly above instead.
                    debug("firstupto", "try (", i_avoid, ",", upper_bound, "]")
                    i = ceil_within_limits(i_sp, i_avoid+1, upper_bound)
            else:
                debug("firstupto", "try [", 0, ",", upper_bound, "]")
                i = ceil_within_limits(i_sp, 0, upper_bound)
            debug("firstupto", "i =", i)
            if self.b_low + i*self.b_step <= self.b_high:
                m = self.b_low + i*self.b_step
                debug("firstupto", "ok, m = {} {}".format(m, self.f(m)))
                return m
            else:
                # Nothing suitable in this progression: move on to the
                # next one, which starts where this one ended and takes
                # its step and end from the next two convergents.
                debug("firstupto", "get more convergents")
                try:
                    self.b_low = self.b_high
                    self.b_step = next(self.xconv)[1]
                    self.b_high = next(self.xconv)[1]
                except StopIteration:
                    debug("firstupto", "continued fraction terminated")
                    return None

def to_integer(s):
    """Return the integer exactly equal to a spigot's value, or None.

A value that is already a Python integer is returned unchanged. For a
spigot, the result is its numerator if it reports a known rational
value with denominator 1, and None otherwise."""
    if not isinstance(s, numbers.Integral):
        assert isinstance(s, spigot.Spigot)
        rational = s.known_rational_value()
        is_exact_integer = rational is not None and rational[1] == 1
        return rational[0] if is_exact_integer else None
    return s

def from_string(s, base=10):
    """Convert a command-line string into a number.

The string is first tried as an integer literal in the given base. If
int() rejects it with ValueError, it is handed to spigot.eval as an
expression instead (note that base is not passed on in that case)."""
    try:
        value = int(s, base)
    except ValueError:
        value = spigot.eval(s)
    return value

def writeaspowers(n, u, v):
    """Find unique integers ui,vi such that n = u^ui * v^vi.

One of ui,vi might be negative. If no such ui,vi exist, return
(None,None). We also return that if ui,vi are not unique, which can
only happen if log(u,v) is rational.

Each of n,u,v may be a Python integer or a spigot. If any of them is
not known to be an exact integer, the result is (None,None). A target
of n = 0 also gives (None,None), as does the degenerate case
u = v = 1."""

    n = to_integer(n)
    u = to_integer(u)
    v = to_integer(v)
    if n is None or u is None or v is None:
        return None, None

    # The main strategy of writeaspowers() is to try to find two
    # 'atomic' factors of u,v. By 'atomic' I don't necessarily mean
    # prime: a prime would do if we happen to find out, but finding
    # prime factors is hard in general, and we don't need anything
    # that specific for this purpose. What we want is some number
    # which you can divide off u a certain number of times and then it
    # becomes immediately _coprime_ to u - i.e. no smaller factor of
    # it will come off. And you can do the same to v. Then anything
    # expressible as a product of powers of u,v has that same
    # property, and we've acquired a linear relation saying what the
    # powers ui,vi have to be - namely, if our atom g has multiplicity
    # gu in u, gv in v and gn in n, then we know ui*gu + vi*gv = gn.
    #
    # The technique for finding such an 'atom' is to start with
    # gcd(u,v), and keep trying to prove it _is_ an atom, by dividing
    # it off u as many times as possible and then seeing if any
    # fraction of it still survives (i.e. if the gcd of our candidate
    # atom and whatever is left of u is non-trivial). If one does, we
    # replace our atom candidate with that smaller fraction, and try
    # again. So we reduce the size of our candidate in every
    # iteration, and hence the algorithm must terminate.
    #
    # Once we've divided off this atom, we'll then look for a second
    # atom between the remains of u,v, and since this search technique
    # will find as large an atom as possible, it follows that the
    # second atom will have _different_ ui and vi - i.e. the second
    # atom gives us a second independent linear relation, and between
    # those two relations, we'll have enough information to nail down
    # a single pair of values that ui and vi will have to have if they
    # exist at all.

    def find_common_atom(u,v):
        """Find a number 'atomically' dividing u and v.

The returned value g should have the property that u/g^i and v/g^j,
for some i,j, are both integers coprime to g. Returns None if u and v
are coprime."""
        g = gcd(u,v)
        if g == 1:
            return None
        prev = max(u,v)
        while g < prev:
            prev = g
            # Try to 'split' our atom using each of u and v. If we
            # succeed in either attempt, g will reduce and the while
            # loop will go round again. If we fail, we'll terminate
            # the loop and return g.
            for w in u,v:
                # Divide off g as many times as we can.
                while w % g == 0:
                    w //= g
                # Is there a nontrivial common factor between g and
                # what's left of w?
                gg = gcd(g, w)
                if gg > 1:
                    # Yes, so let that be our new g.
                    g = gg
        return g

    def find_multiplicity(u, g):
        """Return i such that u/g^i is an integer coprime to g.

Returns None if no such i exists. u must be nonzero, since zero is
divisible by g any number of times."""
        ret = 0
        # Divide off g as many times as we can, and count them.
        while u % g == 0:
            u //= g
            ret += 1
        # Now we want the remains of u to be coprime to g. If not,
        # fail.
        if gcd(u, g) != 1:
            return None
        return ret

    def find_atom_and_multiplicities(u, v):
        """Find an atom between u and v, and its multiplicities.

Returns a tuple (g,gu,gv) such that u/g^gu and v/g^gv are integers
coprime to g, or (None,None,None) if u and v are both 1."""
        if u == 1 and v == 1:
            # Trivial case: nothing at all to find here.
            return None, None, None
        g = find_common_atom(u, v)
        if g is None:
            # We can make do with using u or v itself as an atom at a
            # pinch; it will have multiplicity 1 in one of u,v and 0
            # in the other, which is good enough for a linear
            # relation.
            g = u if u != 1 else v
        return g, find_multiplicity(u, g), find_multiplicity(v, g)

    debug("writeaspowers", "try to write", n, "as powers of", u, "and", v)

    if n == 0:
        # Zero is not a product of powers of u and v. It must be
        # rejected here, because find_multiplicity() would otherwise
        # divide zero by an atom for ever.
        debug("writeaspowers", "no answer: target is zero")
        return None, None

    # Look for our first atom, g0.
    g0, g0u, g0v = find_atom_and_multiplicities(u, v)
    debug("writeaspowers", "found atom", g0)
    if g0 is None:
        # u = v = 1. Then n = 1 can be written with any exponents at
        # all, and nothing else can be written at all, so either way
        # there is no unique answer to return.
        debug("writeaspowers", "no-atom failure case")
        return None, None

    # Divide off g0 from both u and v, and look for a second atom in
    # the remainders.
    u0 = u // g0 ** g0u
    v0 = v // g0 ** g0v
    g1, g1u, g1v = find_atom_and_multiplicities(u0, v0)

    if g1 is None:
        # This can only happen if u and v are both exact powers of g0.
        # In that case, any power of g0**gcd(g0u,g0v) is representable
        # in infinitely many ways, and anything that is not such a
        # power is not representable at all. We don't distinguish
        # nonexistence from nonuniqueness in this function's API, so
        # we can just fail immediately in both cases.
        debug("writeaspowers", "one-atom failure case")
        return None, None

    debug("writeaspowers", "found another atom", g1)

    # Our atoms should be coprime.
    assert gcd(g0, g1) == 1

    # Find the multiplicity of each atom in our target number n. If it
    # doesn't divide 'atomically' into n, there's no point going any
    # further - we'll have already proved that n is not any ratio of
    # powers of u,v.
    g0n = find_multiplicity(n, g0)
    if g0n is None:
        debug("writeaspowers", "no answer:", g0, "does not go evenly into", n)
        return None, None
    g1n = find_multiplicity(n, g1)
    if g1n is None:
        debug("writeaspowers", "no answer:", g1, "does not go evenly into", n)
        return None, None
    debug("writeaspowers", "multiplicities of", g0, g1, "in", n, ":", g0n, g1n)
    debug("writeaspowers", "multiplicities of", g0, g1, "in", u, ":", g0u, g1u)
    debug("writeaspowers", "multiplicities of", g0, g1, "in", v, ":", g0v, g1v)

    # Now we have our two linear relations in ui,vi, namely
    #
    #   g0u * ui + g0v * vi = g0n
    #   g1u * ui + g1v * vi = g1n
    #
    # I.e. we need to solve the matrix equation
    #
    #   (g0u g0v) (ui) = (g0n)
    #   (g1u g1v) (vi)   (g1n)
    #
    # which has solution
    #
    #   (ui) = 1/(g0u g1v - g0v g1u) ( g1v -g0v) (g0n)
    #   (vi)                         (-g1u  g0u) (g1n)

    det = g0u * g1v - g0v * g1u
    # We expect our two linear relations to be independent, i.e. the
    # determinant of the matrix is nonzero so the linear system has a
    # unique solution. I admit my proof of that was a _bit_ handwavy;
    # if I turn out to be wrong about it, the fix will be to go back
    # to the code that finds g1, check _there_ whether the two linear
    # relations are scalings of each other, and if so, amalgamate g0
    # and g1 into a larger atom and try again to find a new g1.
    assert det != 0

    # Numerators of the fractions giving ui and vi, with common
    # denominator det.
    ui_num = +g1v * g0n -g0v * g1n
    vi_num = -g1u * g0n +g0u * g1n
    debug("writeaspowers", "raw matrix solution is",
          ui_num, "/", det, "and", vi_num, "/", det)

    # Those fractions should have integer values, or else the solution
    # isn't valid.
    if ui_num % det or vi_num % det:
        return None, None
    ui = ui_num // det
    vi = vi_num // det
    debug("writeaspowers", "integer solution is", ui, "and", vi)

    # Now we know that ui,vi cannot be anything _other_ than the
    # values we've found. But we don't know that they _are_ the values
    # we've found - the remains of u,v other than our two atoms g0,g1
    # might still come out wrong. So we now check by actually
    # computing the product u^ui v^vi, and seeing if it is equal to n.
    #
    # (Of course, the product could fail to be an integer, if one or
    # both of our putative ui,vi is negative, so we must check that
    # too.)
    num = u**max(ui,0) * v**max(vi,0)
    denom = u**max(-ui,0) * v**max(-vi,0)
    if num % denom != 0 or num // denom != n:
        debug("writeaspowers", "failure:", num, "/", denom, "!=", n)
        return None, None

    debug("writeaspowers", "success! Returning", ui, vi)
    return ui, vi

def powbegin(a, b, base, minpower, mult_factor=1):
    """Find the smallest power of a that begins with the digits of b.

a is the number to take powers of, b the number whose digits (written
in the given base) the result must start with. minpower is the
smallest exponent that may be returned; if it is None, one is chosen
automatically so that the digits of b fall before the point.
mult_factor is a value each power of a is multiplied by before its
leading digits are examined.

Returns the exponent as an integer, or None if the search establishes
that there is no answer."""
    debug("modmax", "a =", a, "b =", b, "base =", base, "minpower =", minpower)
    if minpower is None:
        # Default to choosing a minimum power which guarantees that
        # a^result >= b. It's pedantic and annoying to ask 'what power
        # of 2 begins with 100' and be told that 2^0 begins 1.00! I'd
        # much rather be told that 2^196 = 100433...06336. So the
        # default choice of minpower arranges that all the digits of b
        # will appear _before_ the decimal point.
        #
        # If a user really does want digits below the decimal point,
        # they can explicitly reset to the more straightforward
        # semantics by saying -m0.
        #
        # To calculate the right minimum power, spigot itself is the
        # easiest way - and using the dyadic log function guarantees
        # that this expression cannot hang.
        minpower = int(spigot.eval("ceil(log(b/f,a))",
                                   {"a":a, "b":b, "f":mult_factor}))
        debug("modmax", "auto minpower =", minpower)

    # If some power of a is equal to b times a power of the base, then
    # that's the absolute optimum answer, and we'll have to be aware
    # when we're coming up on it to avoid running into an exactness
    # hazard.
    target = spigot.eval("b/f", {"b":b, "f":mult_factor})
    exact_answer, _ = writeaspowers(target, a, base)
    debug("modmax", "exact_answer =", exact_answer)
    if exact_answer is not None:
        # From here on, exponents are counted relative to minpower.
        exact_answer -= minpower
        debug("modmax", "  adjusted to", exact_answer)
        if exact_answer < 0: exact_answer = None

    # If a power of a is similarly equal to b+1 times a power of the
    # base, then that's a value that is definitely _not_ the answer
    # (it's on the precise upper bound of our legal interval, but the
    # interval is open at the top end), and we'll want to tell
    # firstupto to make sure to avoid trying to return it.
    target = spigot.eval("(b+1)/f", {"b":b, "f":mult_factor})
    exact_non_answer, _ = writeaspowers(target, a, base)
    debug("modmax", "exact_non_answer =", exact_non_answer)
    if exact_non_answer is not None:
        exact_non_answer -= minpower
        debug("modmax", "  adjusted to", exact_non_answer)
        if exact_non_answer < 0: exact_non_answer = None

    scope = {"a":a, "b":b, "base":base, "m":minpower, "f":mult_factor}
    x = spigot.log(a, base)
    # lowerbound and y are the ends of the target interval for
    # frac(n*x). When an end is known to be hit exactly at offset 0,
    # use an exact constant rather than evaluating the expression.
    if exact_answer == 0:
        lowerbound = spigot.Spigot(0)
    else:
        lowerbound = spigot.eval("frac(log(b/f,base)-m log(a,base))", scope)
    if exact_non_answer == 0:
        y = spigot.Spigot(1)
    else:
        y = spigot.eval("frac(log((b+1)/f,base)-m log(a,base))", scope)

    debug("modmax", "x = {}".format(x))
    debug("modmax", "target interval [{},{}]".format(lowerbound, y))

    if lowerbound > y:
        # Special case: minpower was already in the interval!
        return minpower

    f = make_f(x)
    firstupto = Firstupto(x)

    # n is the current candidate exponent, relative to minpower.
    n = 0
    while True:
        # Find m = firstupto(y - f(n)).
        k = y - f(n)
        debug("modmax", "n =", n)
        debug("modmax", "need k = {}".format(k))
        if exact_non_answer is not None:
            debug("modmax", "exact_non_answer =", exact_non_answer)
            M = exact_non_answer - n
        else:
            M = None

        m = firstupto(k, M)
        debug("modmax", "got m =", m)
        if m is None or not f(m):
            # If firstupto returned None, or f(m) is actually zero,
            # that means the convergents of x ran out, so the value n
            # we already have is still the best we can do. So return
            # it, if it's good enough.
            if f(n) < lowerbound:
                debug("modmax", "convergents finished, no solution")
                return None
            debug("modmax", "convergents finished, returning", n)
            return n + minpower

        debug("modmax", "fm = {}".format(f(m)))
        # i is the number of steps of size m we take from n, with the
        # exact non-answer (if it lies on this progression) avoided.
        i_sp = spigot.eval("(y-fn) / fm", {"fn":f(n), "fm":f(m), "y":y})
        if (exact_non_answer is not None and
            exact_non_answer >= n and
            (exact_non_answer - n) % m == 0):
            i_avoid = (exact_non_answer - n) // m
            debug("modmax", "avoiding", i_avoid)
            i = floor_within_limits(i_sp, i_avoid, None)
            if i == i_avoid:
                i = floor_within_limits(i_sp, None, i_avoid - 1)
        else:
            i = floor_within_limits(i_sp, None, None)
        debug("modmax", "i =", i)

        # Our possible answers are n, n+m, ..., n+im. Pick the first of
        # those which is at least lowerbound, if any is.
        j_sp = spigot.eval("(lowerbound-fn) / fm",
                           {"fn":f(n), "fm":f(m), "lowerbound":lowerbound})
        # j_limit = i+1 lets a clamped result signal "none of them".
        j_limit = i+1
        if (exact_answer is not None and
            exact_answer <= n + i*m and
            (exact_answer - n) % m == 0):
            # The exact answer lies on this progression, so the search
            # need not look beyond it.
            j_limit = min(i, (exact_answer - n) // m)
        j = ceil_within_limits(j_sp, 0, j_limit)
        debug("modmax", "j =", j)

        if j <= i:
            debug("modmax", "done!")
            return n + j*m + minpower

        # No answer among these candidates: continue from the last one.
        n = n + i*m

# Hand-written regression tests run by test() in "standard" mode. Each
# entry is ((a, b, base, minpower), expected): the first four values
# are passed as the arguments of powbegin(), and expected is the value
# it must return (None where no answer exists).
directed_test_cases = [
    ((2, 13, 10, 0), 17),  # nice easy basic case
    ((7, 10, 10, 0), 0),   # b is a power of base (trivial answer!)
    ((7, 9, 10, 0), 13),   # b+1 is a power of base
    ((7, 49, 10, 0), 2),   # b is a power of a
    ((7, 48, 10, 0), 228), # b+1 is a power of a
    ((2, 3, 5, 0), 4),     # same here, but exercises a different code path
    ((9, 3, 27, 0), 2),    # log_{base}(a) is rational, answer exists
    ((9, 4, 27, 0), None), # log_{base}(a) is rational, answer does not exist
    ((14645, 42487246, 36, 0), 34473395), # a larger case to prove it goes fast
    ((49, 2022, 10, 0), 183), # regression test for a past bug
    ((7, 49, 10, 3), 73),  # test minpower > 0
    ((7, 49, 10, None), 2), # test auto-discovery of minpower
    ((2, 7, 10, 46), 46),  # test minpower == exactly where we were aiming
    ((7, 49, 10, 2), 2),   # same again but a different code path
]

def test(testclasses):
    """Run the built-in self-tests, print a summary, and exit.

testclasses is a collection of test class names. "standard" runs the
entries of directed_test_cases; "soak" generates and runs a larger set
of cases built from small powers in several bases.

Each test case is limited to 10 seconds by means of SIGALRM where the
platform provides it. On platforms without SIGALRM the tests still
run, but without a time limit.

This function does not return: it calls sys.exit with status 1 if any
test failed and 0 otherwise."""
    passes = 0
    fails = 0
    class TimeoutException(Exception): pass
    def sigalrm(signum, frame):
        raise TimeoutException

    # signal.SIGALRM and signal.alarm() are only available on Unix.
    can_timeout = hasattr(signal, "SIGALRM") and hasattr(signal, "alarm")

    def runtests(test_cases, passes, fails):
        """Run each test case and return the updated (passes, fails)."""
        for (a, b, base, minpower), expected in test_cases:
            desc = (
                ("power of {:d}{} starting '{}' in base {:d} "+
                 "should be {}").format(
                    a,
                    "" if minpower is None else " at least %d" % minpower,
                    format_in_base(b, base), base,
                    "None" if expected is None else "%d" % expected))
            # Append the command-line arguments that reproduce this
            # case by hand.
            desc += " [-b{:d} {} {}{}]".format(
                base, format_in_base(a, base), format_in_base(b, base),
                "" if minpower is None else " -m{:d}".format(minpower))
            debug("test", "test case:", desc)
            try:
                if can_timeout:
                    signal.alarm(10)
                got = powbegin(a, b, base, minpower)
                if got != expected:
                    sys.stdout.write("test FAILED: {} but got {}\n".format(
                        desc, got))
                    fails += 1
                else:
                    debug("test", "passed")
                    passes += 1
            except Exception as e:
                # This includes TimeoutException raised by the alarm
                # handler, so a hung test is reported as a failure.
                sys.stdout.write("test FAILED: {} but threw {!r}\n".format(
                    desc, e))
                debug("test",
                      "exception traceback:\n" + traceback.format_exc())
                fails += 1
            finally:
                # Cancel any pending alarm before the next test.
                if can_timeout:
                    signal.alarm(0)
        return passes, fails

    if can_timeout:
        signal.signal(signal.SIGALRM, sigalrm)

    if "standard" in testclasses:
        debug("test", "running", len(directed_test_cases), "directed tests")
        passes, fails = runtests(directed_test_cases, passes, fails)
    if "soak" in testclasses:
        debug("testgen", "generating soak tests")
        soak_test_dict = {}
        max_digits = 6
        for base in 4,5,6,8,12: # numbers of the form: p, p^2, p^3, qp, qp^2
            max_value = base**max_digits
            for a in range(2, 2*base+1):
                # Smallest power of a that reaches max_value.
                power_limit = next(power for power in itertools.count()
                                   if a**power >= max_value)
                debug("testgen", "doing powers of", a, "in base", base,
                      "up to", power_limit)
                for power in range(power_limit-1, -1, -1):
                    # Compute the actual value of a^power.
                    b = a**power
                    debug("testgen", a, "^", power, "in base", base,
                          "=", format_in_base(b, base))

                    # Add zeroes on the end until it has the maximum
                    # number of digits we handle. (This is because,
                    # for example, if a larger power of a starts with
                    # '10' and this one is simply '1', we'll expect
                    # this smaller power to be returned even for the
                    # prefix '10', because it will be treated as 1.0.)
                    while b * base < max_value:
                        b *= base

                    # Insert each prefix of it into our list of
                    # outputs. Since we process powers in descending
                    # order, we can safely overwrite any value we find
                    # in our dictionary already.
                    while b > 0:
                        debug("testgen", "inserting (",
                              a, format_in_base(b, base), base, ") ->", power)
                        soak_test_dict[(a, b, base, 0)] = power
                        b //= base
        soak_test_cases = sorted(soak_test_dict.items())
        debug("test", "running", len(soak_test_cases), "soak tests")
        passes, fails = runtests(soak_test_cases, passes, fails)
    sys.stdout.write("passed {:d} failed {:d}\n".format(passes, fails))
    sys.exit(1 if fails > 0 else 0)

def main():
    """Command-line entry point.

Parses the arguments, records which debug classes are enabled, and
then either runs the self-tests (if -t or -T was given) or prints the
result of powbegin() for the given a and b."""
    global dbgclasses_enabled

    parser = argparse.ArgumentParser(
        description='Find the smallest power of one number beginning with '
        'the digits of another.')
    add_option = parser.add_argument
    add_option("a", nargs="?", help="Number to look for a power of.")
    add_option("b", nargs="?", help="Number whose digits the target"
               " value should start with.")
    add_option("--minimum", "-m", type=int, help="Smallest power of"
               " a to consider returning.")
    add_option("--multiply", "-M", default="1",
               help="Value to multiply with each power of a before "
               " checking its initial digits.")
    add_option("--base", "-b", type=int, default=10,
               help="Base (radix) in which to interpret a, b and the"
               " argument to --multiply (if any).")
    add_option("--debug", "-d", metavar="component", action="append",
               help="Print diagnostics from a specified component"
               " of the program.")
    add_option("--test", "-t", dest="tests", action="append_const",
               const="standard", help="Run built-in self-tests.")
    add_option("--soaktest", "-T", dest="tests", action="append_const",
               const="soak", help="Run more thorough soak tests.")
    args = parser.parse_args()

    dbgclasses_enabled = set(args.debug or [])

    if args.tests is not None:
        test(args.tests)
        return

    if args.a is None or args.b is None:
        # a and b are declared optional so that argparse accepts
        # their absence in test mode. Outside test mode they are
        # mandatory, so enforce that by hand.
        program = os.path.basename(sys.argv[0])
        sys.exit("{}: error: too few arguments".format(program))

    a = from_string(args.a, args.base)
    b = from_string(args.b, args.base)
    mult_factor = from_string(args.multiply, args.base)
    answer = powbegin(a, b, args.base, args.minimum, mult_factor)
    sys.stdout.write("{}\n".format(answer))

# Run the command-line interface only when executed as a script, not
# when imported as a module.
if __name__ == '__main__':
    main()
