// -*- C++ -*-
//===--------------------------- barrier ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX_BARRIER
#define _LIBCUDACXX_BARRIER

/*
    barrier synopsis

namespace std
{

  template<class CompletionFunction = see below>
  class barrier
  {
  public:
    using arrival_token = see below;

    constexpr explicit barrier(ptrdiff_t phase_count,
                               CompletionFunction f = CompletionFunction());
    ~barrier();

    barrier(const barrier&) = delete;
    barrier& operator=(const barrier&) = delete;

    [[nodiscard]] arrival_token arrive(ptrdiff_t update = 1);
    void wait(arrival_token&& arrival) const;

    void arrive_and_wait();
    void arrive_and_drop();

  private:
    CompletionFunction __completion; // exposition only
  };

}

*/

#ifndef __cuda_std__
#  include <__config>
#endif // __cuda_std__

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#ifndef __cuda_std__
#ifndef _LIBCUDACXX_HAS_NO_TREE_BARRIER
#  include <thread>
#  include <vector>
#endif
#else
#ifndef _CCCL_COMPILER_NVRTC
#include <new>
#endif // _CCCL_COMPILER_NVRTC
#endif // __cuda_std__

#include "__assert" // all public C++ headers provide the assertion handler
#include "__debug"
#include "atomic"
#include "chrono"
#include "cstddef"

#ifndef __cuda_std__
#include <__pragma_push>
#endif // __cuda_std__

#ifdef _LIBCUDACXX_HAS_NO_THREADS
# error <barrier> is not supported on this single threaded system
#endif

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// Default completion function for barriers: a callable whose invocation has
// no effect. It also serves as the tag that selects the packed-counter
// __barrier_base specialization where one is provided.
struct __empty_completion
{
    _LIBCUDACXX_INLINE_VISIBILITY
    void operator()() noexcept
    {
        // Intentionally empty: nothing to do when a phase completes.
    }
};

#ifndef _LIBCUDACXX_HAS_NO_TREE_BARRIER

// Barrier implementation selected when _LIBCUDACXX_HAS_NO_TREE_BARRIER is not
// defined. Arrivals are combined pairwise through an array of 64-byte aligned
// nodes (one node per two expected arrivals), each node holding one phase
// ticket per combining round. Waiters poll the shared __phase counter.
//
// _CompletionF: callable invoked (with no arguments) by the arrival that
//               completes a phase, before the phase counter is advanced.
// _Sco:         scope parameter forwarded to __atomic_base.
template<class _CompletionF = __empty_completion, int _Sco = 0>
class alignas(64) __barrier_base {

    // Number of arrivals required to complete the current phase. Only updated
    // by the arrival that completes a phase (see arrive()).
    ptrdiff_t                       __expected;
    // Pending change to __expected, accumulated by arrive_and_drop() and folded
    // into __expected when a phase completes.
    __atomic_base<ptrdiff_t, _Sco>  __expected_adjustment;
    _CompletionF                    __completion;

    // The phase counter advances by 2 per completed phase; the intermediate
    // value (old phase + 1) is used inside tickets to mark "one of two arrived".
    using __phase_t = uint8_t;
    __atomic_base<__phase_t, _Sco>  __phase;

    // One combining node: a ticket per round (at most 64 rounds are supported,
    // see the assertion in __arrive).
    struct alignas(64) __state_t
    {
        struct {
            __atomic_base<__phase_t, _Sco> __phase = ATOMIC_VAR_INIT(0);
        } __tickets[64];
    };
    ::std::vector<__state_t>   __state;

