]> Repositorios git - scryer-prolog.git/commitdiff
switch Rcu to the arcu crate
authorBennet Bleßmann <[email protected]>
Sat, 6 Jul 2024 12:49:19 +0000 (14:49 +0200)
committerBennet Bleßmann <[email protected]>
Sat, 6 Jul 2024 12:49:19 +0000 (14:49 +0200)
The arcu crate is a more general implementation of the Rcu I implemented in here in scryer. It contains some bug-fixes regarding race-conditions in the Rcu update function, which could cause leaks and uses after free.

Source of the problem was the Relaxed load/strore/update of the reference count in side the Arc not being properly ordered with other load/stores.

Cargo.lock
Cargo.toml
src/arena.rs
src/atom_table.rs
src/lib.rs
src/rcu.rs [deleted file]

index 43806baf7e19c3210a9c58a431e54f56e63f6a11..eb2bd1b2313bfbbb05da7d195b3b79f40cff781d 100644 (file)
@@ -108,6 +108,12 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
+[[package]]
+name = "arcu"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8727c0fb4c436605c8f11c579ec86edcb729134aec4ee66e454efd99a91859f"
+
 [[package]]
 name = "arrayvec"
 version = "0.5.2"
@@ -2558,6 +2564,7 @@ dependencies = [
 name = "scryer-prolog"
 version = "0.9.4"
 dependencies = [
+ "arcu",
  "assert_cmd",
  "base64 0.12.3",
  "bit-set",
index 3fe3577280b9714afd09e1854de78e14c07cdc6e..45ba6d6efdafaa12761a638230729eba4a4a4ebc 100644 (file)
@@ -74,6 +74,7 @@ static_assertions = "1.1.0"
 
 serde_json = "1.0.95"
 serde = "1.0.159"
+arcu = { version = "0.1.1", features = ["thread_local_counter"] }
 
 [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
 crossterm = { version = "0.20.0", optional = true }
index e0cdeef9c3026322b5194de18165489c2a96cf8b..571595a6687c6f3c8063117cb1edb67b27b426fe 100644 (file)
@@ -7,12 +7,14 @@ use crate::machine::machine_indices::*;
 use crate::machine::streams::*;
 use crate::parser::char_reader::CharReader;
 use crate::raw_block::*;
-use crate::rcu::Rcu;
-use crate::rcu::RcuRef;
 use crate::read::*;
 use crate::types::UntypedArenaPtr;
 
 use crate::parser::dashu::{Integer, Rational};
+use arcu::atomic::Arcu;
+use arcu::epoch_counters::GlobalEpochCounterPool;
+use arcu::rcu_ref::RcuRef;
+use arcu::Rcu;
 use ordered_float::OrderedFloat;
 
 use std::cell::UnsafeCell;
@@ -79,7 +81,7 @@ impl RawBlockTraits for F64Table {
 
 #[derive(Debug)]
 pub struct F64Table {
-    block: Rcu<RawBlock<F64Table>>,
+    block: Arcu<RawBlock<F64Table>, GlobalEpochCounterPool>,
     update: Mutex<()>,
 }
 
@@ -93,7 +95,7 @@ pub fn lookup_float(
         .upgrade()
         .expect("We should only be looking up floats while there is a float table");
 
-    RcuRef::try_map(f64table.block.active_epoch(), |raw_block| unsafe {
+    RcuRef::try_map(f64table.block.read(), |raw_block| unsafe {
         raw_block
             .base
             .add(offset.0)
@@ -118,7 +120,7 @@ impl F64Table {
                 atom_table
             } else {
                 let atom_table = Arc::new(Self {
-                    block: Rcu::new(RawBlock::new()),
+                    block: Arcu::new(RawBlock::new(), GlobalEpochCounterPool),
                     update: Mutex::new(()),
                 });
                 *guard = Arc::downgrade(&atom_table);
@@ -133,7 +135,7 @@ impl F64Table {
 
         // we don't have an index table for lookups as AtomTable does so
         // just get the epoch after we take the upgrade lock
-        let mut block_epoch = self.block.active_epoch();
+        let mut block_epoch = self.block.read();
 
         let mut ptr;
 
@@ -143,7 +145,7 @@ impl F64Table {
             if ptr.is_null() {
                 let new_block = block_epoch.grow_new().unwrap();
                 self.block.replace(new_block);
-                block_epoch = self.block.active_epoch();
+                block_epoch = self.block.read();
             } else {
                 break;
             }
index 5216f6a31736840fb246eee4144d5e06b9dcc456..4a426cd0cb049e1c6683ea568e5aa48ed146ae2b 100644 (file)
@@ -2,7 +2,6 @@
 
 use crate::parser::ast::MAX_ARITY;
 use crate::raw_block::*;
-use crate::rcu::{Rcu, RcuRef};
 use crate::types::*;
 
 use std::cmp::Ordering;
@@ -16,6 +15,10 @@ use std::sync::Mutex;
 use std::sync::RwLock;
 use std::sync::Weak;
 
+use arcu::atomic::Arcu;
+use arcu::epoch_counters::GlobalEpochCounterPool;
+use arcu::rcu_ref::RcuRef;
+use arcu::Rcu;
 use indexmap::IndexSet;
 
 use scryer_modular_bitfield::prelude::*;
@@ -180,7 +183,7 @@ impl Atom {
             let atom_table =
                 arc_atom_table().expect("We should only have an Atom while there is an AtomTable");
 
-            AtomTableRef::try_map(atom_table.inner.active_epoch(), |buf| unsafe {
+            AtomTableRef::try_map(atom_table.inner.read(), |buf| unsafe {
                 let ptr = buf
                     .block
                     .base
@@ -278,17 +281,17 @@ impl Ord for Atom {
 #[derive(Debug)]
 pub struct InnerAtomTable {
     block: RawBlock<AtomTable>,
-    pub table: Rcu<IndexSet<Atom>>,
+    pub table: Arcu<IndexSet<Atom>, GlobalEpochCounterPool>,
 }
 
 #[derive(Debug)]
 pub struct AtomTable {
-    inner: Rcu<InnerAtomTable>,
+    inner: Arcu<InnerAtomTable, GlobalEpochCounterPool>,
     // this lock is taking during resizing
     update: Mutex<()>,
 }
 
-pub type AtomTableRef<M> = RcuRef<InnerAtomTable, M>;
+pub type AtomTableRef<M> = arcu::rcu_ref::RcuRef<InnerAtomTable, M>;
 
 impl InnerAtomTable {
     #[inline(always)]
@@ -296,7 +299,7 @@ impl InnerAtomTable {
         STATIC_ATOMS_MAP
             .get(string)
             .cloned()
-            .or_else(|| self.table.active_epoch().get(string).cloned())
+            .or_else(|| self.table.read().get(string).cloned())
     }
 }
 
@@ -314,10 +317,13 @@ impl AtomTable {
                 atom_table
             } else {
                 let atom_table = Arc::new(Self {
-                    inner: Rcu::new(InnerAtomTable {
-                        block: RawBlock::new(),
-                        table: Rcu::new(IndexSet::new()),
-                    }),
+                    inner: Arcu::new(
+                        InnerAtomTable {
+                            block: RawBlock::new(),
+                            table: Arcu::new(IndexSet::new(), GlobalEpochCounterPool),
+                        },
+                        GlobalEpochCounterPool,
+                    ),
                     update: Mutex::new(()),
                 });
                 *guard = Arc::downgrade(&atom_table);
@@ -327,13 +333,13 @@ impl AtomTable {
     }
 
     pub fn active_table(&self) -> RcuRef<IndexSet<Atom>, IndexSet<Atom>> {
-        self.inner.active_epoch().table.active_epoch()
+        self.inner.read().table.read()
     }
 
     pub fn build_with(atom_table: &AtomTable, string: &str) -> Atom {
         loop {
-            let mut block_epoch = atom_table.inner.active_epoch();
-            let mut table_epoch = block_epoch.table.active_epoch();
+            let mut block_epoch = atom_table.inner.read();
+            let mut table_epoch = block_epoch.table.read();
 
             if let Some(atom) = block_epoch.lookup_str(string) {
                 return atom;
@@ -342,10 +348,8 @@ impl AtomTable {
             // take a lock to prevent concurrent updates
             let update_guard = atom_table.update.lock().unwrap();
 
-            let is_same_allocation =
-                RcuRef::same_epoch(&block_epoch, &atom_table.inner.active_epoch());
-            let is_same_atom_list =
-                RcuRef::same_epoch(&table_epoch, &block_epoch.table.active_epoch());
+            let is_same_allocation = RcuRef::same_epoch(&block_epoch, &atom_table.inner.read());
+            let is_same_atom_list = RcuRef::same_epoch(&table_epoch, &block_epoch.table.read());
 
             if !(is_same_allocation && is_same_atom_list) {
                 // some other thread raced us between our lookup and
@@ -364,14 +368,14 @@ impl AtomTable {
                     if ptr.is_null() {
                         // garbage collection would go here
                         let new_block = block_epoch.block.grow_new().unwrap();
-                        let new_table = Rcu::new(table_epoch.clone());
+                        let new_table = Arcu::new(table_epoch.clone(), GlobalEpochCounterPool);
                         let new_alloc = InnerAtomTable {
                             block: new_block,
                             table: new_table,
                         };
                         atom_table.inner.replace(new_alloc);
-                        block_epoch = atom_table.inner.active_epoch();
-                        table_epoch = block_epoch.table.active_epoch();
+                        block_epoch = atom_table.inner.read();
+                        table_epoch = block_epoch.table.read();
                     } else {
                         break ptr;
                     }
index 56d967eb256d77831883d71465c21e8a1ef2b4d8..90f498a30131703839e7a3fcb7975d64c3497dae 100644 (file)
@@ -42,8 +42,6 @@ pub mod types;
 
 use instructions::instr;
 
-mod rcu;
-
 #[cfg(target_arch = "wasm32")]
 use wasm_bindgen::prelude::*;
 
diff --git a/src/rcu.rs b/src/rcu.rs
deleted file mode 100644 (file)
index 75ecef3..0000000
+++ /dev/null
@@ -1,220 +0,0 @@
-use std::{
-    cell::OnceCell,
-    fmt::Debug,
-    mem::ManuallyDrop,
-    ops::Deref,
-    ptr::NonNull,
-    sync::{
-        atomic::{AtomicPtr, AtomicU8},
-        Arc, RwLock, Weak,
-    },
-};
-
-// the epoch counters of all threads that have ever accessed an Rcu
-// threads that have finished will have a dangling Weak reference and can be cleand up
-// having this be shared between all Rcu's is a tradeof,
-// writes will be slower as more epoch counters need to be waited for
-// reads should be faster as a thread only needs to register itself once on the first read
-//
-static EPOCH_COUNTERS: RwLock<Vec<Weak<AtomicU8>>> = RwLock::new(Vec::new());
-
-thread_local! {
-    // odd value means the current thread is about to access the active_epoch of an Rcu
-    // a thread has a single epoch counter for all Rcu it accesses,
-    // as a thread can only access one Rcu at a time
-    static THREAD_EPOCH_COUNTER: OnceCell<Arc<AtomicU8>> = const { OnceCell::new() };
-}
-
-pub struct Rcu<T> {
-    active_value: AtomicPtr<T>,
-}
-
-impl<T: std::fmt::Debug> std::fmt::Debug for Rcu<T> {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        let active_epoch = self.active_epoch();
-        f.debug_struct("Rcu")
-            .field("active_value", &active_epoch)
-            .finish()
-    }
-}
-
-impl<T> Rcu<T> {
-    pub fn new(initial_value: T) -> Self {
-        Rcu {
-            active_value: AtomicPtr::new(Arc::into_raw(Arc::new(initial_value)).cast_mut()),
-        }
-    }
-
-    pub fn active_epoch(&self) -> RcuRef<T, T> {
-        THREAD_EPOCH_COUNTER.with(|epoch_counter| {
-            let epoch_counter = epoch_counter.get_or_init(|| {
-                let epoch_counter = Arc::new(AtomicU8::new(0));
-                // register the current threads epoch counter on init
-                EPOCH_COUNTERS
-                    .write()
-                    .unwrap()
-                    .push(Arc::downgrade(&epoch_counter));
-                epoch_counter
-            });
-
-            let old = epoch_counter.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
-            assert!(old % 2 == 0, "Old Epoch counter value should be even!");
-        });
-
-        let arc_ptr = self.active_value.load(std::sync::atomic::Ordering::Acquire);
-
-        let arc = unsafe {
-            // Safety:
-            // - the ptr was created in Rcu::new or Rcu::replace with Arc::into_raw
-            // - the Rcu is responsible for of the arc's strong refrences
-            // - the Rcu is alive as this function takes a reference to the Rcu
-            // - replace will wait with decrementing the old values strong count until our epoich counter is even again
-            Arc::increment_strong_count(arc_ptr);
-            // Safety:
-            // - the ptr was created in Rcu::new or Rcu::replace with Arc::into_raw
-            // - we have just ensured an additional strong count by incrementing the count
-            Arc::from_raw(arc_ptr)
-        };
-
-        THREAD_EPOCH_COUNTER.with(|epoch_counter| {
-            let old = epoch_counter
-                .get().expect("we initialized the OnceCell when we incremented the epoch counter the fist time")
-                .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
-            assert!(old % 2 != 0, "Old Epoch counter value should be odd!");
-        });
-
-        RcuRef {
-            data: arc.deref().into(),
-            arc,
-        }
-    }
-
-    /*
-     *  replace the Rcu'S content with a new value
-     *
-     *  This does not syncronize write and last to update the active_value pointer wins,
-     *  all writes that do not win will be lost, though not leaked.
-     *  This will block untill the old value can be reclaimed,
-     *  i.e. all threads whitnest to be in the read critical sections
-     *       have been witnest to have left the critical section at least once
-     */
-    pub fn replace(&self, new_value: T) {
-        let arc_ptr = self.active_value.swap(
-            Arc::into_raw(Arc::new(new_value)).cast_mut(),
-            std::sync::atomic::Ordering::AcqRel,
-        );
-
-        // maually drop as we need to ensure not to drop the arc while
-        // we have not witnest all threads to be or have been outside the read critical section
-        // i.e. even epoch counter or different odd epoch counter
-        // Safety:
-        // - the ptr was created in Rcu::new or Rcu::replace with Arc::into_raw
-        // - the Rcu itself holds one strong count
-        let arc = unsafe { ManuallyDrop::new(Arc::from_raw(arc_ptr)) };
-
-        let epochs = EPOCH_COUNTERS.read().unwrap().clone();
-        let mut epochs = epochs
-            .into_iter()
-            .flat_map(|elem| {
-                let arc = elem.upgrade()?;
-                let init_val = arc.load(std::sync::atomic::Ordering::Acquire);
-                if init_val % 2 == 0 {
-                    // already even can be ignored
-                    return None;
-                }
-                // odd initial value thread is in read critical section
-                // need to wait for the value to change before we can drop the arc
-                Some((init_val, elem))
-            })
-            .collect::<Vec<_>>();
-
-        while !epochs.is_empty() {
-            epochs.retain(|elem| {
-                let Some(arc) = elem.1.upgrade() else {
-                    // as the thread is dead it can't have a ref to old arc
-                    return false;
-                };
-                // the epoch counter has not changed so the thread is still in the same instance of the critical section
-                // any different value is ok as
-                // - even values indicate the thread is outside the critical section
-                // - a diffrent odd value indicates the thread has left the critical section and can subsequently only read the new active_value
-                arc.load(std::sync::atomic::Ordering::Acquire) == elem.0
-            })
-        }
-
-        // Safety:
-        // - we have not dropped the arc another way
-        // - we witnessed all threads either with an even epoch count or with a new odd count
-        //   as such they must have left the critical section at some point
-        ManuallyDrop::into_inner(arc);
-    }
-}
-
-pub struct RcuRef<T, M>
-where
-    T: ?Sized,
-    M: ?Sized,
-{
-    arc: Arc<T>,
-    data: NonNull<M>,
-}
-
-impl<T: ?Sized, M: ?Sized + Debug> Debug for RcuRef<T, M> {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("RcuRef")
-            .field("data", &self.deref())
-            .finish()
-    }
-}
-
-// use assoiated functions rather than methods so that we don't overlap
-// with functions of the Deref Target type
-impl<T: ?Sized, M: ?Sized> RcuRef<T, M> {
-    pub fn map<N: ?Sized, F: for<'a> FnOnce(&'a M) -> &'a N>(referece: Self, f: F) -> RcuRef<T, N> {
-        RcuRef {
-            arc: referece.arc,
-            data: f(unsafe { referece.data.as_ref() }).into(),
-        }
-    }
-
-    pub fn try_map<N: ?Sized, F: for<'a> FnOnce(&'a M) -> Option<&'a N>>(
-        referece: Self,
-        f: F,
-    ) -> Option<RcuRef<T, N>> {
-        let val = f(unsafe { referece.data.as_ref() })?;
-        Some(RcuRef {
-            arc: Arc::clone(&referece.arc),
-            data: val.into(),
-        })
-    }
-
-    pub fn same_epoch<M2>(this: &Self, other: &RcuRef<T, M2>) -> bool {
-        Arc::ptr_eq(&this.arc, &other.arc)
-    }
-
-    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
-        std::ptr::addr_eq(this.data.as_ptr(), other.data.as_ptr())
-    }
-
-    pub fn clone(this: &Self) -> Self {
-        Self {
-            arc: Arc::clone(&this.arc),
-            data: this.data,
-        }
-    }
-
-    pub fn get_root(this: &Self) -> &T {
-        &this.arc
-    }
-}
-
-impl<T: ?Sized, M: ?Sized> Deref for RcuRef<T, M> {
-    type Target = M;
-
-    fn deref(&self) -> &Self::Target {
-        // Safety: The pointer points into the arc we are holding
-        // while we are alive so is the target
-        // as the content is in an Rcu no mutable acess is given out
-        unsafe { self.data.as_ref() }
-    }
-}