#include <assert.h>
#include <coroutine>
#include <exception>
#include <iostream>
#include <optional>
#include <utility>

// To avoid a tedious prefix on all these implementation class
// names, we put them all in a namespace of their own (SG) and
// export just the user-facing type under a more sensible name
// (StackableGenerator).

namespace SG {

// The implementation classes below mention one another before they are
// defined (for example, PromiseMid returns a UserFacing and a
// FinalSuspendAwaiter), so every class template is declared up front
// together with its template parameter list.

template <class Yield> class PromiseBase;

template <class Yield, class Return> class PromiseMid;
template <class Yield, class Return> class Promise;
template <class Yield, class Return> class SubroutineAwaiterBase;
template <class Yield, class Return> class SubroutineAwaiter;
template <class Yield, class Return> class FinalSuspendAwaiter;
template <class Yield, class Return> class UserFacing;

// Typed coroutine handle for a generator with the given Yield/Return types.
template <class Yield, class Return>
using Handle = std::coroutine_handle<Promise<Yield, Return>>;

// State common to every promise in a stack of nested generators. It is
// templated only on Yield, so promises whose Return types differ can still
// point at one another through PromiseBase pointers.
template <typename Yield> class PromiseBase {
  public:
    // Promise of the outermost generator of this stack (itself if outermost).
    PromiseBase *outer;
    // Promise of the coroutine that co_awaited this one; nullptr if outermost.
    PromiseBase *parent;
    // Innermost active generator. Only ever accessed through `outer`, so
    // only the outermost promise's copy is meaningful.
    PromiseBase *current;

    // Most recent co_yield value; only the outermost promise's copy is used.
    std::optional<Yield> last_yield;

    // Exception that escaped this coroutine's own body, if any.
    std::exception_ptr exception;

    PromiseBase() : outer(this), parent(nullptr), current(this) {}

    // Promises refer to each other by address (outer/parent/current), so
    // copying or moving one is disallowed.
    PromiseBase(const PromiseBase &) = delete;
    PromiseBase(PromiseBase &&) = delete;
    PromiseBase &operator=(const PromiseBase &) = delete;
    PromiseBase &operator=(PromiseBase &&) = delete;

    virtual ~PromiseBase() = default;

    // Type-erased handle to the coroutine that owns this promise.
    virtual std::coroutine_handle<> handle() = 0;

    // Called on the caller's promise when it co_awaits `callee`: the callee
    // joins this stack and becomes its innermost active generator.
    void link_to_callee(PromiseBase *callee) {
        PromiseBase *const root = outer;
        callee->outer = root;
        callee->parent = root->current;
        root->current = callee;
    }

    // Called when this coroutine finishes: whoever awaited it (or nobody,
    // for the outermost generator) becomes the innermost active one again.
    void unlink_from_caller() {
        outer->current = parent;
    }
};

// Promise machinery shared by Promise<Yield, Return> and its void
// specialization: everything except how the coroutine's result is stored.
template <typename Yield, typename Return>
class PromiseMid : public PromiseBase<Yield> {
    friend class SubroutineAwaiterBase<Yield, Return>;
    friend class SubroutineAwaiter<Yield, Return>;

    // Typed handle to the coroutine owning this promise. The static_cast
    // relies on this object really being a Promise<Yield, Return>, which
    // holds because PromiseMid is only used as the base of Promise.
    Handle<Yield, Return> handle_specific()
    {
        return Handle<Yield, Return>::from_promise(
            *static_cast<Promise<Yield, Return> *>(this));
    }

  public:
    // Wrap the newly created coroutine frame in its owning UserFacing.
    UserFacing<Yield, Return> get_return_object()
    {
        return UserFacing<Yield, Return> { handle_specific() };
    }

    // Generators are lazy: the body does not start until resumed, either by
    // next_value() (outermost) or by a co_await in another generator.
    std::suspend_always initial_suspend() { return {}; }

    // On completion, pop this coroutine off its generator stack; the
    // returned awaiter then decides which coroutine runs next.
    FinalSuspendAwaiter<Yield, Return> final_suspend() noexcept
    {
        this->unlink_from_caller();
        return FinalSuspendAwaiter<Yield, Return> { this };
    }

    // Store the exception in this coroutine's own promise. It is rethrown
    // by whoever observes this coroutine finishing: the awaiting caller
    // (SubroutineAwaiterBase::check_exception) or, for the outermost
    // generator, UserFacing::next_value().
    void unhandled_exception() {
        this->exception = std::current_exception();
    }

    // `co_await other_generator(...)` runs the callee as a nested generator
    // that yields into this generator's stream. Because await_transform is
    // defined, only UserFacing objects with the same Yield type can be
    // co_awaited inside a generator.
    template <typename CalleeReturn>
    SubroutineAwaiter<Yield, CalleeReturn>
    await_transform(UserFacing<Yield, CalleeReturn> &&callee)
    {
        return SubroutineAwaiter<Yield, CalleeReturn>(std::move(callee));
    }

