From: bakaq Date: Wed, 29 Jan 2025 14:35:04 +0000 (-0300) Subject: Add input stream channel X-Git-Tag: v0.10.0~71^2~7 X-Git-Url: https://git.sagredo.dev/?a=commitdiff_plain;h=7a6620b52df02f33690ecf3033b7920ae855c819;p=scryer-prolog.git Add input stream channel --- diff --git a/src/arena.rs b/src/arena.rs index e3e6cbbd..113ff6f0 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -182,6 +182,7 @@ pub enum ArenaHeaderTag { StaticStringStream = 0b110100, ByteStream = 0b111000, CallbackStream = 0b111001, + InputChannelStream = 0b111010, StandardOutputStream = 0b1100, StandardErrorStream = 0b11000, NullStream = 0b111100, @@ -845,6 +846,9 @@ unsafe fn drop_slab_in_place(value: NonNull, tag: ArenaHeaderTag) { ArenaHeaderTag::CallbackStream => { drop_typed_slab_in_place!(CallbackStream, value); } + ArenaHeaderTag::InputChannelStream => { + drop_typed_slab_in_place!(InputChannelStream, value); + } ArenaHeaderTag::LiveLoadState | ArenaHeaderTag::InactiveLoadState => { drop_typed_slab_in_place!(LiveLoadState, value); } diff --git a/src/machine/config.rs b/src/machine/config.rs index 34529434..06dbb77a 100644 --- a/src/machine/config.rs +++ b/src/machine/config.rs @@ -1,4 +1,7 @@ -use std::borrow::Cow; +use std::cell::RefCell; +use std::io::{Seek, SeekFrom, Write}; +use std::rc::Rc; +use std::{borrow::Cow, io::Cursor}; use rand::{rngs::StdRng, SeedableRng}; @@ -34,10 +37,45 @@ impl StreamConfig { } /// Calls the given callbacks when the respective streams are written to. - pub fn with_callbacks(stdout: Option, stderr: Option) -> Self { - StreamConfig { - inner: StreamConfigInner::Callbacks { stdout, stderr }, - } + /// + /// This also returns a handler to the stdin do the [`Machine`](crate::Machine). + pub fn with_callbacks(stdout: Option, stderr: Option) -> (UserInput, Self) { + let stdin = Rc::new(RefCell::new(Cursor::new(Vec::new()))); + ( + UserInput { + inner: stdin.clone(), + }, + StreamConfig { + inner: StreamConfigInner::Callbacks { + stdin, + stdout, + stderr, + }, + }, + ) + } +} + +/// A handler for the stdin of the [`Machine`](crate::Machine). +#[derive(Debug)] +pub struct UserInput { + inner: Rc>>>, +} + +impl Write for UserInput { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut inner = self.inner.borrow_mut(); + let pos = inner.position(); + + inner.seek(SeekFrom::End(0))?; + let result = inner.write(buf); + inner.seek(SeekFrom::Start(pos))?; + + result + } + + fn flush(&mut self) -> std::io::Result<()> { + self.inner.borrow_mut().flush() } } @@ -47,6 +85,7 @@ enum StreamConfigInner { #[default] Memory, Callbacks { + stdin: Rc>>>, stdout: Option, stderr: Option, }, @@ -102,8 +141,12 @@ impl MachineBuilder { Stream::from_owned_string("".to_owned(), &mut machine_st.arena), Stream::stderr(&mut machine_st.arena), ), - StreamConfigInner::Callbacks { stdout, stderr } => ( - Stream::Null(StreamOptions::default()), + StreamConfigInner::Callbacks { + stdin, + stdout, + stderr, + } => ( + Stream::input_channel(stdin, &mut machine_st.arena), stdout.map_or_else( || Stream::Null(StreamOptions::default()), |x| Stream::from_callback(x, &mut machine_st.arena), diff --git a/src/machine/streams.rs b/src/machine/streams.rs index 0c180d4a..ce0c8d05 100644 --- a/src/machine/streams.rs +++ b/src/machine/streams.rs @@ -16,6 +16,7 @@ pub use scryer_modular_bitfield::prelude::*; #[cfg(feature = "http")] use bytes::{buf::Reader as BufReader, Buf, Bytes}; +use std::cell::RefCell; use std::cmp::Ordering; use std::error::Error; use std::fmt; @@ -29,6 +30,7 @@ use std::net::{Shutdown, TcpStream}; use std::ops::{Deref, DerefMut}; use std::path::PathBuf; use std::ptr; +use std::rc::Rc; #[cfg(feature = "tls")] use native_tls::TlsStream; @@ -410,6 +412,18 @@ impl Write for CallbackStream { } } +#[derive(Debug)] +pub struct InputChannelStream { + pub(crate) inner: Rc>>>, +} + +impl Read for InputChannelStream { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.inner.borrow_mut().read(buf) + } +} + #[bitfield] #[repr(u64)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] @@ -536,6 +550,7 @@ arena_allocated_impl_for_stream!(StaticStringStream, StaticStringStream); arena_allocated_impl_for_stream!(StandardOutputStream, StandardOutputStream); arena_allocated_impl_for_stream!(StandardErrorStream, StandardErrorStream); arena_allocated_impl_for_stream!(CharReader, CallbackStream); +arena_allocated_impl_for_stream!(CharReader, InputChannelStream); #[derive(Debug, Copy, Clone)] pub enum Stream { @@ -555,6 +570,7 @@ pub enum Stream { StandardOutput(TypedArenaPtr), StandardError(TypedArenaPtr), Callback(TypedArenaPtr), + InputChannel(TypedArenaPtr), } impl From> for Stream { @@ -585,6 +601,14 @@ impl Stream { )) } + #[inline] + pub fn input_channel(cursor: Rc>>>, arena: &mut Arena) -> Stream { + Stream::InputChannel(arena_alloc!( + StreamLayout::new(CharReader::new(InputChannelStream { inner: cursor })), + arena + )) + } + #[inline] pub fn stdin(arena: &mut Arena, add_history: bool) -> Stream { Stream::Readline(arena_alloc!( @@ -619,6 +643,9 @@ impl Stream { Stream::Null(StreamOptions::default()) } ArenaHeaderTag::CallbackStream => Stream::Callback(unsafe { ptr.as_typed_ptr() }), + ArenaHeaderTag::InputChannelStream => { + Stream::InputChannel(unsafe { ptr.as_typed_ptr() }) + } _ => unreachable!(), } } @@ -656,6 +683,7 @@ impl Stream { Stream::StandardOutput(ptr) => ptr.header_ptr(), Stream::StandardError(ptr) => ptr.header_ptr(), Stream::Callback(ptr) => ptr.header_ptr(), + Stream::InputChannel(ptr) => ptr.header_ptr(), } } @@ -677,6 +705,7 @@ impl Stream { Stream::StandardOutput(ref ptr) => &ptr.options, Stream::StandardError(ref ptr) => &ptr.options, Stream::Callback(ref ptr) => &ptr.options, + Stream::InputChannel(ref ptr) => &ptr.options, } } @@ -698,6 +727,7 @@ impl Stream { Stream::StandardOutput(ref mut ptr) => &mut ptr.options, Stream::StandardError(ref mut ptr) => &mut ptr.options, Stream::Callback(ref mut ptr) => &mut ptr.options, + Stream::InputChannel(ref mut ptr) => &mut ptr.options, } } @@ -720,6 +750,7 @@ impl Stream { Stream::StandardOutput(ptr) => ptr.lines_read += incr_num_lines_read, Stream::StandardError(ptr) => ptr.lines_read += incr_num_lines_read, Stream::Callback(ptr) => ptr.lines_read += incr_num_lines_read, + Stream::InputChannel(ptr) => ptr.lines_read += incr_num_lines_read, } } @@ -742,6 +773,7 @@ impl Stream { Stream::StandardOutput(ptr) => ptr.lines_read = value, Stream::StandardError(ptr) => ptr.lines_read = value, Stream::Callback(ptr) => ptr.lines_read = value, + Stream::InputChannel(ptr) => ptr.lines_read = value, } } @@ -764,6 +796,7 @@ impl Stream { Stream::StandardOutput(ptr) => ptr.lines_read, Stream::StandardError(ptr) => ptr.lines_read, Stream::Callback(ptr) => ptr.lines_read, + Stream::InputChannel(ptr) => ptr.lines_read, } } } @@ -780,6 +813,7 @@ impl CharRead for Stream { Stream::Readline(rl_stream) => (*rl_stream).peek_char(), Stream::StaticString(src) => (*src).peek_char(), Stream::Byte(cursor) => (*cursor).peek_char(), + Stream::InputChannel(cursor) => (*cursor).peek_char(), #[cfg(feature = "http")] Stream::HttpWrite(_) => Some(Err(std::io::Error::new( ErrorKind::PermissionDenied, @@ -807,6 +841,7 @@ impl CharRead for Stream { Stream::Readline(rl_stream) => (*rl_stream).read_char(), Stream::StaticString(src) => (*src).read_char(), Stream::Byte(cursor) => (*cursor).read_char(), + Stream::InputChannel(cursor) => (*cursor).read_char(), #[cfg(feature = "http")] Stream::HttpWrite(_) => Some(Err(std::io::Error::new( ErrorKind::PermissionDenied, @@ -841,6 +876,7 @@ impl CharRead for Stream { | Stream::StandardOutput(_) | Stream::Null(_) | Stream::Callback(_) => {} + Stream::InputChannel(_) => {} } } @@ -855,6 +891,7 @@ impl CharRead for Stream { Stream::Readline(ref mut rl_stream) => rl_stream.consume(nread), Stream::StaticString(ref mut src) => src.consume(nread), Stream::Byte(ref mut cursor) => cursor.consume(nread), + Stream::InputChannel(ref mut cursor) => cursor.consume(nread), #[cfg(feature = "http")] Stream::HttpWrite(_) => {} Stream::OutputFile(_) @@ -879,6 +916,7 @@ impl Read for Stream { Stream::Readline(rl_stream) => (*rl_stream).read(buf), Stream::StaticString(src) => (*src).read(buf), Stream::Byte(cursor) => (*cursor).read(buf), + Stream::InputChannel(cursor) => (*cursor).read(buf), #[cfg(feature = "http")] Stream::HttpWrite(_) => Err(std::io::Error::new( ErrorKind::PermissionDenied, @@ -915,6 +953,7 @@ impl Write for Stream { StreamError::WriteToInputStream, )), Stream::StaticString(_) + | Stream::InputChannel(_) | Stream::Readline(_) | Stream::InputFile(..) | Stream::Null(_) => Err(std::io::Error::new( @@ -942,6 +981,7 @@ impl Write for Stream { StreamError::FlushToInputStream, )), Stream::StaticString(_) + | Stream::InputChannel(_) | Stream::Readline(_) | Stream::InputFile(_) | Stream::Null(_) => Err(std::io::Error::new( @@ -1095,6 +1135,7 @@ impl Stream { Stream::StandardOutput(stream) => stream.past_end_of_stream, Stream::StandardError(stream) => stream.past_end_of_stream, Stream::Callback(stream) => stream.past_end_of_stream, + Stream::InputChannel(stream) => stream.past_end_of_stream, } } @@ -1122,6 +1163,7 @@ impl Stream { Stream::StandardOutput(stream) => stream.past_end_of_stream = value, Stream::StandardError(stream) => stream.past_end_of_stream = value, Stream::Callback(stream) => stream.past_end_of_stream = value, + Stream::InputChannel(stream) => stream.past_end_of_stream = value, } } @@ -1221,6 +1263,7 @@ impl Stream { #[cfg(feature = "tls")] Stream::NamedTls(..) => atom!("read_append"), Stream::Byte(_) + | Stream::InputChannel(_) | Stream::Readline(_) | Stream::StaticString(_) | Stream::InputFile(..) => atom!("read"), @@ -1396,6 +1439,10 @@ impl Stream { stream.drop_payload(); Ok(()) } + Stream::InputChannel(mut stream) => { + stream.drop_payload(); + Ok(()) + } Stream::StaticString(mut stream) => { stream.drop_payload(); Ok(()) @@ -1423,6 +1470,7 @@ impl Stream { Stream::HttpRead(..) => true, Stream::NamedTcp(..) | Stream::Byte(_) + | Stream::InputChannel(_) | Stream::Readline(_) | Stream::StaticString(_) | Stream::InputFile(..) => true, diff --git a/src/macros.rs b/src/macros.rs index 9b2cbabe..547ccc00 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -306,6 +306,7 @@ macro_rules! match_untyped_arena_ptr_pat { | ArenaHeaderTag::StaticStringStream | ArenaHeaderTag::ByteStream | ArenaHeaderTag::CallbackStream + | ArenaHeaderTag::InputChannelStream | ArenaHeaderTag::StandardOutputStream | ArenaHeaderTag::StandardErrorStream };