From: Emilie Burgun Date: Sun, 2 Feb 2025 23:07:39 +0000 (+0100) Subject: Fix close/1 messing up stream_aliases when user_input or user_output aren't set to... X-Git-Tag: v0.10.0~74^2~5 X-Git-Url: https://git.sagredo.dev/?a=commitdiff_plain;h=2fe7b55343ccf6396f4ad6a8088462592302ad06;p=scryer-prolog.git Fix close/1 messing up stream_aliases when user_input or user_output aren't set to Stdin and Stdout --- diff --git a/src/machine/streams.rs b/src/machine/streams.rs index d434c569..212ec463 100644 --- a/src/machine/streams.rs +++ b/src/machine/streams.rs @@ -1288,11 +1288,10 @@ impl Stream { )) } + /// Drops the stream handle and marks the arena pointer as [`ArenaHeaderTag::Dropped`]. #[inline] pub(crate) fn close(&mut self) -> Result<(), std::io::Error> { - let mut stream = std::mem::replace(self, Stream::Null(StreamOptions::default())); - - match stream { + match self { Stream::NamedTcp(ref mut tcp_stream) => { tcp_stream.inner_mut().tcp_stream.shutdown(Shutdown::Both) } @@ -1322,7 +1321,20 @@ impl Stream { Ok(()) } - _ => Ok(()), + Stream::Byte(mut stream) => { + stream.drop_payload(); + Ok(()) + } + Stream::StaticString(mut stream) => { + stream.drop_payload(); + Ok(()) + } + + Stream::Null(_) => Ok(()), + + Stream::Readline(_) | Stream::StandardOutput(_) | Stream::StandardError(_) => { + unreachable!(); + } } } @@ -1893,3 +1905,45 @@ impl MachineState { } } } + +#[cfg(test)] +mod test { + use super::*; + use crate::machine::config::*; + + #[test] + #[cfg_attr(miri, ignore)] + fn close_memory_user_output_stream() { + let mut machine = MachineBuilder::new() + .with_streams(StreamConfig::in_memory()) + .build(); + + let results = machine + .run_query( + "\\+ \\+ (current_output(Stream), close(Stream)), write(user_output, hello).", + ) + .collect::>(); + + assert_eq!(results.len(), 1); + assert!(results[0].is_ok()); + + let mut actual = String::new(); + machine.user_output.read_to_string(&mut actual).unwrap(); + assert_eq!(actual, "hello"); + } + + #[test] + #[cfg_attr(miri, ignore)] + fn close_memory_user_output_stream_twice() { + let mut machine = MachineBuilder::new() + .with_streams(StreamConfig::in_memory()) + .build(); + + let results = machine + .run_query("\\+ \\+ (current_output(Stream), close(Stream), close(Stream)).") + .collect::>(); + + assert_eq!(results.len(), 1); + assert!(results[0].is_ok()); + } +} diff --git a/src/machine/system_calls.rs b/src/machine/system_calls.rs index 14f9094b..7a7eeb57 100644 --- a/src/machine/system_calls.rs +++ b/src/machine/system_calls.rs @@ -3868,47 +3868,26 @@ impl Machine { stream.flush().unwrap(); // 8.11.6.1b) } - self.indices.streams.remove(&stream); - - if stream == self.user_input { - self.user_input = self - .indices - .stream_aliases - .get(&atom!("user_input")) - .cloned() - .unwrap(); - - self.indices.streams.insert(self.user_input); - } else if stream == self.user_output { - self.user_output = self - .indices - .stream_aliases - .get(&atom!("user_output")) - .cloned() - .unwrap(); - - self.indices.streams.insert(self.user_output); + if stream == self.user_input || stream == self.user_output || stream.is_stderr() { + // stdin, stdout and stderr shouldn't be removed from the store, so return now + return Ok(()); } - if !stream.is_stdin() && !stream.is_stdout() && !stream.is_stderr() { - if let Some(alias) = stream.options().get_alias() { - self.indices.stream_aliases.swap_remove(&alias); - } - - let close_result = stream.close(); - - if close_result.is_err() { - let stub = functor_stub(atom!("close"), 1); - let addr = stream_as_cell!(stream); - let err = self - .machine_st - .existence_error(ExistenceError::Stream(addr)); + self.indices.streams.remove(&stream); - return Err(self.machine_st.error_form(err, stub)); - } + if let Some(alias) = stream.options().get_alias() { + self.indices.stream_aliases.swap_remove(&alias); } - Ok(()) + stream.close().map_err(|_| { + let stub = functor_stub(atom!("close"), 1); + let addr = stream_as_cell!(stream); + let err = self + .machine_st + .existence_error(ExistenceError::Stream(addr)); + + self.machine_st.error_form(err, stub) + }) } #[inline(always)]