    // Hand a value to whoever drives the outermost generator, then suspend
    // so that its resume() in next_value() returns. The by-value argument
    // is moved, not copied a second time.
    std::suspend_always yield_value(Yield y)
    {
        this->outer->last_yield = std::move(y);
        return {};
    }

    std::coroutine_handle<> handle() override { return handle_specific(); }
};

// Promise for a generator that finishes with `co_return value;`. The value
// is held in `ret` (value-initialized until then) until the awaiting
// caller collects it in SubroutineAwaiter::await_resume().
template <typename Yield, typename Return>
class Promise : public PromiseMid<Yield, Return> {
  public:
    Return ret{};
    void return_value(Return r) { ret = std::move(r); }
};

// Specialization for generators that return nothing: `co_return;` (or
// falling off the end of the body) goes through return_void, and there is
// no `ret` member for a caller to read.
template <typename Yield>
class Promise<Yield, void> : public PromiseMid<Yield, void>
{
  public:
    void return_void()
    {
        // No result to record.
    }
};

template <typename Yield, typename Return> class SubroutineAwaiterBase {
    UserFacing<Yield, Return> callee;

  protected:
    Promise<Yield, Return> *callee_promise;

  public:
    SubroutineAwaiterBase(UserFacing<Yield, Return> &&callee)
        : callee(std::move(callee)) {}

    bool await_ready() { return false; }

    template <typename CallerReturn>
    std::coroutine_handle<> await_suspend(Handle<Yield, CallerReturn> caller)
    {
        callee_promise = &callee.handle.promise();
        caller.promise().link_to_callee(callee_promise);
        return callee.handle;
    }

    void check_exception() {
        if (this->callee_promise->exception)
            std::rethrow_exception(this->callee_promise->exception);
    }
};

// Awaiter for a callee with a return value: the co_await expression
// evaluates to whatever the callee passed to `co_return`.
template <typename Yield, typename Return>
class SubroutineAwaiter : public SubroutineAwaiterBase<Yield, Return> {
  public:
    // Inherit the constructor instead of relying on C++20 parenthesized
    // aggregate initialization, which some compilers do not support.
    using SubroutineAwaiterBase<Yield, Return>::SubroutineAwaiterBase;

    Return await_resume() {
        this->check_exception();
        // The callee has finished and its frame dies with this awaiter,
        // so the stored result can be moved out rather than copied.
        return std::move(this->callee_promise->ret);
    }
};

// Awaiter for a callee returning void: the co_await only propagates any
// exception that escaped the callee.
template <typename Yield>
class SubroutineAwaiter<Yield, void>
    : public SubroutineAwaiterBase<Yield, void> {
  public:
    // Inherit the constructor instead of relying on C++20 parenthesized
    // aggregate initialization, which some compilers do not support.
    using SubroutineAwaiterBase<Yield, void>::SubroutineAwaiterBase;

    void await_resume() {
        this->check_exception();
    }
};

// Awaiter returned by PromiseMid::final_suspend(). It keeps the finished
// coroutine suspended so its owner can destroy the frame, and picks what
// runs next. That is the coroutine that co_awaited this one, or, for the
// outermost generator, nothing, so control returns to the resume() call
// in next_value().
template <typename Yield, typename Return>
class FinalSuspendAwaiter {
    PromiseMid<Yield, Return> *promise;

  public:
    FinalSuspendAwaiter(PromiseMid<Yield, Return> *owner)
        : promise(owner) {}

    bool await_ready() noexcept { return false; }

    std::coroutine_handle<> await_suspend(Handle<Yield, Return>) noexcept {
        auto *const awaiting_caller = promise->parent;
        if (awaiting_caller == nullptr)
            return std::noop_coroutine();
        return awaiting_caller->handle();
    }

