]> Repositorios git - scryer-prolog.git/commitdiff
Add callback streams
authorbakaq <[email protected]>
Tue, 28 Jan 2025 20:58:40 +0000 (17:58 -0300)
committerbakaq <[email protected]>
Sun, 16 Feb 2025 06:52:52 +0000 (03:52 -0300)
src/arena.rs
src/machine/config.rs
src/machine/streams.rs
src/macros.rs

index e5aceb4dfac24c5cdce956cb4d55c3b4cfb33a4e..e3e6cbbdf01b75d80535a5aa5658c3396b2ff861 100644 (file)
@@ -181,6 +181,7 @@ pub enum ArenaHeaderTag {
     ReadlineStream = 0b110000,
     StaticStringStream = 0b110100,
     ByteStream = 0b111000,
+    CallbackStream = 0b111001,
     StandardOutputStream = 0b1100,
     StandardErrorStream = 0b11000,
     NullStream = 0b111100,
@@ -841,6 +842,9 @@ unsafe fn drop_slab_in_place(value: NonNull<AllocSlab>, tag: ArenaHeaderTag) {
         ArenaHeaderTag::ByteStream => {
             drop_typed_slab_in_place!(ByteStream, value);
         }
+        ArenaHeaderTag::CallbackStream => {
+            drop_typed_slab_in_place!(CallbackStream, value);
+        }
         ArenaHeaderTag::LiveLoadState | ArenaHeaderTag::InactiveLoadState => {
             drop_typed_slab_in_place!(LiveLoadState, value);
         }
index 2981899d0899e2461a517eb7db013e01d5c8f97c..34529434de023d2eee5cf982d002c7de2f41166a 100644 (file)
@@ -6,7 +6,8 @@ use crate::Machine;
 
 use super::{
     bootstrapping_compile, current_dir, import_builtin_impls, libraries, load_module, Atom,
-    CompilationTarget, IndexStore, ListingSource, MachineArgs, MachineState, Stream, StreamOptions,
+    Callback, CompilationTarget, IndexStore, ListingSource, MachineArgs, MachineState, Stream,
+    StreamOptions,
 };
 
 /// Describes how the streams of a [`Machine`](crate::Machine) will be handled.
@@ -31,6 +32,13 @@ impl StreamConfig {
             inner: StreamConfigInner::Memory,
         }
     }
+
+    /// Calls the given callbacks when the respective streams are written to.
+    pub fn with_callbacks(stdout: Option<Callback>, stderr: Option<Callback>) -> Self {
+        StreamConfig {
+            inner: StreamConfigInner::Callbacks { stdout, stderr },
+        }
+    }
 }
 
 #[derive(Default)]
@@ -38,6 +46,10 @@ enum StreamConfigInner {
     Stdio,
     #[default]
     Memory,
+    Callbacks {
+        stdout: Option<Callback>,
+        stderr: Option<Callback>,
+    },
 }
 
 /// Describes how a [`Machine`](crate::Machine) will be configured.
@@ -90,6 +102,17 @@ impl MachineBuilder {
                 Stream::from_owned_string("".to_owned(), &mut machine_st.arena),
                 Stream::stderr(&mut machine_st.arena),
             ),
+            StreamConfigInner::Callbacks { stdout, stderr } => (
+                Stream::Null(StreamOptions::default()),
+                stdout.map_or_else(
+                    || Stream::Null(StreamOptions::default()),
+                    |x| Stream::from_callback(x, &mut machine_st.arena),
+                ),
+                stderr.map_or_else(
+                    || Stream::Null(StreamOptions::default()),
+                    |x| Stream::from_callback(x, &mut machine_st.arena),
+                ),
+            ),
         };
 
         let mut wam = Machine {
index b1ecba0354cf2d80ab3eca758e45759789aaeee6..0c180d4a795fc007eeb37e2093ba6f7364a96b14 100644 (file)
@@ -24,6 +24,7 @@ use std::fs::{File, OpenOptions};
 use std::hash::Hash;
 use std::io;
 use std::io::{Cursor, ErrorKind, Read, Seek, SeekFrom, Write};
+use std::mem::ManuallyDrop;
 use std::net::{Shutdown, TcpStream};
 use std::ops::{Deref, DerefMut};
 use std::path::PathBuf;
@@ -375,6 +376,40 @@ impl Write for StandardErrorStream {
     }
 }
 
+pub type Callback = Box<dyn FnMut(&mut Cursor<Vec<u8>>)>;
+
+pub struct CallbackStream {
+    pub(crate) inner: Cursor<Vec<u8>>,
+    callback: Callback,
+}
+
+impl Debug for CallbackStream {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("CallbackStream")
+            .field("inner", &self.inner)
+            .finish()
+    }
+}
+
+impl Write for CallbackStream {
+    #[inline]
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        let pos = self.inner.position();
+
+        self.inner.seek(SeekFrom::End(0))?;
+        let result = self.inner.write(buf);
+        self.inner.seek(SeekFrom::Start(pos))?;
+
+        result
+    }
+
+    #[inline]
+    fn flush(&mut self) -> std::io::Result<()> {
+        (self.callback)(&mut self.inner);
+        self.inner.flush()
+    }
+}
+
 #[bitfield]
 #[repr(u64)]
 #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
