]> Repositorios git - scryer-prolog.git/commitdiff
Add input stream channel
authorbakaq <[email protected]>
Wed, 29 Jan 2025 14:35:04 +0000 (11:35 -0300)
committerbakaq <[email protected]>
Sun, 16 Feb 2025 07:04:48 +0000 (04:04 -0300)
src/arena.rs
src/machine/config.rs
src/machine/streams.rs
src/macros.rs

index e3e6cbbdf01b75d80535a5aa5658c3396b2ff861..113ff6f0e60eb1e4bd38c9df581b31c2191a38db 100644 (file)
@@ -182,6 +182,7 @@ pub enum ArenaHeaderTag {
     StaticStringStream = 0b110100,
     ByteStream = 0b111000,
     CallbackStream = 0b111001,
+    InputChannelStream = 0b111010,
     StandardOutputStream = 0b1100,
     StandardErrorStream = 0b11000,
     NullStream = 0b111100,
@@ -845,6 +846,9 @@ unsafe fn drop_slab_in_place(value: NonNull<AllocSlab>, tag: ArenaHeaderTag) {
         ArenaHeaderTag::CallbackStream => {
             drop_typed_slab_in_place!(CallbackStream, value);
         }
+        ArenaHeaderTag::InputChannelStream => {
+            drop_typed_slab_in_place!(InputChannelStream, value);
+        }
         ArenaHeaderTag::LiveLoadState | ArenaHeaderTag::InactiveLoadState => {
             drop_typed_slab_in_place!(LiveLoadState, value);
         }
index 34529434de023d2eee5cf982d002c7de2f41166a..06dbb77ac96b998f1c4213801f4e04bb26b20bf1 100644 (file)
@@ -1,4 +1,7 @@
-use std::borrow::Cow;
+use std::cell::RefCell;
+use std::io::{Seek, SeekFrom, Write};
+use std::rc::Rc;
+use std::{borrow::Cow, io::Cursor};
 
 use rand::{rngs::StdRng, SeedableRng};
 
@@ -34,10 +37,45 @@ impl StreamConfig {
     }
 
     /// Calls the given callbacks when the respective streams are written to.
-    pub fn with_callbacks(stdout: Option<Callback>, stderr: Option<Callback>) -> Self {
-        StreamConfig {
-            inner: StreamConfigInner::Callbacks { stdout, stderr },
-        }
+    ///
+    /// This also returns a handler to the stdin do the [`Machine`](crate::Machine).
+    pub fn with_callbacks(stdout: Option<Callback>, stderr: Option<Callback>) -> (UserInput, Self) {
+        let stdin = Rc::new(RefCell::new(Cursor::new(Vec::new())));
+        (
+            UserInput {
+                inner: stdin.clone(),
+            },
+            StreamConfig {
+                inner: StreamConfigInner::Callbacks {
+                    stdin,
+                    stdout,
+                    stderr,
+                },
+            },
+        )
+    }
+}
+
+/// A handler for the stdin of the [`Machine`](crate::Machine).
+#[derive(Debug)]
+pub struct UserInput {
+    inner: Rc<RefCell<Cursor<Vec<u8>>>>,
+}
+
+impl Write for UserInput {
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        let mut inner = self.inner.borrow_mut();
+        let pos = inner.position();
+
+        inner.seek(SeekFrom::End(0))?;
+        let result = inner.write(buf);
+        inner.seek(SeekFrom::Start(pos))?;
+
+        result
+    }
+
+    fn flush(&mut self) -> std::io::Result<()> {
+        self.inner.borrow_mut().flush()
     }
 }
 
@@ -47,6 +85,7 @@ enum StreamConfigInner {
     #[default]
     Memory,
     Callbacks {
+        stdin: Rc<RefCell<Cursor<Vec<u8>>>>,
         stdout: Option<Callback>,
         stderr: Option<Callback>,
     },
@@ -102,8 +141,12 @@ impl MachineBuilder {
                 Stream::from_owned_string("".to_owned(), &mut machine_st.arena),
                 Stream::stderr(&mut machine_st.arena),
             ),
-            StreamConfigInner::Callbacks { stdout, stderr } => (
-                Stream::Null(StreamOptions::default()),
+            StreamConfigInner::Callbacks {
+                stdin,
+                stdout,
+                stderr,
+            } => (
+                Stream::input_channel(stdin, &mut machine_st.arena),
                 stdout.map_or_else(
                     || Stream::Null(StreamOptions::default()),
                     |x| Stream::from_callback(x, &mut machine_st.arena),
index 0c180d4a795fc007eeb37e2093ba6f7364a96b14..ce0c8d0540ef52727ae9d3cbfc09e51a77f13d25 100644 (file)
@@ -16,6 +16,7 @@ pub use scryer_modular_bitfield::prelude::*;
 
 #[cfg(feature = "http")]
 use bytes::{buf::Reader as BufReader, Buf, Bytes};
+use std::cell::RefCell;
 use std::cmp::Ordering;
 use std::error::Error;
 use std::fmt;
@@ -29,6 +30,7 @@ use std::net::{Shutdown, TcpStream};
 use std::ops::{Deref, DerefMut};
 use std::path::PathBuf;
 use std::ptr;
+use std::rc::Rc;
 
 #[cfg(feature = "tls")]
 use native_tls::TlsStream;
@@ -410,6 +412,18 @@ impl Write for CallbackStream {
     }
 }
 
+#[derive(Debug)]
+pub struct InputChannelStream {
+    pub(crate) inner: Rc<RefCell<Cursor<Vec<u8>>>>,
+}
+
+impl Read for InputChannelStream {
+    #[inline]
+    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
+        self.inner.borrow_mut().read(buf)
+    }
+}
+
 #[bitfield]
 #[repr(u64)]
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@@ -536,6 +550,7 @@ arena_allocated_impl_for_stream!(StaticStringStream, StaticStringStream);
 arena_allocated_impl_for_stream!(StandardOutputStream, StandardOutputStream);
 arena_allocated_impl_for_stream!(StandardErrorStream, StandardErrorStream);
 arena_allocated_impl_for_stream!(CharReader<CallbackStream>, CallbackStream);
+arena_allocated_impl_for_stream!(CharReader<InputChannelStream>, InputChannelStream);
 
 #[derive(Debug, Copy, Clone)]
 pub enum Stream {
@@ -555,6 +570,7 @@ pub enum Stream {
     StandardOutput(TypedArenaPtr<StandardOutputStream>),
     StandardError(TypedArenaPtr<StandardErrorStream>),
     Callback(TypedArenaPtr<CallbackStream>),
+    InputChannel(TypedArenaPtr<InputChannelStream>),
 }
 
 impl From<TypedArenaPtr<ReadlineStream>> for Stream {
@@ -585,6 +601,14 @@ impl Stream {
         ))
     }
 
+    #[inline]
+    pub fn input_channel(cursor: Rc<RefCell<Cursor<Vec<u8>>>>, arena: &mut Arena) -> Stream {
+        Stream::InputChannel(arena_alloc!(
+            StreamLayout::new(CharReader::new(InputChannelStream { inner: cursor })),
+            arena
+        ))
+    }
+
     #[inline]
     pub fn stdin(arena: &mut Arena, add_history: bool) -> Stream {
         Stream::Readline(arena_alloc!(
@@ -619,6 +643,9 @@ impl Stream {
                 Stream::Null(StreamOptions::default())
             }
             ArenaHeaderTag::CallbackStream => Stream::Callback(unsafe { ptr.as_typed_ptr() }),
+            ArenaHeaderTag::InputChannelStream => {
+                Stream::InputChannel(unsafe { ptr.as_typed_ptr() })
+            }
             _ => unreachable!(),
         }
     }
@@ -656,6 +683,7 @@ impl Stream {
             Stream::StandardOutput(ptr) => ptr.header_ptr(),
             Stream::StandardError(ptr) => ptr.header_ptr(),
             Stream::Callback(ptr) => ptr.header_ptr(),
+            Stream::InputChannel(ptr) => ptr.header_ptr(),
         }
     }
 
@@ -677,6 +705,7 @@ impl Stream {
             Stream::StandardOutput(ref ptr) => &ptr.options,
             Stream::StandardError(ref ptr) => &ptr.options,
             Stream::Callback(ref ptr) => &ptr.options,
+            Stream::InputChannel(ref ptr) => &ptr.options,
         }
     }
 
@@ -698,6 +727,7 @@ impl Stream {
             Stream::StandardOutput(ref mut ptr) => &mut ptr.options,
             Stream::StandardError(ref mut ptr) => &mut ptr.options,
             Stream::Callback(ref mut ptr) => &mut ptr.options,
+            Stream::InputChannel(ref mut ptr) => &mut ptr.options,
         }
     }
 
@@ -720,6 +750,7 @@ impl Stream {
             Stream::StandardOutput(ptr) => ptr.lines_read += incr_num_lines_read,
             Stream::StandardError(ptr) => ptr.lines_read += incr_num_lines_read,
             Stream::Callback(ptr) => ptr.lines_read += incr_num_lines_read,
+            Stream::InputChannel(ptr) => ptr.lines_read += incr_num_lines_read,
         }
     }
 