    // Registers a single arrival for the phase identified by __old_phase.
    //
    // Returns true when the number of participants remaining at the current
    // round has collapsed to one, i.e. the caller is the arrival that finishes
    // the phase; returns false once the caller has parked its arrival in a
    // ticket as the first of a pair.
    inline _LIBCUDACXX_INLINE_VISIBILITY
    bool __arrive(__phase_t const __old_phase)
    {
        // Ticket values: __old_phase = untouched, __half_step = first of a pair
        // has arrived, __full_step = the node is complete for this round.
        __phase_t const __half_step = __old_phase + 1, __full_step = __old_phase + 2;
#ifndef _LIBCUDACXX_HAS_NO_THREAD_FAVORITE_BARRIER_INDEX
        // Start probing at the index remembered in
        // __libcpp_thread_favorite_barrier_index (declared outside this file).
        ptrdiff_t __current = __libcpp_thread_favorite_barrier_index,
#else
        ptrdiff_t __current = 0,
#endif
                  __current_expected = __expected,
                  __last_node = (__current_expected >> 1);
        for(size_t __round = 0;; ++__round) {
            _LIBCUDACXX_ASSERT(__round <= 63, "");
            // Only one participant is left at this round: the phase is complete.
            if(__current_expected == 1)
                return true;
            // Probe nodes starting at __current until a ticket accepts this arrival.
            for(;;++__current) {
#ifndef _LIBCUDACXX_HAS_NO_THREAD_FAVORITE_BARRIER_INDEX
                if(0 == __round) {
                    // Wrap the starting index and remember where this thread probed.
                    if(__current >= __current_expected)
                        __current = 0;
                    __libcpp_thread_favorite_barrier_index = __current;
                }
#endif
                _LIBCUDACXX_ASSERT(__current <= __last_node, "");
                __phase_t expect = __old_phase;
                // With an odd participant count the last node has no partner, so a
                // single arrival completes it directly.
                if(__current == __last_node && (__current_expected & 1))
                {
                    if(__state[__current].__tickets[__round].__phase.compare_exchange_strong(expect, __full_step, memory_order_acq_rel))
                        break;    // I'm 1 in 1, go to next __round
                    _LIBCUDACXX_ASSERT(expect == __full_step, "");
                }
                else if(__state[__current].__tickets[__round].__phase.compare_exchange_strong(expect, __half_step, memory_order_acq_rel))
                {
                    return false; // I'm 1 in 2, done with arrival
                }
                else if(expect == __half_step)
                {
                    if(__state[__current].__tickets[__round].__phase.compare_exchange_strong(expect, __full_step, memory_order_acq_rel))
                        break;    // I'm 2 in 2, go to next __round
                    _LIBCUDACXX_ASSERT(expect == __full_step, "");
                }
                // Reaching this point means the ticket was already complete; the
                // assertion states this is only expected during the first round.
                _LIBCUDACXX_ASSERT(__round == 0 && expect == __full_step, "");
            }
            // Halve the participant count (rounding up) for the next round and
            // clear the bit for this round from the node indices.
            __current_expected = (__current_expected >> 1) + (__current_expected & 1);
            __current &= ~( 1 << __round );
            __last_node &= ~( 1 << __round );
        }
    }

public:
    // Token returned by arrive(): the phase counter value observed on arrival.
    using arrival_token = __phase_t;

    // Creates a barrier expecting __expected arrivals per phase and allocates
    // (__expected + 1) / 2 combining nodes.
    // Note: the non-negativity assertion runs after the members (including the
    // node vector) have already been initialized.
    inline _LIBCUDACXX_INLINE_VISIBILITY
    __barrier_base(ptrdiff_t __expected, _CompletionF __completion = _CompletionF())
            : __expected(__expected), __expected_adjustment(0), __completion(__completion),
              __phase(0), __state((__expected+1) >> 1)
    {
        _LIBCUDACXX_ASSERT(__expected >= 0, "");
    }

    inline _LIBCUDACXX_INLINE_VISIBILITY
    ~__barrier_base() = default;

    __barrier_base(__barrier_base const&) = delete;
    __barrier_base& operator=(__barrier_base const&) = delete;

