]> Repositorios git - scryer-prolog.git/commitdiff
Refactor UserInput to use channels
authorbakaq <[email protected]>
Thu, 30 Jan 2025 09:14:38 +0000 (06:14 -0300)
committerbakaq <[email protected]>
Sun, 16 Feb 2025 07:04:48 +0000 (04:04 -0300)
src/machine/config.rs
src/machine/lib_machine/tests.rs
src/machine/streams.rs

index 06dbb77ac96b998f1c4213801f4e04bb26b20bf1..12e5a5381d82809d1636709665c8f4bfaaa1724a 100644 (file)
@@ -1,7 +1,6 @@
-use std::cell::RefCell;
-use std::io::{Seek, SeekFrom, Write};
-use std::rc::Rc;
-use std::{borrow::Cow, io::Cursor};
+use std::borrow::Cow;
+use std::io::Write;
+use std::sync::mpsc::{channel, Receiver, Sender};
 
 use rand::{rngs::StdRng, SeedableRng};
 
@@ -40,14 +39,12 @@ impl StreamConfig {
     ///
     /// This also returns a handler to the stdin do the [`Machine`](crate::Machine).
     pub fn with_callbacks(stdout: Option<Callback>, stderr: Option<Callback>) -> (UserInput, Self) {
-        let stdin = Rc::new(RefCell::new(Cursor::new(Vec::new())));
+        let (sender, receiver) = channel();
         (
-            UserInput {
-                inner: stdin.clone(),
-            },
+            UserInput { inner: sender },
             StreamConfig {
                 inner: StreamConfigInner::Callbacks {
-                    stdin,
+                    stdin: receiver,
                     stdout,
                     stderr,
                 },
@@ -59,23 +56,19 @@ impl StreamConfig {
 /// A handler for the stdin of the [`Machine`](crate::Machine).
 #[derive(Debug)]
 pub struct UserInput {
-    inner: Rc<RefCell<Cursor<Vec<u8>>>>,
+    inner: Sender<Vec<u8>>,
 }
 
 impl Write for UserInput {
     fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
-        let mut inner = self.inner.borrow_mut();
-        let pos = inner.position();
-
-        inner.seek(SeekFrom::End(0))?;
-        let result = inner.write(buf);
-        inner.seek(SeekFrom::Start(pos))?;
-
-        result
+        self.inner
+            .send(buf.into())
+            .map(|_| buf.len())
+            .map_err(|_| std::io::ErrorKind::BrokenPipe.into())
     }
 
     fn flush(&mut self) -> std::io::Result<()> {
-        self.inner.borrow_mut().flush()
+        Ok(())
     }
 }
 
@@ -85,7 +78,7 @@ enum StreamConfigInner {
     #[default]
     Memory,
     Callbacks {
-        stdin: Rc<RefCell<Cursor<Vec<u8>>>>,
+        stdin: Receiver<Vec<u8>>,
         stdout: Option<Callback>,
         stderr: Option<Callback>,
     },
index 40c33225e755a9d1e1d74ac3f8b32fbce301c376..502f4d89b6a8684fb14591d2e359cfeb109d8b60 100644 (file)
@@ -620,7 +620,7 @@ fn callback_streams() {
 
     let (mut user_input, streams) = StreamConfig::with_callbacks(
         Some(Box::new(move |x| {
-            x.read_to_string(&mut *test_string2.borrow_mut()).unwrap();
+            x.read_to_string(&mut test_string2.borrow_mut()).unwrap();
         })),
         None,
     );
index ce0c8d0540ef52727ae9d3cbfc09e51a77f13d25..edb6756b83040b6687339699557267305efe8568 100644 (file)
@@ -16,7 +16,6 @@ pub use scryer_modular_bitfield::prelude::*;
 
 #[cfg(feature = "http")]
 use bytes::{buf::Reader as BufReader, Buf, Bytes};
-use std::cell::RefCell;
 use std::cmp::Ordering;
 use std::error::Error;
 use std::fmt;
@@ -30,7 +29,8 @@ use std::net::{Shutdown, TcpStream};
 use std::ops::{Deref, DerefMut};
 use std::path::PathBuf;
 use std::ptr;
-use std::rc::Rc;
+use std::sync::mpsc::Receiver;
+use std::sync::mpsc::TryRecvError;
 
 #[cfg(feature = "tls")]
 use native_tls::TlsStream;
@@ -414,13 +414,50 @@ impl Write for CallbackStream {
 
 #[derive(Debug)]
 pub struct InputChannelStream {
-    pub(crate) inner: Rc<RefCell<Cursor<Vec<u8>>>>,
+    pub(crate) inner: Cursor<Vec<u8>>,
+    pub eof: bool,
+    channel: Receiver<Vec<u8>>,
 }
 
 impl Read for InputChannelStream {
     #[inline]
     fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
-        self.inner.borrow_mut().read(buf)
+        if self.eof {
+            return Ok(0);
+        }
+
+        let to_read = buf.len();
+        let mut total_read = 0;
+
+        loop {
+            total_read += self.inner.read(&mut buf[total_read..])?;
+
+            if total_read < to_read {
+                // We need to get more data to read
+                match self.channel.try_recv() {
+                    Ok(data) => {
+                        // Append into self.inner
+                        let pos = self.inner.position();
+                        assert_eq!(pos as usize, self.inner.get_ref().len());
+                        self.inner.write_all(&data)?;
+                        self.inner.seek(SeekFrom::Start(pos))?;
+                    }
+                    Err(TryRecvError::Empty) => {
+                        // Data is pending
+                        break;
+                    }
+                    Err(TryRecvError::Disconnected) => {
+                        // The other end of the channel was closed so we are EOF
+                        self.eof = true;
+                        break;
+                    }
+                }
+            } else {
+                assert_eq!(total_read, to_read);
+                break;
+            }
+        }
+        Ok(total_read)
     }
 }
 
@@ -602,9 +639,14 @@ impl Stream {
     }
 
     #[inline]
-    pub fn input_channel(cursor: Rc<RefCell<Cursor<Vec<u8>>>>, arena: &mut Arena) -> Stream {
+    pub fn input_channel(channel: Receiver<Vec<u8>>, arena: &mut Arena) -> Stream {
+        let inner = Cursor::new(Vec::new());
         Stream::InputChannel(arena_alloc!(
-            StreamLayout::new(CharReader::new(InputChannelStream { inner: cursor })),
+            StreamLayout::new(CharReader::new(InputChannelStream {
+                inner,
+                eof: false,
+                channel
+            })),
             arena
         ))
     }
@@ -1239,6 +1281,13 @@ impl Stream {
                     AtEndOfStream::Past
                 }
             }
+            Stream::InputChannel(stream_layout) => {
+                if stream_layout.stream.get_ref().eof {
+                    AtEndOfStream::At
+                } else {
+                    AtEndOfStream::Not
+                }
+            }
             _ => AtEndOfStream::Not,
         }
     }
@@ -1519,6 +1568,10 @@ impl Stream {
                 readline_stream.reset();
                 true
             }
+            Stream::InputChannel(ref mut input_channel_stream) => {
+                input_channel_stream.stream.get_mut().inner.set_position(0);
+                true
+            }
             _ => false,
         }
     }