@@ -500,6 +535,7 @@ arena_allocated_impl_for_stream!(ReadlineStream, ReadlineStream);
 arena_allocated_impl_for_stream!(StaticStringStream, StaticStringStream);
 arena_allocated_impl_for_stream!(StandardOutputStream, StandardOutputStream);
 arena_allocated_impl_for_stream!(StandardErrorStream, StandardErrorStream);
+arena_allocated_impl_for_stream!(CharReader<CallbackStream>, CallbackStream);
 
 #[derive(Debug, Copy, Clone)]
 pub enum Stream {
@@ -518,6 +554,7 @@ pub enum Stream {
     Readline(TypedArenaPtr<ReadlineStream>),
     StandardOutput(TypedArenaPtr<StandardOutputStream>),
     StandardError(TypedArenaPtr<StandardErrorStream>),
+    Callback(TypedArenaPtr<CallbackStream>),
 }
 
 impl From<TypedArenaPtr<ReadlineStream>> for Stream {
@@ -581,6 +618,7 @@ impl Stream {
             ArenaHeaderTag::Dropped | ArenaHeaderTag::NullStream => {
                 Stream::Null(StreamOptions::default())
             }
+            ArenaHeaderTag::CallbackStream => Stream::Callback(unsafe { ptr.as_typed_ptr() }),
             _ => unreachable!(),
         }
     }
@@ -617,6 +655,7 @@ impl Stream {
             Stream::Readline(ptr) => ptr.header_ptr(),
             Stream::StandardOutput(ptr) => ptr.header_ptr(),
             Stream::StandardError(ptr) => ptr.header_ptr(),
+            Stream::Callback(ptr) => ptr.header_ptr(),
         }
     }
 
@@ -637,6 +676,7 @@ impl Stream {
             Stream::Readline(ref ptr) => &ptr.options,
             Stream::StandardOutput(ref ptr) => &ptr.options,
             Stream::StandardError(ref ptr) => &ptr.options,
+            Stream::Callback(ref ptr) => &ptr.options,
         }
     }
 
@@ -657,6 +697,7 @@ impl Stream {
             Stream::Readline(ref mut ptr) => &mut ptr.options,
             Stream::StandardOutput(ref mut ptr) => &mut ptr.options,
             Stream::StandardError(ref mut ptr) => &mut ptr.options,
+            Stream::Callback(ref mut ptr) => &mut ptr.options,
         }
     }
 
@@ -678,6 +719,7 @@ impl Stream {
             Stream::Readline(ptr) => ptr.lines_read += incr_num_lines_read,
             Stream::StandardOutput(ptr) => ptr.lines_read += incr_num_lines_read,
             Stream::StandardError(ptr) => ptr.lines_read += incr_num_lines_read,
+            Stream::Callback(ptr) => ptr.lines_read += incr_num_lines_read,
         }
     }
 
@@ -699,6 +741,7 @@ impl Stream {
             Stream::Readline(ptr) => ptr.lines_read = value,
             Stream::StandardOutput(ptr) => ptr.lines_read = value,
             Stream::StandardError(ptr) => ptr.lines_read = value,
+            Stream::Callback(ptr) => ptr.lines_read = value,
         }
     }
 
@@ -720,6 +763,7 @@ impl Stream {
             Stream::Readline(ptr) => ptr.lines_read,
             Stream::StandardOutput(ptr) => ptr.lines_read,
             Stream::StandardError(ptr) => ptr.lines_read,
+            Stream::Callback(ptr) => ptr.lines_read,
         }
     }
 }
