Switch to two separate atomic counters

It's faster and still correct. I was confused remembering something
about atomic shared pointer ideas before.
This commit is contained in:
2025-09-11 10:41:42 -04:00
parent b9106a0d3c
commit 0f179eed88

View File

@@ -20,84 +20,42 @@
namespace detail {
struct ControlBlock {
// Least significant 32 bits are strong reference count
// Most significant 32 bits are weak reference count
std::atomic<uint64_t> ref_counts;
std::atomic<uint32_t> strong_count;
std::atomic<uint32_t> weak_count;
ControlBlock() : ref_counts(1) {} // Start with 1 strong reference
ControlBlock()
: strong_count(1), weak_count(0) {} // Start with 1 strong reference
/**
* @brief Increment strong reference count
* @return Previous ref_counts value (both strong and weak counts)
* @return Previous strong count
*/
uint64_t increment_strong() noexcept {
uint64_t old_value;
uint64_t new_value;
do {
old_value = ref_counts.load(std::memory_order_relaxed);
uint32_t strong_count = static_cast<uint32_t>(old_value);
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value =
(static_cast<uint64_t>(weak_count) << 32) | (strong_count + 1);
} while (!ref_counts.compare_exchange_weak(old_value, new_value,
std::memory_order_relaxed));
return old_value;
uint32_t increment_strong() noexcept {
return strong_count.fetch_add(1, std::memory_order_relaxed);
}
/**
* @brief Decrement strong reference count
* @return Previous ref_counts value (both strong and weak counts)
* @return Previous strong count
*/
uint64_t decrement_strong() noexcept {
uint64_t old_value;
uint64_t new_value;
do {
old_value = ref_counts.load(std::memory_order_relaxed);
uint32_t strong_count = static_cast<uint32_t>(old_value);
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value =
(static_cast<uint64_t>(weak_count) << 32) | (strong_count - 1);
} while (!ref_counts.compare_exchange_weak(old_value, new_value,
std::memory_order_acq_rel));
return old_value;
uint32_t decrement_strong() noexcept {
return strong_count.fetch_sub(1, std::memory_order_acq_rel);
}
/**
* @brief Increment weak reference count
* @return Previous ref_counts value (both strong and weak counts)
* @return Previous weak count
*/
uint64_t increment_weak() noexcept {
uint64_t old_value;
uint64_t new_value;
do {
old_value = ref_counts.load(std::memory_order_relaxed);
uint32_t strong_count = static_cast<uint32_t>(old_value);
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value = (static_cast<uint64_t>(weak_count + 1) << 32) | strong_count;
} while (!ref_counts.compare_exchange_weak(old_value, new_value,
std::memory_order_relaxed));
return old_value;
uint32_t increment_weak() noexcept {
return weak_count.fetch_add(1, std::memory_order_relaxed);
}
/**
* @brief Decrement weak reference count
* @return Previous ref_counts value (both strong and weak counts)
* @return Previous weak count
*/
uint64_t decrement_weak() noexcept {
uint64_t old_value;
uint64_t new_value;
do {
old_value = ref_counts.load(std::memory_order_relaxed);
uint32_t strong_count = static_cast<uint32_t>(old_value);
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value = (static_cast<uint64_t>(weak_count - 1) << 32) | strong_count;
} while (!ref_counts.compare_exchange_weak(old_value, new_value,
std::memory_order_acq_rel));
return old_value;
uint32_t decrement_weak() noexcept {
return weak_count.fetch_sub(1, std::memory_order_acq_rel);
}
};
} // namespace detail
@@ -214,17 +172,17 @@ private:
*/
void release() noexcept {
if (control_block) {
uint64_t prev = control_block->decrement_strong();
uint32_t prev_strong = static_cast<uint32_t>(prev);
uint32_t prev_strong = control_block->decrement_strong();
// If this was the last strong reference, destroy the object
if (prev_strong == 1) {
T *obj = get();
obj->~T();
// If no weak references either, free the entire allocation
uint32_t prev_weak = static_cast<uint32_t>(prev >> 32);
if (prev_weak == 0) {
// Check if there are any weak references
uint32_t current_weak =
control_block->weak_count.load(std::memory_order_acquire);
if (current_weak == 0) {
std::free(control_block);
}
}
@@ -246,28 +204,23 @@ template <typename T> struct WeakRef {
if (!control_block) {
return Ref<T>();
}
uint64_t old_value;
uint64_t new_value;
do {
// Use acquire ordering to ensure that any subsequent use of the returned
// Ref (like dereferencing the object pointer) cannot be reordered before
// this safety check. This would ideally use memory_order_consume for
// dependency ordering, but the folk wisdom is "don't use that".
old_value = control_block->ref_counts.load(std::memory_order_acquire);
uint32_t strong_count = static_cast<uint32_t>(old_value);
// If strong count is 0, object is being destroyed
if (strong_count == 0) {
return Ref<T>();
// Try to increment strong count if it's not zero
uint32_t expected_strong =
control_block->strong_count.load(std::memory_order_relaxed);
while (expected_strong > 0) {
// Try to increment the strong count
if (control_block->strong_count.compare_exchange_weak(
expected_strong, expected_strong + 1, std::memory_order_acquire,
std::memory_order_relaxed)) {
// Success - we incremented the strong count
return Ref<T>(control_block);
}
// CAS failed, expected_strong now contains the current value, retry
}
uint32_t weak_count = static_cast<uint32_t>(old_value >> 32);
new_value =
(static_cast<uint64_t>(weak_count) << 32) | (strong_count + 1);
} while (!control_block->ref_counts.compare_exchange_weak(
old_value, new_value, std::memory_order_relaxed));
return Ref<T>(control_block);
// Strong count was 0, object is being destroyed
return Ref<T>();
}
/**
@@ -361,14 +314,16 @@ private:
*/
void release() noexcept {
if (control_block) {
uint64_t prev = control_block->decrement_weak();
uint32_t prev_strong = static_cast<uint32_t>(prev);
uint32_t prev_weak = static_cast<uint32_t>(prev >> 32);
uint32_t prev_weak = control_block->decrement_weak();
// If this was the last weak reference and no strong references, free
// control block
if (prev_weak == 1 && prev_strong == 0) {
std::free(control_block);
// If this was the last weak reference, check if we need to free control
// block
if (prev_weak == 1) {
uint32_t current_strong =
control_block->strong_count.load(std::memory_order_acquire);
if (current_strong == 0) {
std::free(control_block);
}
}
}
}