    // Registers `update` arrivals (asserted to be positive) and returns the
    // phase value observed before arriving. If one of these arrivals completes
    // the phase, this thread runs the completion function, applies the pending
    // adjustment to __expected, and publishes the new phase with release order.
     _LIBCUDACXX_NODISCARD_ATTRIBUTE inline _LIBCUDACXX_INLINE_VISIBILITY
    arrival_token arrive(ptrdiff_t update = 1)
    {
        _LIBCUDACXX_ASSERT(update > 0, "");
        auto __old_phase = __phase.load(memory_order_relaxed);
        for(; update; --update)
            if(__arrive(__old_phase)) {
                __completion();
                __expected += __expected_adjustment.load(memory_order_relaxed);
                __expected_adjustment.store(0, memory_order_relaxed);
                __phase.store(__old_phase + 2, memory_order_release);
            }
        return __old_phase;
    }
    // Polls (via __libcpp_thread_poll_with_backoff) until the phase counter
    // differs from the token returned by arrive().
    inline _LIBCUDACXX_INLINE_VISIBILITY
    void wait(arrival_token&& __old_phase) const
    {
        __libcpp_thread_poll_with_backoff([=]() -> bool {
            return __phase.load(memory_order_acquire) != __old_phase;
        });
    }
    // Arrives once and waits for the phase that arrival belongs to to complete.
    inline _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_wait()
    {
        wait(arrive());
    }
    // Records that one fewer arrival is expected from the next phase onwards,
    // then arrives once for the current phase without waiting.
    inline _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_drop()
    {
        __expected_adjustment.fetch_sub(1, memory_order_relaxed);
        (void)arrive();
    }
};

#else

// Alignment specifier applied to the data members of the non-tree barrier
// implementations below: ABI versions before 3 align each annotated member to
// 64 bytes, later ABI versions add no extra alignment.
# if _LIBCUDACXX_CUDA_ABI_VERSION < 3
#  define _LIBCUDACXX_BARRIER_ALIGNMENTS alignas(64)
# else
#  define _LIBCUDACXX_BARRIER_ALIGNMENTS
# endif

// Polling predicate bound to a barrier and an arrival token. Each call
// forwards to the barrier's private __try_wait with the stored token and
// returns its result, so it can be handed to a poll-with-backoff loop.
template<class _Barrier>
class __barrier_poll_tester_phase {
    using __token_t = typename _Barrier::arrival_token;

    _Barrier const* __barrier_;   // barrier being observed (not owned)
    __token_t       __token_;     // token captured at construction

public:
    // Binds the predicate to *__this_ and takes over the given arrival token.
    _LIBCUDACXX_INLINE_VISIBILITY
    __barrier_poll_tester_phase(_Barrier const* __this_,
            typename _Barrier::arrival_token&& __phase_)
        : __barrier_(__this_)
        , __token_(static_cast<__token_t&&>(__phase_))
    {}

    // Returns whatever the barrier's __try_wait reports for the stored token.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool operator()() const
    {
        bool const __done = __barrier_->__try_wait(__token_);
        return __done;
    }
};

// Polling predicate bound to a barrier and a phase parity. Each call forwards
// to the barrier's private __try_wait_parity with the stored parity and
// returns its result, so it can be handed to a poll-with-backoff loop.
template<class _Barrier>
class __barrier_poll_tester_parity {
    _Barrier const* __barrier_;   // barrier being observed (not owned)
    bool            __odd_;       // parity value captured at construction

public:
    // Binds the predicate to *__this_ and remembers the parity to test against.
    _LIBCUDACXX_INLINE_VISIBILITY
    __barrier_poll_tester_parity(_Barrier const* __this_, bool __parity_)
        : __barrier_(__this_)
        , __odd_(__parity_)
    {}

    // Returns whatever the barrier's __try_wait_parity reports for the stored parity.
    inline _LIBCUDACXX_INLINE_VISIBILITY
    bool operator()() const
    {
        bool const __done = __barrier_->__try_wait_parity(__odd_);
        return __done;
    }
};

// Free-function entry point to a barrier's private __try_wait (the barrier
// classes befriend this template). Hands the arrival token on as an rvalue and
// returns the barrier's answer unchanged.
template<class _Barrier>
_LIBCUDACXX_INLINE_VISIBILITY
bool __call_try_wait(const _Barrier& __b, typename _Barrier::arrival_token&& __phase)
{
    using __token_t = typename _Barrier::arrival_token;
    return __b.__try_wait(static_cast<__token_t&&>(__phase));
}