@@ -742,6 +773,7 @@ impl Stream {
             Stream::StandardOutput(ptr) => ptr.lines_read = value,
             Stream::StandardError(ptr) => ptr.lines_read = value,
             Stream::Callback(ptr) => ptr.lines_read = value,
+            Stream::InputChannel(ptr) => ptr.lines_read = value,
         }
     }
 
@@ -764,6 +796,7 @@ impl Stream {
             Stream::StandardOutput(ptr) => ptr.lines_read,
             Stream::StandardError(ptr) => ptr.lines_read,
             Stream::Callback(ptr) => ptr.lines_read,
+            Stream::InputChannel(ptr) => ptr.lines_read,
         }
     }
 }
@@ -780,6 +813,7 @@ impl CharRead for Stream {
             Stream::Readline(rl_stream) => (*rl_stream).peek_char(),
             Stream::StaticString(src) => (*src).peek_char(),
             Stream::Byte(cursor) => (*cursor).peek_char(),
+            Stream::InputChannel(cursor) => (*cursor).peek_char(),
             #[cfg(feature = "http")]
             Stream::HttpWrite(_) => Some(Err(std::io::Error::new(
                 ErrorKind::PermissionDenied,
@@ -807,6 +841,7 @@ impl CharRead for Stream {
             Stream::Readline(rl_stream) => (*rl_stream).read_char(),
             Stream::StaticString(src) => (*src).read_char(),
             Stream::Byte(cursor) => (*cursor).read_char(),
+            Stream::InputChannel(cursor) => (*cursor).read_char(),
             #[cfg(feature = "http")]
             Stream::HttpWrite(_) => Some(Err(std::io::Error::new(
                 ErrorKind::PermissionDenied,
@@ -841,6 +876,7 @@ impl CharRead for Stream {
             | Stream::StandardOutput(_)
             | Stream::Null(_)
             | Stream::Callback(_) => {}
+            Stream::InputChannel(_) => {}
         }
     }
 
@@ -855,6 +891,7 @@ impl CharRead for Stream {
             Stream::Readline(ref mut rl_stream) => rl_stream.consume(nread),
             Stream::StaticString(ref mut src) => src.consume(nread),
             Stream::Byte(ref mut cursor) => cursor.consume(nread),
+            Stream::InputChannel(ref mut cursor) => cursor.consume(nread),
             #[cfg(feature = "http")]
             Stream::HttpWrite(_) => {}
             Stream::OutputFile(_)
@@ -879,6 +916,7 @@ impl Read for Stream {
             Stream::Readline(rl_stream) => (*rl_stream).read(buf),
             Stream::StaticString(src) => (*src).read(buf),
             Stream::Byte(cursor) => (*cursor).read(buf),
+            Stream::InputChannel(cursor) => (*cursor).read(buf),
             #[cfg(feature = "http")]
             Stream::HttpWrite(_) => Err(std::io::Error::new(
                 ErrorKind::PermissionDenied,
@@ -915,6 +953,7 @@ impl Write for Stream {
                 StreamError::WriteToInputStream,
             )),
             Stream::StaticString(_)
+            | Stream::InputChannel(_)
             | Stream::Readline(_)
             | Stream::InputFile(..)
             | Stream::Null(_) => Err(std::io::Error::new(
@@ -942,6 +981,7 @@ impl Write for Stream {
                 StreamError::FlushToInputStream,
             )),
             Stream::StaticString(_)
+            | Stream::InputChannel(_)
             | Stream::Readline(_)
             | Stream::InputFile(_)
             | Stream::Null(_) => Err(std::io::Error::new(
@@ -1095,6 +1135,7 @@ impl Stream {
             Stream::StandardOutput(stream) => stream.past_end_of_stream,
             Stream::StandardError(stream) => stream.past_end_of_stream,
             Stream::Callback(stream) => stream.past_end_of_stream,
+            Stream::InputChannel(stream) => stream.past_end_of_stream,
         }
     }
 
@@ -1122,6 +1163,7 @@ impl Stream {
             Stream::StandardOutput(stream) => stream.past_end_of_stream = value,
             Stream::StandardError(stream) => stream.past_end_of_stream = value,
             Stream::Callback(stream) => stream.past_end_of_stream = value,
+            Stream::InputChannel(stream) => stream.past_end_of_stream = value,
         }
     }
 
@@ -1221,6 +1263,7 @@ impl Stream {
             #[cfg(feature = "tls")]
             Stream::NamedTls(..) => atom!("read_append"),
             Stream::Byte(_)
+            | Stream::InputChannel(_)
             | Stream::Readline(_)
             | Stream::StaticString(_)
             | Stream::InputFile(..) => atom!("read"),
@@ -1396,6 +1439,10 @@ impl Stream {
                 stream.drop_payload();
                 Ok(())
             }
+            Stream::InputChannel(mut stream) => {
+                stream.drop_payload();
+                Ok(())
+            }
             Stream::StaticString(mut stream) => {
                 stream.drop_payload();
                 Ok(())
@@ -1423,6 +1470,7 @@ impl Stream {
             Stream::HttpRead(..) => true,
             Stream::NamedTcp(..)
             | Stream::Byte(_)
+            | Stream::InputChannel(_)
             | Stream::Readline(_)
             | Stream::StaticString(_)
             | Stream::InputFile(..) => true,
index 9b2cbabe4bb528eef1f57cc93db09b165b42ab25..547ccc007dd2d35019c45b000981fbca1ae747ef 100644 (file)
@@ -306,6 +306,7 @@ macro_rules! match_untyped_arena_ptr_pat {
             | ArenaHeaderTag::StaticStringStream
             | ArenaHeaderTag::ByteStream
             | ArenaHeaderTag::CallbackStream
+            | ArenaHeaderTag::InputChannelStream
             | ArenaHeaderTag::StandardOutputStream
             | ArenaHeaderTag::StandardErrorStream
     };