]> Repositorios git - scryer-prolog.git/commitdiff
optimize CharReader
authorSkgland <[email protected]>
Thu, 30 Apr 2026 17:44:31 +0000 (19:44 +0200)
committerSkgland <[email protected]>
Thu, 30 Apr 2026 17:46:40 +0000 (19:46 +0200)
src/parser/char_reader.rs

index 334a75f5bd1939eeb303fe55be7d59b970864ec1..62cc81566e36246faa2d47d2d6749ea5b9d01d45 100644 (file)
@@ -102,19 +102,27 @@ impl<R> CharReader<R> {
 }
 
 impl<R: Read> CharReader<R> {
+    pub fn read_chunck(&mut self) -> io::Result<usize> {
+        let mut chunk = [0u8; 8 * 1024];
+        let nread = self.inner.read(&mut chunk)?;
+        self.buf.extend_from_slice(&chunk[..nread]);
+        Ok(nread)
+    }
+
     pub fn refresh_buffer(&mut self) -> io::Result<&[u8]> {
         // If we've reached the end of our internal buffer then we need to fetch
         // some more data from the underlying reader.
         // Branch using `>=` instead of the more correct `==`
         // to tell the compiler that the pos..cap slice is always valid.
         if self.pos >= self.buf.len() {
-            self.buf.clear();
-
-            let mut chunk = [0u8; 8 * 1024];
-            let nread = self.inner.read(&mut chunk)?;
+            // make some space in buf
+            if self.buf.len() > 4 {
+                // keep 4 bytes so that put_back_char can put back at least one char
+                self.buf.drain(4..);
+            }
+            self.pos = self.buf.len();
 
-            self.buf.extend_from_slice(&chunk[..nread]);
-            self.pos = 0;
+            self.read_chunck()?;
         }
 
         Ok(&self.buf[self.pos..])
@@ -143,93 +151,83 @@ impl<R: Read> CharRead for CharReader<R> {
             // leading bytes until either the buffer is
             // empty, or we have a valid code point.
 
-            let mut split_point = 1;
-            let mut badbytes = vec![];
+            // note we might have a sequence of invalid bytes followed by valid bytes followed by invalid bytes
 
-            loop {
-                let (bad, rest) = buf.split_at(split_point);
+            let err = str::from_utf8(buf).expect_err("the start of buf should be invalid utf-8");
+            assert_eq!(err.valid_up_to(), 0, "the error should be a prefix");
 
-                if rest.is_empty() || str::from_utf8(rest).is_ok() {
-                    badbytes.extend_from_slice(bad);
-                    break;
-                }
+            let invalid_prefix = err.error_len().expect("we should have at least 4 bytes");
 
-                split_point += 1;
-            }
+            let bad_bytes = buf[..invalid_prefix].to_vec();
 
             // Raise the error. If we still have data in
             // the buffer, it will be returned on the next
             // loop.
 
-            io::Error::new(io::ErrorKind::InvalidData, BadUtf8Error { bytes: badbytes })
-        };
+            io::Error::new(
+                io::ErrorKind::InvalidData,
+                BadUtf8Error { bytes: bad_bytes },
+            )
+        }
 
-        loop {
+        // while we haven't consumed all bytes from the buffer
+        while self.pos < self.buf.len() {
+            // buf must be non-empty
             let buf = &self.buf[self.pos..];
 
-            if !buf.is_empty() {
-                let e = match str::from_utf8(buf) {
-                    Ok(s) => {
-                        let mut chars = s.chars();
-                        let c = chars.next().unwrap();
-
-                        return Some(Ok(c));
-                    }
-                    Err(e) => e,
-                };
-
-                if buf.len() - e.valid_up_to() >= 4 {
-                    return Some(Err(bad_bytes_error(buf)));
-                } else if self.pos >= self.buf.len() {
-                    return None;
-                } else if self.buf.len() - self.pos >= 4 && self.pos < e.valid_up_to() {
-                    return match str::from_utf8(&self.buf[self.pos..self.pos + e.valid_up_to()]) {
-                        Ok(s) => {
-                            let mut chars = s.chars();
-                            let c = chars.next().unwrap();
-
-                            Some(Ok(c))
-                        }
-                        Err(e) => {
-                            let badbytes = self.buf[self.pos..self.pos + e.valid_up_to()].to_vec();
-
-                            Some(Err(io::Error::new(
-                                io::ErrorKind::InvalidData,
-                                BadUtf8Error { bytes: badbytes },
-                            )))
-                        }
-                    };
-                } else {
-                    let buf_len = self.buf.len();
-
-                    for (c, idx) in (self.pos..buf_len).enumerate() {
-                        self.buf[c] = self.buf[idx];
-                    }
-
-                    self.buf.truncate(buf_len - self.pos);
-
-                    let buf_len = self.buf.len();
-                    self.pos = 0;
-
-                    if buf_len >= 4 {
-                        continue;
-                    }
-
-                    let mut word = [0u8; 4];
-                    let word_slice = &mut word[buf_len..4];
-
-                    match self.inner.read(word_slice) {
-                        Err(e) => return Some(Err(e)),
-                        Ok(0) => return Some(Err(bad_bytes_error(&self.buf))),
-                        Ok(nread) => {
-                            self.buf.extend_from_slice(&word_slice[0..nread]);
-                        }
-                    }
+            // we need at most 4 bytes for a char so don't decode the whole buffer
+            // as it can be quite large and we are going to discard the remaining chars anyway
+            // if there is a valid prefix
+            let prefix = if buf.len() > 4 { &buf[..4] } else { buf };
+
+            let e = match str::from_utf8(prefix) {
+                Ok(s) => {
+                    let mut chars = s.chars();
+                    let c = chars.next().expect(
+                        "a non-empty buffer that is valid utf-8 contains at least one character",
+                    );
+
+                    return Some(Ok(c));
+                }
+                Err(e) => e,
+            };
+
+            if e.valid_up_to() != 0 {
+                // the valid prefix is non-empty so it is guaranteed that we can decode at least one char
+                let c = str::from_utf8(&prefix[..e.valid_up_to()])
+                    .expect("prefix is verified valid up to this point")
+                    .chars()
+                    .next()
+                    .expect("the valid prefix was non-empty");
+                return Some(Ok(c));
+            }
+
+            if e.error_len().is_some() {
+                return Some(Err(bad_bytes_error(buf)));
+            }
+
+            // buf is too short to deterin if the remaining bytes in buf are a valid char
+            // i.e. the content of bufg is a prefix of a valid utf-8 encoded char
+            //
+            // we need to read more data from the underlying stream
+            // so that we can determin its validity
+
+            if self.buf.len() > 4 {
+                // keep a prefix of 4 bytes so that we put back at least one char
+                self.buf.drain(4..self.pos);
+                self.pos = 4;
+            }
+
+            match self.read_chunck() {
+                Err(e) => return Some(Err(e)),
+                Ok(0) => return Some(Err(bad_bytes_error(&self.buf))),
+                Ok(_) => {
+                    // successfully filled the buffer with another chuck of data
                 }
-            } else {
-                return None;
             }
         }
+
+        None
     }
 
     #[inline(always)]
@@ -394,6 +392,15 @@ mod tests {
             std::io::ErrorKind::InvalidData
         );
 
+        let err = read_string
+            .read_char()
+            .unwrap()
+            .unwrap_err()
+            .downcast::<BadUtf8Error>()
+            .unwrap();
+
+        read_string.consume(err.bytes.len());
+
         for c in "more_text".chars() {
             assert_eq!(read_string.peek_char().unwrap().ok(), Some(c));
             assert_eq!(read_string.read_char().unwrap().ok(), Some(c));
@@ -404,6 +411,15 @@ mod tests {
             std::io::ErrorKind::InvalidData
         );
 
+        let err = read_string
+            .read_char()
+            .unwrap()
+            .unwrap_err()
+            .downcast::<BadUtf8Error>()
+            .unwrap();
+
+        read_string.consume(err.bytes.len());
+
         assert!(read_string.read_char().is_none());
     }