// Free-function entry point to a barrier's private __try_wait_parity (the
// barrier classes befriend this template). Returns the barrier's answer for
// the given parity unchanged.
template<class _Barrier>
_LIBCUDACXX_INLINE_VISIBILITY
bool __call_try_wait_parity(const _Barrier& __b, bool __parity)
{
    bool const __result = __b.__try_wait_parity(__parity);
    return __result;
}


// Counter-based barrier used when the tree barrier is disabled and a
// user-supplied completion function is present.
//
// State:
//   __expected  arrivals required per phase (lowered by arrive_and_drop()).
//   __arrived   arrivals still outstanding in the current phase; counts down
//               to zero and is then reset from __expected.
//   __phase     boolean phase flag, flipped every time a phase completes.
//
// _CompletionF: callable invoked with no arguments by the arrival that
//               completes a phase, before the phase flag is flipped.
// _Sco:         scope parameter forwarded to __atomic_base.
template<class _CompletionF, int _Sco = 0>
class __barrier_base {

    _LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<ptrdiff_t, _Sco> __expected, __arrived;
    _LIBCUDACXX_BARRIER_ALIGNMENTS _CompletionF                   __completion;
    _LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<bool, _Sco>      __phase;

public:
    // Token returned by arrive(): the phase flag observed on arrival.
    using arrival_token = bool;

private:
    // The polling helpers need access to the private __try_wait* members.
    template<typename _Barrier>
    friend class __barrier_poll_tester_phase;
    template<typename _Barrier>
    friend class __barrier_poll_tester_parity;
    template<typename _Barrier>
    _LIBCUDACXX_INLINE_VISIBILITY
    friend bool __call_try_wait(const _Barrier& __b,
    typename _Barrier::arrival_token&& __phase);
    template<typename _Barrier>
    _LIBCUDACXX_INLINE_VISIBILITY
    friend bool __call_try_wait_parity(const _Barrier& __b, bool __parity);

    // Returns true when the phase flag no longer equals __old, i.e. the phase
    // the token belongs to has completed.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __try_wait(arrival_token __old) const
    {
        return __phase.load(memory_order_acquire) != __old;
    }
    // Parity flavour of __try_wait; for this implementation the token already
    // is the parity, so it simply forwards.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __try_wait_parity(bool __parity) const
    {
        return __try_wait(__parity);
    }

public:
    __barrier_base() = default;

    // Creates a barrier expecting __expected arrivals per phase.
    // The completion function is taken by value and moved into place, so
    // move-only callables are accepted and copyable ones are not copied twice.
    _LIBCUDACXX_INLINE_VISIBILITY
    __barrier_base(ptrdiff_t __expected, _CompletionF __completion = _CompletionF())
        : __expected(__expected), __arrived(__expected),
          __completion(_CUDA_VSTD::move(__completion)), __phase(false)
    {
    }

    ~__barrier_base() = default;

    __barrier_base(__barrier_base const&) = delete;
    __barrier_base& operator=(__barrier_base const&) = delete;

    // Registers __update arrivals and returns the phase flag observed before
    // arriving. The arrival that brings the outstanding count to zero runs the
    // completion function, re-arms __arrived from the current __expected,
    // flips the phase with release order and wakes all waiters.
    _LIBCUDACXX_NODISCARD_ATTRIBUTE _LIBCUDACXX_INLINE_VISIBILITY
    arrival_token arrive(ptrdiff_t __update = 1)
    {
        auto const __old_phase = __phase.load(memory_order_relaxed);
        auto const __result = __arrived.fetch_sub(__update, memory_order_acq_rel) - __update;
        // Read after the decrement so a preceding arrive_and_drop() on this
        // thread is reflected in the value used to re-arm the counter.
        auto const __new_expected = __expected.load(memory_order_relaxed);

        _LIBCUDACXX_DEBUG_ASSERT(__result >= 0, "");

        if(0 == __result) {
            __completion();
            __arrived.store(__new_expected, memory_order_relaxed);
            __phase.store(!__old_phase, memory_order_release);
            __cxx_atomic_notify_all(&__phase.__a_);
        }
        return __old_phase;
    }
    // Blocks until the phase flag differs from the given token.
    _LIBCUDACXX_INLINE_VISIBILITY
    void wait(arrival_token&& __old_phase) const
    {
        __phase.wait(__old_phase, memory_order_acquire);
    }
    // Arrives once and waits for the phase that arrival belongs to to complete.
    _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_wait()
    {
        wait(arrive());
    }
    // Lowers the expected count by one for subsequent phases, then arrives
    // once for the current phase without waiting.
    _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_drop()
    {
        __expected.fetch_sub(1, memory_order_relaxed);
        (void)arrive();
    }

