본문 바로가기

진리는어디에/C++

[C++20] 코루틴(Coroutine) - 예제

Coroutine.h

#ifndef _COROUTINE_H_
#define _COROUTINE_H_

#include <coroutine>
#include <iterator>
#include <memory>
#include <utility>

/// Minimal owning wrapper around a C++20 coroutine frame.
///
/// @tparam T                Type produced by `co_yield` / `co_return`, or `void` for a
///                          coroutine that only suspends and produces no value.
/// @tparam INITIAL_SUSPEND  Awaitable returned by `initial_suspend()`. With the default
///                          `std::suspend_always` the body does not start until the first
///                          resume(); with `std::suspend_never` it runs immediately, up to its
///                          first suspension point, before the caller receives the Coroutine.
///
/// Copies share one coroutine frame through a `std::shared_ptr`; the frame is destroyed
/// when the last copy is released. No synchronization is done around resume().
template <class T, class INITIAL_SUSPEND = std::suspend_always>
class Coroutine
{
private:
    class Impl;

    // Promise members common to the value-producing and the void promise.
    struct promise_base
    {
        INITIAL_SUSPEND initial_suspend()
        {
            return INITIAL_SUSPEND{};
        }

        // Suspend (rather than finish) at the end so the frame outlives the body:
        // handle.done() stays queryable and Impl's destructor is what frees the frame.
        std::suspend_always final_suspend() noexcept
        {
            return {};
        }

        // Invoked from inside the implicit catch handler that wraps the coroutine body.
        // Rethrowing lets the original exception propagate out of resume(); the
        // coroutine is then treated as suspended at its final suspend point.
        void unhandled_exception()
        {
            throw;
        }
    };

    // Promise for coroutines producing values of type R.
    // `Dummy` exists only so the `void` case below can be a *partial* specialization:
    // an explicit `template <>` specialization inside a class template is not portable
    // (GCC rejects it).
    // Because this promise declares return_value (and therefore cannot declare
    // return_void), a value-producing coroutine must end with `co_return <value>;` —
    // flowing off the end of its body is undefined behavior.
    template <class R, class Dummy = void>
    struct promise_type_impl : public promise_base
    {
        // Most recent value passed to co_yield or co_return.
        R value{};

        Coroutine get_return_object()
        {
            return Coroutine{ std::make_shared<Impl>(std::coroutine_handle<promise_type_impl>::from_promise(*this)) };
        }

        std::suspend_always yield_value(R&& new_value)
        {
            value = std::move(new_value);
            return {};
        }

        std::suspend_always yield_value(const R& new_value)
        {
            value = new_value;
            return {};
        }

        void return_value(R&& new_value)
        {
            value = std::move(new_value);
        }

        void return_value(const R& new_value)
        {
            value = new_value;
        }
    };

    // Promise for Coroutine<void, ...>: stores nothing; flowing off the end is allowed.
    template <class Dummy>
    struct promise_type_impl<void, Dummy> : public promise_base
    {
        Coroutine get_return_object()
        {
            return Coroutine{ std::make_shared<Impl>(std::coroutine_handle<promise_type_impl>::from_promise(*this)) };
        }

        void return_void()
        {
        }
    };

public:
    using promise_type = promise_type_impl<T>;

public:
    /// Creates an empty wrapper that owns no coroutine; done() is true.
    Coroutine() = default;

    /// Takes (shared) ownership of a coroutine frame; used by get_return_object().
    Coroutine(std::shared_ptr<Impl> impl)
        : impl(std::move(impl))
    {
    }

    // Copy/move construction and assignment are the implicit ones: they share `impl`.

    /// Same as resume().
    bool operator()() const
    {
        return resume();
    }

    /// Resumes the coroutine if it has not finished yet.
    /// @return true if the coroutine is suspended again after this call (so a newly
    ///         yielded value is available and it can be resumed further); false if it
    ///         finished during this call, had already finished, or is empty.
    bool resume() const
    {
        if (done())
        {
            return false;
        }

        impl->handle.resume();

        return !impl->handle.done();
    }

    /// Access to the promise (e.g. `promise().value` for value coroutines).
    /// Precondition: this wrapper is not empty.
    promise_type& promise()
    {
        return impl->handle.promise();
    }

    /// True when there is no coroutine, or it has reached its final suspend point.
    bool done() const
    {
        return !impl || !impl->handle || impl->handle.done();
    }

    /// Raw handle of the owned coroutine (name spelling kept for existing callers).
    /// Precondition: this wrapper is not empty.
    std::coroutine_handle<promise_type> corotine_handle() const
    {
        return impl->handle;
    }

    /// Single-pass input iterator over the values a value coroutine yields.
    struct iterator
    {
        explicit iterator(Coroutine* coroutine)
            : coroutine(coroutine)
        {
        }

        /// Current value; valid only while the iterator is not at end.
        const T& operator* () const
        {
            return coroutine->promise().value;
        }

        /// Advances to the next yielded value.
        iterator& operator++()
        {
            coroutine->resume();
            return *this;
        }

        /// At end when there is no coroutine or it has finished.
        bool operator == (std::default_sentinel_t) const
        {
            return nullptr == coroutine || coroutine->done();
        }

