From: Bennet Bleßmann Date: Sat, 6 Jul 2024 12:49:19 +0000 (+0200) Subject: switch Rcu to the arcu crate X-Git-Tag: v0.10.0~127^2~20 X-Git-Url: https://git.sagredo.dev/?a=commitdiff_plain;h=d87400afa0fef3edffd6295c747e27b229ee7c9d;p=scryer-prolog.git switch Rcu to the arcu crate The arcu crate is a more general implementation of the Rcu I implemented in here in scryer. It contains some bug-fixes regarding race-conditions in the Rcu update function, which could cause leaks and uses after free. Source of the problem was the Relaxed load/strore/update of the reference count in side the Arc not being properly ordered with other load/stores. --- diff --git a/Cargo.lock b/Cargo.lock index 43806baf..eb2bd1b2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "arcu" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8727c0fb4c436605c8f11c579ec86edcb729134aec4ee66e454efd99a91859f" + [[package]] name = "arrayvec" version = "0.5.2" @@ -2558,6 +2564,7 @@ dependencies = [ name = "scryer-prolog" version = "0.9.4" dependencies = [ + "arcu", "assert_cmd", "base64 0.12.3", "bit-set", diff --git a/Cargo.toml b/Cargo.toml index 3fe35772..45ba6d6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ static_assertions = "1.1.0" serde_json = "1.0.95" serde = "1.0.159" +arcu = { version = "0.1.1", features = ["thread_local_counter"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] crossterm = { version = "0.20.0", optional = true } diff --git a/src/arena.rs b/src/arena.rs index e0cdeef9..571595a6 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -7,12 +7,14 @@ use crate::machine::machine_indices::*; use crate::machine::streams::*; use crate::parser::char_reader::CharReader; use crate::raw_block::*; -use crate::rcu::Rcu; -use crate::rcu::RcuRef; use crate::read::*; use crate::types::UntypedArenaPtr; use crate::parser::dashu::{Integer, Rational}; +use arcu::atomic::Arcu; +use arcu::epoch_counters::GlobalEpochCounterPool; +use arcu::rcu_ref::RcuRef; +use arcu::Rcu; use ordered_float::OrderedFloat; use std::cell::UnsafeCell; @@ -79,7 +81,7 @@ impl RawBlockTraits for F64Table { #[derive(Debug)] pub struct F64Table { - block: Rcu>, + block: Arcu, GlobalEpochCounterPool>, update: Mutex<()>, } @@ -93,7 +95,7 @@ pub fn lookup_float( .upgrade() .expect("We should only be looking up floats while there is a float table"); - RcuRef::try_map(f64table.block.active_epoch(), |raw_block| unsafe { + RcuRef::try_map(f64table.block.read(), |raw_block| unsafe { raw_block .base .add(offset.0) @@ -118,7 +120,7 @@ impl F64Table { atom_table } else { let atom_table = Arc::new(Self { - block: Rcu::new(RawBlock::new()), + block: Arcu::new(RawBlock::new(), GlobalEpochCounterPool), update: Mutex::new(()), }); *guard = Arc::downgrade(&atom_table); @@ -133,7 +135,7 @@ impl F64Table { // we don't have an index table for lookups as AtomTable does so // just get the epoch after we take the upgrade lock - let mut block_epoch = self.block.active_epoch(); + let mut block_epoch = self.block.read(); let mut ptr; @@ -143,7 +145,7 @@ impl F64Table { if ptr.is_null() { let new_block = block_epoch.grow_new().unwrap(); self.block.replace(new_block); - block_epoch = self.block.active_epoch(); + block_epoch = self.block.read(); } else { break; } diff --git a/src/atom_table.rs b/src/atom_table.rs index 5216f6a3..4a426cd0 100644 --- a/src/atom_table.rs +++ b/src/atom_table.rs @@ -2,7 +2,6 @@ use crate::parser::ast::MAX_ARITY; use crate::raw_block::*; -use crate::rcu::{Rcu, RcuRef}; use crate::types::*; use std::cmp::Ordering; @@ -16,6 +15,10 @@ use std::sync::Mutex; use std::sync::RwLock; use std::sync::Weak; +use arcu::atomic::Arcu; +use arcu::epoch_counters::GlobalEpochCounterPool; +use arcu::rcu_ref::RcuRef; +use arcu::Rcu; use indexmap::IndexSet; use scryer_modular_bitfield::prelude::*; @@ -180,7 +183,7 @@ impl Atom { let atom_table = arc_atom_table().expect("We should only have an Atom while there is an AtomTable"); - AtomTableRef::try_map(atom_table.inner.active_epoch(), |buf| unsafe { + AtomTableRef::try_map(atom_table.inner.read(), |buf| unsafe { let ptr = buf .block .base @@ -278,17 +281,17 @@ impl Ord for Atom { #[derive(Debug)] pub struct InnerAtomTable { block: RawBlock, - pub table: Rcu>, + pub table: Arcu, GlobalEpochCounterPool>, } #[derive(Debug)] pub struct AtomTable { - inner: Rcu, + inner: Arcu, // this lock is taking during resizing update: Mutex<()>, } -pub type AtomTableRef = RcuRef; +pub type AtomTableRef = arcu::rcu_ref::RcuRef; impl InnerAtomTable { #[inline(always)] @@ -296,7 +299,7 @@ impl InnerAtomTable { STATIC_ATOMS_MAP .get(string) .cloned() - .or_else(|| self.table.active_epoch().get(string).cloned()) + .or_else(|| self.table.read().get(string).cloned()) } } @@ -314,10 +317,13 @@ impl AtomTable { atom_table } else { let atom_table = Arc::new(Self { - inner: Rcu::new(InnerAtomTable { - block: RawBlock::new(), - table: Rcu::new(IndexSet::new()), - }), + inner: Arcu::new( + InnerAtomTable { + block: RawBlock::new(), + table: Arcu::new(IndexSet::new(), GlobalEpochCounterPool), + }, + GlobalEpochCounterPool, + ), update: Mutex::new(()), }); *guard = Arc::downgrade(&atom_table); @@ -327,13 +333,13 @@ impl AtomTable { } pub fn active_table(&self) -> RcuRef, IndexSet> { - self.inner.active_epoch().table.active_epoch() + self.inner.read().table.read() } pub fn build_with(atom_table: &AtomTable, string: &str) -> Atom { loop { - let mut block_epoch = atom_table.inner.active_epoch(); - let mut table_epoch = block_epoch.table.active_epoch(); + let mut block_epoch = atom_table.inner.read(); + let mut table_epoch = block_epoch.table.read(); if let Some(atom) = block_epoch.lookup_str(string) { return atom; @@ -342,10 +348,8 @@ impl AtomTable { // take a lock to prevent concurrent updates let update_guard = atom_table.update.lock().unwrap(); - let is_same_allocation = - RcuRef::same_epoch(&block_epoch, &atom_table.inner.active_epoch()); - let is_same_atom_list = - RcuRef::same_epoch(&table_epoch, &block_epoch.table.active_epoch()); + let is_same_allocation = RcuRef::same_epoch(&block_epoch, &atom_table.inner.read()); + let is_same_atom_list = RcuRef::same_epoch(&table_epoch, &block_epoch.table.read()); if !(is_same_allocation && is_same_atom_list) { // some other thread raced us between our lookup and @@ -364,14 +368,14 @@ impl AtomTable { if ptr.is_null() { // garbage collection would go here let new_block = block_epoch.block.grow_new().unwrap(); - let new_table = Rcu::new(table_epoch.clone()); + let new_table = Arcu::new(table_epoch.clone(), GlobalEpochCounterPool); let new_alloc = InnerAtomTable { block: new_block, table: new_table, }; atom_table.inner.replace(new_alloc); - block_epoch = atom_table.inner.active_epoch(); - table_epoch = block_epoch.table.active_epoch(); + block_epoch = atom_table.inner.read(); + table_epoch = block_epoch.table.read(); } else { break ptr; } diff --git a/src/lib.rs b/src/lib.rs index 56d967eb..90f498a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,8 +42,6 @@ pub mod types; use instructions::instr; -mod rcu; - #[cfg(target_arch = "wasm32")] use wasm_bindgen::prelude::*; diff --git a/src/rcu.rs b/src/rcu.rs deleted file mode 100644 index 75ecef31..00000000 --- a/src/rcu.rs +++ /dev/null @@ -1,220 +0,0 @@ -use std::{ - cell::OnceCell, - fmt::Debug, - mem::ManuallyDrop, - ops::Deref, - ptr::NonNull, - sync::{ - atomic::{AtomicPtr, AtomicU8}, - Arc, RwLock, Weak, - }, -}; - -// the epoch counters of all threads that have ever accessed an Rcu -// threads that have finished will have a dangling Weak reference and can be cleand up -// having this be shared between all Rcu's is a tradeof, -// writes will be slower as more epoch counters need to be waited for -// reads should be faster as a thread only needs to register itself once on the first read -// -static EPOCH_COUNTERS: RwLock>> = RwLock::new(Vec::new()); - -thread_local! { - // odd value means the current thread is about to access the active_epoch of an Rcu - // a thread has a single epoch counter for all Rcu it accesses, - // as a thread can only access one Rcu at a time - static THREAD_EPOCH_COUNTER: OnceCell> = const { OnceCell::new() }; -} - -pub struct Rcu { - active_value: AtomicPtr, -} - -impl std::fmt::Debug for Rcu { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let active_epoch = self.active_epoch(); - f.debug_struct("Rcu") - .field("active_value", &active_epoch) - .finish() - } -} - -impl Rcu { - pub fn new(initial_value: T) -> Self { - Rcu { - active_value: AtomicPtr::new(Arc::into_raw(Arc::new(initial_value)).cast_mut()), - } - } - - pub fn active_epoch(&self) -> RcuRef { - THREAD_EPOCH_COUNTER.with(|epoch_counter| { - let epoch_counter = epoch_counter.get_or_init(|| { - let epoch_counter = Arc::new(AtomicU8::new(0)); - // register the current threads epoch counter on init - EPOCH_COUNTERS - .write() - .unwrap() - .push(Arc::downgrade(&epoch_counter)); - epoch_counter - }); - - let old = epoch_counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel); - assert!(old % 2 == 0, "Old Epoch counter value should be even!"); - }); - - let arc_ptr = self.active_value.load(std::sync::atomic::Ordering::Acquire); - - let arc = unsafe { - // Safety: - // - the ptr was created in Rcu::new or Rcu::replace with Arc::into_raw - // - the Rcu is responsible for of the arc's strong refrences - // - the Rcu is alive as this function takes a reference to the Rcu - // - replace will wait with decrementing the old values strong count until our epoich counter is even again - Arc::increment_strong_count(arc_ptr); - // Safety: - // - the ptr was created in Rcu::new or Rcu::replace with Arc::into_raw - // - we have just ensured an additional strong count by incrementing the count - Arc::from_raw(arc_ptr) - }; - - THREAD_EPOCH_COUNTER.with(|epoch_counter| { - let old = epoch_counter - .get().expect("we initialized the OnceCell when we incremented the epoch counter the fist time") - .fetch_add(1, std::sync::atomic::Ordering::AcqRel); - assert!(old % 2 != 0, "Old Epoch counter value should be odd!"); - }); - - RcuRef { - data: arc.deref().into(), - arc, - } - } - - /* - * replace the Rcu'S content with a new value - * - * This does not syncronize write and last to update the active_value pointer wins, - * all writes that do not win will be lost, though not leaked. - * This will block untill the old value can be reclaimed, - * i.e. all threads whitnest to be in the read critical sections - * have been witnest to have left the critical section at least once - */ - pub fn replace(&self, new_value: T) { - let arc_ptr = self.active_value.swap( - Arc::into_raw(Arc::new(new_value)).cast_mut(), - std::sync::atomic::Ordering::AcqRel, - ); - - // maually drop as we need to ensure not to drop the arc while - // we have not witnest all threads to be or have been outside the read critical section - // i.e. even epoch counter or different odd epoch counter - // Safety: - // - the ptr was created in Rcu::new or Rcu::replace with Arc::into_raw - // - the Rcu itself holds one strong count - let arc = unsafe { ManuallyDrop::new(Arc::from_raw(arc_ptr)) }; - - let epochs = EPOCH_COUNTERS.read().unwrap().clone(); - let mut epochs = epochs - .into_iter() - .flat_map(|elem| { - let arc = elem.upgrade()?; - let init_val = arc.load(std::sync::atomic::Ordering::Acquire); - if init_val % 2 == 0 { - // already even can be ignored - return None; - } - // odd initial value thread is in read critical section - // need to wait for the value to change before we can drop the arc - Some((init_val, elem)) - }) - .collect::>(); - - while !epochs.is_empty() { - epochs.retain(|elem| { - let Some(arc) = elem.1.upgrade() else { - // as the thread is dead it can't have a ref to old arc - return false; - }; - // the epoch counter has not changed so the thread is still in the same instance of the critical section - // any different value is ok as - // - even values indicate the thread is outside the critical section - // - a diffrent odd value indicates the thread has left the critical section and can subsequently only read the new active_value - arc.load(std::sync::atomic::Ordering::Acquire) == elem.0 - }) - } - - // Safety: - // - we have not dropped the arc another way - // - we witnessed all threads either with an even epoch count or with a new odd count - // as such they must have left the critical section at some point - ManuallyDrop::into_inner(arc); - } -} - -pub struct RcuRef -where - T: ?Sized, - M: ?Sized, -{ - arc: Arc, - data: NonNull, -} - -impl Debug for RcuRef { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RcuRef") - .field("data", &self.deref()) - .finish() - } -} - -// use assoiated functions rather than methods so that we don't overlap -// with functions of the Deref Target type -impl RcuRef { - pub fn map FnOnce(&'a M) -> &'a N>(referece: Self, f: F) -> RcuRef { - RcuRef { - arc: referece.arc, - data: f(unsafe { referece.data.as_ref() }).into(), - } - } - - pub fn try_map FnOnce(&'a M) -> Option<&'a N>>( - referece: Self, - f: F, - ) -> Option> { - let val = f(unsafe { referece.data.as_ref() })?; - Some(RcuRef { - arc: Arc::clone(&referece.arc), - data: val.into(), - }) - } - - pub fn same_epoch(this: &Self, other: &RcuRef) -> bool { - Arc::ptr_eq(&this.arc, &other.arc) - } - - pub fn ptr_eq(this: &Self, other: &Self) -> bool { - std::ptr::addr_eq(this.data.as_ptr(), other.data.as_ptr()) - } - - pub fn clone(this: &Self) -> Self { - Self { - arc: Arc::clone(&this.arc), - data: this.data, - } - } - - pub fn get_root(this: &Self) -> &T { - &this.arc - } -} - -impl Deref for RcuRef { - type Target = M; - - fn deref(&self) -> &Self::Target { - // Safety: The pointer points into the arc we are holding - // while we are alive so is the target - // as the content is in an Rcu no mutable acess is given out - unsafe { self.data.as_ref() } - } -}