    // Largest expected count this implementation accepts.
    _LIBCUDACXX_INLINE_VISIBILITY
    static constexpr ptrdiff_t max() noexcept
    {
        return numeric_limits<ptrdiff_t>::max();
    }
};

// Specialization for barriers without a completion function. The whole state
// lives in one 64-bit atomic so that an arrival is a single fetch_add:
//
//   bit  63      phase bit (flips each time a phase completes)
//   bits 32..62  arrival counter, stored as 2^31 - (arrivals still outstanding)
//   bits  0..31  expected field, stored as 2^31 - (arrivals expected per phase)
//
// Adding to the arrival counter carries into bit 63 exactly when the last
// outstanding arrival is registered, which is what flips the phase.
template<int _Sco>
class __barrier_base<__empty_completion, _Sco> {

    static constexpr uint64_t __expected_unit = 1ull;
    static constexpr uint64_t __arrived_unit = 1ull << 32;
    static constexpr uint64_t __expected_mask = __arrived_unit - 1;
    static constexpr uint64_t __phase_bit = 1ull << 63;
    static constexpr uint64_t __arrived_mask = (__phase_bit - 1) & ~__expected_mask;

    _LIBCUDACXX_BARRIER_ALIGNMENTS __atomic_base<uint64_t, _Sco> __phase_arrived_expected;

public:
    // Token returned by arrive(): the packed word masked down to its phase bit.
    using arrival_token = uint64_t;

private:
    // The polling helpers need access to the private __try_wait* members.
    template<typename _Barrier>
    friend class __barrier_poll_tester_phase;
    template<typename _Barrier>
    friend class __barrier_poll_tester_parity;
    template<typename _Barrier>
    _LIBCUDACXX_INLINE_VISIBILITY
    friend bool __call_try_wait(const _Barrier& __b,
    typename _Barrier::arrival_token&& __phase);
    template<typename _Barrier>
    _LIBCUDACXX_INLINE_VISIBILITY
    friend bool __call_try_wait_parity(const _Barrier& __b, bool __parity);

    // Builds the initial packed word for a barrier expecting __count arrivals:
    // both the arrival counter and the expected field start at 2^31 - __count.
    // The arithmetic is carried out in uint64_t so the 32-bit shift never acts
    // on a signed (or 32-bit wide) operand; for every valid count the resulting
    // bit pattern is the same as with the former ptrdiff_t arithmetic.
    static _LIBCUDACXX_INLINE_VISIBILITY constexpr
    uint64_t __init(ptrdiff_t __count) noexcept
    {
#if _CCCL_STD_VER > 2011
        // This debug assert is not supported in C++11 due to resulting in a
        // multi-statement constexpr function.
        _LIBCUDACXX_DEBUG_ASSERT(__count >= 0, "Count must be non-negative.");
#endif // _CCCL_STD_VER > 2011
        return (((static_cast<uint64_t>(1) << 31) - static_cast<uint64_t>(__count)) << 32)
              | ((static_cast<uint64_t>(1) << 31) - static_cast<uint64_t>(__count));
    }
    // Returns true when the current phase bit differs from __phase, which must
    // be either 0 or __phase_bit.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __try_wait_phase(uint64_t __phase) const
    {
        uint64_t const __current = __phase_arrived_expected.load(memory_order_acquire);
        return ((__current & __phase_bit) != __phase);
    }
    // Token flavour: only the phase bit of the token is compared.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __try_wait(arrival_token __old) const
    {
        return __try_wait_phase(__old & __phase_bit);
    }
    // Parity flavour: true maps to a set phase bit, false to a clear one.
    _LIBCUDACXX_INLINE_VISIBILITY
    bool __try_wait_parity(bool __parity) const
    {
        return __try_wait_phase(__parity ? __phase_bit : 0);
    }

public:
    __barrier_base() = default;