    private:
        Coroutine* coroutine;
    };

    /// Resumes once to reach the first yielded value (no-op if already finished).
    /// NOTE(review): for an eagerly started coroutine (std::suspend_never) the first
    /// value is already available before begin(), so this extra resume skips it.
    iterator begin()
    {
        if (!impl)
        {
            return iterator{ nullptr };
        }

        resume();

        return iterator{ this };
    }

    std::default_sentinel_t end()
    {
        return {};
    }

private:
    // Sole owner of the coroutine frame; destroys it when the last Coroutine copy goes away.
    class Impl
    {
    public:
        explicit Impl(std::coroutine_handle<promise_type> handle)
            : handle(handle)
        {
        }

        // Copying would destroy the same frame twice.
        Impl(const Impl&) = delete;
        Impl& operator=(const Impl&) = delete;

        ~Impl()
        {
            if (handle)
            {
                handle.destroy();
            }
        }

        std::coroutine_handle<promise_type> handle;
    };

private:
    std::shared_ptr<Impl> impl;
};

#endif

Coroutine Example

// main.cpp
#include <iostream>
#include "Coroutine.h"

// Lazily started coroutine: with the default INITIAL_SUSPEND (std::suspend_always) nothing
// below runs until the caller's first resume. That resume prints the first line and
// suspends; the next resume prints the second line and finishes the coroutine.
Coroutine<void> lazily_start()
{
    std::ostream& out = std::cout;
    out << "\tlazily_start 1" << std::endl;
    co_await std::suspend_always{};   // hand control back to the caller
    out << "\tlazily_start 2" << std::endl;
}

// Eagerly started coroutine: INITIAL_SUSPEND = std::suspend_never makes the body start as
// soon as eagerly_start() is called, so the first line is printed before the caller gets
// the Coroutine back; the caller's next resume prints the second line.
Coroutine<void, std::suspend_never> eagerly_start()
{
    std::ostream& out = std::cout;
    out << "\teagerly_start 1" << std::endl;
    co_await std::suspend_always{};   // first and only suspension point
    out << "\teagerly_start 2" << std::endl;
}

// Generator that yields the integers of the half-open range [start, end), one per resume.
// Coroutine<int>'s promise declares return_value but not return_void, so flowing off the
// end of this body would be undefined behavior; the explicit co_return finishes the
// coroutine cleanly. The value it stores (`end`) is not one of the yielded values.
Coroutine<int> yield(int start, int end)
{
    for (int i = start; i < end; ++i)
    {
        co_yield i;
    }

    co_return end;
}

int main()
{
    {
        std::cout << "==== lazily_start example ====" << std::endl;
        Coroutine<void> coroutine = lazily_start();
        std::cout << "main 1" << std::endl;
        coroutine();
        std::cout << "main 2" << std::endl;
        coroutine();
        /* OUTPUT
        ==== lazily_start example ====
        main 1
                lazily_start 1
        main 2
                lazily_start 2
        */
    }

    {
        std::cout << "==== eagerly_start example ====" << std::endl;
        Coroutine<void, std::suspend_never> coroutine = eagerly_start();
        std::cout << "main 1" << std::endl;
        coroutine();
        std::cout << "main 2" << std::endl;
        coroutine();
        /* OUTPUT
        ==== eagerly_start example ====
                eagerly_start 1
        main 1
                eagerly_start 2
        main 2
        */
    }

    {
        std::cout << "==== yield while loop example ====" << std::endl;
        Coroutine<int> coroutine = yield(0, 5);

        int i = 0;

        while (true == coroutine.resume())
        {
            std::cout << "main " << i++ << std::endl;
            std::cout << "\tyield value " << coroutine.promise().value << std::endl;
        }
        /* OUTPUT
        ==== yield while loop example ====
        main 0
                yield value 0
        main 1
                yield value 1
        main 2
                yield value 2
        main 3
                yield value 3
        main 4
                yield value 4
        */
    }

    {
        std::cout << "==== yield ranged-for example ====" << std::endl;
        Coroutine<int> coroutine = yield(5, 10);

        int i = 0;

        for (int value : coroutine)
        {
            std::cout << "main " << i++ << std::endl;
            std::cout << "\tyield value " << value << std::endl;
        }
        /* OUTPUT
        ==== yield ranged-for example ====
        main 0
                yield value 5
        main 1
                yield value 6
        main 2
                yield value 7
        main 3
                yield value 8
        main 4
                yield value 9
        */
    }
}

결과

==== lazily_start example ====
main 1
        lazily_start 1
main 2
        lazily_start 2
==== eagerly_start example ====
        eagerly_start 1
main 1
        eagerly_start 2
main 2
==== yield while loop example ====
main 0
        yield value 0
main 1
        yield value 1
main 2
        yield value 2
main 3
        yield value 3
main 4
        yield value 4
==== yield ranged-for example ====
main 0
        yield value 5
main 1
        yield value 6
main 2
        yield value 7
main 3
        yield value 8
main 4
        yield value 9

부록 1. 같이 읽으면 좋은 글

유익한 글이었다면 공감(❤) 버튼 꾹!! 추가 문의 사항은 댓글로!!