#include <coroutine>
#include <iostream>
#include <memory>
#include <string>

/// Interface for objects that hand out a string each time they are asked.
class AbstractBaseClass {
  public:
    /// Produces a value. Declared non-const, so an implementation is allowed
    /// to change its own state between calls.
    virtual std::string get_value() = 0;

    /// Virtual so that a derived object can be deleted through a pointer to
    /// this base class.
    virtual ~AbstractBaseClass() = default;
};

// Promise is only forward-declared at this point: CoroutineDerivedClass stores
// a Handle, and Promise's full definition comes after that class.
class Promise;
// Typed handle to a coroutine whose promise object is a Promise.
using Handle = std::coroutine_handle<Promise>;

class CoroutineDerivedClass : public AbstractBaseClass {
  private:
    friend class Promise;
    Handle handle;

    CoroutineDerivedClass(Handle handle) : handle(handle) {}

  public:
    std::string get_value() override;

    ~CoroutineDerivedClass() {
        handle.destroy();
    }
};

/// Promise type for coroutines that return `AbstractBaseClass *`; it is
/// selected through the std::coroutine_traits specialization in this file.
class Promise {
  public:
    /// Text most recently passed to co_yield. CoroutineDerivedClass::get_value
    /// clears it before each resume and reads it afterwards.
    std::string yielded_value;

    /// Heap-allocates the wrapper object that the coroutine's caller receives.
    /// The wrapper is given the handle of the coroutine this promise belongs to.
    CoroutineDerivedClass *get_return_object() {
        Handle self = Handle::from_promise(*this);
        return new CoroutineDerivedClass(self);
    }

    /// Always suspends, so the coroutine body does not run until the first
    /// resume.
    std::suspend_always initial_suspend() { return std::suspend_always{}; }

    /// Always suspends after the body finishes, so the frame stays alive until
    /// it is destroyed through the handle.
    std::suspend_always final_suspend() noexcept { return std::suspend_always{}; }

    /// Stores a copy of the yielded text and suspends the coroutine.
    std::suspend_always yield_value(std::string_view val) {
        yielded_value.assign(val);
        return std::suspend_always{};
    }

    /// Empty handler: an exception escaping the coroutine body is discarded.
    void unhandled_exception() {}

    /// Nothing to record when the coroutine body runs off its end.
    void return_void() {}
};

/// Resumes the coroutine once and returns whatever it yielded.
///
/// @return The string passed to co_yield during this resume, or an empty
///         string if the coroutine finished without yielding or had already
///         finished before this call.
std::string CoroutineDerivedClass::get_value() {
    // Clear the slot first so a resume that does not yield reports "" rather
    // than the previous value.
    handle.promise().yielded_value.clear();
    // Resuming a coroutine that is suspended at its final suspend point is
    // undefined behaviour, so a finished coroutine is left alone.
    if (!handle.done()) {
        handle.resume();
    }
    return handle.promise().yielded_value;
}

template<typename... ArgTypes>
struct std::coroutine_traits<AbstractBaseClass *, ArgTypes...> {
    using promise_type = Promise;
};

/// AbstractBaseClass implementation written as an ordinary class; it keeps no
/// state and answers with the same text every time.
class ConventionalDerivedClass : public AbstractBaseClass {
  public:
    ConventionalDerivedClass() = default;

    /// Returns a fixed greeting.
    std::string get_value() override {
        std::string greeting{"hello from ConventionalDerivedClass::get_value"};
        return greeting;
    }
};

/// Coroutine factory: the caller receives the object created by
/// Promise::get_return_object. The body starts suspended and yields one
/// message per resume, two messages in total, before running off its end.
AbstractBaseClass *demo_coroutine() {
    static constexpr const char *kMessages[] = {
        "hello from coroutine, part 1",
        "hello from coroutine, part 2",
    };
    for (const char *message : kMessages) {
        co_yield message;
    }
}

/// Ordinary factory: heap-allocates a ConventionalDerivedClass and returns it
/// through the base-class pointer. The caller is responsible for deleting it.
AbstractBaseClass *demo_non_coroutine() {
    AbstractBaseClass *instance = new ConventionalDerivedClass();
    return instance;
}

int main() {
    auto foo = demo_coroutine();
    auto bar = demo_non_coroutine();
    std::cout << "foo says: " << foo->get_value() << "\n";
    std::cout << "foo says: " << foo->get_value() << "\n";
    std::cout << "bar says: " << bar->get_value() << "\n";
    std::cout << "bar says: " << bar->get_value() << "\n";
    delete foo;
    delete bar;
}