    // Creates a barrier expecting __count arrivals per phase. The completion
    // argument carries no state and is ignored.
    _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11
    __barrier_base(ptrdiff_t __count, __empty_completion = __empty_completion())
        : __phase_arrived_expected(__init(__count)) {
        _LIBCUDACXX_DEBUG_ASSERT(__count >= 0, "");
    }

    ~__barrier_base() = default;

    __barrier_base(__barrier_base const&) = delete;
    __barrier_base& operator=(__barrier_base const&) = delete;

    // Registers __update arrivals with one fetch_add on the arrival counter
    // and returns the phase bit observed before arriving. If the addition
    // flipped the phase bit, this thread re-arms the arrival counter from the
    // expected field and notifies waiters.
    _LIBCUDACXX_NODISCARD_ATTRIBUTE inline _LIBCUDACXX_INLINE_VISIBILITY
    arrival_token arrive(ptrdiff_t __update = 1)
    {
        auto const __inc = __arrived_unit * __update;
        auto const __old = __phase_arrived_expected.fetch_add(__inc, memory_order_acq_rel);
        // Bit 63 differs between the old and new word: this arrival completed the phase.
        if((__old ^ (__old + __inc)) & __phase_bit) {
            // The arrival counter is now zero; add the expected field back in
            // so it again reads 2^31 - expected for the next phase.
            __phase_arrived_expected.fetch_add((__old & __expected_mask) << 32, memory_order_relaxed);
            __phase_arrived_expected.notify_all();
        }
        return __old & __phase_bit;
    }
    // Polls with backoff until the phase bit differs from the token's phase bit.
    _LIBCUDACXX_INLINE_VISIBILITY
    void wait(arrival_token&& __phase) const
    {
        __libcpp_thread_poll_with_backoff(__barrier_poll_tester_phase<__barrier_base>(this, _CUDA_VSTD::move(__phase)));
    }
    // Polls with backoff until the phase bit differs from the given parity.
    _LIBCUDACXX_INLINE_VISIBILITY
    void wait_parity(bool __parity) const
    {
        __libcpp_thread_poll_with_backoff(__barrier_poll_tester_parity<__barrier_base>(this, __parity));
    }
    // Arrives once and waits for the phase that arrival belongs to to complete.
    _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_wait()
    {
        wait(arrive());
    }
    // Increments the expected field (stored as 2^31 - expected, so this lowers
    // the expected count by one for later phases), then arrives once for the
    // current phase without waiting.
    _LIBCUDACXX_INLINE_VISIBILITY
    void arrive_and_drop()
    {
        __phase_arrived_expected.fetch_add(__expected_unit, memory_order_relaxed);
        (void)arrive();
    }

    // Largest expected count this implementation accepts.
    _LIBCUDACXX_INLINE_VISIBILITY
    static constexpr ptrdiff_t max() noexcept
    {
        return numeric_limits<int32_t>::max();
    }
};

#endif //_LIBCUDACXX_HAS_NO_TREE_BARRIER

// Public barrier type: a thin wrapper that selects the __barrier_base
// implementation matching the completion function type and inherits its
// arrive/wait interface unchanged.
//
// _CompletionF: callable invoked when a phase completes; defaults to the
//               no-op __empty_completion.
template<class _CompletionF = __empty_completion>
class barrier : public __barrier_base<_CompletionF> {
public:
    // Creates a barrier expecting __count arrivals per phase.
    // The completion function is received by value and handed to the base as
    // an rvalue, so it is moved rather than copied a second time.
    _LIBCUDACXX_INLINE_VISIBILITY constexpr
    barrier(ptrdiff_t __count, _CompletionF __completion = _CompletionF())
        : __barrier_base<_CompletionF>(__count, _CUDA_VSTD::move(__completion)) {
    }
};

_LIBCUDACXX_END_NAMESPACE_STD

#ifndef __cuda_std__
#include <__pragma_pop>
#else
#include "__cuda/barrier.h"
#endif // __cuda_std__

#endif //_LIBCUDACXX_BARRIER
