From 4c46f0e54df9448d832ae0c701d6e47c51605280 Mon Sep 17 00:00:00 2001 From: Mark Thom Date: Wed, 16 Apr 2025 22:20:31 -0700 Subject: [PATCH] correct GetPartialString (#2887) --- src/machine/dispatch.rs | 166 +++++++++++++++++++++++++--------- src/machine/heap.rs | 97 ++++++++++---------- src/machine/mod.rs | 10 +- src/machine/partial_string.rs | 2 +- 4 files changed, 176 insertions(+), 99 deletions(-) diff --git a/src/machine/dispatch.rs b/src/machine/dispatch.rs index 3cb1a8e5..4494b8cd 100644 --- a/src/machine/dispatch.rs +++ b/src/machine/dispatch.rs @@ -2806,62 +2806,142 @@ impl Machine { self.machine_st.p += 1; } &Instruction::GetPartialString(_, ref string, reg) => { - use crate::machine::partial_string::{HeapPStrIter, PStrCmpResult}; + self.machine_st.heap[0] = self.machine_st[reg]; + + let mut h = 0; + let mut string_cursor = string.as_str(); + + while let Some(c) = string_cursor.chars().next() { + read_heap_cell!(self.machine_st.heap[h], + (HeapCellValueTag::PStrLoc, pstr_loc) => { + let heap_slice = &self.machine_st.heap.as_slice()[pstr_loc ..]; + + match compare_pstr_slices(heap_slice, string_cursor.as_bytes()) { + PStrSegmentCmpResult::Continue(v1, v2) => { + // for v2, the value of a TailIndex mustn't ever be read + // since string does not lie in the heap. + match (v1, v2) { + (PStrContinuable::TailIndex(tail_idx), PStrContinuable::TailIndex(_)) => { + self.machine_st.s = HeapPtr::HeapCell(tail_idx + cell_index!(pstr_loc)); + self.machine_st.s_offset = 0; + self.machine_st.mode = MachineMode::Read; + + break; + } + (PStrContinuable::TailIndex(tail_idx), PStrContinuable::PStrOffset(pos)) => { + h = tail_idx + cell_index!(pstr_loc); + string_cursor = &string_cursor[pos ..]; + } + (PStrContinuable::PStrOffset(pos), PStrContinuable::TailIndex(_)) => { + self.machine_st.s = HeapPtr::PStr(pstr_loc); + self.machine_st.s_offset = pos; + self.machine_st.mode = MachineMode::Read; - let deref_v = self.machine_st.deref(self.machine_st[reg]); - let store_v = self.machine_st.store(deref_v); + break; + } + _ => unreachable!(), + } + } + _ => { + self.machine_st.fail = true; + break; + } + } + } + (HeapCellValueTag::Lis, l) => { + let cell = self.machine_st.store(self.machine_st.deref(self.machine_st.heap[l])); - read_heap_cell!(store_v, - (HeapCellValueTag::Str | - HeapCellValueTag::Lis | - HeapCellValueTag::PStrLoc) => { - self.machine_st.heap[0] = store_v; - let heap_pstr_iter = HeapPStrIter::new(&self.machine_st.heap, 0); + if let Some(d) = cell.as_char() { + if c != d { + self.machine_st.fail = true; + break; + } + } else if let Some(r) = cell.as_var() { + self.machine_st.bind(r, char_as_cell!(c)); + } else { + self.machine_st.fail = true; + } - match heap_pstr_iter.compare_pstr_to_string(string) { - Some(PStrCmpResult::CompletePStrMatch { chars_matched, pstr_loc }) => { - self.machine_st.s_offset = chars_matched; - self.machine_st.s = HeapPtr::PStr(pstr_loc); - self.machine_st.mode = MachineMode::Read; + if self.machine_st.fail { + break; + } else { + h = l+1; + string_cursor = &string_cursor[c.len_utf8() ..]; + + if string_cursor.is_empty() { + self.machine_st.s = HeapPtr::HeapCell(h); + self.machine_st.s_offset = 0; + self.machine_st.mode = MachineMode::Read; + } } - Some(PStrCmpResult::PartialPStrMatch { string, var_loc }) => { - let cell = backtrack_on_resource_error!( - self.machine_st, - self.machine_st.heap.allocate_pstr(string) - ); + } + (HeapCellValueTag::Str, s) => { + let cell = self.machine_st.store(self.machine_st.deref(self.machine_st.heap[s+1])); - self.machine_st.mode = MachineMode::Write; - unify!(self.machine_st, cell, heap_loc_as_cell!(var_loc)); + if let Some(d) = cell.as_char() { + if c != d { + self.machine_st.fail = true; + break; + } + } else if let Some(r) = cell.as_var() { + self.machine_st.bind(r, char_as_cell!(c)); + } else { + self.machine_st.fail = true; } - Some(PStrCmpResult::ListMatch { list_loc }) => { + + if self.machine_st.fail { + break; + } + + h = s+2; + string_cursor = &string_cursor[c.len_utf8() ..]; + + if string_cursor.is_empty() { + self.machine_st.s = HeapPtr::HeapCell(h); self.machine_st.s_offset = 0; - self.machine_st.s = HeapPtr::HeapCell(list_loc); self.machine_st.mode = MachineMode::Read; } - None => { - self.machine_st.fail = true; + } + (HeapCellValueTag::AttrVar | HeapCellValueTag::Var, v) => { + if h == v { + let target_cell = backtrack_on_resource_error!( + self.machine_st, + self.machine_st.heap.allocate_pstr(string_cursor) + ); + + self.machine_st.bind( + self.machine_st.heap[h].as_var().unwrap(), + target_cell, + ); + + self.machine_st.mode = MachineMode::Write; + break; + } else { + h = v; } } - } - (HeapCellValueTag::AttrVar | - HeapCellValueTag::StackVar | - HeapCellValueTag::Var) => { - let target_cell = backtrack_on_resource_error!( - self.machine_st, - self.machine_st.heap.allocate_pstr(string) - ); + (HeapCellValueTag::StackVar, s) => { + debug_assert_eq!(h, 0); - self.machine_st.bind( - store_v.as_var().unwrap(), - target_cell, - ); + let target_cell = backtrack_on_resource_error!( + self.machine_st, + self.machine_st.heap.allocate_pstr(string_cursor) + ); - self.machine_st.mode = MachineMode::Write; - } - _ => { - self.machine_st.fail = true; - } - ); + self.machine_st.bind( + Ref::stack_cell(s), + target_cell, + ); + + self.machine_st.mode = MachineMode::Write; + break; + } + _ => { + self.machine_st.fail = true; + break; + } + ); + } step_or_fail!(self, self.machine_st.p += 1); } diff --git a/src/machine/heap.rs b/src/machine/heap.rs index fa554437..f682a1a0 100644 --- a/src/machine/heap.rs +++ b/src/machine/heap.rs @@ -135,69 +135,72 @@ pub(crate) enum PStrSegmentCmpResult { } pub(crate) fn compare_pstr_slices(slice1: &[u8], slice2: &[u8]) -> PStrSegmentCmpResult { - use std::cmp::Ordering; - debug_assert!(!slice1.is_empty() && !slice2.is_empty()); let find_tail = |slice| unsafe { scan_slice_to_str(slice).tail_idx }; - match slice1 - .iter() - .zip(slice2.iter()) - .position(|(b1, b2)| b1 != b2 || *b1 == 0 || *b2 == 0) - { - Some(pos) => { - if slice1[pos] == 0 { - // subtract 1 from pos to offset the increment of scan_slice_to_str if the - // string is "\0\". - let tail1_idx = find_tail(&slice1[pos..]); - - if slice2[pos] == 0 { - let tail2_idx = find_tail(&slice2[pos..]); - - PStrSegmentCmpResult::Continue( - PStrContinuable::TailIndex(tail1_idx + cell_index!(pos)), - PStrContinuable::TailIndex(tail2_idx + cell_index!(pos)), - ) - } else { - PStrSegmentCmpResult::Continue( - PStrContinuable::TailIndex(tail1_idx + cell_index!(pos)), - PStrContinuable::PStrOffset(pos), - ) - } - } else if slice2[pos] == 0 { + let calculate_result = |pos| { + use std::cmp::Ordering; + + if slice1.get(pos).cloned().unwrap_or(0) == 0 { + // subtract 1 from pos to offset the increment of scan_slice_to_str if the + // string is "\0\". + let tail1_idx = find_tail(&slice1[pos..]); + let offset_pos_1 = (ALIGN - slice1.as_ptr().align_offset(ALIGN)) % ALIGN; + + if slice2.get(pos).cloned().unwrap_or(0) == 0 { let tail2_idx = find_tail(&slice2[pos..]); + let offset_pos_2 = (ALIGN - slice2.as_ptr().align_offset(ALIGN)) % ALIGN; PStrSegmentCmpResult::Continue( - PStrContinuable::PStrOffset(pos), - PStrContinuable::TailIndex(tail2_idx + cell_index!(pos)), + PStrContinuable::TailIndex(tail1_idx + cell_index!(pos + offset_pos_1)), + PStrContinuable::TailIndex(tail2_idx + cell_index!(pos + offset_pos_2)), ) } else { - // Compute 7-byte chunks with the mismatching character at pos in the middle of - // each. This way, the character of which the byte at pos is a part will be - // validated and reached eventually by the utf8_chunks() iterator. + PStrSegmentCmpResult::Continue( + PStrContinuable::TailIndex(tail1_idx + cell_index!(pos)), + PStrContinuable::PStrOffset(pos), + ) + } + } else if slice2.get(pos).cloned().unwrap_or(0) == 0 { + let tail2_idx = find_tail(&slice2[pos..]); + let offset_pos_2 = (ALIGN - slice2.as_ptr().align_offset(ALIGN)) % ALIGN; - let slice1_range = pos.saturating_sub(3)..(pos + 4).min(slice1.len()); - let slice2_range = pos.saturating_sub(3)..(pos + 4).min(slice2.len()); + PStrSegmentCmpResult::Continue( + PStrContinuable::PStrOffset(pos), + PStrContinuable::TailIndex(tail2_idx + cell_index!(pos + offset_pos_2)), + ) + } else { + // Compute 7-byte chunks with the mismatching character at pos in the middle of + // each. This way, the character of which the byte at pos is a part will be + // validated and reached eventually by the utf8_chunks() iterator. - let chars1_iter = slice1[slice1_range].utf8_chunks(); - let chars2_iter = slice2[slice2_range].utf8_chunks(); + let slice1_range = pos.saturating_sub(3)..(pos + 4).min(slice1.len()); + let slice2_range = pos.saturating_sub(3)..(pos + 4).min(slice2.len()); - for (chunk1, chunk2) in chars1_iter.zip(chars2_iter) { - let result = chunk1.valid().cmp(chunk2.valid()); + let chars1_iter = slice1[slice1_range].utf8_chunks(); + let chars2_iter = slice2[slice2_range].utf8_chunks(); - if result == Ordering::Greater { - return PStrSegmentCmpResult::Greater; - } else if result == Ordering::Less { - return PStrSegmentCmpResult::Less; - } - } + for (chunk1, chunk2) in chars1_iter.zip(chars2_iter) { + let result = chunk1.valid().cmp(chunk2.valid()); - unreachable!() + if result == Ordering::Greater { + return PStrSegmentCmpResult::Greater; + } else if result == Ordering::Less { + return PStrSegmentCmpResult::Less; + } } - } - None => { + unreachable!() } + }; + + match slice1 + .iter() + .zip(slice2.iter()) + .position(|(b1, b2)| b1 != b2 || *b1 == 0 || *b2 == 0) + { + Some(pos) => calculate_result(pos), + None => calculate_result(slice1.len().min(slice2.len())), } } diff --git a/src/machine/mod.rs b/src/machine/mod.rs index c0dcb6e4..674dda42 100644 --- a/src/machine/mod.rs +++ b/src/machine/mod.rs @@ -618,14 +618,8 @@ impl Machine { } ); } - &Instruction::GetPartialString( - Level::Shallow, - ref string, - RegType::Temp(t), - // has_tail, - ) => { + &Instruction::GetPartialString(Level::Shallow, ref string, RegType::Temp(t)) => { use crate::machine::partial_string::HeapPStrIter; - let cell = self.deref_register(t); read_heap_cell!(cell, @@ -633,7 +627,7 @@ impl Machine { self.machine_st.heap[0] = cell; let iter = HeapPStrIter::new(&self.machine_st.heap, 0); - if iter.compare_pstr_to_string(&string).is_none() { + if iter.compare_pstr_to_string(string).is_none() { return false; } diff --git a/src/machine/partial_string.rs b/src/machine/partial_string.rs index 43163ce2..a8e908a9 100644 --- a/src/machine/partial_string.rs +++ b/src/machine/partial_string.rs @@ -60,7 +60,7 @@ impl<'a> HeapPStrIter<'a> { self.brent_st.hare } - pub fn compare_pstr_to_string<'b>(self, mut s: &'b str) -> Option> { + pub fn compare_pstr_to_string(self, mut s: &str) -> Option { let mut curr_hare = self.brent_st.hare; while !s.is_empty() { -- 2.54.0