Switch to two separate atomic counters

It's faster and still correct. I was confused remembering something
about atomic shared pointer ideas before.
This commit is contained in:
2025-09-11 10:41:42 -04:00
parent b9106a0d3c
commit 0f179eed88

View File

@@ -20,84 +20,42 @@
namespace detail { namespace detail {
struct ControlBlock { struct ControlBlock {
// Least significant 32 bits are strong reference count std::atomic<uint32_t> strong_count;
// Most significant 32 bits are weak reference count std::atomic<uint32_t> weak_count;
std::atomic<uint64_t> ref_counts;
ControlBlock() : ref_counts(1) {} // Start with 1 strong reference ControlBlock()
: strong_count(1), weak_count(0) {} // Start with 1 strong reference
/** /**
* @brief Increment strong reference count * @brief Increment strong reference count
* @return Previous ref_counts value (both strong and weak counts) * @return Previous strong count
*/ */
uint64_t increment_strong() noexcept { uint32_t increment_strong() noexcept {
uint64_t old_value; return strong_count.fetch_add(1, std::memory_order_relaxed);
uint64_t new_value;
do {
old_value = ref_counts.load(std::memory_order_relaxed);
uint32_t strong_count = static_cast<uint32_t>(old_value);
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value =
(static_cast<uint64_t>(weak_count) << 32) | (strong_count + 1);
} while (!ref_counts.compare_exchange_weak(old_value, new_value,
std::memory_order_relaxed));
return old_value;
} }
/** /**
* @brief Decrement strong reference count * @brief Decrement strong reference count
* @return Previous ref_counts value (both strong and weak counts) * @return Previous strong count
*/ */
uint64_t decrement_strong() noexcept { uint32_t decrement_strong() noexcept {
uint64_t old_value; return strong_count.fetch_sub(1, std::memory_order_acq_rel);
uint64_t new_value;
do {
old_value = ref_counts.load(std::memory_order_relaxed);
uint32_t strong_count = static_cast<uint32_t>(old_value);
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value =
(static_cast<uint64_t>(weak_count) << 32) | (strong_count - 1);
} while (!ref_counts.compare_exchange_weak(old_value, new_value,
std::memory_order_acq_rel));
return old_value;
} }
/** /**
* @brief Increment weak reference count * @brief Increment weak reference count
* @return Previous ref_counts value (both strong and weak counts) * @return Previous weak count
*/ */
uint64_t increment_weak() noexcept { uint32_t increment_weak() noexcept {
uint64_t old_value; return weak_count.fetch_add(1, std::memory_order_relaxed);
uint64_t new_value;
do {
old_value = ref_counts.load(std::memory_order_relaxed);
uint32_t strong_count = static_cast<uint32_t>(old_value);
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value = (static_cast<uint64_t>(weak_count + 1) << 32) | strong_count;
} while (!ref_counts.compare_exchange_weak(old_value, new_value,
std::memory_order_relaxed));
return old_value;
} }
/** /**
* @brief Decrement weak reference count * @brief Decrement weak reference count
* @return Previous ref_counts value (both strong and weak counts) * @return Previous weak count
*/ */
uint64_t decrement_weak() noexcept { uint32_t decrement_weak() noexcept {
uint64_t old_value; return weak_count.fetch_sub(1, std::memory_order_acq_rel);
uint64_t new_value;
do {
old_value = ref_counts.load(std::memory_order_relaxed);
uint32_t strong_count = static_cast<uint32_t>(old_value);
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value = (static_cast<uint64_t>(weak_count - 1) << 32) | strong_count;
} while (!ref_counts.compare_exchange_weak(old_value, new_value,
std::memory_order_acq_rel));
return old_value;
} }
}; };
} // namespace detail } // namespace detail
@@ -214,17 +172,17 @@ private:
*/ */
void release() noexcept { void release() noexcept {
if (control_block) { if (control_block) {
uint64_t prev = control_block->decrement_strong(); uint32_t prev_strong = control_block->decrement_strong();
uint32_t prev_strong = static_cast<uint32_t>(prev);
// If this was the last strong reference, destroy the object // If this was the last strong reference, destroy the object
if (prev_strong == 1) { if (prev_strong == 1) {
T *obj = get(); T *obj = get();
obj->~T(); obj->~T();
// If no weak references either, free the entire allocation // Check if there are any weak references
uint32_t prev_weak = static_cast<uint32_t>(prev >> 32); uint32_t current_weak =
if (prev_weak == 0) { control_block->weak_count.load(std::memory_order_acquire);
if (current_weak == 0) {
std::free(control_block); std::free(control_block);
} }
} }
@@ -246,28 +204,23 @@ template <typename T> struct WeakRef {
if (!control_block) { if (!control_block) {
return Ref<T>(); return Ref<T>();
} }
uint64_t old_value;
uint64_t new_value;
do {
// Use acquire ordering to ensure that any subsequent use of the returned
// Ref (like dereferencing the object pointer) cannot be reordered before
// this safety check. This would ideally use memory_order_consume for
// dependency ordering, but the folk wisdom is "don't use that".
old_value = control_block->ref_counts.load(std::memory_order_acquire);
uint32_t strong_count = static_cast<uint32_t>(old_value);
// If strong count is 0, object is being destroyed // Try to increment strong count if it's not zero
if (strong_count == 0) { uint32_t expected_strong =
return Ref<T>(); control_block->strong_count.load(std::memory_order_relaxed);
while (expected_strong > 0) {
// Try to increment the strong count
if (control_block->strong_count.compare_exchange_weak(
expected_strong, expected_strong + 1, std::memory_order_acquire,
std::memory_order_relaxed)) {
// Success - we incremented the strong count
return Ref<T>(control_block);
}
// CAS failed, expected_strong now contains the current value, retry
} }
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32); // Strong count was 0, object is being destroyed
new_value = return Ref<T>();
(static_cast<uint64_t>(weak_count) << 32) | (strong_count + 1);
} while (!control_block->ref_counts.compare_exchange_weak(
old_value, new_value, std::memory_order_relaxed));
return Ref<T>(control_block);
} }
/** /**
@@ -361,17 +314,19 @@ private:
*/ */
void release() noexcept { void release() noexcept {
if (control_block) { if (control_block) {
uint64_t prev = control_block->decrement_weak(); uint32_t prev_weak = control_block->decrement_weak();
uint32_t prev_strong = static_cast<uint32_t>(prev);
uint32_t prev_weak = static_cast<uint32_t>(prev >> 32);
// If this was the last weak reference and no strong references, free // If this was the last weak reference, check if we need to free control
// control block // block
if (prev_weak == 1 && prev_strong == 0) { if (prev_weak == 1) {
uint32_t current_strong =
control_block->strong_count.load(std::memory_order_acquire);
if (current_strong == 0) {
std::free(control_block); std::free(control_block);
} }
} }
} }
}
template <typename U> friend struct Ref; template <typename U> friend struct Ref;
}; };