Repositorios git - scryer-prolog.git/commitdiff
add more variable probing, chunk type labeling
author Mark Thom <[email protected]>
Wed, 2 Nov 2022 03:10:15 +0000 (21:10 -0600)
committer Mark <[email protected]>
Fri, 23 Jun 2023 20:11:07 +0000 (14:11 -0600)
src/machine/disjuncts.rs

index a97f08c5c8cdbbe9e0e4867de0f6d2273feba7f6..adbf341cff90c2adc9777a9d6f0d42014bcdd0d2 100644 (file)
@@ -133,8 +133,20 @@ impl DerefMut for BranchMap {
 
 type RootSet = IndexSet<BranchNumber>;
 
+/// Label for the kind of chunk the body traversal is currently in.
+///
+/// The traversal keeps one value of this type in a local `chunk_type`
+/// variable and updates it as traversal states are popped.
+#[derive(Debug, Clone, Copy)]
+enum ChunkType {
+    /// Initial value of `chunk_type` when a body traversal starts.
+    Head,
+    /// Assigned when `TraversalState::IncrChunkNum` is processed, and when
+    /// `insert_set_last_chunk_type` reports that it inserted a marker.
+    Mid,
+    /// Assigned when `TraversalState::SetLastChunkType` is processed: the
+    /// remaining terms are considered to belong to a last chunk.
+    Last,
+}
+
 enum TraversalState {
-    BuildDisjunct(usize), // construct a QueryTerm::Branch with number of disjuncts.
+    // construct a QueryTerm::Branch with number of disjuncts, reset
+    // the chunk type to that of the chunk preceding the disjunct.
+    BuildDisjunct(ChunkType, usize),
+    // add the last disjunct to a QueryTerm::Branch, continuing from
+    // where it leaves off.
+    BuildFinalDisjunct(usize),
     BuildIf(usize, Term), // build the P term of P -> Q
     BuildThen(usize, Vec<QueryTerm>), // build the Q term of P -> Q
     BuildNot(usize), // build the P term of \+ P
@@ -144,6 +156,7 @@ enum TraversalState {
     RemoveBranchNum, // remove latest branch number from the root set
     RepBranchNum(BranchNumber), // replace current_branch_number and the latest in the root set
     IncrChunkNum, // increment self.current_chunk_number
+    SetLastChunkType, // consider remaining terms as belonging to a last chunk
 }
 
 impl Term {
@@ -215,6 +228,62 @@ fn merge_branch_seq<Iter: Iterator<Item = BranchInfo>>(branches: Iter) -> Branch
     branch_info
 }
 
+/// Moves every term stacked above the `QueryTerm::Branch` located at
+/// `build_stack[preceding_len]` into a new disjunct of that branch.
+///
+/// `preceding_len` is the length `build_stack` had just before the
+/// branch placeholder was pushed, so the placeholder sits at index
+/// `preceding_len` and the disjunct's terms occupy `preceding_len + 1 ..`.
+///
+/// If the element at `preceding_len` is not a `QueryTerm::Branch`, the
+/// drained terms are dropped and `build_stack` is otherwise unchanged.
+///
+/// # Panics
+///
+/// Panics if `build_stack` holds no element at index `preceding_len`.
+fn flatten_into_disjunct(build_stack: &mut Vec<QueryTerm>, preceding_len: usize) {
+    // Collect eagerly: the drain must start *after* the branch placeholder
+    // (so the placeholder survives), and its mutable borrow of `build_stack`
+    // has to end before the placeholder is indexed below.
+    let disjunct = build_stack.drain(preceding_len + 1 ..).collect();
+
+    if let QueryTerm::Branch(ref mut disjuncts) = &mut build_stack[preceding_len] {
+        disjuncts.push(disjunct);
+    }
+}
+
+/// Classifies `term` for chunk-boundary detection.
+///
+/// * `Some(true)`  - `term` is a variable, or a clause/atom that
+///   `ClauseType::is_inbuilt` does not recognise at its arity.
+/// * `Some(false)` - `term` is the cut (`!`, as an atom or a char) or is
+///   recognised by `ClauseType::is_inbuilt`.
+/// * `None`        - `term` is any other kind of term.
+fn term_in_other_chunk(term: &Term) -> Option<bool> {
+    // Settle the cases that need no inbuilt lookup first; what remains is
+    // a (name, arity) pair to test against the inbuilt table.
+    let (name, arity) = match term {
+        Term::Var(..) => return Some(true),
+        Term::Literal(_, Literal::Atom(atom!("!"))) |
+        Term::Literal(_, Literal::Char('!')) => return Some(false),
+        Term::Clause(_, name, terms) => (name, terms.len()),
+        Term::Literal(_, Literal::Atom(name)) => (name, 0),
+        _ => return None,
+    };
+
+    Some(!ClauseType::is_inbuilt(name, arity))
+}
+
+/// Pushes the traversal states yielded by `iter` onto `state_stack`,
+/// inserting a `TraversalState::SetLastChunkType` marker directly beneath
+/// the first state at which the scan stops.
+///
+/// States are examined in the order `iter` yields them. A state lets the
+/// scan continue when `term_in_other_chunk` returns `Some(false)` for its
+/// term, or `Some(true)` while this call has not yet accepted any state.
+/// The scan stops at a `Some(true)` seen after at least one accepted
+/// state, or at `None`: the marker is pushed, then the stopping state,
+/// and every remaining state is pushed without being examined.
+///
+/// Returns `true` when every state was accepted and no marker was
+/// inserted (including when `iter` is empty), and `false` when a marker
+/// was inserted. Callers reset their chunk type to `ChunkType::Mid` on
+/// `false`.
+///
+/// # Panics
+///
+/// Panics via `unreachable!` if an examined state is neither
+/// `TraversalState::Term` nor `TraversalState::BuildIf`.
+fn insert_set_last_chunk_type(
+    state_stack: &mut Vec<TraversalState>,
+    mut iter: impl Iterator<Item = TraversalState>,
+) -> bool {
+    let beg = state_stack.len();
+    let mut idx = beg;
+
+    for traversal_st in iter.by_ref() {
+        // Match through a reference so that `traversal_st` stays intact
+        // and can still be pushed whole afterwards.
+        let will_break = match &traversal_st {
+            TraversalState::Term(term) | TraversalState::BuildIf(_, term) => {
+                match term_in_other_chunk(term) {
+                    Some(true) if idx > beg => true,
+                    Some(_) => {
+                        idx += 1;
+                        false
+                    }
+                    None => true,
+                }
+            }
+            _ => {
+                unreachable!();
+            }
+        };
+
+        if will_break {
+            state_stack.push(TraversalState::SetLastChunkType);
+            state_stack.push(traversal_st);
+            break;
+        }
+
+        state_stack.push(traversal_st);
+    }
+
+    // Whatever the scan did not examine is pushed unchanged.
+    state_stack.extend(iter);
+
+    // `idx` only equals the final length when every push was an accepted
+    // state, i.e. when no marker (and no stopping state) was inserted.
+    idx == state_stack.len()
+}
+
 impl VariableClassifier {
     pub fn new(call_policy: CallPolicy) -> Self {
         Self {
@@ -286,16 +355,16 @@ impl VariableClassifier {
         }
     }
 
-    fn probe_body_term(&mut self, term: &Term) {
-        // true to iterate the root, which may be a variable!
+    /// Visits `term` breadth-first, root included, and calls
+    /// `probe_body_var` for each variable found, forwarding `term_loc`
+    /// unchanged.
+    // NOTE(review): `term_loc` is declared as `GenContext` here, but it is
+    // passed as the second argument of `probe_body_var`, which is declared
+    // as `chunk_type: ChunkType` - confirm which type is intended.
+    fn probe_body_term(&mut self, term: &Term, term_loc: GenContext) {
+        // second arg is true to iterate the root, which may be a variable
         for term_ref in breadth_first_iter(term, true) {
             if let TermRef::Var(_, _, var_name) = term_ref {
-                self.probe_body_var(Var::from(var_name));
+                self.probe_body_var(Var::from(var_name), term_loc);
             }
         }
     }
 
-    fn probe_body_var(&mut self, var_name: Var) {
+    fn probe_body_var(&mut self, var_name: Var, chunk_type: ChunkType) {
         let branch_info_v = self.branch_map.entry(var_name)
             .or_insert_with(|| vec![]);
 
@@ -369,6 +438,7 @@ impl VariableClassifier {
     ) -> Result<Vec<QueryTerm>, CompilationError> {
         let mut state_stack = vec![TraversalState::Term(term)];
         let mut build_stack = vec![];
+        let mut chunk_type  = ChunkType::Head;
 
         while let Some(traversal_st) = state_stack.pop() {
             match traversal_st {
@@ -386,19 +456,26 @@ impl VariableClassifier {
                 }
                 TraversalState::IncrChunkNum => {
                     self.current_chunk_num += 1;
+                    chunk_type = ChunkType::Mid;
                 }
-                TraversalState::BuildDisjunct(preceding_len) => {
-                    let iter = build_stack.drain(preceding_len ..);
-
-                    if let QueryTerm::Branch(ref mut disjuncts) = &mut build_stack[preceding_len] {
-                        disjuncts.push(iter.collect());
-                    }
+                TraversalState::ResetCallPolicy(call_policy) => {
+                    self.call_policy = call_policy;
+                }
+                TraversalState::SetLastChunkType => {
+                    chunk_type = ChunkType::Last;
+                }
+                TraversalState::BuildDisjunct(reset_chunk_type, preceding_len) => {
+                    chunk_type = reset_chunk_type;
+                    flatten_into_disjunct(&mut build_stack, preceding_len);
+                }
+                TraversalState::BuildFinalDisjunct(preceding_len) => {
+                    flatten_into_disjunct(&mut build_stack, preceding_len);
                 }
                 TraversalState::BuildIf(preceding_len, then_term) => {
                     let iter = build_stack.drain(preceding_len ..);
-                    let build_stack_len = build_stack.len();
 
-                    state_stack.push(TraversalState::BuildThen(build_stack_len, iter.collect()));
+                    state_stack.push(TraversalState::BuildThen(preceding_len, iter.collect()));
+                    state_stack.push(TraversalState::Term(then_term));
                 }
                 TraversalState::BuildThen(preceding_len, if_terms) => {
                     let iter = build_stack.drain(preceding_len ..);
@@ -408,20 +485,22 @@ impl VariableClassifier {
                     let iter = build_stack.drain(preceding_len ..);
                     build_stack.push(QueryTerm::Not(iter.collect()));
                 }
-                TraversalState::ResetCallPolicy(call_policy) => {
-                    self.call_policy = call_policy;
-                }
                 TraversalState::Term(term) => {
                     match term {
                         Term::Clause(_, atom!(","), terms) if terms.len() == 2 => {
-                            state_stack.extend(
-                                unfold_by_str(terms[1], atom!(","))
-                                    .into_iter()
-                                    .rev()
-                                    .map(TraversalState::Term),
-                            );
-
-                            state_stack.push(TraversalState::Term(terms[0]));
+                            let iter = unfold_by_str(terms[1], atom!(","))
+                                .into_iter()
+                                .rev()
+                                .chain(std::iter::once(terms[0]))
+                                .map(TraversalState::Term);
+
+                            if let ChunkType::Last = chunk_type {
+                                if !insert_set_last_chunk_type(&mut state_stack, iter) {
+                                    chunk_type = ChunkType::Mid;
+                                }
+                            } else {
+                                state_stack.extend(iter);
+                            }
                         }
                         Term::Clause(_, atom!(";"), terms) if terms.len() == 2 => {
                             let first_branch_num = self.current_branch_num.split();
@@ -442,31 +521,46 @@ impl VariableClassifier {
                             }
 
                             let build_stack_len = build_stack.len();
-
                             build_stack.push(QueryTerm::Branch(vec![]));
-                            state_stack.push(TraversalState::BuildDisjunct(build_stack_len));
 
                             state_stack.push(TraversalState::RepBranchNum(
                                 self.current_branch_num.halve_delta(),
                             ));
 
                             let iter = branches.into_iter().zip(branch_numbers.into_iter());
+                            let final_disjunct_loc = state_stack.len();
 
                             for (term, branch_num) in iter.rev() {
-                                state_stack.push(TraversalState::BuildDisjunct(build_stack_len));
+                                state_stack.push(TraversalState::BuildDisjunct(chunk_type, build_stack_len));
 
                                 state_stack.push(TraversalState::RemoveBranchNum);
                                 state_stack.push(TraversalState::Term(term));
                                 state_stack.push(TraversalState::AddBranchNum(branch_num));
                             }
+
+                            state_stack[final_disjunct_loc] =
+                                TraversalState::BuildFinalDisjunct(build_stack_len);
                         }
                         Term::Clause(_, atom!("->"), mut terms) if terms.len() == 2 => {
                             let then_term = terms.pop().unwrap();
                             let if_term = terms.pop().unwrap();
+
                             let build_stack_len = build_stack.len();
 
-                            state_stack.push(TraversalState::BuildIf(build_stack_len, then_term));
-                            state_stack.push(TraversalState::Term(if_term));
+                            // TODO: insert GetLevelAndUnify between
+                            // the two traversal states and detect
+                            // that as a chunk boundary in
+                            // insert_set_last_chunk_type ??
+
+                            let iter = vec![TraversalState::BuildIf(build_stack_len, then_term),
+                                            TraversalState::Term(if_term)]
+                                .into_iter();
+
+                            if let ChunkType::Last = chunk_type {
+                                if !insert_set_last_chunk_type(&mut state_stack, iter) {
+                                    chunk_type = ChunkType::Mid;
+                                }
+                            }
                         }
                         Term::Clause(_, atom!("\\+"), terms) if terms.len() == 1 => {
                             let build_stack_len = build_stack.len();
@@ -477,8 +571,14 @@ impl VariableClassifier {
                         Term::Clause(_, atom!("$get_level"), terms) if terms.len() == 1 => {
                             state_stack.push(TraversalState::IncrChunkNum);
 
+                            // TODO: need to classify this variable?
                             if let Term::Var(_, ref var) = &terms[0] {
-                                build_stack.push(QueryTerm::GetLevelAndUnify(Cell::default(), var.clone()));
+                                build_stack.push(
+                                    QueryTerm::GetLevelAndUnify(
+                                        Cell::default(),
+                                        var.clone(),
+                                    ),
+                                );
                             } else {
                                 return Err(CompilationError::InadmissibleQueryTerm);
                             }
@@ -514,6 +614,10 @@ impl VariableClassifier {
                                         state_stack.push(TraversalState::IncrChunkNum);
                                     }
 
+                                    for term in terms.iter() {
+                                        self.probe_body_term(term, term_loc);
+                                    }
+
                                     build_stack.push(
                                         qualified_clause_to_query_term(
                                             loader,
@@ -527,6 +631,9 @@ impl VariableClassifier {
                                 (module_name, predicate_name) => {
                                     state_stack.push(TraversalState::IncrChunkNum);
 
+                                    self.probe_body_term(&module_name, term_loc);
+                                    self.probe_body_term(&predicate_name, term_loc);
+
                                     terms.push(module_name);
                                     terms.push(predicate_name);
 
@@ -541,7 +648,11 @@ impl VariableClassifier {
                                 }
                             }
                         }
-                        Term::Clause(cell, atom!("$call_with_inference_counting"), terms) if terms.len() == 2 => {
+                        Term::Clause(cell, atom!("$call_with_inference_counting"), terms) if terms.len() == 1 => {
+                            for term in terms.iter() {
+                                self.probe_body_term(term, term_loc);
+                            }
+
                             state_stack.push(TraversalState::ResetCallPolicy(self.call_policy));
                             state_stack.push(TraversalState::Term(terms[0]));
 
@@ -553,7 +664,7 @@ impl VariableClassifier {
                             }
 
                             for term in terms.iter() {
-                                self.probe_body_term(term);
+                                self.probe_body_term(term, term_loc);
                             }
 
                             build_stack.push(