]> Repositorios git - scryer-prolog.git/commitdiff
correct GetPartialString (#2887)
authorMark Thom <[email protected]>
Thu, 17 Apr 2025 05:20:31 +0000 (22:20 -0700)
committerMark Thom <[email protected]>
Wed, 23 Apr 2025 06:33:11 +0000 (23:33 -0700)
src/machine/dispatch.rs
src/machine/heap.rs
src/machine/mod.rs
src/machine/partial_string.rs

index 3cb1a8e5effdaa54989122ecc1b5db0b7d9fae41..4494b8cde7d283c77ffaa1483c8d05e5ff9e1f77 100644 (file)
@@ -2806,62 +2806,142 @@ impl Machine {
                         self.machine_st.p += 1;
                     }
                     &Instruction::GetPartialString(_, ref string, reg) => {
-                        use crate::machine::partial_string::{HeapPStrIter, PStrCmpResult};
+                        self.machine_st.heap[0] = self.machine_st[reg];
+
+                        let mut h = 0;
+                        let mut string_cursor = string.as_str();
+
+                        while let Some(c) = string_cursor.chars().next() {
+                            read_heap_cell!(self.machine_st.heap[h],
+                                (HeapCellValueTag::PStrLoc, pstr_loc) => {
+                                    let heap_slice = &self.machine_st.heap.as_slice()[pstr_loc ..];
+
+                                    match compare_pstr_slices(heap_slice, string_cursor.as_bytes()) {
+                                        PStrSegmentCmpResult::Continue(v1, v2) => {
+                                            // for v2, the value of a TailIndex mustn't ever be read
+                                            // since string does not lie in the heap.
+                                            match (v1, v2) {
+                                                (PStrContinuable::TailIndex(tail_idx), PStrContinuable::TailIndex(_)) => {
+                                                    self.machine_st.s = HeapPtr::HeapCell(tail_idx + cell_index!(pstr_loc));
+                                                    self.machine_st.s_offset = 0;
+                                                    self.machine_st.mode = MachineMode::Read;
+
+                                                    break;
+                                                }
+                                                (PStrContinuable::TailIndex(tail_idx), PStrContinuable::PStrOffset(pos)) => {
+                                                    h = tail_idx + cell_index!(pstr_loc);
+                                                    string_cursor = &string_cursor[pos ..];
+                                                }
+                                                (PStrContinuable::PStrOffset(pos), PStrContinuable::TailIndex(_)) => {
+                                                    self.machine_st.s = HeapPtr::PStr(pstr_loc);
+                                                    self.machine_st.s_offset = pos;
+                                                    self.machine_st.mode = MachineMode::Read;
 
-                        let deref_v = self.machine_st.deref(self.machine_st[reg]);
-                        let store_v = self.machine_st.store(deref_v);
+                                                    break;
+                                                }
+                                                _ => unreachable!(),
+                                            }
+                                        }
+                                        _ => {
+                                            self.machine_st.fail = true;
+                                            break;
+                                        }
+                                    }
+                                }
+                                (HeapCellValueTag::Lis, l) => {
+                                    let cell = self.machine_st.store(self.machine_st.deref(self.machine_st.heap[l]));
 
-                        read_heap_cell!(store_v,
-                            (HeapCellValueTag::Str |
-                             HeapCellValueTag::Lis |
-                             HeapCellValueTag::PStrLoc) => {
-                                self.machine_st.heap[0] = store_v;
-                                let heap_pstr_iter = HeapPStrIter::new(&self.machine_st.heap, 0);
+                                    if let Some(d) = cell.as_char() {
+                                        if c != d {
+                                            self.machine_st.fail = true;
+                                            break;
+                                        }
+                                    } else if let Some(r) = cell.as_var() {
+                                        self.machine_st.bind(r, char_as_cell!(c));
+                                    } else {
+                                        self.machine_st.fail = true;
+                                    }
 
-                                match heap_pstr_iter.compare_pstr_to_string(string) {
-                                    Some(PStrCmpResult::CompletePStrMatch { chars_matched, pstr_loc }) => {
-                                        self.machine_st.s_offset = chars_matched;
-                                        self.machine_st.s = HeapPtr::PStr(pstr_loc);
-                                        self.machine_st.mode = MachineMode::Read;
+                                    if self.machine_st.fail {
+                                        break;
+                                    } else {
+                                        h = l+1;
+                                        string_cursor = &string_cursor[c.len_utf8() ..];
+
+                                        if string_cursor.is_empty() {
+                                            self.machine_st.s = HeapPtr::HeapCell(h);
+                                            self.machine_st.s_offset = 0;
+                                            self.machine_st.mode = MachineMode::Read;
+                                        }
                                     }
-                                    Some(PStrCmpResult::PartialPStrMatch { string, var_loc }) => {
-                                        let cell = backtrack_on_resource_error!(
-                                            self.machine_st,
-                                            self.machine_st.heap.allocate_pstr(string)
-                                        );
+                                }
+                                (HeapCellValueTag::Str, s) => {
+                                    let cell = self.machine_st.store(self.machine_st.deref(self.machine_st.heap[s+1]));
 
-                                        self.machine_st.mode = MachineMode::Write;
-                                        unify!(self.machine_st, cell, heap_loc_as_cell!(var_loc));
+                                    if let Some(d) = cell.as_char() {
+                                        if c != d {
+                                            self.machine_st.fail = true;
+                                            break;
+                                        }
+                                    } else if let Some(r) = cell.as_var() {
+                                        self.machine_st.bind(r, char_as_cell!(c));
+                                    } else {
+                                        self.machine_st.fail = true;
                                     }
-                                    Some(PStrCmpResult::ListMatch { list_loc }) => {
+
+                                    if self.machine_st.fail {
+                                        break;
+                                    }
+
+                                    h = s+2;
+                                    string_cursor = &string_cursor[c.len_utf8() ..];
+
+                                    if string_cursor.is_empty() {
+                                        self.machine_st.s = HeapPtr::HeapCell(h);
                                         self.machine_st.s_offset = 0;
-                                        self.machine_st.s = HeapPtr::HeapCell(list_loc);
                                         self.machine_st.mode = MachineMode::Read;
                                     }
-                                    None => {
-                                        self.machine_st.fail = true;
+                                }
+                                (HeapCellValueTag::AttrVar | HeapCellValueTag::Var, v) => {
+                                    if h == v {
+                                        let target_cell = backtrack_on_resource_error!(
+                                            self.machine_st,
+                                            self.machine_st.heap.allocate_pstr(string_cursor)
+                                        );
+
+                                        self.machine_st.bind(
+                                            self.machine_st.heap[h].as_var().unwrap(),
+                                            target_cell,
+                                        );
+
+                                        self.machine_st.mode = MachineMode::Write;
+                                        break;
+                                    } else {
+                                        h = v;
                                     }
                                 }
-                            }
-                            (HeapCellValueTag::AttrVar |
-                             HeapCellValueTag::StackVar |
-                             HeapCellValueTag::Var) => {
-                                let target_cell = backtrack_on_resource_error!(
-                                    self.machine_st,
-                                    self.machine_st.heap.allocate_pstr(string)
-                                );
+                                (HeapCellValueTag::StackVar, s) => {
+                                    debug_assert_eq!(h, 0);
 
-                                self.machine_st.bind(
-                                    store_v.as_var().unwrap(),
-                                    target_cell,
-                                );
+                                    let target_cell = backtrack_on_resource_error!(
+                                        self.machine_st,
+                                        self.machine_st.heap.allocate_pstr(string_cursor)
+                                    );
 
-                                self.machine_st.mode = MachineMode::Write;
-                            }
-                            _ => {
-                                self.machine_st.fail = true;
-                            }
-                        );
+                                    self.machine_st.bind(
+                                        Ref::stack_cell(s),
+                                        target_cell,
+                                    );
+
+                                    self.machine_st.mode = MachineMode::Write;
+                                    break;
+                                }
+                                _ => {
+                                    self.machine_st.fail = true;
+                                    break;
+                                }
+                            );
+                        }
 
                         step_or_fail!(self, self.machine_st.p += 1);
                     }
index fa554437afc5fa4734c914639fc077d65de9253b..f682a1a0879f38443a3fa8a188dc777ffc46e4e1 100644 (file)
@@ -135,69 +135,72 @@ pub(crate) enum PStrSegmentCmpResult {
 }
 
 pub(crate) fn compare_pstr_slices(slice1: &[u8], slice2: &[u8]) -> PStrSegmentCmpResult {
-    use std::cmp::Ordering;
-
     debug_assert!(!slice1.is_empty() && !slice2.is_empty());
     let find_tail = |slice| unsafe { scan_slice_to_str(slice).tail_idx };
 
-    match slice1
-        .iter()
-        .zip(slice2.iter())
-        .position(|(b1, b2)| b1 != b2 || *b1 == 0 || *b2 == 0)
-    {
-        Some(pos) => {
-            if slice1[pos] == 0 {
-                // subtract 1 from pos to offset the increment of scan_slice_to_str if the
-                // string is "\0\".
-                let tail1_idx = find_tail(&slice1[pos..]);
-
-                if slice2[pos] == 0 {
-                    let tail2_idx = find_tail(&slice2[pos..]);
-
-                    PStrSegmentCmpResult::Continue(
-                        PStrContinuable::TailIndex(tail1_idx + cell_index!(pos)),
-                        PStrContinuable::TailIndex(tail2_idx + cell_index!(pos)),
-                    )
-                } else {
-                    PStrSegmentCmpResult::Continue(
-                        PStrContinuable::TailIndex(tail1_idx + cell_index!(pos)),
-                        PStrContinuable::PStrOffset(pos),
-                    )
-                }
-            } else if slice2[pos] == 0 {
+    let calculate_result = |pos| {
+        use std::cmp::Ordering;
+
+        if slice1.get(pos).cloned().unwrap_or(0) == 0 {
+            // subtract 1 from pos to offset the increment of scan_slice_to_str if the
+            // string is "\0\".
+            let tail1_idx = find_tail(&slice1[pos..]);
+            let offset_pos_1 = (ALIGN - slice1.as_ptr().align_offset(ALIGN)) % ALIGN;
+
+            if slice2.get(pos).cloned().unwrap_or(0) == 0 {
                 let tail2_idx = find_tail(&slice2[pos..]);
+                let offset_pos_2 = (ALIGN - slice2.as_ptr().align_offset(ALIGN)) % ALIGN;
 
                 PStrSegmentCmpResult::Continue(
-                    PStrContinuable::PStrOffset(pos),
-                    PStrContinuable::TailIndex(tail2_idx + cell_index!(pos)),
+                    PStrContinuable::TailIndex(tail1_idx + cell_index!(pos + offset_pos_1)),
+                    PStrContinuable::TailIndex(tail2_idx + cell_index!(pos + offset_pos_2)),
                 )
             } else {
-                // Compute 7-byte chunks with the mismatching character at pos in the middle of
-                // each. This way, the character of which the byte at pos is a part will be
-                // validated and reached eventually by the utf8_chunks() iterator.
+                PStrSegmentCmpResult::Continue(
+                    PStrContinuable::TailIndex(tail1_idx + cell_index!(pos)),
+                    PStrContinuable::PStrOffset(pos),
+                )
+            }
+        } else if slice2.get(pos).cloned().unwrap_or(0) == 0 {
+            let tail2_idx = find_tail(&slice2[pos..]);
+            let offset_pos_2 = (ALIGN - slice2.as_ptr().align_offset(ALIGN)) % ALIGN;
 
-                let slice1_range = pos.saturating_sub(3)..(pos + 4).min(slice1.len());
-                let slice2_range = pos.saturating_sub(3)..(pos + 4).min(slice2.len());
+            PStrSegmentCmpResult::Continue(
+                PStrContinuable::PStrOffset(pos),
+                PStrContinuable::TailIndex(tail2_idx + cell_index!(pos + offset_pos_2)),
+            )
+        } else {
+            // Compute 7-byte chunks with the mismatching character at pos in the middle of
+            // each. This way, the character of which the byte at pos is a part will be
+            // validated and reached eventually by the utf8_chunks() iterator.
 
-                let chars1_iter = slice1[slice1_range].utf8_chunks();
-                let chars2_iter = slice2[slice2_range].utf8_chunks();
+            let slice1_range = pos.saturating_sub(3)..(pos + 4).min(slice1.len());
+            let slice2_range = pos.saturating_sub(3)..(pos + 4).min(slice2.len());
 
-                for (chunk1, chunk2) in chars1_iter.zip(chars2_iter) {
-                    let result = chunk1.valid().cmp(chunk2.valid());
+            let chars1_iter = slice1[slice1_range].utf8_chunks();
+            let chars2_iter = slice2[slice2_range].utf8_chunks();
 
-                    if result == Ordering::Greater {
-                        return PStrSegmentCmpResult::Greater;
-                    } else if result == Ordering::Less {
-                        return PStrSegmentCmpResult::Less;
-                    }
-                }
+            for (chunk1, chunk2) in chars1_iter.zip(chars2_iter) {
+                let result = chunk1.valid().cmp(chunk2.valid());
 
-                unreachable!()
+                if result == Ordering::Greater {
+                    return PStrSegmentCmpResult::Greater;
+                } else if result == Ordering::Less {
+                    return PStrSegmentCmpResult::Less;
+                }
             }
-        }
-        None => {
+
             unreachable!()
         }
+    };
+
+    match slice1
+        .iter()
+        .zip(slice2.iter())
+        .position(|(b1, b2)| b1 != b2 || *b1 == 0 || *b2 == 0)
+    {
+        Some(pos) => calculate_result(pos),
+        None => calculate_result(slice1.len().min(slice2.len())),
     }
 }
 
index 989ff489f7ae207829f17e39195c946eb745b197..d76d00f23eaee220b862bafbcac64583658f6132 100644 (file)
@@ -638,14 +638,8 @@ impl Machine {
                         }
                     );
                 }
-                &Instruction::GetPartialString(
-                    Level::Shallow,
-                    ref string,
-                    RegType::Temp(t),
-                    // has_tail,
-                ) => {
+                &Instruction::GetPartialString(Level::Shallow, ref string, RegType::Temp(t)) => {
                     use crate::machine::partial_string::HeapPStrIter;
-
                     let cell = self.deref_register(t);
 
                     read_heap_cell!(cell,
@@ -653,7 +647,7 @@ impl Machine {
                             self.machine_st.heap[0] = cell;
                             let iter = HeapPStrIter::new(&self.machine_st.heap, 0);
 
-                            if iter.compare_pstr_to_string(&string).is_none() {
+                            if iter.compare_pstr_to_string(string).is_none() {
                                 return false;
                             }
 
index 43163ce20c8d2262d4c3e6751aad5fe0e47ec9c5..a8e908a9b6551b9ef88644753ac67d70839c0368 100644 (file)
@@ -60,7 +60,7 @@ impl<'a> HeapPStrIter<'a> {
         self.brent_st.hare
     }
 
-    pub fn compare_pstr_to_string<'b>(self, mut s: &'b str) -> Option<PStrCmpResult<'b>> {
+    pub fn compare_pstr_to_string(self, mut s: &str) -> Option<PStrCmpResult> {
         let mut curr_hare = self.brent_st.hare;
 
         while !s.is_empty() {