    void await_resume() noexcept {}
};

// The generator object handed to user code (exported as
// StackableGenerator). It uniquely owns its coroutine frame: it is
// move-only, and whichever object holds the handle destroys the frame.
template <typename Yield, typename Return> class UserFacing {
    friend class SubroutineAwaiterBase<Yield, Return>;
    friend class SubroutineAwaiter<Yield, Return>;
    friend class PromiseMid<Yield, Return>;
    friend class Promise<Yield, Return>;

  public:
    // Runs the generator until it next yields or finishes.
    //
    // Returns the yielded value. Returns std::nullopt once the generator has
    // run to completion, or if this object has been moved from. If an
    // exception escapes the outermost coroutine, it is rethrown here and
    // the generator is then finished.
    //
    // Must only be called on an outermost generator, not on one that is
    // being co_awaited by another generator (checked by assert).
    std::optional<Yield> next_value() {
        if (!handle)
            return std::nullopt;

        auto &outer = handle.promise();
        assert(outer.outer == &outer);
        if (!outer.current)
            return std::nullopt;

        // Resume whichever generator in the stack is innermost; resume()
        // returns once some generator yields or the outermost one finishes.
        auto curr = outer.current->handle();
        outer.last_yield = std::nullopt;
        outer.exception = nullptr;
        assert(!curr.done());
        curr.resume();

        if (outer.exception)
            std::rethrow_exception(outer.exception);

        // Hand the value over without copying it, leaving the slot empty.
        return std::exchange(outer.last_yield, std::nullopt);
    }

    UserFacing(const UserFacing &) = delete;
    UserFacing &operator=(const UserFacing &) = delete;

    UserFacing(UserFacing &&rhs) noexcept
        : handle(std::exchange(rhs.handle, nullptr)) {}

    UserFacing &operator=(UserFacing &&rhs) noexcept {
        // Guard against self-move, which would otherwise destroy this
        // object's own frame.
        if (this != &rhs) {
            if (handle)
                handle.destroy();
            handle = std::exchange(rhs.handle, nullptr);
        }
        return *this;
    }

    ~UserFacing() {
        if (handle)
            handle.destroy();
    }

  private:
    Handle<Yield, Return> handle;

    // Only created by PromiseMid::get_return_object().
    explicit UserFacing(Handle<Yield, Return> handle) : handle(handle) {}
};

} // namespace SG

template <typename Yield, typename Return, typename... ArgTypes>
struct std::coroutine_traits<SG::UserFacing<Yield, Return>, ArgTypes...> {
    using promise_type = SG::Promise<Yield, Return>;
};

template <typename Yield, typename Return = void>
using StackableGenerator = SG::UserFacing<Yield, Return>;

#include <sstream>
#include <string>
#include <string_view>

using namespace std::literals;

// Exception thrown by the demo generators. It derives publicly from
// std::exception, so it can also be caught as std::exception. Its message
// is available both via str() and via what().
class TestException : public std::exception {
    std::string s;
  public:
    TestException(std::string_view sv) : s(sv) {}
    std::string_view str() const { return s; }
    const char *what() const noexcept override { return s.c_str(); }
};

// Decimal text for i. std::to_string produces the same text as streaming
// an int through an ostringstream in the default "C" locale, without
// constructing a stream for every call.
std::string to_string(int i)
{
    return std::to_string(i);
}

// Demo nested generator: yields n labelled messages, then (for n >= 2)
// delegates to a copy of itself with n - 1, and finally returns 5 * n.
StackableGenerator<std::string, int> subroutine(int n)
{
    const std::string label = "subroutine(n="s + to_string(n) + ") yield #"s;
    for (int idx = 0; idx < n; ++idx)
        co_yield label + to_string(idx);
    if (n >= 2)
        co_await subroutine(n - 1);
    co_return n * 5;
}

// Demo nested generator: yields one greeting, then throws
// TestException("oops!") when n is 0 and otherwise returns n + 1.
StackableGenerator<std::string, int> subroutine_that_throws(int n)
{
    const std::string greeting =
        "hello from subroutine_that_throws("s + to_string(n) + ")"s;
    co_yield greeting;
    if (n != 0)
        co_return n + 1;
    throw TestException("oops!");
}

// Demo outermost generator. It exercises plain yields, awaiting a nested
// generator only for its yields, awaiting one for its return value, and
// catching an exception thrown inside a nested generator.
StackableGenerator<std::string> toplevel(int n)
{
    for (int i = 0; i < n; i++)
        co_yield "top-level initial yield #"s + to_string(i);
    // Deliberately not wrapped in try: with n == 0 the callee throws, and
    // the exception escapes toplevel to whoever is calling next_value().
    co_await subroutine_that_throws(n);
    int k = co_await subroutine(n);
    co_yield "subroutine returned "s + to_string(k);
    for (int i = 0; i < 3; i++) {
        co_yield "about to call subroutine_that_throws("s + to_string(i) + ")"s;
        std::optional<TestException> e;
        try {
            int r = co_await subroutine_that_throws(i);
            co_yield "it returned "s + to_string(r);
        } catch (const TestException &ee) {
            // co_await/co_yield may not appear inside an exception handler
            // (C++20 [expr.await]), so keep a copy of the exception and
            // report it after leaving the catch block.
            e = ee;
        }
        if (e)
            co_yield "it threw '"s + std::string(e->str()) + "'"s;
    }
    for (int i = 0; i < n; i++)
        co_yield "top-level final yield #"s + to_string(i);
}

int main()
{
    try {
        auto g = toplevel(5);
        while (auto val = g.next_value())
            std::cout << "yielded '" << *val << "'\n";

        g = toplevel(0);
        while (auto val = g.next_value())
            std::cout << "yielded '" << *val << "'\n";
    } catch (TestException ee) {
        std::cout << "main() caught '" << ee.str() << "'\n";
    }
}
