From 0f179eed88cc00393a7abc3c824aa28df7b99d58 Mon Sep 17 00:00:00 2001 From: Andrew Noyes Date: Thu, 11 Sep 2025 10:41:42 -0400 Subject: [PATCH] Switch to two separate atomic counters It's faster and still correct. I was confused earlier, misremembering something about atomic shared-pointer designs. --- src/reference.hpp | 133 +++++++++++++++------------------------------- 1 file changed, 44 insertions(+), 89 deletions(-) diff --git a/src/reference.hpp b/src/reference.hpp index 15ff539..e90a21a 100644 --- a/src/reference.hpp +++ b/src/reference.hpp @@ -20,84 +20,42 @@ namespace detail { struct ControlBlock { - // Least significant 32 bits are strong reference count - // Most significant 32 bits are weak reference count - std::atomic ref_counts; + std::atomic strong_count; + std::atomic weak_count; - ControlBlock() : ref_counts(1) {} // Start with 1 strong reference + ControlBlock() + : strong_count(1), weak_count(0) {} // Start with 1 strong reference /** * @brief Increment strong reference count - * @return Previous ref_counts value (both strong and weak counts) + * @return Previous strong count */ - uint64_t increment_strong() noexcept { - uint64_t old_value; - uint64_t new_value; - do { - old_value = ref_counts.load(std::memory_order_relaxed); - uint32_t strong_count = static_cast(old_value); - uint32_t weak_count = static_cast(old_value >> 32); - new_value = - (static_cast(weak_count) << 32) | (strong_count + 1); - } while (!ref_counts.compare_exchange_weak(old_value, new_value, - std::memory_order_relaxed)); - - return old_value; + uint32_t increment_strong() noexcept { + return strong_count.fetch_add(1, std::memory_order_relaxed); } /** * @brief Decrement strong reference count - * @return Previous ref_counts value (both strong and weak counts) + * @return Previous strong count */ - uint64_t decrement_strong() noexcept { - uint64_t old_value; - uint64_t new_value; - do { - old_value = ref_counts.load(std::memory_order_relaxed); - uint32_t strong_count = 
static_cast(old_value); - uint32_t weak_count = static_cast(old_value >> 32); - new_value = - (static_cast(weak_count) << 32) | (strong_count - 1); - } while (!ref_counts.compare_exchange_weak(old_value, new_value, - std::memory_order_acq_rel)); - - return old_value; + uint32_t decrement_strong() noexcept { + return strong_count.fetch_sub(1, std::memory_order_acq_rel); } /** * @brief Increment weak reference count - * @return Previous ref_counts value (both strong and weak counts) + * @return Previous weak count */ - uint64_t increment_weak() noexcept { - uint64_t old_value; - uint64_t new_value; - do { - old_value = ref_counts.load(std::memory_order_relaxed); - uint32_t strong_count = static_cast(old_value); - uint32_t weak_count = static_cast(old_value >> 32); - new_value = (static_cast(weak_count + 1) << 32) | strong_count; - } while (!ref_counts.compare_exchange_weak(old_value, new_value, - std::memory_order_relaxed)); - - return old_value; + uint32_t increment_weak() noexcept { + return weak_count.fetch_add(1, std::memory_order_relaxed); } /** * @brief Decrement weak reference count - * @return Previous ref_counts value (both strong and weak counts) + * @return Previous weak count */ - uint64_t decrement_weak() noexcept { - uint64_t old_value; - uint64_t new_value; - do { - old_value = ref_counts.load(std::memory_order_relaxed); - uint32_t strong_count = static_cast(old_value); - uint32_t weak_count = static_cast(old_value >> 32); - new_value = (static_cast(weak_count - 1) << 32) | strong_count; - } while (!ref_counts.compare_exchange_weak(old_value, new_value, - std::memory_order_acq_rel)); - - return old_value; + uint32_t decrement_weak() noexcept { + return weak_count.fetch_sub(1, std::memory_order_acq_rel); } }; } // namespace detail @@ -214,17 +172,17 @@ private: */ void release() noexcept { if (control_block) { - uint64_t prev = control_block->decrement_strong(); - uint32_t prev_strong = static_cast(prev); + uint32_t prev_strong = 
control_block->decrement_strong(); // If this was the last strong reference, destroy the object if (prev_strong == 1) { T *obj = get(); obj->~T(); - // If no weak references either, free the entire allocation - uint32_t prev_weak = static_cast(prev >> 32); - if (prev_weak == 0) { + // Check if there are any weak references + uint32_t current_weak = + control_block->weak_count.load(std::memory_order_acquire); + if (current_weak == 0) { std::free(control_block); } } @@ -246,28 +204,23 @@ template struct WeakRef { if (!control_block) { return Ref(); } - uint64_t old_value; - uint64_t new_value; - do { - // Use acquire ordering to ensure that any subsequent use of the returned - // Ref (like dereferencing the object pointer) cannot be reordered before - // this safety check. This would ideally use memory_order_consume for - // dependency ordering, but the folk wisdom is "don't use that". - old_value = control_block->ref_counts.load(std::memory_order_acquire); - uint32_t strong_count = static_cast(old_value); - // If strong count is 0, object is being destroyed - if (strong_count == 0) { - return Ref(); + // Try to increment strong count if it's not zero + uint32_t expected_strong = + control_block->strong_count.load(std::memory_order_relaxed); + while (expected_strong > 0) { + // Try to increment the strong count + if (control_block->strong_count.compare_exchange_weak( + expected_strong, expected_strong + 1, std::memory_order_acquire, + std::memory_order_relaxed)) { + // Success - we incremented the strong count + return Ref(control_block); } + // CAS failed, expected_strong now contains the current value, retry + } - uint32_t weak_count = static_cast(old_value >> 32); - new_value = - (static_cast(weak_count) << 32) | (strong_count + 1); - } while (!control_block->ref_counts.compare_exchange_weak( - old_value, new_value, std::memory_order_relaxed)); - - return Ref(control_block); + // Strong count was 0, object is being destroyed + return Ref(); } /** @@ -361,14 
+314,16 @@ private: */ void release() noexcept { if (control_block) { - uint64_t prev = control_block->decrement_weak(); - uint32_t prev_strong = static_cast(prev); - uint32_t prev_weak = static_cast(prev >> 32); + uint32_t prev_weak = control_block->decrement_weak(); - // If this was the last weak reference and no strong references, free - // control block - if (prev_weak == 1 && prev_strong == 0) { - std::free(control_block); + // If this was the last weak reference, check if we need to free control + // block + if (prev_weak == 1) { + uint32_t current_strong = + control_block->strong_count.load(std::memory_order_acquire); + if (current_strong == 0) { + std::free(control_block); + } } } }