From 0cf46d3ec4f0fc1e9fd9c8222336a28ea3b4ad3b Mon Sep 17 00:00:00 2001 From: Emilie Burgun Date: Mon, 3 Feb 2025 14:00:37 +0100 Subject: [PATCH] Encapsulate accesses to IndexStore::streams and ::stream_aliases These two fields are able to hold `Stream` instances, which predicates like `close/1` expect to be managed properly for their correctness. To ensure that this is the case, I have removed direct accesses to those two fields, so that they can be properly managed in one place. --- src/machine/loader.rs | 2 +- src/machine/machine_indices.rs | 105 ++++++++++++++++++++- src/machine/mod.rs | 28 ++---- src/machine/streams.rs | 12 +-- src/machine/system_calls.rs | 164 +++++++++++++++------------------ 5 files changed, 187 insertions(+), 124 deletions(-) diff --git a/src/machine/loader.rs b/src/machine/loader.rs index 90080e9e..b7e7c4b5 100644 --- a/src/machine/loader.rs +++ b/src/machine/loader.rs @@ -1808,7 +1808,7 @@ impl Machine { pub(crate) fn push_load_context(&mut self) -> CallResult { let stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("$push_load_context"), 2, )?; diff --git a/src/machine/machine_indices.rs b/src/machine/machine_indices.rs index 2f259125..4c849068 100644 --- a/src/machine/machine_indices.rs +++ b/src/machine/machine_indices.rs @@ -7,8 +7,9 @@ use crate::atom_table::*; use crate::forms::*; use crate::machine::loader::*; use crate::machine::machine_state::*; -use crate::machine::streams::Stream; +use crate::machine::streams::{Stream, StreamOptions}; use crate::machine::ClauseType; +use crate::machine::MachineStubGen; use fxhash::FxBuildHasher; use indexmap::{IndexMap, IndexSet}; @@ -261,8 +262,8 @@ pub struct IndexStore { pub(super) meta_predicates: MetaPredicateDir, pub(super) modules: ModuleDir, pub(super) op_dir: OpDir, - pub(super) streams: StreamDir, - pub(super) stream_aliases: StreamAliasDir, + streams: StreamDir, + stream_aliases: StreamAliasDir, } impl IndexStore { @@ -459,6 +460,94 @@ impl IndexStore { } } + pub(crate) fn add_stream( + &mut self, + stream: Stream, + stub_name: Atom, + stub_arity: usize, + ) -> Result<(), MachineStubGen> { + if let Some(alias) = stream.options().get_alias() { + if self.stream_aliases.contains_key(&alias) { + return Err(Box::new(move |machine_st| { + machine_st.occupied_alias_permission_error(alias, stub_name, stub_arity) + })); + } + + self.stream_aliases.insert(alias, stream); + } + + self.streams.insert(stream); + + Ok(()) + } + + pub(crate) fn remove_stream(&mut self, stream: Stream) { + if let Some(alias) = stream.options().get_alias() { + debug_assert_eq!(self.stream_aliases.get(&alias), Some(&stream)); + self.stream_aliases.swap_remove(&alias); + } + self.streams.remove(&stream); + } + + pub(crate) fn update_stream_options( + &mut self, + mut stream: Stream, + callback: F, + ) { + if let Some(prev_alias) = stream.options().get_alias() { + debug_assert_eq!(self.stream_aliases.get(&prev_alias), Some(&stream)); + } + let options = stream.options_mut(); + let prev_alias = options.get_alias(); + + callback(options); + + if options.get_alias() != prev_alias { + if let Some(prev_alias) = prev_alias { + self.stream_aliases.swap_remove(&prev_alias); + } + if let Some(new_alias) = options.get_alias() { + self.stream_aliases.insert(new_alias, stream); + } + } + } + + pub(crate) fn has_stream(&self, alias: Atom) -> bool { + self.stream_aliases.contains_key(&alias) + } + + pub(crate) fn get_stream(&self, alias: Atom) -> Option { + self.stream_aliases.get(&alias).copied() + } + + pub(crate) fn iter_streams<'a, R: std::ops::RangeBounds>( + &'a self, + range: R, + ) -> impl Iterator + 'a { + self.streams.range(range).into_iter().copied() + } + + /// Forcibly sets `alias` to `stream`. + /// If there was a previous stream with that alias, it will lose that alias. + /// + /// Consider using [`add_stream`](Self::add_stream) if you wish to instead + /// return an error when stream aliases conflict. + pub(crate) fn set_stream(&mut self, alias: Atom, mut stream: Stream) { + if let Some(mut prev_stream) = self.get_stream(alias) { + if prev_stream == stream { + // Nothing to do, as the stream is already present + return; + } + + prev_stream.options_mut().set_alias_to_atom_opt(None); + } + + stream.options_mut().set_alias_to_atom_opt(Some(alias)); + + self.stream_aliases.insert(alias, stream); + self.streams.insert(stream); + } + #[inline] pub(super) fn new() -> Self { index_store!( @@ -468,3 +557,13 @@ impl IndexStore { ) } } + +/// A stream is said to have a "protected" alias if modifying its +/// alias would cause breakage in other parts of the code. +/// +/// A stream with a protected alias cannot be realiased through +/// [`IndexStore::update_stream_options`]. Instead, one has to use +/// [`IndexStore::set_stream`] to do so. +fn is_protected_alias(alias: Atom) -> bool { + alias == atom!("user_input") || alias == atom!("user_output") || alias == atom!("user_error") +} diff --git a/src/machine/mod.rs b/src/machine/mod.rs index a62c0caa..57a6213b 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -483,32 +483,16 @@ impl Machine { } } + /// Ensures that [`Machine::indices`] properly reflects + /// the streams stored in [`Machine::user_input`], [`Machine::user_output`] + /// and [`Machine::user_error`]. pub(crate) fn configure_streams(&mut self) { - self.user_input - .options_mut() - .set_alias_to_atom_opt(Some(atom!("user_input"))); - self.indices - .stream_aliases - .insert(atom!("user_input"), self.user_input); - - self.indices.streams.insert(self.user_input); - - self.user_output - .options_mut() - .set_alias_to_atom_opt(Some(atom!("user_output"))); - + .set_stream(atom!("user_input"), self.user_input); self.indices - .stream_aliases - .insert(atom!("user_output"), self.user_output); - - self.indices.streams.insert(self.user_output); - + .set_stream(atom!("user_output"), self.user_output); self.indices - .stream_aliases - .insert(atom!("user_error"), self.user_error); - - self.indices.streams.insert(self.user_error); + .set_stream(atom!("user_error"), self.user_error); } #[inline(always)] diff --git a/src/machine/streams.rs b/src/machine/streams.rs index 212ec463..3432fadc 100644 --- a/src/machine/streams.rs +++ b/src/machine/streams.rs @@ -1594,7 +1594,7 @@ impl MachineState { pub(crate) fn get_stream_or_alias( &mut self, addr: HeapCellValue, - stream_aliases: &StreamAliasDir, + indices: &IndexStore, caller: Atom, arity: usize, ) -> Result { @@ -1604,8 +1604,8 @@ impl MachineState { (HeapCellValueTag::Atom, (name, arity)) => { debug_assert_eq!(arity, 0); - return match stream_aliases.get(&name) { - Some(stream) if !stream.is_null_stream() => Ok(*stream), + return match indices.get_stream(name) { + Some(stream) if !stream.is_null_stream() => Ok(stream), _ => { let stub = functor_stub(caller, arity); let addr = atom_as_cell!(name); @@ -1622,8 +1622,8 @@ impl MachineState { debug_assert_eq!(arity, 0); - return match stream_aliases.get(&name) { - Some(stream) if !stream.is_null_stream() => Ok(*stream), + return match indices.get_stream(name) { + Some(stream) if !stream.is_null_stream() => Ok(stream), _ => { let stub = functor_stub(caller, arity); let addr = atom_as_cell!(name); @@ -1813,7 +1813,7 @@ impl MachineState { // 8.11.5.3l) if let Some(alias) = options.get_alias() { - if indices.stream_aliases.contains_key(&alias) { + if indices.has_stream(alias) { return Err(self.occupied_alias_permission_error(alias, atom!("open"), 4)); } } diff --git a/src/machine/system_calls.rs b/src/machine/system_calls.rs index 7a7eeb57..2b05b8f0 100644 --- a/src/machine/system_calls.rs +++ b/src/machine/system_calls.rs @@ -41,7 +41,6 @@ use indexmap::IndexSet; use std::cell::Cell; use std::cmp::Ordering; -use std::collections::BTreeSet; use std::convert::TryFrom; use std::env; #[cfg(feature = "ffi")] @@ -55,7 +54,6 @@ use std::mem; use std::net::{SocketAddr, ToSocketAddrs}; use std::net::{TcpListener, TcpStream}; use std::num::NonZeroU32; -use std::ops::Sub; use std::process; #[cfg(feature = "http")] use std::str::FromStr; @@ -2543,7 +2541,7 @@ impl Machine { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("peek_byte"), 2, )?; @@ -2634,7 +2632,7 @@ impl Machine { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("peek_char"), 2, )?; @@ -2726,7 +2724,7 @@ impl Machine { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("peek_code"), 2, )?; @@ -3147,7 +3145,7 @@ impl Machine { pub(crate) fn put_code(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("put_code"), 2, )?; @@ -3199,7 +3197,7 @@ impl Machine { pub(crate) fn put_char(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("put_char"), 2, )?; @@ -3243,7 +3241,7 @@ impl Machine { pub(crate) fn put_chars(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("$put_chars"), 2, )?; @@ -3292,7 +3290,7 @@ impl Machine { pub(crate) fn put_byte(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("put_byte"), 2, )?; @@ -3359,7 +3357,7 @@ impl Machine { pub(crate) fn get_byte(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("get_byte"), 2, )?; @@ -3444,7 +3442,7 @@ impl Machine { pub(crate) fn get_char(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("get_char"), 2, )?; @@ -3539,7 +3537,7 @@ impl Machine { pub(crate) fn get_n_chars(&mut self) -> CallResult { let stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("get_n_chars"), 3, )?; @@ -3608,7 +3606,7 @@ impl Machine { pub(crate) fn get_code(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("get_code"), 2, )?; @@ -3715,19 +3713,11 @@ impl Machine { #[inline(always)] pub(crate) fn first_stream(&mut self) { - let mut first_stream = None; - let mut null_streams = BTreeSet::new(); - - for stream in self.indices.streams.iter().cloned() { - if !stream.is_null_stream() { - first_stream = Some(stream); - break; - } else { - null_streams.insert(stream); - } - } - - self.indices.streams = self.indices.streams.sub(&null_streams); + let first_stream = self + .indices + .iter_streams(..) + .filter(|s| !s.is_null_stream()) + .next(); if let Some(first_stream) = first_stream { let stream = stream_as_cell!(first_stream); @@ -3743,20 +3733,12 @@ impl Machine { #[inline(always)] pub(crate) fn next_stream(&mut self) { let prev_stream = cell_as_stream!(self.deref_register(1)); - - let mut next_stream = None; - let mut null_streams = BTreeSet::new(); - - for stream in self.indices.streams.range(prev_stream..).skip(1).cloned() { - if !stream.is_null_stream() { - next_stream = Some(stream); - break; - } else { - null_streams.insert(stream); - } - } - - self.indices.streams = self.indices.streams.sub(&null_streams); + let next_stream = self + .indices + .iter_streams(prev_stream..) + .filter(|s| !s.is_null_stream()) + .skip(1) + .next(); if let Some(next_stream) = next_stream { let var = self.deref_register(2).as_var().unwrap(); @@ -3772,7 +3754,7 @@ impl Machine { pub(crate) fn flush_output(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("flush_output"), 1, )?; @@ -3859,7 +3841,7 @@ impl Machine { pub(crate) fn close(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("close"), 2, )?; @@ -3873,11 +3855,7 @@ impl Machine { return Ok(()); } - self.indices.streams.remove(&stream); - - if let Some(alias) = stream.options().get_alias() { - self.indices.stream_aliases.swap_remove(&alias); - } + self.indices.remove_stream(stream); stream.close().map_err(|_| { let stub = functor_stub(atom!("close"), 1); @@ -4445,11 +4423,10 @@ impl Machine { &mut self.machine_st.arena, ); *stream.options_mut() = StreamOptions::default(); - if let Some(alias) = stream.options().get_alias() { - self.indices.stream_aliases.insert(alias, stream); - } - self.indices.streams.insert(stream); + self.indices + .add_stream(stream, atom!("http_open"), 3) + .map_err(|stub_gen| stub_gen(&mut self.machine_st))?; let stream = stream_as_cell!(stream); @@ -4667,7 +4644,10 @@ impl Machine { ); *stream.options_mut() = StreamOptions::default(); stream.options_mut().set_stream_type(StreamType::Binary); - self.indices.streams.insert(stream); + + self.indices.add_stream(stream, atom!("http_accept"), 7) + .map_err(|stub_gen| stub_gen(&mut self.machine_st))?; + let stream = stream_as_cell!(stream); let handle: TypedArenaPtr = arena_alloc!(request.response, &mut self.machine_st.arena); @@ -4781,7 +4761,11 @@ impl Machine { ); *stream.options_mut() = StreamOptions::default(); stream.options_mut().set_stream_type(StreamType::Binary); - self.indices.streams.insert(stream); + + + self.indices.add_stream(stream, atom!("http_answer"), 4) + .map_err(|stub_gen| stub_gen(&mut self.machine_st))?; + let stream = stream_as_cell!(stream); self.machine_st.bind(stream_addr.as_var().unwrap(), stream); } @@ -5096,11 +5080,10 @@ impl Machine { .stream_from_file_spec(file_spec, &mut self.indices, &options)?; *stream.options_mut() = options; - self.indices.streams.insert(stream); - if let Some(alias) = stream.options().get_alias() { - self.indices.stream_aliases.insert(alias, stream); - } + self.indices + .add_stream(stream, atom!("open"), 4) + .map_err(|stub_gen| stub_gen(&mut self.machine_st))?; let stream_var = self.deref_register(3); self.machine_st @@ -5181,7 +5164,7 @@ impl Machine { pub(crate) fn set_stream_options(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("open"), 4, )?; @@ -5937,12 +5920,9 @@ impl Machine { pub(crate) fn set_input(&mut self) -> CallResult { let addr = self.deref_register(1); - let stream = self.machine_st.get_stream_or_alias( - addr, - &self.indices.stream_aliases, - atom!("set_input"), - 1, - )?; + let stream = + self.machine_st + .get_stream_or_alias(addr, &self.indices, atom!("set_input"), 1)?; if !stream.is_input_stream() { let stub = functor_stub(atom!("set_input"), 1); @@ -5964,12 +5944,9 @@ impl Machine { #[inline(always)] pub(crate) fn set_output(&mut self) -> CallResult { let addr = self.deref_register(1); - let stream = self.machine_st.get_stream_or_alias( - addr, - &self.indices.stream_aliases, - atom!("set_output"), - 1, - )?; + let stream = + self.machine_st + .get_stream_or_alias(addr, &self.indices, atom!("set_output"), 1)?; if !stream.is_output_stream() { let stub = functor_stub(atom!("set_output"), 1); @@ -6275,7 +6252,7 @@ impl Machine { let stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("read_term"), 3, )?; @@ -6516,7 +6493,7 @@ impl Machine { } if let Some(alias) = options.get_alias() { - if self.indices.stream_aliases.contains_key(&alias) { + if self.indices.has_stream(alias) { return Err(self.machine_st.occupied_alias_permission_error( alias, atom!("socket_client_open"), @@ -6532,11 +6509,9 @@ impl Machine { *stream.options_mut() = options; - if let Some(alias) = stream.options().get_alias() { - self.indices.stream_aliases.insert(alias, stream); - } - - self.indices.streams.insert(stream); + self.indices + .add_stream(stream, atom!("socket_client_open"), 7) + .map_err(|stub_gen| stub_gen(&mut self.machine_st))?; stream_as_cell!(stream) } @@ -6544,7 +6519,7 @@ impl Machine { return Err(self.machine_st.open_permission_error( addr, atom!("socket_client_open"), - 3, + 7, )); } Err(ErrorKind::NotFound) => { @@ -6661,7 +6636,7 @@ impl Machine { } if let Some(alias) = options.get_alias() { - if self.indices.stream_aliases.contains_key(&alias) { + if self.indices.has_stream(alias) { return Err(self.machine_st.occupied_alias_permission_error( alias, atom!("socket_server_accept"), @@ -6688,11 +6663,10 @@ impl Machine { *tcp_stream.options_mut() = options; - if let Some(alias) = &tcp_stream.options().get_alias() { - self.indices.stream_aliases.insert(*alias, tcp_stream); - } - - self.indices.streams.insert(tcp_stream); + self.indices.add_stream(tcp_stream, atom!("socket_server_accept"), 4) + .map_err(|stub_gen| { + stub_gen(&mut self.machine_st) + })?; let tcp_stream = stream_as_cell!(tcp_stream); let client = atom_as_cell!(client); @@ -6728,7 +6702,7 @@ impl Machine { { let stream0 = self.machine_st.get_stream_or_alias( self.machine_st.registers[2], - &self.indices.stream_aliases, + &self.indices, atom!("tls_client_negotiate"), 3, )?; @@ -6747,7 +6721,10 @@ impl Machine { let addr = atom!("TLS"); let stream = Stream::from_tls_stream(addr, stream, &mut self.machine_st.arena); - self.indices.streams.insert(stream); + + self.indices + .add_stream(stream, atom!("tls_client_negotiate"), 3) + .map_err(|stub_gen| stub_gen(&mut self.machine_st))?; self.machine_st.heap.push(stream_as_cell!(stream)); let stream_addr = self.deref_register(3); @@ -6782,7 +6759,7 @@ impl Machine { let stream0 = self.machine_st.get_stream_or_alias( self.machine_st.registers[3], - &self.indices.stream_aliases, + &self.indices, atom!("tls_server_negotiate"), 3, )?; @@ -6801,7 +6778,10 @@ impl Machine { }; let stream = Stream::from_tls_stream(atom!("TLS"), stream, &mut self.machine_st.arena); - self.indices.streams.insert(stream); + + self.indices + .add_stream(stream, atom!("tls_server_negotiate"), 3) + .map_err(|stub_gen| stub_gen(&mut self.machine_st))?; let stream_addr = self.deref_register(4); self.machine_st @@ -6843,7 +6823,7 @@ impl Machine { pub(crate) fn set_stream_position(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("set_stream_position"), 2, )?; @@ -6886,7 +6866,7 @@ impl Machine { pub(crate) fn stream_property(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("stream_property"), 2, )?; @@ -7237,7 +7217,7 @@ impl Machine { pub(crate) fn write_term(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("write_term"), 3, )?; @@ -8113,7 +8093,7 @@ impl Machine { pub(crate) fn devour_whitespace(&mut self) -> CallResult { let mut stream = self.machine_st.get_stream_or_alias( self.machine_st.registers[1], - &self.indices.stream_aliases, + &self.indices, atom!("$devour_whitespace"), 1, )?; -- 2.54.0