@@ -744,7 +788,8 @@ impl CharRead for Stream {
             Stream::OutputFile(_)
             | Stream::StandardError(_)
             | Stream::StandardOutput(_)
-            | Stream::Null(_) => Some(Err(std::io::Error::new(
+            | Stream::Null(_)
+            | Stream::Callback(_) => Some(Err(std::io::Error::new(
                 ErrorKind::PermissionDenied,
                 StreamError::ReadFromOutputStream,
             ))),
@@ -770,7 +815,8 @@ impl CharRead for Stream {
             Stream::OutputFile(_)
             | Stream::StandardError(_)
             | Stream::StandardOutput(_)
-            | Stream::Null(_) => Some(Err(std::io::Error::new(
+            | Stream::Null(_)
+            | Stream::Callback(_) => Some(Err(std::io::Error::new(
                 ErrorKind::PermissionDenied,
                 StreamError::ReadFromOutputStream,
             ))),
@@ -793,7 +839,8 @@ impl CharRead for Stream {
             Stream::OutputFile(_)
             | Stream::StandardError(_)
             | Stream::StandardOutput(_)
-            | Stream::Null(_) => {}
+            | Stream::Null(_)
+            | Stream::Callback(_) => {}
         }
     }
 
@@ -813,7 +860,8 @@ impl CharRead for Stream {
             Stream::OutputFile(_)
             | Stream::StandardError(_)
             | Stream::StandardOutput(_)
-            | Stream::Null(_) => {}
+            | Stream::Null(_)
+            | Stream::Callback(_) => {}
         }
     }
 }
@@ -839,7 +887,8 @@ impl Read for Stream {
             Stream::OutputFile(_)
             | Stream::StandardError(_)
             | Stream::StandardOutput(_)
-            | Stream::Null(_) => Err(std::io::Error::new(
+            | Stream::Null(_)
+            | Stream::Callback(_) => Err(std::io::Error::new(
                 ErrorKind::PermissionDenied,
                 StreamError::ReadFromOutputStream,
             )),
@@ -855,6 +904,7 @@ impl Write for Stream {
             #[cfg(feature = "tls")]
             Stream::NamedTls(ref mut tls_stream) => tls_stream.get_mut().write(buf),
             Stream::Byte(ref mut cursor) => cursor.get_mut().write(buf),
+            Stream::Callback(ref mut callback_stream) => callback_stream.get_mut().write(buf),
             Stream::StandardOutput(stream) => stream.write(buf),
             Stream::StandardError(stream) => stream.write(buf),
             #[cfg(feature = "http")]
@@ -881,6 +931,7 @@ impl Write for Stream {
             #[cfg(feature = "tls")]
             Stream::NamedTls(ref mut tls_stream) => tls_stream.stream.get_mut().flush(),
             Stream::Byte(ref mut cursor) => cursor.stream.get_mut().flush(),
+            Stream::Callback(ref mut callback_stream) => callback_stream.stream.get_mut().flush(),
             Stream::StandardError(stream) => stream.stream.flush(),
             Stream::StandardOutput(stream) => stream.stream.flush(),
             #[cfg(feature = "http")]
@@ -1043,6 +1094,7 @@ impl Stream {
             Stream::Readline(stream) => stream.past_end_of_stream,
             Stream::StandardOutput(stream) => stream.past_end_of_stream,
             Stream::StandardError(stream) => stream.past_end_of_stream,
+            Stream::Callback(stream) => stream.past_end_of_stream,
         }
     }
 
@@ -1069,6 +1121,7 @@ impl Stream {
             Stream::Readline(stream) => stream.past_end_of_stream = value,
             Stream::StandardOutput(stream) => stream.past_end_of_stream = value,
             Stream::StandardError(stream) => stream.past_end_of_stream = value,
+            Stream::Callback(stream) => stream.past_end_of_stream = value,
         }
     }
 
@@ -1175,7 +1228,10 @@ impl Stream {
             Stream::OutputFile(file) if file.is_append => atom!("append"),
             #[cfg(feature = "http")]
             Stream::HttpWrite(_) => atom!("write"),
-            Stream::OutputFile(_) | Stream::StandardError(_) | Stream::StandardOutput(_) => {
+            Stream::OutputFile(_)
+            | Stream::StandardError(_)
+            | Stream::StandardOutput(_)
+            | Stream::Callback(_) => {
                 atom!("write")
             }
             Stream::Null(_) => atom!(""),
@@ -1198,6 +1254,17 @@ impl Stream {
         ))
     }
 
+    #[inline]
+    pub fn from_callback(callback: Callback, arena: &mut Arena) -> Self {
+        Stream::Callback(arena_alloc!(
+            ManuallyDrop::new(StreamLayout::new(CharReader::new(CallbackStream {
+                inner: Cursor::new(Vec::new()),
+                callback,
+            }))),
+            arena
+        ))
+    }
+
     #[inline]
     pub(crate) fn from_tcp_stream(address: Atom, tcp_stream: TcpStream, arena: &mut Arena) -> Self {
         tcp_stream.set_read_timeout(None).unwrap();
@@ -1325,6 +1392,10 @@ impl Stream {
                 stream.drop_payload();
                 Ok(())
             }
+            Stream::Callback(mut stream) => {
+                stream.drop_payload();
+                Ok(())
+            }
             Stream::StaticString(mut stream) => {
                 stream.drop_payload();
                 Ok(())
@@ -1370,6 +1441,7 @@ impl Stream {
             | Stream::StandardOutput(_)
             | Stream::NamedTcp(..)
             | Stream::Byte(_)
+            | Stream::Callback(_)
             | Stream::OutputFile(..) => true,
             _ => false,
         }
index 30a863cac21439ac9ca57cd1890febcaf1c660dc..9b2cbabe4bb528eef1f57cc93db09b165b42ab25 100644 (file)
@@ -305,6 +305,7 @@ macro_rules! match_untyped_arena_ptr_pat {
             | ArenaHeaderTag::ReadlineStream
             | ArenaHeaderTag::StaticStringStream
             | ArenaHeaderTag::ByteStream
+            | ArenaHeaderTag::CallbackStream
             | ArenaHeaderTag::StandardOutputStream
             | ArenaHeaderTag::StandardErrorStream
     };