From: bakaq Date: Thu, 30 Jan 2025 09:14:38 +0000 (-0300) Subject: Refactor UserInput to use channels X-Git-Tag: v0.10.0~71^2~5 X-Git-Url: https://git.sagredo.dev/?a=commitdiff_plain;h=baae1dca15080735f7b0a77093b86390f14f4e38;p=scryer-prolog.git Refactor UserInput to use channels --- diff --git a/src/machine/config.rs b/src/machine/config.rs index 06dbb77a..12e5a538 100644 --- a/src/machine/config.rs +++ b/src/machine/config.rs @@ -1,7 +1,6 @@ -use std::cell::RefCell; -use std::io::{Seek, SeekFrom, Write}; -use std::rc::Rc; -use std::{borrow::Cow, io::Cursor}; +use std::borrow::Cow; +use std::io::Write; +use std::sync::mpsc::{channel, Receiver, Sender}; use rand::{rngs::StdRng, SeedableRng}; @@ -40,14 +39,12 @@ impl StreamConfig { /// /// This also returns a handler to the stdin do the [`Machine`](crate::Machine). pub fn with_callbacks(stdout: Option, stderr: Option) -> (UserInput, Self) { - let stdin = Rc::new(RefCell::new(Cursor::new(Vec::new()))); + let (sender, receiver) = channel(); ( - UserInput { - inner: stdin.clone(), - }, + UserInput { inner: sender }, StreamConfig { inner: StreamConfigInner::Callbacks { - stdin, + stdin: receiver, stdout, stderr, }, @@ -59,23 +56,19 @@ impl StreamConfig { /// A handler for the stdin of the [`Machine`](crate::Machine). #[derive(Debug)] pub struct UserInput { - inner: Rc>>>, + inner: Sender>, } impl Write for UserInput { fn write(&mut self, buf: &[u8]) -> std::io::Result { - let mut inner = self.inner.borrow_mut(); - let pos = inner.position(); - - inner.seek(SeekFrom::End(0))?; - let result = inner.write(buf); - inner.seek(SeekFrom::Start(pos))?; - - result + self.inner + .send(buf.into()) + .map(|_| buf.len()) + .map_err(|_| std::io::ErrorKind::BrokenPipe.into()) } fn flush(&mut self) -> std::io::Result<()> { - self.inner.borrow_mut().flush() + Ok(()) } } @@ -85,7 +78,7 @@ enum StreamConfigInner { #[default] Memory, Callbacks { - stdin: Rc>>>, + stdin: Receiver>, stdout: Option, stderr: Option, }, diff --git a/src/machine/lib_machine/tests.rs b/src/machine/lib_machine/tests.rs index 40c33225..502f4d89 100644 --- a/src/machine/lib_machine/tests.rs +++ b/src/machine/lib_machine/tests.rs @@ -620,7 +620,7 @@ fn callback_streams() { let (mut user_input, streams) = StreamConfig::with_callbacks( Some(Box::new(move |x| { - x.read_to_string(&mut *test_string2.borrow_mut()).unwrap(); + x.read_to_string(&mut test_string2.borrow_mut()).unwrap(); })), None, ); diff --git a/src/machine/streams.rs b/src/machine/streams.rs index ce0c8d05..edb6756b 100644 --- a/src/machine/streams.rs +++ b/src/machine/streams.rs @@ -16,7 +16,6 @@ pub use scryer_modular_bitfield::prelude::*; #[cfg(feature = "http")] use bytes::{buf::Reader as BufReader, Buf, Bytes}; -use std::cell::RefCell; use std::cmp::Ordering; use std::error::Error; use std::fmt; @@ -30,7 +29,8 @@ use std::net::{Shutdown, TcpStream}; use std::ops::{Deref, DerefMut}; use std::path::PathBuf; use std::ptr; -use std::rc::Rc; +use std::sync::mpsc::Receiver; +use std::sync::mpsc::TryRecvError; #[cfg(feature = "tls")] use native_tls::TlsStream; @@ -414,13 +414,50 @@ impl Write for CallbackStream { #[derive(Debug)] pub struct InputChannelStream { - pub(crate) inner: Rc>>>, + pub(crate) inner: Cursor>, + pub eof: bool, + channel: Receiver>, } impl Read for InputChannelStream { #[inline] fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - self.inner.borrow_mut().read(buf) + if self.eof { + return Ok(0); + } + + let to_read = buf.len(); + let mut total_read = 0; + + loop { + total_read += self.inner.read(&mut buf[total_read..])?; + + if total_read < to_read { + // We need to get more data to read + match self.channel.try_recv() { + Ok(data) => { + // Append into self.inner + let pos = self.inner.position(); + assert_eq!(pos as usize, self.inner.get_ref().len()); + self.inner.write_all(&data)?; + self.inner.seek(SeekFrom::Start(pos))?; + } + Err(TryRecvError::Empty) => { + // Data is pending + break; + } + Err(TryRecvError::Disconnected) => { + // The other end of the channel was closed so we are EOF + self.eof = true; + break; + } + } + } else { + assert_eq!(total_read, to_read); + break; + } + } + Ok(total_read) } } @@ -602,9 +639,14 @@ impl Stream { } #[inline] - pub fn input_channel(cursor: Rc>>>, arena: &mut Arena) -> Stream { + pub fn input_channel(channel: Receiver>, arena: &mut Arena) -> Stream { + let inner = Cursor::new(Vec::new()); Stream::InputChannel(arena_alloc!( - StreamLayout::new(CharReader::new(InputChannelStream { inner: cursor })), + StreamLayout::new(CharReader::new(InputChannelStream { + inner, + eof: false, + channel + })), arena )) } @@ -1239,6 +1281,13 @@ impl Stream { AtEndOfStream::Past } } + Stream::InputChannel(stream_layout) => { + if stream_layout.stream.get_ref().eof { + AtEndOfStream::At + } else { + AtEndOfStream::Not + } + } _ => AtEndOfStream::Not, } } @@ -1519,6 +1568,10 @@ impl Stream { readline_stream.reset(); true } + Stream::InputChannel(ref mut input_channel_stream) => { + input_channel_stream.stream.get_mut().inner.set_position(0); + true + } _ => false, } }