From 09448de139e50357fec43bcc2c517d0ea58d5c65 Mon Sep 17 00:00:00 2001 From: ecoricemon Date: Sun, 17 May 2026 12:18:51 +0900 Subject: [PATCH 1/3] feat: Make query callable on multiple threads --- crates/logic-eval/Cargo.toml | 7 + crates/logic-eval/README.md | 3 +- crates/logic-eval/benches/query_threads.rs | 154 +++++++++ crates/logic-eval/examples/access_control.rs | 3 +- .../logic-eval/examples/dependency_graph.rs | 3 +- .../examples/meal_recommendation.rs | 3 +- crates/logic-eval/src/prove/common.rs | 33 ++ crates/logic-eval/src/prove/db.rs | 293 +++++++--------- crates/logic-eval/src/prove/prover.rs | 320 ++++++++---------- crates/logic-eval/src/prove/repr.rs | 101 +++--- 10 files changed, 492 insertions(+), 428 deletions(-) create mode 100644 crates/logic-eval/benches/query_threads.rs diff --git a/crates/logic-eval/Cargo.toml b/crates/logic-eval/Cargo.toml index 01e09e2..2340f4e 100644 --- a/crates/logic-eval/Cargo.toml +++ b/crates/logic-eval/Cargo.toml @@ -16,6 +16,13 @@ indexmap = { workspace = true } fxhash = { workspace = true } smallvec = { workspace = true } +[dev-dependencies] +criterion = "0.4" + +[[bench]] +name = "query_threads" +harness = false + [[example]] name = "meal_recommendation" path = "examples/meal_recommendation.rs" diff --git a/crates/logic-eval/README.md b/crates/logic-eval/README.md index dd52e14..062887f 100644 --- a/crates/logic-eval/README.md +++ b/crates/logic-eval/README.md @@ -47,7 +47,7 @@ Replace `src/main.rs` with: use logic_eval::{parse_str, Database, StrInterner}; fn main() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = StrInterner::new(); let pantry = r#" @@ -61,7 +61,6 @@ fn main() { "#; db.insert_dataset(parse_str(pantry, &interner).unwrap()); - db.commit(); let query = parse_str("can_make($Meal).", &interner).unwrap(); let mut results = db.query(query); diff --git a/crates/logic-eval/benches/query_threads.rs b/crates/logic-eval/benches/query_threads.rs new file mode 100644 index 0000000..fc53e00 --- /dev/null +++ b/crates/logic-eval/benches/query_threads.rs @@ -0,0 +1,154 @@ +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use logic_eval::{Atom, Clause, ClauseDataset, Database, Expr, Term}; + +const NODES: usize = 240; +const QUERY_COUNT: usize = 64; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct BenchAtom(String); + +impl Atom for BenchAtom { + fn is_variable(&self) -> bool { + self.0.starts_with('$') + } +} + +fn sym(name: impl Into) -> BenchAtom { + BenchAtom(name.into()) +} + +fn atom(name: impl Into) -> Term { + Term::atom(sym(name)) +} + +fn compound( + name: impl Into, + args: impl IntoIterator>, +) -> Term { + Term::compound(sym(name), args) +} + +fn parent(from: usize, to: usize) -> Clause { + Clause::fact(compound("parent", [atom(node(from)), atom(node(to))])) +} + +fn node(index: usize) -> String { + format!("n{index}") +} + +fn build_database() -> Database { + let mut clauses = Vec::new(); + + // A chain makes each ancestry query produce many answers. A few skip edges add branching so the + // engine does enough independent work for parallel throughput to show up in the benchmark. + for i in 0..NODES - 1 { + clauses.push(parent(i, i + 1)); + } + for i in 0..NODES - 3 { + if i % 7 == 0 { + clauses.push(parent(i, i + 3)); + } + } + + clauses.push(Clause::rule( + compound("ancestor", [atom("$X"), atom("$Y")]), + Expr::term(compound("parent", [atom("$X"), atom("$Y")])), + )); + clauses.push(Clause::rule( + compound("ancestor", [atom("$X"), atom("$Z")]), + Expr::expr_and([ + Expr::term(compound("parent", [atom("$X"), atom("$Y")])), + Expr::term(compound("ancestor", [atom("$Y"), atom("$Z")])), + ]), + )); + + let mut db = Database::default(); + db.insert_dataset(ClauseDataset(clauses)); + db +} + +fn build_queries() -> Vec> { + (0..QUERY_COUNT) + .map(|i| { + let root = i % (NODES / 3); + Expr::term(compound("ancestor", [atom(node(root)), atom("$Who")])) + }) + .collect() +} + +fn count_answers(db: &Database, query: Expr) -> usize { + let mut cx = db.query(query); + let mut count = 0; + while let Some(answer) = cx.prove_next() { + count += answer.count(); + } + count +} + +fn run_serial(db: &Database, queries: &[Expr]) -> usize { + queries + .iter() + .cloned() + .map(|query| count_answers(db, query)) + .sum() +} + +fn run_threaded(db: &Database, queries: &[Expr], threads: usize) -> usize { + let chunk_len = (queries.len() + threads - 1) / threads; + + std::thread::scope(|scope| { + let handles = queries + .chunks(chunk_len) + .map(|chunk| { + scope.spawn(move || { + chunk + .iter() + .cloned() + .map(|query| count_answers(db, query)) + .sum::() + }) + }) + .collect::>(); + + handles + .into_iter() + .map(|handle| handle.join().unwrap()) + .sum() + }) +} + +fn benchmark_query_threads(c: &mut Criterion) { + let db = build_database(); + let queries = build_queries(); + let expected = run_serial(&db, &queries); + + let mut group = c.benchmark_group("logic_eval_query_threads"); + group.sample_size(10); + + group.bench_with_input(BenchmarkId::new("threads", 1), &1, |b, _| { + b.iter(|| { + let count = run_serial(&db, &queries); + assert_eq!(count, expected); + black_box(count) + }); + }); + + for threads in [2, 4] { + group.bench_with_input( + BenchmarkId::new("threads", threads), + &threads, + |b, &threads| { + b.iter(|| { + let count = run_threaded(&db, &queries, threads); + assert_eq!(count, expected); + black_box(count) + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, benchmark_query_threads); +criterion_main!(benches); diff --git a/crates/logic-eval/examples/access_control.rs b/crates/logic-eval/examples/access_control.rs index f513d54..e931bf1 100644 --- a/crates/logic-eval/examples/access_control.rs +++ b/crates/logic-eval/examples/access_control.rs @@ -6,7 +6,7 @@ use logic_eval::{parse_str, Database, StrInterner}; fn main() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = StrInterner::new(); // Facts describe users, roles, and documents. @@ -27,7 +27,6 @@ fn main() { "#; db.insert_dataset(parse_str(policy, &interner).unwrap()); - db.commit(); println!("Who can read the handbook?"); { diff --git a/crates/logic-eval/examples/dependency_graph.rs b/crates/logic-eval/examples/dependency_graph.rs index 33da001..8ef76a1 100644 --- a/crates/logic-eval/examples/dependency_graph.rs +++ b/crates/logic-eval/examples/dependency_graph.rs @@ -5,7 +5,7 @@ use logic_eval::{parse_str, Database, StrInterner}; fn main() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = StrInterner::new(); // Direct dependencies are facts. The recursive rule finds indirect dependencies. @@ -24,7 +24,6 @@ fn main() { "#; db.insert_dataset(parse_str(graph, &interner).unwrap()); - db.commit(); println!("app requires:"); { diff --git a/crates/logic-eval/examples/meal_recommendation.rs b/crates/logic-eval/examples/meal_recommendation.rs index e411cf9..a1c6b87 100644 --- a/crates/logic-eval/examples/meal_recommendation.rs +++ b/crates/logic-eval/examples/meal_recommendation.rs @@ -6,7 +6,7 @@ use logic_eval::{parse_str, Database, StrInterner}; fn main() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = StrInterner::new(); // Facts describe what is true. Rules describe what can be inferred. @@ -22,7 +22,6 @@ fn main() { // Load the facts and rules into the database before asking questions. db.insert_dataset(parse_str(pantry, &interner).unwrap()); - db.commit(); // $Meal is a variable. logic-eval will find every value that works. let query = parse_str("can_make($Meal).", &interner).unwrap(); diff --git a/crates/logic-eval/src/prove/common.rs b/crates/logic-eval/src/prove/common.rs index eca3065..3d3eae2 100644 --- a/crates/logic-eval/src/prove/common.rs +++ b/crates/logic-eval/src/prove/common.rs @@ -31,3 +31,36 @@ impl<'int> Atom for Name> { self.starts_with(VAR_PREFIX) } } + +impl Atom for String { + fn is_variable(&self) -> bool { + self.starts_with(VAR_PREFIX) + } +} + +impl Atom for Box +where + T: Atom + AsRef, +{ + fn is_variable(&self) -> bool { + (**self).as_ref().starts_with(VAR_PREFIX) + } +} + +impl Atom for std::rc::Rc +where + T: Atom + AsRef, +{ + fn is_variable(&self) -> bool { + (**self).as_ref().starts_with(VAR_PREFIX) + } +} + +impl Atom for std::sync::Arc +where + T: Atom + AsRef, +{ + fn is_variable(&self) -> bool { + (**self).as_ref().starts_with(VAR_PREFIX) + } +} diff --git a/crates/logic-eval/src/prove/db.rs b/crates/logic-eval/src/prove/db.rs index bb089c7..049cdb6 100644 --- a/crates/logic-eval/src/prove/db.rs +++ b/crates/logic-eval/src/prove/db.rs @@ -1,9 +1,9 @@ use super::{ prover::{ format::{NamedExprView, NamedTermView}, - Integer, NameIntMap, NameIntMapState, ProveCx, Prover, + Integer, NameIntMap, ProveCx, Prover, }, - repr::{ClauseId, TermStorage, TermStorageLen}, + repr::{ClauseId, TermStorage}, }; use crate::{ parse::{ @@ -26,52 +26,20 @@ pub struct Database { /// Predicates that should be handled by tabling. table_clauses: IndexSet>, - /// We do not allow duplicate clauses in the dataset. - dup_checker: DuplicateClauseChecker, - /// Term and expression storage. stor: TermStorage, - /// Proof search engine. - prover: Prover, - /// Mappings between `T` and [`Integer`]. /// /// [`Integer`] is used internally for fast comparison, but clients need values mapped back to /// `T`. nimap: NameIntMap, - /// States of the database fields. - /// - /// Used when discarding uncommitted database changes. - revert_point: Option, + /// We do not allow duplicate clauses in the dataset. + dup_checker: DuplicateClauseChecker, } impl Database { - /// Creates an empty database. - /// - /// # Examples - /// - /// ``` - /// use logic_eval::{Database, InternedStr, Name}; - /// - /// type Db<'a> = Database>>; - /// - /// let db: Db<'_> = Database::new(); - /// assert_eq!(db.clauses().count(), 0); - /// ``` - pub fn new() -> Self { - Self { - clauses: IndexMap::default(), - table_clauses: IndexSet::default(), - dup_checker: DuplicateClauseChecker::default(), - stor: TermStorage::new(), - prover: Prover::new(), - nimap: NameIntMap::new(), - revert_point: None, - } - } - /// Iterates over all terms stored in the database. /// /// # Examples @@ -81,9 +49,8 @@ impl Database { /// /// let interner = StrInterner::new(); /// let clause: Clause<_> = parse_str("sunny.", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_clause(clause); - /// db.commit(); /// /// let terms = db.terms().map(|term| term.to_string()).collect::>(); /// assert_eq!(terms, vec!["sunny"]); @@ -104,9 +71,8 @@ impl Database { /// /// let interner = StrInterner::new(); /// let clause: Clause<_> = parse_str("sunny.", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_clause(clause); - /// db.commit(); /// /// let clauses = db.clauses().map(|clause| clause.to_string()).collect::>(); /// assert_eq!(clauses, vec!["sunny."]); @@ -130,10 +96,9 @@ impl Database { /// /// let interner = StrInterner::new(); /// let dataset: ClauseDataset<_> = parse_str("sunny.\nwarm.", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// /// db.insert_dataset(dataset); - /// db.commit(); /// /// assert_eq!(db.clauses().count(), 2); /// ``` @@ -153,19 +118,13 @@ impl Database { /// let interner = StrInterner::new(); /// let clause: Clause<_> = parse_str("sunny.", &interner).unwrap(); /// let query: Expr<_> = parse_str("sunny", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// /// db.insert_clause(clause); - /// db.commit(); /// /// assert!(db.query(query).is_true()); /// ``` pub fn insert_clause(&mut self, clause: Clause) { - // Saves current state. We will revert DB when the change is not committed. - if self.revert_point.is_none() { - self.revert_point = Some(self.state()); - } - let clause = clause.map(&mut |t| self.nimap.name_to_int(t)); // Records whether the clause needs tabling. @@ -194,7 +153,7 @@ impl Database { .or_insert(vec![value]); } - /// Starts a query against the committed database. + /// Starts a query against the database. /// /// # Examples /// @@ -204,9 +163,8 @@ impl Database { /// let interner = StrInterner::new(); /// let dataset: ClauseDataset<_> = parse_str("parent(alice, bob).", &interner).unwrap(); /// let query: Expr<_> = parse_str("parent(alice, $Who)", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_dataset(dataset); - /// db.commit(); /// /// let mut cx = db.query(query); /// let answer = cx.prove_next().unwrap().next().unwrap(); @@ -214,42 +172,16 @@ impl Database { /// assert_eq!(answer.get_lhs_variable().as_ref(), "$Who"); /// assert_eq!(answer.rhs().to_string(), "bob"); /// ``` - pub fn query(&mut self, expr: Expr) -> ProveCx<'_, T> { - // Discards uncommitted changes. - if let Some(revert_point) = self.revert_point.take() { - self.revert(revert_point); - } - - self.prover.prove( + pub fn query(&self, expr: Expr) -> ProveCx<'_, T> { + Prover::new().prove( expr, &self.clauses, &self.table_clauses, - &mut self.stor, - &mut self.nimap, + &self.stor, + &self.nimap, ) } - /// Commits pending clause insertions. - /// - /// # Examples - /// - /// ``` - /// use logic_eval::{parse_str, Clause, Database, Expr, StrInterner}; - /// - /// let interner = StrInterner::new(); - /// let clause: Clause<_> = parse_str("sunny.", &interner).unwrap(); - /// let query: Expr<_> = parse_str("sunny", &interner).unwrap(); - /// let mut db = Database::new(); - /// - /// db.insert_clause(clause); - /// db.commit(); - /// - /// assert!(db.query(query).is_true()); - /// ``` - pub fn commit(&mut self) { - self.revert_point.take(); - } - /// * sanitize - Removes unacceptable characters from prolog. /// /// Requires T to implement [`AsRef`] so that functor names can be serialized into Prolog @@ -262,9 +194,8 @@ impl Database { /// /// let interner = StrInterner::new(); /// let dataset: ClauseDataset<_> = parse_str("parent(Alice, Bob).", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_dataset(dataset); - /// db.commit(); /// /// let prolog = db.to_prolog(|name| name); /// assert_eq!(prolog, "parent(alice, bob).\n"); @@ -440,39 +371,17 @@ impl Database { } } } - - fn revert( - &mut self, - DatabaseState { - clauses_len, - clause_set_len, - stor_len, - nimap_state, - }: DatabaseState, - ) { - self.clauses.truncate(clauses_len.len()); - for (i, len) in clauses_len.into_iter().enumerate() { - self.clauses[i].truncate(len); - } - self.dup_checker.truncate(clause_set_len); - self.stor.truncate(stor_len); - self.nimap.revert(nimap_state); - // `self.prover: Prover` does not store any persistent data. - } - - fn state(&self) -> DatabaseState { - DatabaseState { - clauses_len: self.clauses.values().map(|v| v.len()).collect(), - clause_set_len: self.dup_checker.len(), - stor_len: self.stor.len(), - nimap_state: self.nimap.state(), - } - } } -impl Default for Database { +impl Default for Database { fn default() -> Self { - Self::new() + Self { + clauses: IndexMap::default(), + table_clauses: IndexSet::default(), + stor: TermStorage::default(), + nimap: NameIntMap::default(), + dup_checker: DuplicateClauseChecker::default(), + } } } @@ -480,22 +389,14 @@ impl Debug for Database { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Database") .field("clauses", &self.clauses) - .field("dup_checker", &self.dup_checker) + .field("table_clauses", &self.table_clauses) .field("stor", &self.stor) .field("nimap", &self.nimap) - .field("revert_point", &self.revert_point) + .field("dup_checker", &self.dup_checker) .finish_non_exhaustive() } } -#[derive(Debug, PartialEq, Eq)] -struct DatabaseState { - clauses_len: Vec, - clause_set_len: usize, - stor_len: TermStorageLen, - nimap_state: NameIntMapState, -} - #[derive(Debug, Default)] struct DuplicateClauseChecker { seen: IndexSet>, @@ -523,14 +424,6 @@ impl DuplicateClauseChecker { self.vars.clear(); is_new } - - fn len(&self) -> usize { - self.seen.len() - } - - fn truncate(&mut self, len: usize) { - self.seen.truncate(len); - } } /// Rewrites variables using the values produced by `canonical_var`. @@ -645,9 +538,8 @@ impl<'a, T: Atom> ClauseRef<'a, T> { /// /// let interner = StrInterner::new(); /// let clause: Clause<_> = parse_str("outdoors :- sunny.", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_clause(clause); - /// db.commit(); /// /// let clause = db.clauses().next().unwrap(); /// assert_eq!(clause.head().to_string(), "outdoors"); @@ -666,9 +558,8 @@ impl<'a, T: Atom> ClauseRef<'a, T> { /// /// let interner = StrInterner::new(); /// let clause: Clause<_> = parse_str("outdoors :- sunny.", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_clause(clause); - /// db.commit(); /// /// let clause = db.clauses().next().unwrap(); /// assert_eq!(clause.body().unwrap().to_string(), "sunny"); @@ -728,7 +619,7 @@ impl<'a, T: Atom> Iterator for NamedTermViewIter<'a, T> { impl FusedIterator for NamedTermViewIter<'_, T> {} #[cfg(test)] -mod str_atom_tests { +mod tests { use crate::{parse, NameIn}; type Interner = any_intern::DroplessInterner; @@ -750,7 +641,7 @@ mod str_atom_tests { assert_eq!(answer, expected); } - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); for _ in 0..2 { @@ -763,15 +654,15 @@ mod str_atom_tests { g($X) :- f($X). ", ); - let len = db.stor.len(); + let len = db.terms().count(); assert_query(&mut db, &interner); - assert_eq!(db.stor.len(), len); + assert_eq!(db.terms().count(), len); } } #[test] fn test_not_expression() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -794,7 +685,7 @@ mod str_atom_tests { #[test] fn test_and_expression() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -816,7 +707,7 @@ mod str_atom_tests { #[test] fn test_or_expression() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -837,7 +728,7 @@ mod str_atom_tests { #[test] fn test_mixed_expression() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -865,7 +756,7 @@ mod str_atom_tests { #[test] fn test_and_has_higher_precedence_than_or() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -898,7 +789,7 @@ mod str_atom_tests { #[test] fn test_parentheses_override_and_or_precedence() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -924,7 +815,7 @@ mod str_atom_tests { #[test] fn test_not_applies_to_parenthesized_or() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -954,7 +845,7 @@ mod str_atom_tests { #[test] fn test_not_with_grouped_and_or_expression() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -996,7 +887,7 @@ mod str_atom_tests { #[test] fn test_simple_recursion() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -1032,7 +923,7 @@ mod str_atom_tests { #[test] fn test_right_recursion() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -1067,7 +958,7 @@ mod str_atom_tests { // SLG resolution (tabling) is required to pass this test. #[test] fn test_mid_recursion() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -1105,7 +996,7 @@ mod str_atom_tests { // SLG resolution (tabling) is required to pass this test. #[test] fn test_left_recursion() { - let mut db = Database::new(); + let mut db = Database::default(); let interner = Interner::new(); insert_dataset( @@ -1138,25 +1029,76 @@ mod str_atom_tests { } #[test] - fn test_discarding_uncomitted_change() { - let mut db = Database::new(); + fn test_inserted_clause_is_immediately_visible_to_query() { + let mut db = Database::default(); let interner = Interner::new(); let clause: Clause<'_> = parse::parse_str("f(a).", &interner).unwrap(); db.insert_clause(clause); - let fa_state = db.state(); - db.commit(); let clause: Clause<'_> = parse::parse_str("f(b).", &interner).unwrap(); db.insert_clause(clause); + assert_eq!(db.clauses().count(), 2); let query: Expr<'_> = parse::parse_str("f($X).", &interner).unwrap(); - let answer = collect_answer(db.query(query)); - - // `f(b).` was discarded. - let expected = [["$X = a"]]; + let mut answer = collect_answer(db.query(query)); + let mut expected = [["$X = a"], ["$X = b"]]; + answer.sort_unstable(); + expected.sort_unstable(); assert_eq!(answer, expected); - assert_eq!(db.state(), fa_state); + } + + #[test] + fn test_database_and_query_context_are_send_sync() { + fn assert_send_sync() {} + assert_send_sync::>(); + } + + #[test] + fn test_query_from_multiple_threads() { + let mut db = Database::default(); + let interner = Interner::new(); + + insert_dataset( + &mut db, + &interner, + r" + parent(alice, bob). + parent(alice, carol). + parent(carol, dave). + + ancestor($X, $Y) :- parent($X, $Y). + ancestor($X, $Z) :- parent($X, $Y), ancestor($Y, $Z). + ", + ); + + let barrier = std::sync::Barrier::new(4); + std::thread::scope(|scope| { + let mut handles = Vec::new(); + for _ in 0..4 { + let db = &db; + let barrier = &barrier; + handles.push(scope.spawn(|| { + barrier.wait(); + + let query: Expr<'_> = + parse::parse_str("ancestor(alice, $Who).", &interner).unwrap(); + + let mut answers = collect_answer(db.query(query)) + .into_iter() + .collect::>(); + answers.sort_unstable(); + answers + })); + } + + for handle in handles { + assert_eq!( + handle.join().unwrap(), + [["$Who = bob"], ["$Who = carol"], ["$Who = dave"]] + ); + } + }); } // === Test helper functions === @@ -1164,7 +1106,6 @@ mod str_atom_tests { fn insert_dataset<'int>(db: &mut Database<'int>, interner: &'int Interner, text: &str) { let dataset: ClauseDataset<'int> = parse::parse_str(text, interner).unwrap(); db.insert_dataset(dataset); - db.commit(); } fn collect_answer(mut cx: ProveCx<'_, '_>) -> Vec> { @@ -1178,8 +1119,8 @@ mod str_atom_tests { } #[cfg(test)] -mod tests { - use crate::{Atom, Clause, ClauseDataset, Database, Expr, ProveCx, Term}; +mod custom_atom_tests { + use crate::{Atom, Clause, Database, Expr, ProveCx, Term}; #[test] fn test_custom_atom() { @@ -1203,7 +1144,7 @@ mod tests { } } - let mut db = Database::new(); + let mut db = Database::default(); let child_a_b = Clause::fact(Term::compound( A::child, @@ -1228,16 +1169,13 @@ mod tests { Expr::term_compound(A::descend, [Term::atom(A::Y), Term::atom(A::Z)]), ]), ); - insert_dataset( - &mut db, - crate::ClauseDataset(vec![ - child_a_b, - child_b_c, - child_c_d, - descend_x_y, - descend_x_z, - ]), - ); + db.insert_dataset(crate::ClauseDataset(vec![ + child_a_b, + child_b_c, + child_c_d, + descend_x_y, + descend_x_z, + ])); let query = Expr::term_compound(A::descend, [Term::atom(A::X), Term::atom(A::Y)]); let mut answer = collect_answer(db.query(query)); @@ -1276,11 +1214,6 @@ mod tests { // === Test helper functions === - fn insert_dataset(db: &mut Database, dataset: ClauseDataset) { - db.insert_dataset(dataset); - db.commit(); - } - fn collect_answer(mut cx: ProveCx<'_, T>) -> Vec, Term)>> { let mut v = Vec::new(); while let Some(eval) = cx.prove_next() { diff --git a/crates/logic-eval/src/prove/prover.rs b/crates/logic-eval/src/prove/prover.rs index 8346c27..e7c983a 100644 --- a/crates/logic-eval/src/prove/prover.rs +++ b/crates/logic-eval/src/prove/prover.rs @@ -2,7 +2,7 @@ use super::{ canonical, repr::{ ApplyResult, ClauseId, ExprId, ExprKind, ExprView, TermDeepView, TermElem, TermId, - TermStorage, TermStorageLen, TermView, TermViewMut, UniqueTermArray, + TermStorage, TermView, TermViewMut, UniqueTermArray, }, table::Table, }; @@ -55,6 +55,12 @@ pub(crate) struct Prover { /// SLG resolution. table: Table, + + /// Database clauses imported into query-local storage. + /// + /// The cached clause is a per-query template. Each proof use still clones this template before + /// variable freshening, so repeated use of a database rule does not re-import from `db_stor`. + imported_clause_templates: Map, } impl Prover { @@ -70,6 +76,7 @@ impl Prover { temp_var_buf: Map::default(), temp_var_int: 0, table: Table::default(), + imported_clause_templates: Map::default(), } } @@ -81,43 +88,18 @@ impl Prover { self.query_answers.clear(); self.queue.clear(); self.table.clear(); + self.imported_clause_templates.clear(); } pub(crate) fn prove<'a, T: Atom>( - &'a mut self, + self, query: Expr, clauses: &'a IndexMap, Vec>, table_clauses: &'a IndexSet>, - stor: &'a mut TermStorage, - nimap: &'a mut NameIntMap, + db_stor: &'a TermStorage, + db_nimap: &'a NameIntMap, ) -> ProveCx<'a, T> { - self.clear(); - - let old_nimap_state = nimap.state(); - let query = query.map(&mut |name| nimap.name_to_int(name)); - - let old_stor_len = stor.len(); - self.query = stor.insert_expr(query); - - stor.get_expr(self.query) - .with_term(&mut |term: TermView<'_, Integer>| { - term.with_variable(|term| self.query_vars.push(term.id)); - }); - - let node_kind = NodeKind::Expr(self.query); - let node_parent = self.nodes.len(); - self.nodes.push(Node::new(node_kind, node_parent)); - self.queue.push(0); - - ProveCx { - prover: self, - clauses, - table_clauses, - stor, - nimap, - old_stor_len, - old_nimap_state, - } + ProveCx::new(self, query, clauses, table_clauses, db_stor, db_nimap) } /// Evaluates the given node against matching clauses or table answers, then returns whether a @@ -130,6 +112,7 @@ impl Prover { node_index: usize, clauses: &IndexMap, Vec>, table_clauses: &IndexSet>, + db_stor: &TermStorage, stor: &mut TermStorage, ) -> Option { let node_expr = match self.nodes[node_index].kind { @@ -147,8 +130,7 @@ impl Prover { let node_leftmost = stor.get_expr(node_expr).leftmost_term().id; let node_leftmost_pred = stor.get_term(node_leftmost).predicate(); - let mut similar_clauses = &[][..]; - let mut clause_buf: SmallVec<[ClauseId; 1]> = SmallVec::new(); + let mut similar_clauses: SmallVec<[ClauseId; 2]> = SmallVec::new(); // === SLG path === // * Table entry - Created from non-canonical leftmost term of the node. In tabling, @@ -174,11 +156,10 @@ impl Prover { for (var, answer) in vars.into_iter().zip(answers) { term.replace(var, *answer); } - clause_buf.push(ClauseId { + similar_clauses.push(ClauseId { head: term.id(), body: None, }); - similar_clauses = &clause_buf[..]; // More answers? We'll handle them next time. if !entry.answers(next_offset).is_empty() { @@ -197,7 +178,17 @@ impl Prover { if similar_clauses.is_empty() { if let Some(v) = clauses.get(&node_leftmost_pred) { - similar_clauses = v.as_slice() + for &clause in v { + let template = + if let Some(&template) = self.imported_clause_templates.get(&clause) { + template + } else { + let template = stor.import_clause(db_stor, clause); + self.imported_clause_templates.insert(clause, template); + template + }; + similar_clauses.push(template); + } } } @@ -211,7 +202,7 @@ impl Prover { } let clause = Self::convert_var_into_temp( - *clause, + clause, stor, &mut self.temp_var_buf, &mut self.temp_var_int, @@ -223,9 +214,9 @@ impl Prover { } // We may need to apply true or false to the leftmost term of the node expression due to - // unification failure or exhaustive search. + // unification failure or for the exhaustive search. // - Unification failure means the leftmost term should be false. - // - But we need to consider exhaustive search at the same time. + // - But we need to consider the exhaustive search at the same time. let expr = stor.get_expr(node_expr); let eval = self.nodes.len() > old_len; @@ -753,16 +744,51 @@ enum UnifyOp { /// Proof-search context for a query. pub struct ProveCx<'a, T: Atom> { - prover: &'a mut Prover, + prover: Prover, clauses: &'a IndexMap, Vec>, table_clauses: &'a IndexSet>, - stor: &'a mut TermStorage, - nimap: &'a mut NameIntMap, - old_stor_len: TermStorageLen, - old_nimap_state: NameIntMapState, + db_stor: &'a TermStorage, + stor: TermStorage, + nimap: QueryNameIntMap<'a, T>, } impl<'a, T: Atom> ProveCx<'a, T> { + fn new( + mut prover: Prover, + query: Expr, + clauses: &'a IndexMap, Vec>, + table_clauses: &'a IndexSet>, + db_stor: &'a TermStorage, + db_nimap: &'a NameIntMap, + ) -> Self { + prover.clear(); + + let mut nimap = QueryNameIntMap::new(db_nimap); + let query = query.map(&mut |name| nimap.name_to_int(name)); + + let mut stor = TermStorage::default(); + prover.query = stor.insert_expr(query); + + stor.get_expr(prover.query) + .with_term(&mut |term: TermView<'_, Integer>| { + term.with_variable(|term| prover.query_vars.push(term.id)); + }); + + let node_kind = NodeKind::Expr(prover.query); + let node_parent = prover.nodes.len(); + prover.nodes.push(Node::new(node_kind, node_parent)); + prover.queue.push(0); + + Self { + prover, + clauses, + table_clauses, + db_stor, + stor, + nimap, + } + } + /// Returns the next proof result, if one is available. /// /// # Examples @@ -774,9 +800,8 @@ impl<'a, T: Atom> ProveCx<'a, T> { /// let dataset: ClauseDataset<_> = /// parse_str("parent(alice, bob). parent(alice, carol).", &interner).unwrap(); /// let query: Expr<_> = parse_str("parent(alice, $Who)", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_dataset(dataset); - /// db.commit(); /// /// let mut cx = db.query(query); /// let mut answers = Vec::new(); @@ -790,18 +815,21 @@ impl<'a, T: Atom> ProveCx<'a, T> { /// ``` pub fn prove_next(&mut self) -> Option> { while let Some(node_index) = self.prover.queue.pop() { - if let Some(proof_result) = - self.prover - .evaluate_node(node_index, self.clauses, self.table_clauses, self.stor) - { + if let Some(proof_result) = self.prover.evaluate_node( + node_index, + self.clauses, + self.table_clauses, + self.db_stor, + &mut self.stor, + ) { // Return Some(EvalView) only if the result is TRUE and yielded a new ground // query answer. - if proof_result && self.prover.record_query_answer(self.stor) { + if proof_result && self.prover.record_query_answer(&mut self.stor) { return Some(EvalView { query_vars: &self.prover.query_vars, terms: &self.stor.terms.buf, term_assigns: &self.prover.term_assigns, - nimap: self.nimap, + nimap: &self.nimap, start: 0, end: self.prover.query_vars.len(), }); @@ -821,9 +849,8 @@ impl<'a, T: Atom> ProveCx<'a, T> { /// let interner = StrInterner::new(); /// let dataset: ClauseDataset<_> = parse_str("sunny.", &interner).unwrap(); /// let query: Expr<_> = parse_str("sunny", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_dataset(dataset); - /// db.commit(); /// /// assert!(db.query(query).is_true()); /// ``` @@ -832,10 +859,38 @@ impl<'a, T: Atom> ProveCx<'a, T> { } } -impl Drop for ProveCx<'_, T> { - fn drop(&mut self) { - self.stor.truncate(self.old_stor_len.clone()); - self.nimap.revert(self.old_nimap_state.clone()); +/// Name mapping for a query: database names are borrowed, query-only names are local. +struct QueryNameIntMap<'a, T> { + database: &'a NameIntMap, + local: NameIntMap, +} + +impl<'a, T> QueryNameIntMap<'a, T> { + fn new(database: &'a NameIntMap) -> Self { + Self { + database, + local: NameIntMap { + name2int: IndexMap::default(), + int2name: IndexMap::default(), + next_int: database.next_int, + }, + } + } + + fn get_name(&self, int: &Integer) -> Option<&T> { + self.database + .get_name(int) + .or_else(|| self.local.get_name(int)) + } +} + +impl QueryNameIntMap<'_, T> { + fn name_to_int(&mut self, name: T) -> Integer { + if let Some(int) = self.database.name2int.get(&name) { + *int + } else { + self.local.name_to_int(name) + } } } @@ -844,7 +899,7 @@ pub struct EvalView<'a, T> { query_vars: &'a [TermId], terms: &'a [TermElem], term_assigns: &'a TermAssignments, - nimap: &'a NameIntMap, + nimap: &'a QueryNameIntMap<'a, T>, /// Inclusive start: usize, /// Exclusive @@ -895,7 +950,7 @@ pub struct Assignment<'a, T> { buf: &'a [TermElem], from: TermId, term_assigns: &'a TermAssignments, - nimap: &'a NameIntMap, + nimap: &'a QueryNameIntMap<'a, T>, } impl<'a, T: 'a> Assignment<'a, T> { @@ -911,9 +966,8 @@ impl<'a, T: 'a> Assignment<'a, T> { /// let interner = StrInterner::new(); /// let dataset: ClauseDataset<_> = parse_str("parent(alice, bob).", &interner).unwrap(); /// let query: Expr<_> = parse_str("parent(alice, $Who)", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_dataset(dataset); - /// db.commit(); /// /// let mut cx = db.query(query); /// let assignment = cx.prove_next().unwrap().next().unwrap(); @@ -954,9 +1008,8 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { /// let interner = StrInterner::new(); /// let dataset: ClauseDataset<_> = parse_str("parent(alice, bob).", &interner).unwrap(); /// let query: Expr<_> = parse_str("parent(alice, $Who)", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_dataset(dataset); - /// db.commit(); /// /// let mut cx = db.query(query); /// let assignment = cx.prove_next().unwrap().next().unwrap(); @@ -979,9 +1032,8 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { /// let interner = StrInterner::new(); /// let dataset: ClauseDataset<_> = parse_str("parent(alice, bob).", &interner).unwrap(); /// let query: Expr<_> = parse_str("parent(alice, $Who)", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_dataset(dataset); - /// db.commit(); /// /// let mut cx = db.query(query); /// let assignment = cx.prove_next().unwrap().next().unwrap(); @@ -992,7 +1044,7 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { Self::term_deep_view_to_term(self.rhs_view(), self.nimap) } - fn term_view_to_term(view: TermView<'_, Integer>, nimap: &NameIntMap) -> Term { + fn term_view_to_term(view: TermView<'_, Integer>, nimap: &QueryNameIntMap<'_, T>) -> Term { let functor = view.functor(); let args = view.args(); @@ -1010,7 +1062,10 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { Term { functor, args } } - fn term_deep_view_to_term(view: TermDeepView<'_, Integer>, nimap: &NameIntMap) -> Term { + fn term_deep_view_to_term( + view: TermDeepView<'_, Integer>, + nimap: &QueryNameIntMap<'_, T>, + ) -> Term { let functor = view.functor(); let args = view.args(); @@ -1031,24 +1086,19 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { impl Display for Assignment<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let view = format::NamedTermView::new(self.lhs_view(), self.nimap); - Display::fmt(&view, f)?; + Display::fmt(&self.lhs(), f)?; f.write_str(" = ")?; - let view = format::NamedTermDeepView::new(self.rhs_view(), self.nimap); - Display::fmt(&view, f) + Display::fmt(&self.rhs(), f) } } impl Debug for Assignment<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let lhs = format::NamedTermView::new(self.lhs_view(), self.nimap); - let rhs = format::NamedTermDeepView::new(self.rhs_view(), self.nimap); - f.debug_struct("Assignment") - .field("lhs", &lhs) - .field("rhs", &rhs) + .field("lhs", &self.lhs()) + .field("rhs", &self.rhs()) .finish() } } @@ -1199,38 +1249,9 @@ pub(crate) struct NameIntMap { } impl NameIntMap { - pub(crate) fn new() -> Self { - Self { - name2int: IndexMap::default(), - int2name: IndexMap::default(), - next_int: 0, - } - } - pub(crate) fn get_name(&self, int: &Integer) -> Option<&T> { self.int2name.get(int) } - - pub(crate) fn state(&self) -> NameIntMapState { - NameIntMapState { - name2int_len: self.name2int.len(), - int2name_len: self.int2name.len(), - next_int: self.next_int, - } - } - - pub(crate) fn revert( - &mut self, - NameIntMapState { - name2int_len, - int2name_len, - next_int, - }: NameIntMapState, - ) { - self.name2int.truncate(name2int_len); - self.int2name.truncate(int2name_len); - self.next_int = next_int; - } } impl NameIntMap { @@ -1249,11 +1270,14 @@ impl NameIntMap { } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct NameIntMapState { - name2int_len: usize, - int2name_len: usize, - next_int: u32, +impl Default for NameIntMap { + fn default() -> Self { + Self { + name2int: IndexMap::default(), + int2name: IndexMap::default(), + next_int: 0, + } + } } pub(crate) mod format { @@ -1288,9 +1312,8 @@ pub(crate) mod format { /// let interner = StrInterner::new(); /// let clause: Clause<_> = parse_str("parent(alice, bob).", &interner).unwrap(); /// let expected: Term<_> = parse_str("parent(alice, bob)", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_clause(clause); - /// db.commit(); /// /// let term = db.terms().next().unwrap(); /// assert!(term.is(&expected)); @@ -1318,9 +1341,8 @@ pub(crate) mod format { /// let interner = StrInterner::new(); /// let clause: Clause<_> = parse_str("parent(alice, bob).", &interner).unwrap(); /// let expected: Term<_> = parse_str("bob", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_clause(clause); - /// db.commit(); /// /// let term = db.terms().next().unwrap(); /// assert!(term.contains(&expected)); @@ -1388,71 +1410,6 @@ pub(crate) mod format { } } - pub(crate) struct NamedTermDeepView<'a, T> { - view: TermDeepView<'a, Integer>, - nimap: &'a NameIntMap, - } - - impl<'a, T> NamedTermDeepView<'a, T> { - pub(crate) const fn new(view: TermDeepView<'a, Integer>, nimap: &'a NameIntMap) -> Self { - Self { view, nimap } - } - } - - impl<'a, T: Display> Display for NamedTermDeepView<'a, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, nimap } = self; - - let functor = view.functor(); - let args = view.args(); - let num_args = args.len(); - - write_int(functor, nimap, f)?; - - if num_args > 0 { - f.write_char('(')?; - for (i, arg) in args.enumerate() { - fmt::Display::fmt(&Self::new(arg, nimap), f)?; - if i + 1 < num_args { - f.write_str(", ")?; - } - } - f.write_char(')')?; - } - Ok(()) - } - } - - impl<'a, T: Debug> Debug for NamedTermDeepView<'a, T> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, nimap } = self; - - let functor = view.functor(); - let args = view.args(); - let num_args = args.len(); - - if num_args == 0 { - if let Some(name) = nimap.get_name(functor) { - fmt::Debug::fmt(name, f) - } else { - fmt::Debug::fmt(functor, f) - } - } else { - let name_str = if let Some(name) = nimap.get_name(functor) { - format!("{:?}", name) - } else { - format!("{:?}", functor) - }; - let mut d = f.debug_tuple(&name_str); - - for arg in args { - d.field(&Self::new(arg, nimap)); - } - d.finish() - } - } - } - pub struct NamedExprView<'a, T> { view: ExprView<'a, Integer>, nimap: &'a NameIntMap, @@ -1475,9 +1432,8 @@ pub(crate) mod format { /// let interner = StrInterner::new(); /// let clause: Clause<_> = parse_str("outdoors :- sunny, warm.", &interner).unwrap(); /// let expected: Term<_> = parse_str("warm", &interner).unwrap(); - /// let mut db = Database::new(); + /// let mut db = Database::default(); /// db.insert_clause(clause); - /// db.commit(); /// /// let body = db.clauses().next().unwrap().body().unwrap(); /// assert!(body.contains_term(&expected)); diff --git a/crates/logic-eval/src/prove/repr.rs b/crates/logic-eval/src/prove/repr.rs index b10e156..0508d70 100644 --- a/crates/logic-eval/src/prove/repr.rs +++ b/crates/logic-eval/src/prove/repr.rs @@ -18,25 +18,6 @@ pub(crate) struct TermStorage { } impl TermStorage { - pub(crate) fn new() -> Self { - Self { - exprs: ExprArray::new(), - terms: UniqueTermArray::new(), - } - } - - pub(crate) fn len(&self) -> TermStorageLen { - TermStorageLen { - expr_len: self.exprs.len(), - term_len: self.terms.len(), - } - } - - pub(crate) fn truncate(&mut self, len: TermStorageLen) { - self.exprs.truncate(len.expr_len); - self.terms.truncate(len.term_len); - } - pub(crate) fn get_expr(&self, id: ExprId) -> ExprView<'_, T> { self.exprs.get(id, &self.terms.buf) } @@ -62,24 +43,33 @@ impl TermStorage { pub(crate) fn insert_term(&mut self, term: Term) -> TermId { self.terms.insert(term) } + + // TODO: It seems to be a redundant deserialization and serialization. + pub(crate) fn import_clause(&mut self, src: &Self, clause: ClauseId) -> ClauseId { + ClauseId { + head: self.insert_term(src.get_term(clause.head).deserialize()), + body: clause + .body + .map(|expr| self.insert_expr(src.get_expr(expr).deserialize())), + } + } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct TermStorageLen { - expr_len: usize, - term_len: TermArrayLen, +impl Default for TermStorage { + fn default() -> Self { + Self { + exprs: ExprArray::default(), + terms: UniqueTermArray::default(), + } + } } -#[derive(Debug)] +#[derive(Debug, Default)] pub(crate) struct ExprArray { buf: Vec, } impl ExprArray { - const fn new() -> Self { - Self { buf: Vec::new() } - } - fn get<'a, T>(&'a self, id: ExprId, term_buf: &'a [TermElem]) -> ExprView<'a, T> { ExprView { expr_buf: &self.buf, @@ -288,6 +278,17 @@ impl<'a, T> ExprView<'a, T> { } } +impl<'a, T: Clone> ExprView<'a, T> { + pub(crate) fn deserialize(self) -> Expr { + match self.as_kind() { + ExprKind::Term(term) => Expr::Term(term.deserialize()), + ExprKind::Not(inner) => Expr::Not(Box::new(inner.deserialize())), + ExprKind::And(args) => Expr::And(args.map(Self::deserialize).collect()), + ExprKind::Or(args) => Expr::Or(args.map(Self::deserialize).collect()), + } + } +} + pub(crate) enum ExprKind<'a, T> { Term(TermView<'a, T>), Not(ExprView<'a, T>), @@ -710,13 +711,6 @@ pub(crate) struct UniqueTermArray { } impl UniqueTermArray { - fn new() -> Self { - Self { - buf: Vec::new(), - map: PassThroughIndexMap::default(), - } - } - pub(crate) fn terms(&self) -> TermViewIter<'_, T> { TermViewIter { buf: &self.buf, @@ -736,18 +730,6 @@ impl UniqueTermArray { TermViewMut { arr: self, id } } - pub(crate) fn len(&self) -> TermArrayLen { - TermArrayLen { - buf_len: self.buf.len(), - map_len: self.map.len(), - } - } - - pub(crate) fn truncate(&mut self, len: TermArrayLen) { - self.buf.truncate(len.buf_len); - self.map.truncate(len.map_len); - } - fn reserve(&mut self, additional: usize) -> usize { let cur_len = self.buf.len(); self.buf @@ -850,6 +832,15 @@ impl UniqueTermArray { } } +impl Default for UniqueTermArray { + fn default() -> Self { + Self { + buf: Vec::new(), + map: PassThroughIndexMap::default(), + } + } +} + impl AsRef<[TermElem]> for UniqueTermArray { fn as_ref(&self) -> &[TermElem] { &self.buf @@ -904,12 +895,6 @@ impl Iterator for SimilarTerms<'_, T> { impl iter::FusedIterator for SimilarTerms<'_, T> {} -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct TermArrayLen { - buf_len: usize, - map_len: usize, -} - pub(crate) struct TermViewIter<'a, T> { buf: &'a [TermElem], cur: usize, @@ -1513,7 +1498,7 @@ mod tests { #[test] fn test_expr_array_replace_term() { - let mut buf = TermStorage::new(); + let mut buf = TermStorage::default(); let interner = DroplessInterner::default(); let id_expr = insert_expr(&mut buf, &interner, "f(g(X)), (Y; Z), X"); @@ -1567,7 +1552,7 @@ mod tests { #[test] fn test_expr_array_replace_expr() { - let mut buf = TermStorage::new(); + let mut buf = TermStorage::default(); let interner = DroplessInterner::default(); let id_expr = insert_expr(&mut buf, &interner, "X, Y"); @@ -1612,7 +1597,7 @@ mod tests { #[test] fn test_term_array_replace() { - let mut arr = UniqueTermArray::new(); + let mut arr = UniqueTermArray::default(); let interner = DroplessInterner::default(); let id_x = insert_term(&mut arr, &interner, "X"); @@ -1670,7 +1655,7 @@ mod tests { #[test] #[rustfmt::skip] fn test_term_array_replace_with() { - let mut arr = UniqueTermArray::new(); + let mut arr = UniqueTermArray::default(); let interner = DroplessInterner::default(); let id_f = insert_term(&mut arr, &interner, "f($X, $Y, $X)"); @@ -1727,7 +1712,7 @@ mod tests { #[test] fn test_recursive_term() { - let mut arr = UniqueTermArray::new(); + let mut arr = UniqueTermArray::default(); let interner = DroplessInterner::default(); insert_term(&mut arr, &interner, "f(f(a))"); From 53b8d7245eb64a8b16442de2a4f014ab2aaeb504 Mon Sep 17 00:00:00 2001 From: ecoricemon Date: Tue, 19 May 2026 10:37:06 +0900 Subject: [PATCH 2/3] chore: Rename awkward names --- crates/logic-eval/src/lib.rs | 2 +- crates/logic-eval/src/parse/repr.rs | 6 +- crates/logic-eval/src/prove/canonical.rs | 23 +- crates/logic-eval/src/prove/db.rs | 146 ++--- crates/logic-eval/src/prove/mod.rs | 2 +- .../src/prove/{prover.rs => proof_engine.rs} | 578 ++++++++++-------- crates/logic-eval/src/prove/repr.rs | 19 +- crates/logic-eval/src/prove/table.rs | 24 +- 8 files changed, 431 insertions(+), 369 deletions(-) rename crates/logic-eval/src/prove/{prover.rs => proof_engine.rs} (71%) diff --git a/crates/logic-eval/src/lib.rs b/crates/logic-eval/src/lib.rs index b57d8f3..17bba9e 100644 --- a/crates/logic-eval/src/lib.rs +++ b/crates/logic-eval/src/lib.rs @@ -15,7 +15,7 @@ pub use parse::{ pub use prove::{ common::Atom, db::{ClauseIter, ClauseRef, Database}, - prover::ProveCx, + proof_engine::QueryCx, }; /// Re-exports of the interning types used by this crate. diff --git a/crates/logic-eval/src/parse/repr.rs b/crates/logic-eval/src/parse/repr.rs index 2221344..d3e039c 100644 --- a/crates/logic-eval/src/parse/repr.rs +++ b/crates/logic-eval/src/parse/repr.rs @@ -1,5 +1,5 @@ use crate::{ - prove::{canonical as canon, prover::Integer}, + prove::{canonical as canon, proof_engine::AtomId}, Atom, }; use std::{ @@ -124,7 +124,7 @@ impl Clause { } } -impl Clause { +impl Clause { /// Returns `true` if the clause needs SLG resolution (tabling). /// /// If a clause has left or mid recursion, it must be handled by tabling. @@ -145,7 +145,7 @@ impl Clause { // === Internal helper functions === - fn helper(expr: &Expr, head: &Term) -> bool { + fn helper(expr: &Expr, head: &Term) -> bool { match expr { Expr::Term(term) => term == head, Expr::Not(arg) => helper(arg, head), diff --git a/crates/logic-eval/src/prove/canonical.rs b/crates/logic-eval/src/prove/canonical.rs index 37bbd0a..687b185 100644 --- a/crates/logic-eval/src/prove/canonical.rs +++ b/crates/logic-eval/src/prove/canonical.rs @@ -1,6 +1,6 @@ use crate::{ prove::{ - prover::Integer, + proof_engine::AtomId, repr::{TermId, TermStorage, TermViewMut}, }, Atom, Expr, Map, Term, @@ -9,14 +9,17 @@ use crate::{ #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub(crate) struct CanonicalTermId(TermId); -pub(crate) fn canonicalize_term_id(stor: &mut TermStorage, id: TermId) -> CanonicalTermId { - let mut view = stor.get_term_mut(id); +pub(crate) fn canonicalize_term_id( + query_storage: &mut TermStorage, + id: TermId, +) -> CanonicalTermId { + let mut view = query_storage.get_term_mut(id); canonicalize_term_view(&mut view); CanonicalTermId(view.id()) } /// e.g. f($X, $Y, $X) -> f($0, $1, $0) -pub(crate) fn canonicalize_term(term: &mut Term) { +pub(crate) fn canonicalize_term(term: &mut Term) { let mut c = canonicalizer(); term.replace_variables(|functor| *functor = c(*functor)); } @@ -24,7 +27,7 @@ pub(crate) fn canonicalize_term(term: &mut Term) { /// Applies [`canonicalize_term`] to each term without crossing term boundaries. /// /// e.g. f($X), g($Y, $X) -> f($0), g($0, $1) (not f($0), g($1, $0)) -pub(crate) fn canonicalize_expr_on_term(expr: &mut Expr) { +pub(crate) fn canonicalize_expr_on_term(expr: &mut Expr) { match expr { Expr::Term(term) => canonicalize_term(term), Expr::Not(arg) => canonicalize_expr_on_term(arg), @@ -36,7 +39,7 @@ pub(crate) fn canonicalize_expr_on_term(expr: &mut Expr) { } } -pub(crate) fn canonicalize_term_view(view: &mut TermViewMut<'_, Integer>) { +pub(crate) fn canonicalize_term_view(view: &mut TermViewMut<'_, AtomId>) { let mut c = canonicalizer(); view.replace_with(|functor| { if functor.is_variable() { @@ -47,12 +50,12 @@ pub(crate) fn canonicalize_term_view(view: &mut TermViewMut<'_, Integer>) { }); } -fn canonicalizer() -> impl FnMut(Integer) -> Integer { +fn canonicalizer() -> impl FnMut(AtomId) -> AtomId { let mut map = Map::default(); - move |functor: Integer| { + move |functor: AtomId| { if functor.is_variable() { - let next_int = map.len() as u32; - *map.entry(functor).or_insert(Integer::variable(next_int)) + let next_id = map.len() as u32; + *map.entry(functor).or_insert(AtomId::variable(next_id)) } else { functor } diff --git a/crates/logic-eval/src/prove/db.rs b/crates/logic-eval/src/prove/db.rs index 049cdb6..789354d 100644 --- a/crates/logic-eval/src/prove/db.rs +++ b/crates/logic-eval/src/prove/db.rs @@ -1,7 +1,7 @@ use super::{ - prover::{ + proof_engine::{ format::{NamedExprView, NamedTermView}, - Integer, NameIntMap, ProveCx, Prover, + AtomId, NameInterner, ProofEngine, QueryCx, }, repr::{ClauseId, TermStorage}, }; @@ -21,19 +21,19 @@ use core::{ /// A clause database that can answer logic queries. pub struct Database { /// Clauses grouped by predicate. - clauses: IndexMap, Vec>, + clauses: IndexMap, Vec>, /// Predicates that should be handled by tabling. - table_clauses: IndexSet>, + tabled_predicates: IndexSet>, /// Term and expression storage. - stor: TermStorage, + database_storage: TermStorage, - /// Mappings between `T` and [`Integer`]. + /// Mappings between `T` and [`AtomId`]. /// - /// [`Integer`] is used internally for fast comparison, but clients need values mapped back to + /// [`AtomId`] is used internally for fast comparison, but clients need values mapped back to /// `T`. - nimap: NameIntMap, + name_interner: NameInterner, /// We do not allow duplicate clauses in the dataset. dup_checker: DuplicateClauseChecker, @@ -57,8 +57,8 @@ impl Database { /// ``` pub fn terms(&self) -> NamedTermViewIter<'_, T> { NamedTermViewIter { - term_iter: self.stor.terms.terms(), - nimap: &self.nimap, + term_iter: self.database_storage.terms.terms(), + name_interner: &self.name_interner, } } @@ -80,8 +80,8 @@ impl Database { pub fn clauses(&self) -> ClauseIter<'_, T> { ClauseIter { clauses: &self.clauses, - stor: &self.stor, - nimap: &self.nimap, + database_storage: &self.database_storage, + name_interner: &self.name_interner, i: 0, j: 0, } @@ -125,11 +125,11 @@ impl Database { /// assert!(db.query(query).is_true()); /// ``` pub fn insert_clause(&mut self, clause: Clause) { - let clause = clause.map(&mut |t| self.nimap.name_to_int(t)); + let clause = clause.map(&mut |t| self.name_interner.intern(t)); // Records whether the clause needs tabling. if clause.needs_tabling() { - self.table_clauses.insert(clause.head.predicate()); + self.tabled_predicates.insert(clause.head.predicate()); } // If the DB already contains the given clause, then returns. @@ -139,15 +139,17 @@ impl Database { let key = clause.head.predicate(); let value = ClauseId { - head: self.stor.insert_term(clause.head), - body: clause.body.map(|expr| self.stor.insert_expr(expr)), + head: self.database_storage.insert_term(clause.head), + body: clause + .body + .map(|expr| self.database_storage.insert_expr(expr)), }; self.clauses .entry(key) - .and_modify(|similar_clauses| { - if similar_clauses.iter().all(|clause| clause != &value) { - similar_clauses.push(value); + .and_modify(|candidate_clauses| { + if candidate_clauses.iter().all(|clause| clause != &value) { + candidate_clauses.push(value); } }) .or_insert(vec![value]); @@ -172,13 +174,13 @@ impl Database { /// assert_eq!(answer.get_lhs_variable().as_ref(), "$Who"); /// assert_eq!(answer.rhs().to_string(), "bob"); /// ``` - pub fn query(&self, expr: Expr) -> ProveCx<'_, T> { - Prover::new().prove( + pub fn query(&self, expr: Expr) -> QueryCx<'_, T> { + ProofEngine::new().prove( expr, &self.clauses, - &self.table_clauses, - &self.stor, - &self.nimap, + &self.tabled_predicates, + &self.database_storage, + &self.name_interner, ) } @@ -207,21 +209,21 @@ impl Database { let mut prolog_text = String::new(); let mut conv_map = ConversionMap { - int_to_str: Map::default(), + atom_id_to_str: Map::default(), sanitized_to_suffix: Map::default(), - nimap: &self.nimap, + name_interner: &self.name_interner, sanitizer: sanitize, }; for clauses in self.clauses.values() { for clause in clauses { - let head = self.stor.get_term(clause.head); + let head = self.database_storage.get_term(clause.head); write_term(head, &mut conv_map, &mut prolog_text); if let Some(body) = clause.body { prolog_text.push_str(" :- "); - let body = self.stor.get_expr(body); + let body = self.database_storage.get_expr(body); write_expr(body, &mut conv_map, &mut prolog_text); } @@ -234,10 +236,10 @@ impl Database { // === Internal helper functions === struct ConversionMap<'a, T, F> { - int_to_str: Map, + atom_id_to_str: Map, // e.g. 0 -> No suffix, 1 -> _1, 2 -> _2, ... sanitized_to_suffix: Map<&'a str, u32>, - nimap: &'a NameIntMap, + name_interner: &'a NameInterner, sanitizer: F, } @@ -246,9 +248,9 @@ impl Database { T: AsRef, F: FnMut(&str) -> &str, { - fn int_to_str(&mut self, int: Integer) -> &str { - self.int_to_str.entry(int).or_insert_with(|| { - let name = self.nimap.get_name(&int).unwrap(); + fn atom_id_to_str(&mut self, atom_id: AtomId) -> &str { + self.atom_id_to_str.entry(atom_id).or_insert_with(|| { + let name = self.name_interner.get_name(&atom_id).unwrap(); let name: &str = name.as_ref(); let mut is_var = false; @@ -296,7 +298,7 @@ impl Database { } fn write_term( - term: TermView<'_, Integer>, + term: TermView<'_, AtomId>, conv_map: &mut ConversionMap<'_, T, F>, prolog_text: &mut String, ) where @@ -307,7 +309,7 @@ impl Database { let args = term.args(); let num_args = args.len(); - let functor = conv_map.int_to_str(*functor); + let functor = conv_map.atom_id_to_str(*functor); prolog_text.push_str(functor); if num_args > 0 { @@ -323,7 +325,7 @@ impl Database { } fn write_expr( - expr: ExprView<'_, Integer>, + expr: ExprView<'_, AtomId>, conv_map: &mut ConversionMap<'_, T, F>, prolog_text: &mut String, ) where @@ -377,9 +379,9 @@ impl Default for Database { fn default() -> Self { Self { clauses: IndexMap::default(), - table_clauses: IndexSet::default(), - stor: TermStorage::default(), - nimap: NameIntMap::default(), + tabled_predicates: IndexSet::default(), + database_storage: TermStorage::default(), + name_interner: NameInterner::default(), dup_checker: DuplicateClauseChecker::default(), } } @@ -389,9 +391,9 @@ impl Debug for Database { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Database") .field("clauses", &self.clauses) - .field("table_clauses", &self.table_clauses) - .field("stor", &self.stor) - .field("nimap", &self.nimap) + .field("tabled_predicates", &self.tabled_predicates) + .field("database_storage", &self.database_storage) + .field("name_interner", &self.name_interner) .field("dup_checker", &self.dup_checker) .finish_non_exhaustive() } @@ -399,25 +401,25 @@ impl Debug for Database { #[derive(Debug, Default)] struct DuplicateClauseChecker { - seen: IndexSet>, + seen: IndexSet>, - /// Temporary buffer for assigning canonical [`Integer`]s to variables. - vars: Vec, + /// Temporary buffer for assigning canonical [`AtomId`]s to variables. + vars: Vec, } impl DuplicateClauseChecker { /// Returns `true` if the given clause is new and has not been seen before. - fn insert(&mut self, clause: Clause) -> bool { + fn insert(&mut self, clause: Clause) -> bool { let canonical_clause = clause.map(&mut |t| { if !t.is_variable() { t } else if let Some(found) = self.vars.iter().find(|&&var| var == t) { *found } else { - let next_int = self.vars.len() as u32; - let int = Integer::variable(next_int); - self.vars.push(int); - int + let next_id = self.vars.len() as u32; + let atom_id = AtomId::variable(next_id); + self.vars.push(atom_id); + atom_id } }); let is_new = self.seen.insert(canonical_clause); @@ -488,9 +490,9 @@ fn _convert_var_into_num( /// Iterator over clauses in a [`Database`]. #[derive(Clone)] pub struct ClauseIter<'a, T> { - clauses: &'a IndexMap, Vec>, - stor: &'a TermStorage, - nimap: &'a NameIntMap, + clauses: &'a IndexMap, Vec>, + database_storage: &'a TermStorage, + name_interner: &'a NameInterner, i: usize, j: usize, } @@ -513,8 +515,8 @@ impl<'a, T> Iterator for ClauseIter<'a, T> { Some(ClauseRef { id, - stor: self.stor, - nimap: self.nimap, + database_storage: self.database_storage, + name_interner: self.name_interner, }) } } @@ -524,8 +526,8 @@ impl FusedIterator for ClauseIter<'_, T> {} /// Borrowed view of a clause stored in a [`Database`]. pub struct ClauseRef<'a, T> { id: ClauseId, - stor: &'a TermStorage, - nimap: &'a NameIntMap, + database_storage: &'a TermStorage, + name_interner: &'a NameInterner, } impl<'a, T: Atom> ClauseRef<'a, T> { @@ -545,8 +547,8 @@ impl<'a, T: Atom> ClauseRef<'a, T> { /// assert_eq!(clause.head().to_string(), "outdoors"); /// ``` pub fn head(&self) -> NamedTermView<'a, T> { - let head = self.stor.get_term(self.id.head); - NamedTermView::new(head, self.nimap) + let head = self.database_storage.get_term(self.id.head); + NamedTermView::new(head, self.name_interner) } /// Returns the clause body, if this clause is a rule. @@ -566,8 +568,8 @@ impl<'a, T: Atom> ClauseRef<'a, T> { /// ``` pub fn body(&self) -> Option> { self.id.body.map(|id| { - let body = self.stor.get_expr(id); - NamedExprView::new(body, self.nimap) + let body = self.database_storage.get_expr(id); + NamedExprView::new(body, self.name_interner) }) } } @@ -589,12 +591,12 @@ impl Debug for ClauseRef<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut d = f.debug_struct("Clause"); - let head = self.stor.get_term(self.id.head); - d.field("head", &NamedTermView::new(head, self.nimap)); + let head = self.database_storage.get_term(self.id.head); + d.field("head", &NamedTermView::new(head, self.name_interner)); if let Some(body) = self.id.body { - let body = self.stor.get_expr(body); - d.field("body", &NamedExprView::new(body, self.nimap)); + let body = self.database_storage.get_expr(body); + d.field("body", &NamedExprView::new(body, self.name_interner)); } d.finish() @@ -602,8 +604,8 @@ impl Debug for ClauseRef<'_, T> { } pub struct NamedTermViewIter<'a, T> { - term_iter: TermViewIter<'a, Integer>, - nimap: &'a NameIntMap, + term_iter: TermViewIter<'a, AtomId>, + name_interner: &'a NameInterner, } impl<'a, T: Atom> Iterator for NamedTermViewIter<'a, T> { @@ -612,7 +614,7 @@ impl<'a, T: Atom> Iterator for NamedTermViewIter<'a, T> { fn next(&mut self) -> Option { self.term_iter .next() - .map(|view| NamedTermView::new(view, self.nimap)) + .map(|view| NamedTermView::new(view, self.name_interner)) } } @@ -624,7 +626,7 @@ mod tests { type Interner = any_intern::DroplessInterner; type Database<'int> = crate::Database>; - type ProveCx<'a, 'int> = crate::ProveCx<'a, NameIn<'int, Interner>>; + type QueryCx<'a, 'int> = crate::QueryCx<'a, NameIn<'int, Interner>>; type ClauseDataset<'int> = crate::ClauseDatasetIn<'int, Interner>; type Expr<'int> = crate::ExprIn<'int, Interner>; type Clause<'int> = crate::ClauseIn<'int, Interner>; @@ -1108,7 +1110,7 @@ mod tests { db.insert_dataset(dataset); } - fn collect_answer(mut cx: ProveCx<'_, '_>) -> Vec> { + fn collect_answer(mut cx: QueryCx<'_, '_>) -> Vec> { let mut v = Vec::new(); while let Some(eval) = cx.prove_next() { let x = eval.map(|assign| assign.to_string()).collect::>(); @@ -1120,7 +1122,7 @@ mod tests { #[cfg(test)] mod custom_atom_tests { - use crate::{Atom, Clause, Database, Expr, ProveCx, Term}; + use crate::{Atom, Clause, Database, Expr, QueryCx, Term}; #[test] fn test_custom_atom() { @@ -1214,7 +1216,7 @@ mod custom_atom_tests { // === Test helper functions === - fn collect_answer(mut cx: ProveCx<'_, T>) -> Vec, Term)>> { + fn collect_answer(mut cx: QueryCx<'_, T>) -> Vec, Term)>> { let mut v = Vec::new(); while let Some(eval) = cx.prove_next() { let pairs = eval.map(|assign| (assign.lhs(), assign.rhs())).collect(); diff --git a/crates/logic-eval/src/prove/mod.rs b/crates/logic-eval/src/prove/mod.rs index 683d77c..6892cdf 100644 --- a/crates/logic-eval/src/prove/mod.rs +++ b/crates/logic-eval/src/prove/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod canonical; pub(crate) mod common; pub(crate) mod db; -pub(crate) mod prover; +pub(crate) mod proof_engine; pub(crate) mod repr; pub(crate) mod table; diff --git a/crates/logic-eval/src/prove/prover.rs b/crates/logic-eval/src/prove/proof_engine.rs similarity index 71% rename from crates/logic-eval/src/prove/prover.rs rename to crates/logic-eval/src/prove/proof_engine.rs index e7c983a..afcfa42 100644 --- a/crates/logic-eval/src/prove/prover.rs +++ b/crates/logic-eval/src/prove/proof_engine.rs @@ -21,14 +21,14 @@ use smallvec::SmallVec; use std::collections::VecDeque; #[derive(Debug)] -pub(crate) struct Prover { +pub(crate) struct ProofEngine { uni_op: UnificationOperator, /// Nodes created during proof search. nodes: Vec, /// Variable assignments (e.g. X = a, Y = z) - term_assigns: TermAssignments, + unification_assignments: TermVariableBindings, /// A given query. query: ExprId, @@ -39,7 +39,7 @@ pub(crate) struct Prover { query_vars: Vec, /// Previously returned ground query answers. - query_answers: Vec>, + seen_answers: Vec>, /// Task queue containing node index. queue: NodeQueue, @@ -48,33 +48,33 @@ pub(crate) struct Prover { /// /// This buffer is used when we convert variables into temporary variables for a clause. It is /// empty after each conversion. - temp_var_buf: Map, + fresh_var_map: Map, - /// A monotonically increasing integer used to generate temporary variables. - temp_var_int: u32, + /// A monotonically increasing counter used to generate temporary variables. + next_fresh_var: u32, /// SLG resolution. table: Table, /// Database clauses imported into query-local storage. /// - /// The cached clause is a per-query template. Each proof use still clones this template before - /// variable freshening, so repeated use of a database rule does not re-import from `db_stor`. + /// The cached clause is a per-query template. Repeated use of a database rule can freshen from + /// this query-local template without re-importing from `database_storage`. imported_clause_templates: Map, } -impl Prover { +impl ProofEngine { pub(crate) fn new() -> Self { Self { uni_op: UnificationOperator::new(), nodes: Vec::new(), - term_assigns: TermAssignments::default(), + unification_assignments: TermVariableBindings::default(), query: ExprId(0), query_vars: Vec::new(), - query_answers: Vec::new(), + seen_answers: Vec::new(), queue: NodeQueue::default(), - temp_var_buf: Map::default(), - temp_var_int: 0, + fresh_var_map: Map::default(), + next_fresh_var: 0, table: Table::default(), imported_clause_templates: Map::default(), } @@ -83,9 +83,9 @@ impl Prover { fn clear(&mut self) { self.uni_op.clear(); self.nodes.clear(); - self.term_assigns.clear(); + self.unification_assignments.clear(); self.query_vars.clear(); - self.query_answers.clear(); + self.seen_answers.clear(); self.queue.clear(); self.table.clear(); self.imported_clause_templates.clear(); @@ -94,12 +94,19 @@ impl Prover { pub(crate) fn prove<'a, T: Atom>( self, query: Expr, - clauses: &'a IndexMap, Vec>, - table_clauses: &'a IndexSet>, - db_stor: &'a TermStorage, - db_nimap: &'a NameIntMap, - ) -> ProveCx<'a, T> { - ProveCx::new(self, query, clauses, table_clauses, db_stor, db_nimap) + clauses: &'a IndexMap, Vec>, + tabled_predicates: &'a IndexSet>, + database_storage: &'a TermStorage, + database_name_interner: &'a NameInterner, + ) -> QueryCx<'a, T> { + QueryCx::new( + self, + query, + clauses, + tabled_predicates, + database_storage, + database_name_interner, + ) } /// Evaluates the given node against matching clauses or table answers, then returns whether a @@ -110,10 +117,10 @@ impl Prover { fn evaluate_node( &mut self, node_index: usize, - clauses: &IndexMap, Vec>, - table_clauses: &IndexSet>, - db_stor: &TermStorage, - stor: &mut TermStorage, + clauses: &IndexMap, Vec>, + tabled_predicates: &IndexSet>, + database_storage: &TermStorage, + query_storage: &mut TermStorage, ) -> Option { let node_expr = match self.nodes[node_index].kind { NodeKind::Expr(expr_id) => expr_id, @@ -122,22 +129,22 @@ impl Prover { // On a successful proof, records the answer in the nearest ancestor-owned SLG // table entry, then notifies all waiting consumers. if eval { - self.update_answer_and_notify(node_index, stor); + self.update_answer_and_notify(node_index, query_storage); } return Some(eval); } }; - let node_leftmost = stor.get_expr(node_expr).leftmost_term().id; - let node_leftmost_pred = stor.get_term(node_leftmost).predicate(); - let mut similar_clauses: SmallVec<[ClauseId; 2]> = SmallVec::new(); + let node_leftmost = query_storage.get_expr(node_expr).leftmost_term().id; + let node_leftmost_pred = query_storage.get_term(node_leftmost).predicate(); + let mut candidate_clauses: SmallVec<[ClauseId; 2]> = SmallVec::new(); // === SLG path === // * Table entry - Created from non-canonical leftmost term of the node. In tabling, // we use canonical variables for table keys only. - if table_clauses.contains(&node_leftmost_pred) { - let key = canonical::canonicalize_term_id(stor, node_leftmost); + if tabled_predicates.contains(&node_leftmost_pred) { + let key = canonical::canonicalize_term_id(query_storage, node_leftmost); if let Some((_, entry)) = self.table.get_mut(&key) { entry.register_consumer(node_index); @@ -151,12 +158,12 @@ impl Prover { self.nodes[node_index].table_answer_offset = next_offset; // Synthesize an answer clause, then unify it with the current node. - let mut term = stor.get_term_mut(node_leftmost); + let mut term = query_storage.get_term_mut(node_leftmost); let vars = term.as_view().collect_variables(); for (var, answer) in vars.into_iter().zip(answers) { term.replace(var, *answer); } - similar_clauses.push(ClauseId { + candidate_clauses.push(ClauseId { head: term.id(), body: None, }); @@ -167,7 +174,9 @@ impl Prover { } } else { // First encounter: create a table entry, then proceed with SLD. - if let Some(entry) = TableEntry::from_term_view(&stor.get_term(node_leftmost)) { + if let Some(entry) = + TableEntry::from_term_view(&query_storage.get_term(node_leftmost)) + { let index = self.table.register(key, entry); self.nodes[node_index].table_owner = Some(index); } @@ -176,38 +185,38 @@ impl Prover { // === BFS based SLD path === - if similar_clauses.is_empty() { + if candidate_clauses.is_empty() { if let Some(v) = clauses.get(&node_leftmost_pred) { for &clause in v { let template = if let Some(&template) = self.imported_clause_templates.get(&clause) { template } else { - let template = stor.import_clause(db_stor, clause); + let template = query_storage.import_clause(database_storage, clause); self.imported_clause_templates.insert(clause, template); template }; - similar_clauses.push(template); + candidate_clauses.push(template); } } } let old_len = self.nodes.len(); - for clause in similar_clauses { - let head = stor.get_term(clause.head); + for clause in candidate_clauses { + let head = query_storage.get_term(clause.head); - if !stor.get_expr(node_expr).is_unifiable(head) { + if !query_storage.get_expr(node_expr).is_unifiable(head) { continue; } let clause = Self::convert_var_into_temp( clause, - stor, - &mut self.temp_var_buf, - &mut self.temp_var_int, + query_storage, + &mut self.fresh_var_map, + &mut self.next_fresh_var, ); - if let Some(new_node) = self.unify_node_with_clause(node_index, clause, stor) { + if let Some(new_node) = self.unify_node_with_clause(node_index, clause, query_storage) { self.nodes.push(new_node); self.queue.push(self.nodes.len() - 1); } @@ -218,7 +227,7 @@ impl Prover { // - Unification failure means the leftmost term should be false. // - But we need to consider the exhaustive search at the same time. - let expr = stor.get_expr(node_expr); + let expr = query_storage.get_expr(node_expr); let eval = self.nodes.len() > old_len; let mut need_apply = None; @@ -233,7 +242,7 @@ impl Prover { } if let Some(to) = need_apply { - let mut expr = stor.get_expr_mut(node_expr); + let mut expr = query_storage.get_expr_mut(node_expr); let node_kind = match expr.apply_to_leftmost_term(to) { ApplyResult::Expr => NodeKind::Expr(expr.id()), ApplyResult::Complete(eval) => NodeKind::Leaf(eval), @@ -263,7 +272,7 @@ impl Prover { }, } - fn assume_leftmost_term(expr: ExprView<'_, Integer>, to: bool) -> AssumeResult { + fn assume_leftmost_term(expr: ExprView<'_, AtomId>, to: bool) -> AssumeResult { match expr.as_kind() { ExprKind::Term(_) => AssumeResult::Complete { eval: to, @@ -314,7 +323,7 @@ impl Prover { /// Finds the nearest ancestor node that owns SLG table entry, then updates the entry and /// notifies all waiting consumers. - fn update_answer_and_notify(&mut self, node_index: usize, stor: &TermStorage) { + fn update_answer_and_notify(&mut self, node_index: usize, query_storage: &TermStorage) { let tabled_ancestor = { let mut cur = node_index; loop { @@ -333,15 +342,15 @@ impl Prover { let table_index = self.nodes[ancestor].table_owner.unwrap(); let entry = &mut self.table[table_index]; let all_answers_concrete = entry.variables().iter().all(|&var| { - if let Some(answer) = self.term_assigns.find(var) { - !stor.get_term(answer).contains_variable() + if let Some(answer) = self.unification_assignments.find(var) { + !query_storage.get_term(answer).contains_variable() } else { false } }); - if all_answers_concrete && !entry.has_answer(&self.term_assigns) { - entry.update_answer(&self.term_assigns); + if all_answers_concrete && !entry.has_answer(&self.unification_assignments) { + entry.update_answer(&self.unification_assignments); for i in entry.consumer_nodes() { if i != node_index { self.queue.push(i); @@ -360,41 +369,43 @@ impl Prover { // times in a single proof-search path. Each use is considered a distinct clause. fn convert_var_into_temp( mut clause_id: ClauseId, - stor: &mut TermStorage, - temp_var_buf: &mut Map, - temp_var_int: &mut u32, + query_storage: &mut TermStorage, + fresh_var_map: &mut Map, + next_fresh_var: &mut u32, ) -> ClauseId { - debug_assert!(temp_var_buf.is_empty()); + debug_assert!(fresh_var_map.is_empty()); - let mut f = |terms: &mut UniqueTermArray, term_id: TermId| { + let mut f = |terms: &mut UniqueTermArray, term_id: TermId| { let term = terms.get_mut(term_id); if term.is_variable() { let src = term.id(); - temp_var_buf.entry(src).or_insert_with(|| { + fresh_var_map.entry(src).or_insert_with(|| { let temp_term = Term { - functor: Integer::temporary(*temp_var_int), + functor: AtomId::temporary(*next_fresh_var), args: [].into(), }; - *temp_var_int += 1; + *next_fresh_var += 1; terms.insert(temp_term) }); } }; - stor.get_term_mut(clause_id.head).with_terminal(&mut f); + query_storage + .get_term_mut(clause_id.head) + .with_terminal(&mut f); if let Some(body) = clause_id.body { - stor.get_expr_mut(body).with_terminal(&mut f); + query_storage.get_expr_mut(body).with_terminal(&mut f); } - for (src, dst) in temp_var_buf.drain() { - let mut head = stor.get_term_mut(clause_id.head); + for (src, dst) in fresh_var_map.drain() { + let mut head = query_storage.get_term_mut(clause_id.head); head.replace(src, dst); clause_id.head = head.id(); if let Some(body) = clause_id.body { - let mut body = stor.get_expr_mut(body); + let mut body = query_storage.get_expr_mut(body); body.replace_term(src, dst); clause_id.body = Some(body.id()); } @@ -407,7 +418,7 @@ impl Prover { &mut self, node_index: usize, clause: ClauseId, - stor: &mut TermStorage, + query_storage: &mut TermStorage, ) -> Option { debug_assert!(self.uni_op.ops.is_empty()); @@ -415,19 +426,20 @@ impl Prover { unreachable!() }; - if !stor + if !query_storage .get_expr(node_expr) .leftmost_term() - .unify(stor.get_term(clause.head), &mut |op| { + .unify(query_storage.get_term(clause.head), &mut |op| { self.uni_op.push_op(op) }) { return None; } - let (node_expr, clause, uni_history) = self.uni_op.consume_ops(stor, node_expr, clause); + let (node_expr, clause, uni_history) = + self.uni_op.consume_ops(query_storage, node_expr, clause); if let Some(body) = clause.body { - let mut lhs = stor.get_expr_mut(node_expr); + let mut lhs = query_storage.get_expr_mut(node_expr); lhs.replace_leftmost_term(body); let node_kind = NodeKind::Expr(lhs.id()); let node_parent = node_index; @@ -435,7 +447,7 @@ impl Prover { return Some(node); } - let mut lhs = stor.get_expr_mut(node_expr); + let mut lhs = query_storage.get_expr_mut(node_expr); let node_kind = match lhs.apply_to_leftmost_term(true) { ApplyResult::Expr => NodeKind::Expr(lhs.id()), ApplyResult::Complete(eval) => NodeKind::Leaf(eval), @@ -446,9 +458,9 @@ impl Prover { } /// Finds all from/to relations while traversing from the given node to the root, then adds the - /// relations to [`TermAssignments`]. + /// relations to [`TermVariableBindings`]. fn find_assignments(&mut self, node_index: usize) { - self.term_assigns.clear(); + self.unification_assignments.clear(); let mut cur_index = node_index; loop { @@ -456,7 +468,7 @@ impl Prover { let range = node.uni_history.clone(); for (from, to) in self.uni_op.get_record(range).iter().cloned() { - self.term_assigns.add(from, to); + self.unification_assignments.add(from, to); } if node.parent == cur_index { @@ -468,25 +480,25 @@ impl Prover { /// Records the current proof result as a query answer if it is ground and not duplicated, /// then returns whether a new answer was recorded. - fn record_query_answer(&mut self, stor: &mut TermStorage) -> bool { + fn record_query_answer(&mut self, query_storage: &mut TermStorage) -> bool { let mut answer = Vec::with_capacity(self.query_vars.len()); for &var in &self.query_vars { - let Some(resolved) = self.materialize_assigned_term(var, stor) else { + let Some(resolved) = self.materialize_assigned_term(var, query_storage) else { return false; }; answer.push(resolved); } // No query vars -> empty iter -> all() returns true. - if self.query_answers.iter().all(|seen| seen != &answer) { - self.query_answers.push(answer); + if self.seen_answers.iter().all(|seen| seen != &answer) { + self.seen_answers.push(answer); true } else { false } } - /// Builds a fully substituted term for a query-side term from `term_assigns`. + /// Builds a fully substituted term for a query-side term from `unification_assignments`. /// /// Examples: /// @@ -498,19 +510,19 @@ impl Prover { /// | `T = Vec(U)` | `T` | `None` | /// /// This must materialize the whole term tree, not just rewrite functors in place. The returned - /// `TermId` always points to a ground term inserted into `stor`. + /// `TermId` always points to a ground term inserted into `query_storage`. fn materialize_assigned_term( &self, term_id: TermId, - stor: &mut TermStorage, + query_storage: &mut TermStorage, ) -> Option { - let term = stor.get_term(term_id); + let term = query_storage.get_term(term_id); if term.is_variable() { - let resolved = self.term_assigns.find(term_id)?; + let resolved = self.unification_assignments.find(term_id)?; if resolved == term_id { return None; } - return self.materialize_assigned_term(resolved, stor); + return self.materialize_assigned_term(resolved, query_storage); } let functor = *term.functor(); @@ -518,13 +530,13 @@ impl Prover { let args = arg_ids .into_iter() .map(|arg_id| { - self.materialize_assigned_term(arg_id, stor) - .map(|id| stor.get_term(id).deserialize()) + self.materialize_assigned_term(arg_id, query_storage) + .map(|id| query_storage.get_term(id).deserialize()) }) .collect::>>()?; let materialized = Term { functor, args }; - Some(stor.insert_term(materialized)) + Some(query_storage.insert_term(materialized)) } } @@ -573,7 +585,7 @@ impl UnificationOperator { #[must_use] fn consume_ops( &mut self, - stor: &mut TermStorage, + query_storage: &mut TermStorage, mut left: ExprId, mut right: ClauseId, ) -> (ExprId, ClauseId, Range) { @@ -582,7 +594,7 @@ impl UnificationOperator { for op in self.ops.drain(..) { match op { UnifyOp::Left { from, to } => { - let mut expr = stor.get_expr_mut(left); + let mut expr = query_storage.get_expr_mut(left); expr.replace_term(from, to); left = expr.id(); @@ -590,7 +602,7 @@ impl UnificationOperator { } UnifyOp::Right { from, to } => { if let Some(right_body) = right.body { - let mut expr = stor.get_expr_mut(right_body); + let mut expr = query_storage.get_expr_mut(right_body); expr.replace_term(from, to); right.body = Some(expr.id()); @@ -678,7 +690,7 @@ enum NodeKind { } #[derive(Debug, Default)] -pub(crate) struct TermAssignments { +pub(crate) struct TermVariableBindings { /// Union-find from/to relations. /// /// # Examples @@ -687,7 +699,7 @@ pub(crate) struct TermAssignments { relations: Vec, } -impl TermAssignments { +impl TermVariableBindings { pub(crate) fn find(&self, from: TermId) -> Option { let to = *self.relations.get(from.0)?; if from == to { @@ -743,49 +755,50 @@ enum UnifyOp { } /// Proof-search context for a query. -pub struct ProveCx<'a, T: Atom> { - prover: Prover, - clauses: &'a IndexMap, Vec>, - table_clauses: &'a IndexSet>, - db_stor: &'a TermStorage, - stor: TermStorage, - nimap: QueryNameIntMap<'a, T>, +pub struct QueryCx<'a, T: Atom> { + proof_engine: ProofEngine, + clauses: &'a IndexMap, Vec>, + tabled_predicates: &'a IndexSet>, + database_storage: &'a TermStorage, + query_storage: TermStorage, + name_interner: QueryNameInterner<'a, T>, } -impl<'a, T: Atom> ProveCx<'a, T> { +impl<'a, T: Atom> QueryCx<'a, T> { fn new( - mut prover: Prover, + mut proof_engine: ProofEngine, query: Expr, - clauses: &'a IndexMap, Vec>, - table_clauses: &'a IndexSet>, - db_stor: &'a TermStorage, - db_nimap: &'a NameIntMap, + clauses: &'a IndexMap, Vec>, + tabled_predicates: &'a IndexSet>, + database_storage: &'a TermStorage, + database_name_interner: &'a NameInterner, ) -> Self { - prover.clear(); + proof_engine.clear(); - let mut nimap = QueryNameIntMap::new(db_nimap); - let query = query.map(&mut |name| nimap.name_to_int(name)); + let mut name_interner = QueryNameInterner::new(database_name_interner); + let query = query.map(&mut |name| name_interner.intern(name)); - let mut stor = TermStorage::default(); - prover.query = stor.insert_expr(query); + let mut query_storage = TermStorage::default(); + proof_engine.query = query_storage.insert_expr(query); - stor.get_expr(prover.query) - .with_term(&mut |term: TermView<'_, Integer>| { - term.with_variable(|term| prover.query_vars.push(term.id)); + query_storage + .get_expr(proof_engine.query) + .with_term(&mut |term: TermView<'_, AtomId>| { + term.with_variable(|term| proof_engine.query_vars.push(term.id)); }); - let node_kind = NodeKind::Expr(prover.query); - let node_parent = prover.nodes.len(); - prover.nodes.push(Node::new(node_kind, node_parent)); - prover.queue.push(0); + let node_kind = NodeKind::Expr(proof_engine.query); + let node_parent = proof_engine.nodes.len(); + proof_engine.nodes.push(Node::new(node_kind, node_parent)); + proof_engine.queue.push(0); Self { - prover, + proof_engine, clauses, - table_clauses, - db_stor, - stor, - nimap, + tabled_predicates, + database_storage, + query_storage, + name_interner, } } @@ -813,25 +826,29 @@ impl<'a, T: Atom> ProveCx<'a, T> { /// answers.sort_unstable(); /// assert_eq!(answers, vec!["bob", "carol"]); /// ``` - pub fn prove_next(&mut self) -> Option> { - while let Some(node_index) = self.prover.queue.pop() { - if let Some(proof_result) = self.prover.evaluate_node( + pub fn prove_next(&mut self) -> Option> { + while let Some(node_index) = self.proof_engine.queue.pop() { + if let Some(proof_result) = self.proof_engine.evaluate_node( node_index, self.clauses, - self.table_clauses, - self.db_stor, - &mut self.stor, + self.tabled_predicates, + self.database_storage, + &mut self.query_storage, ) { - // Return Some(EvalView) only if the result is TRUE and yielded a new ground + // Return Some(AnswerView) only if the result is TRUE and yielded a new ground // query answer. - if proof_result && self.prover.record_query_answer(&mut self.stor) { - return Some(EvalView { - query_vars: &self.prover.query_vars, - terms: &self.stor.terms.buf, - term_assigns: &self.prover.term_assigns, - nimap: &self.nimap, + if proof_result + && self + .proof_engine + .record_query_answer(&mut self.query_storage) + { + return Some(AnswerView { + query_vars: &self.proof_engine.query_vars, + terms: &self.query_storage.terms.buf, + unification_assignments: &self.proof_engine.unification_assignments, + name_interner: &self.name_interner, start: 0, - end: self.prover.query_vars.len(), + end: self.proof_engine.query_vars.len(), }); } } @@ -860,71 +877,71 @@ impl<'a, T: Atom> ProveCx<'a, T> { } /// Name mapping for a query: database names are borrowed, query-only names are local. -struct QueryNameIntMap<'a, T> { - database: &'a NameIntMap, - local: NameIntMap, +struct QueryNameInterner<'a, T> { + database: &'a NameInterner, + local: NameInterner, } -impl<'a, T> QueryNameIntMap<'a, T> { - fn new(database: &'a NameIntMap) -> Self { +impl<'a, T> QueryNameInterner<'a, T> { + fn new(database: &'a NameInterner) -> Self { Self { database, - local: NameIntMap { - name2int: IndexMap::default(), - int2name: IndexMap::default(), - next_int: database.next_int, + local: NameInterner { + name_to_id: IndexMap::default(), + id_to_name: IndexMap::default(), + next_id: database.next_id, }, } } - fn get_name(&self, int: &Integer) -> Option<&T> { + fn get_name(&self, atom_id: &AtomId) -> Option<&T> { self.database - .get_name(int) - .or_else(|| self.local.get_name(int)) + .get_name(atom_id) + .or_else(|| self.local.get_name(atom_id)) } } -impl QueryNameIntMap<'_, T> { - fn name_to_int(&mut self, name: T) -> Integer { - if let Some(int) = self.database.name2int.get(&name) { - *int +impl QueryNameInterner<'_, T> { + fn intern(&mut self, name: T) -> AtomId { + if let Some(atom_id) = self.database.name_to_id.get(&name) { + *atom_id } else { - self.local.name_to_int(name) + self.local.intern(name) } } } /// View over the assignments produced by one proof result. -pub struct EvalView<'a, T> { +pub struct AnswerView<'a, T> { query_vars: &'a [TermId], - terms: &'a [TermElem], - term_assigns: &'a TermAssignments, - nimap: &'a QueryNameIntMap<'a, T>, + terms: &'a [TermElem], + unification_assignments: &'a TermVariableBindings, + name_interner: &'a QueryNameInterner<'a, T>, /// Inclusive start: usize, /// Exclusive end: usize, } -impl EvalView<'_, T> { +impl AnswerView<'_, T> { const fn len(&self) -> usize { self.end - self.start } } -impl<'a, T> Iterator for EvalView<'a, T> { - type Item = Assignment<'a, T>; +impl<'a, T> Iterator for AnswerView<'a, T> { + type Item = VariableBinding<'a, T>; fn next(&mut self) -> Option { if self.start < self.end { let from = self.query_vars[self.start]; self.start += 1; - Some(Assignment { + Some(VariableBinding { buf: self.terms, from, - term_assigns: self.term_assigns, - nimap: self.nimap, + unification_assignments: self.unification_assignments, + name_interner: self.name_interner, }) } else { None @@ -937,23 +954,23 @@ impl<'a, T> Iterator for EvalView<'a, T> { } } -impl ExactSizeIterator for EvalView<'_, T> { +impl ExactSizeIterator for AnswerView<'_, T> { fn len(&self) -> usize { ::len(self) } } -impl iter::FusedIterator for EvalView<'_, T> {} +impl iter::FusedIterator for AnswerView<'_, T> {} /// A single variable assignment from a proof result. -pub struct Assignment<'a, T> { - buf: &'a [TermElem], +pub struct VariableBinding<'a, T> { + buf: &'a [TermElem], from: TermId, - term_assigns: &'a TermAssignments, - nimap: &'a QueryNameIntMap<'a, T>, + unification_assignments: &'a TermVariableBindings, + name_interner: &'a QueryNameInterner<'a, T>, } -impl<'a, T: 'a> Assignment<'a, T> { +impl<'a, T: 'a> VariableBinding<'a, T> { /// Returns the left-hand-side variable name of the assignment. /// /// Note that the assignment's left-hand side is always a variable. @@ -975,27 +992,27 @@ impl<'a, T: 'a> Assignment<'a, T> { /// assert_eq!(assignment.get_lhs_variable().as_ref(), "$Who"); /// ``` pub fn get_lhs_variable(&self) -> &T { - let int = self.lhs_view().find_variable().unwrap(); - self.nimap.get_name(&int).unwrap() + let atom_id = self.lhs_view().find_variable().unwrap(); + self.name_interner.get_name(&atom_id).unwrap() } - const fn lhs_view(&self) -> TermView<'_, Integer> { + const fn lhs_view(&self) -> TermView<'_, AtomId> { TermView { buf: self.buf, id: self.from, } } - const fn rhs_view(&self) -> TermDeepView<'_, Integer> { + const fn rhs_view(&self) -> TermDeepView<'_, AtomId> { TermDeepView { buf: self.buf, - term_assigns: self.term_assigns, + unification_assignments: self.unification_assignments, id: self.from, } } } -impl<'a, T: Atom + 'a> Assignment<'a, T> { +impl<'a, T: Atom + 'a> VariableBinding<'a, T> { /// Creates the left-hand-side term of the assignment. /// /// Creating a term may allocate memory. @@ -1017,7 +1034,7 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { /// assert_eq!(assignment.lhs().to_string(), "$Who"); /// ``` pub fn lhs(&self) -> Term { - Self::term_view_to_term(self.lhs_view(), self.nimap) + Self::term_view_to_term(self.lhs_view(), self.name_interner) } /// Creates the right-hand-side term of the assignment. @@ -1041,50 +1058,53 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { /// assert_eq!(assignment.rhs().to_string(), "bob"); /// ``` pub fn rhs(&self) -> Term { - Self::term_deep_view_to_term(self.rhs_view(), self.nimap) + Self::term_deep_view_to_term(self.rhs_view(), self.name_interner) } - fn term_view_to_term(view: TermView<'_, Integer>, nimap: &QueryNameIntMap<'_, T>) -> Term { + fn term_view_to_term( + view: TermView<'_, AtomId>, + name_interner: &QueryNameInterner<'_, T>, + ) -> Term { let functor = view.functor(); let args = view.args(); - let functor = if let Some(name) = nimap.get_name(functor) { + let functor = if let Some(name) = name_interner.get_name(functor) { name.clone() } else { - unreachable!("integer {:?} has no name mapping", functor) + unreachable!("atom id {:?} has no name mapping", functor) }; let args = args .into_iter() - .map(|arg| Self::term_view_to_term(arg, nimap)) + .map(|arg| Self::term_view_to_term(arg, name_interner)) .collect(); Term { functor, args } } fn term_deep_view_to_term( - view: TermDeepView<'_, Integer>, - nimap: &QueryNameIntMap<'_, T>, + view: TermDeepView<'_, AtomId>, + name_interner: &QueryNameInterner<'_, T>, ) -> Term { let functor = view.functor(); let args = view.args(); - let functor = if let Some(name) = nimap.get_name(functor) { + let functor = if let Some(name) = name_interner.get_name(functor) { name.clone() } else { - unreachable!("integer {:?} has no name mapping", functor) + unreachable!("atom id {:?} has no name mapping", functor) }; let args = args .into_iter() - .map(|arg| Self::term_deep_view_to_term(arg, nimap)) + .map(|arg| Self::term_deep_view_to_term(arg, name_interner)) .collect(); Term { functor, args } } } -impl Display for Assignment<'_, T> { +impl Display for VariableBinding<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { Display::fmt(&self.lhs(), f)?; @@ -1094,17 +1114,17 @@ impl Display for Assignment<'_, T> { } } -impl Debug for Assignment<'_, T> { +impl Debug for VariableBinding<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Assignment") + f.debug_struct("VariableBinding") .field("lhs", &self.lhs()) .field("rhs", &self.rhs()) .finish() } } -impl ExprView<'_, Integer> { - fn is_unifiable(&self, other: TermView<'_, Integer>) -> bool { +impl ExprView<'_, AtomId> { + fn is_unifiable(&self, other: TermView<'_, AtomId>) -> bool { match self.as_kind() { ExprKind::Term(term) => term.is_unifiable(other), ExprKind::Not(inner) => inner.is_unifiable(other), @@ -1125,7 +1145,7 @@ impl ExprView<'_, Integer> { } } -impl TermView<'_, Integer> { +impl TermView<'_, AtomId> { fn unify(self, other: Self, f: &mut F) -> bool { if self.is_variable() { f(UnifyOp::Left { @@ -1172,17 +1192,17 @@ impl TermView<'_, Integer> { } } -impl TermViewMut<'_, Integer> { +impl TermViewMut<'_, AtomId> { fn is_variable(&self) -> bool { self.arity() == 0 && self.functor().is_variable() } } -/// Internal integer representation of an atom. +/// Internal identifier for an atom. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Integer(u32); +pub struct AtomId(u32); -impl Integer { +impl AtomId { const VAR_FLAG: u32 = 0x1 << 31; const TEMPORARY_FLAG: u32 = 0x1 << 30; @@ -1193,16 +1213,16 @@ impl Integer { Self(index) } - pub(crate) fn variable(int: u32) -> Self { + pub(crate) fn variable(index: u32) -> Self { let mask = Self::VAR_FLAG; - debug_assert_eq!(int & mask, 0); - Self(int | mask) + debug_assert_eq!(index & mask, 0); + Self(index | mask) } - pub(crate) fn temporary(int: u32) -> Self { + pub(crate) fn temporary(index: u32) -> Self { let mask = Self::VAR_FLAG | Self::TEMPORARY_FLAG; - debug_assert_eq!(int & mask, 0); - Self(int | mask) + debug_assert_eq!(index & mask, 0); + Self(index | mask) } pub(crate) const fn is_temporary_variable(self) -> bool { @@ -1211,13 +1231,13 @@ impl Integer { } } -impl Atom for Integer { +impl Atom for AtomId { fn is_variable(&self) -> bool { (Self::VAR_FLAG & self.0) == Self::VAR_FLAG } } -impl Debug for Integer { +impl Debug for AtomId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mask: u32 = Self::VAR_FLAG | Self::TEMPORARY_FLAG; let index = !mask & self.0; @@ -1232,7 +1252,7 @@ impl Debug for Integer { } } -impl ops::AddAssign for Integer { +impl ops::AddAssign for AtomId { fn add_assign(&mut self, rhs: u32) { self.0 += rhs; } @@ -1242,40 +1262,40 @@ impl ops::AddAssign for Integer { /// /// Auto-generated names such as temporary variables are not stored here. #[derive(Debug)] -pub(crate) struct NameIntMap { - name2int: IndexMap, - int2name: IndexMap, - next_int: u32, +pub(crate) struct NameInterner { + name_to_id: IndexMap, + id_to_name: IndexMap, + next_id: u32, } -impl NameIntMap { - pub(crate) fn get_name(&self, int: &Integer) -> Option<&T> { - self.int2name.get(int) +impl NameInterner { + pub(crate) fn get_name(&self, atom_id: &AtomId) -> Option<&T> { + self.id_to_name.get(atom_id) } } -impl NameIntMap { - pub(crate) fn name_to_int(&mut self, name: T) -> Integer { - if let Some(int) = self.name2int.get(&name) { - *int +impl NameInterner { + pub(crate) fn intern(&mut self, name: T) -> AtomId { + if let Some(atom_id) = self.name_to_id.get(&name) { + *atom_id } else { - let int = Integer::from_value(&name, self.next_int); + let atom_id = AtomId::from_value(&name, self.next_id); - self.name2int.insert(name.clone(), int); - self.int2name.insert(int, name); + self.name_to_id.insert(name.clone(), atom_id); + self.id_to_name.insert(atom_id, name); - self.next_int += 1; - int + self.next_id += 1; + atom_id } } } -impl Default for NameIntMap { +impl Default for NameInterner { fn default() -> Self { Self { - name2int: IndexMap::default(), - int2name: IndexMap::default(), - next_int: 0, + name_to_id: IndexMap::default(), + id_to_name: IndexMap::default(), + next_id: 0, } } } @@ -1284,19 +1304,25 @@ pub(crate) mod format { use super::*; pub struct NamedTermView<'a, T> { - view: TermView<'a, Integer>, - nimap: &'a NameIntMap, + view: TermView<'a, AtomId>, + name_interner: &'a NameInterner, } impl<'a, T> NamedTermView<'a, T> { - pub(crate) const fn new(view: TermView<'a, Integer>, nimap: &'a NameIntMap) -> Self { - Self { view, nimap } + pub(crate) const fn new( + view: TermView<'a, AtomId>, + name_interner: &'a NameInterner, + ) -> Self { + Self { + view, + name_interner, + } } fn args<'s>(&'s self) -> impl Iterator> + 's { self.view.args().map(|arg| Self { view: arg, - nimap: self.nimap, + name_interner: self.name_interner, }) } } @@ -1320,7 +1346,7 @@ pub(crate) mod format { /// ``` pub fn is(&self, term: &Term) -> bool { let functor = self.view.functor(); - let Some(functor) = self.nimap.get_name(functor) else { + let Some(functor) = self.name_interner.get_name(functor) else { return false; }; @@ -1358,18 +1384,21 @@ pub(crate) mod format { impl<'a, T: Display> Display for NamedTermView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, nimap } = self; + let Self { + view, + name_interner, + } = self; let functor = view.functor(); let args = view.args(); let num_args = args.len(); - write_int(functor, nimap, f)?; + write_atom_id(functor, name_interner, f)?; if num_args > 0 { f.write_char('(')?; for (i, arg) in args.enumerate() { - fmt::Display::fmt(&Self::new(arg, nimap), f)?; + fmt::Display::fmt(&Self::new(arg, name_interner), f)?; if i + 1 < num_args { f.write_str(", ")?; } @@ -1382,20 +1411,23 @@ pub(crate) mod format { impl<'a, T: Debug> Debug for NamedTermView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, nimap } = self; + let Self { + view, + name_interner, + } = self; let functor = view.functor(); let args = view.args(); let num_args = args.len(); if num_args == 0 { - if let Some(name) = nimap.get_name(functor) { + if let Some(name) = name_interner.get_name(functor) { fmt::Debug::fmt(name, f) } else { fmt::Debug::fmt(functor, f) } } else { - let name_str = if let Some(name) = nimap.get_name(functor) { + let name_str = if let Some(name) = name_interner.get_name(functor) { format!("{:?}", name) } else { format!("{:?}", functor) @@ -1403,7 +1435,7 @@ pub(crate) mod format { let mut d = f.debug_tuple(&name_str); for arg in args { - d.field(&Self::new(arg, nimap)); + d.field(&Self::new(arg, name_interner)); } d.finish() } @@ -1411,13 +1443,19 @@ pub(crate) mod format { } pub struct NamedExprView<'a, T> { - view: ExprView<'a, Integer>, - nimap: &'a NameIntMap, + view: ExprView<'a, AtomId>, + name_interner: &'a NameInterner, } impl<'a, T> NamedExprView<'a, T> { - pub(crate) const fn new(view: ExprView<'a, Integer>, nimap: &'a NameIntMap) -> Self { - Self { view, nimap } + pub(crate) const fn new( + view: ExprView<'a, AtomId>, + name_interner: &'a NameInterner, + ) -> Self { + Self { + view, + name_interner, + } } } @@ -1442,18 +1480,18 @@ pub(crate) mod format { match self.view.as_kind() { ExprKind::Term(view) => NamedTermView { view, - nimap: self.nimap, + name_interner: self.name_interner, } .contains(term), ExprKind::Not(view) => NamedExprView { view, - nimap: self.nimap, + name_interner: self.name_interner, } .contains_term(term), ExprKind::And(args) | ExprKind::Or(args) => args.into_iter().any(|view| { NamedExprView { view, - nimap: self.nimap, + name_interner: self.name_interner, } .contains_term(term) }), @@ -1463,18 +1501,27 @@ pub(crate) mod format { impl<'a, T: Display> Display for NamedExprView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, nimap } = self; + let Self { + view, + name_interner, + } = self; match view.as_kind() { - ExprKind::Term(term) => fmt::Display::fmt(&NamedTermView { view: term, nimap }, f)?, + ExprKind::Term(term) => fmt::Display::fmt( + &NamedTermView { + view: term, + name_interner, + }, + f, + )?, ExprKind::Not(inner) => { f.write_str("\\+ ")?; if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) { f.write_char('(')?; - fmt::Display::fmt(&Self::new(inner, nimap), f)?; + fmt::Display::fmt(&Self::new(inner, name_interner), f)?; f.write_char(')')?; } else { - fmt::Display::fmt(&Self::new(inner, nimap), f)?; + fmt::Display::fmt(&Self::new(inner, name_interner), f)?; } } ExprKind::And(args) => { @@ -1482,10 +1529,10 @@ pub(crate) mod format { for (i, arg) in args.enumerate() { if matches!(arg.as_kind(), ExprKind::Or(_)) { f.write_char('(')?; - fmt::Display::fmt(&Self::new(arg, nimap), f)?; + fmt::Display::fmt(&Self::new(arg, name_interner), f)?; f.write_char(')')?; } else { - fmt::Display::fmt(&Self::new(arg, nimap), f)?; + fmt::Display::fmt(&Self::new(arg, name_interner), f)?; } if i + 1 < num_args { f.write_str(", ")?; @@ -1495,7 +1542,7 @@ pub(crate) mod format { ExprKind::Or(args) => { let num_args = args.len(); for (i, arg) in args.enumerate() { - fmt::Display::fmt(&Self::new(arg, nimap), f)?; + fmt::Display::fmt(&Self::new(arg, name_interner), f)?; if i + 1 < num_args { f.write_str("; ")?; } @@ -1508,25 +1555,30 @@ pub(crate) mod format { impl<'a, T: Debug> Debug for NamedExprView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, nimap } = self; + let Self { + view, + name_interner, + } = self; match view.as_kind() { - ExprKind::Term(term) => fmt::Debug::fmt(&NamedTermView::new(term, nimap), f), + ExprKind::Term(term) => { + fmt::Debug::fmt(&NamedTermView::new(term, name_interner), f) + } ExprKind::Not(inner) => f .debug_tuple("Not") - .field(&NamedExprView::new(inner, nimap)) + .field(&NamedExprView::new(inner, name_interner)) .finish(), ExprKind::And(args) => { let mut d = f.debug_tuple("And"); for arg in args { - d.field(&NamedExprView::new(arg, nimap)); + d.field(&NamedExprView::new(arg, name_interner)); } d.finish() } ExprKind::Or(args) => { let mut d = f.debug_tuple("Or"); for arg in args { - d.field(&NamedExprView::new(arg, nimap)); + d.field(&NamedExprView::new(arg, name_interner)); } d.finish() } @@ -1534,15 +1586,15 @@ pub(crate) mod format { } } - fn write_int( - int: &Integer, - nimap: &NameIntMap, + fn write_atom_id( + atom_id: &AtomId, + name_interner: &NameInterner, f: &mut fmt::Formatter<'_>, ) -> fmt::Result { - if let Some(name) = nimap.get_name(int) { + if let Some(name) = name_interner.get_name(atom_id) { fmt::Display::fmt(name, f) } else { - fmt::Debug::fmt(int, f) + fmt::Debug::fmt(atom_id, f) } } } diff --git a/crates/logic-eval/src/prove/repr.rs b/crates/logic-eval/src/prove/repr.rs index 0508d70..518fa21 100644 --- a/crates/logic-eval/src/prove/repr.rs +++ b/crates/logic-eval/src/prove/repr.rs @@ -1,4 +1,6 @@ -use crate::{prove::prover::TermAssignments, Atom, Expr, PassThroughIndexMap, Predicate, Term}; +use crate::{ + prove::proof_engine::TermVariableBindings, Atom, Expr, PassThroughIndexMap, Predicate, Term, +}; use fxhash::FxHasher; use std::{ hash::{Hash, Hasher}, @@ -1077,7 +1079,7 @@ impl iter::FusedIterator for TermViewArgs<'_, T> {} #[derive(Debug, Clone)] pub struct TermDeepView<'a, T> { pub(crate) buf: &'a [TermElem], - pub(crate) term_assigns: &'a TermAssignments, + pub(crate) unification_assignments: &'a TermVariableBindings, pub(crate) id: TermId, } @@ -1107,17 +1109,20 @@ impl<'a, T> TermDeepView<'a, T> { let end = start + view.arity() as usize; TermDeepViewArgs { buf: view.buf, - term_assigns: view.term_assigns, + unification_assignments: view.unification_assignments, start, end, } } pub(crate) fn jump(&self) -> Self { - let root = self.term_assigns.find(self.id).unwrap_or(self.id); + let root = self + .unification_assignments + .find(self.id) + .unwrap_or(self.id); Self { buf: self.buf, - term_assigns: self.term_assigns, + unification_assignments: self.unification_assignments, id: root, } } @@ -1126,7 +1131,7 @@ impl<'a, T> TermDeepView<'a, T> { #[derive(Clone)] pub(crate) struct TermDeepViewArgs<'a, T> { buf: &'a [TermElem], - term_assigns: &'a TermAssignments, + unification_assignments: &'a TermVariableBindings, /// Inclusive start: TermId, /// Exclusive @@ -1150,7 +1155,7 @@ impl<'a, T> Iterator for TermDeepViewArgs<'a, T> { self.start += 1; Some(TermDeepView { buf: self.buf, - term_assigns: self.term_assigns, + unification_assignments: self.unification_assignments, id, }) } else { diff --git a/crates/logic-eval/src/prove/table.rs b/crates/logic-eval/src/prove/table.rs index a773444..829779b 100644 --- a/crates/logic-eval/src/prove/table.rs +++ b/crates/logic-eval/src/prove/table.rs @@ -1,9 +1,9 @@ use super::{ canonical::CanonicalTermId, - prover::Integer, + proof_engine::AtomId, repr::{TermId, TermView}, }; -use crate::{prove::prover::TermAssignments, Map}; +use crate::{prove::proof_engine::TermVariableBindings, Map}; use core::ops::{Index, IndexMut}; #[derive(Debug, Default)] @@ -76,7 +76,7 @@ impl TableEntry { /// - `view` is just a variable, which does not make sense for tabling. Use a term such as /// `f(X)` instead. /// - `view` does not contain any variables, so it does not need tabling. - pub(crate) fn from_term_view(view: &TermView<'_, Integer>) -> Option { + pub(crate) fn from_term_view(view: &TermView<'_, AtomId>) -> Option { if view.is_variable() || !view.contains_variable() { return None; } @@ -97,12 +97,12 @@ impl TableEntry { } /// See [`AnswerMatrix::update`]. - pub(crate) fn update_answer(&mut self, term_assigns: &TermAssignments) { - self.seen.update(term_assigns); + pub(crate) fn update_answer(&mut self, unification_assignments: &TermVariableBindings) { + self.seen.update(unification_assignments); } - pub(crate) fn has_answer(&self, term_assigns: &TermAssignments) -> bool { - self.seen.has_answer(term_assigns) + pub(crate) fn has_answer(&self, unification_assignments: &TermVariableBindings) -> bool { + self.seen.has_answer(unification_assignments) } pub(crate) fn consumer_nodes(&self) -> impl Iterator + '_ { @@ -177,25 +177,25 @@ impl AnswerMatrix { } } - /// This method assumes that `term_assigns` has ground answers for this entry's variables. - fn update(&mut self, term_assigns: &TermAssignments) { + /// This method assumes that `unification_assignments` has ground answers for this entry's variables. + fn update(&mut self, unification_assignments: &TermVariableBindings) { self.elems.reserve_exact(self.rows); for r in 0..self.rows { let var = self.elems[r]; - let answer = term_assigns.find(var).unwrap(); + let answer = unification_assignments.find(var).unwrap(); self.elems.push(answer); } self.cols += 1; } - fn has_answer(&self, term_assigns: &TermAssignments) -> bool { + fn has_answer(&self, unification_assignments: &TermVariableBindings) -> bool { let vars = self.column(0); for col_idx in 1..self.cols { let answers = self.column(col_idx); if vars .iter() .zip(answers) - .all(|(var, answer)| term_assigns.find(*var) == Some(*answer)) + .all(|(var, answer)| unification_assignments.find(*var) == Some(*answer)) { return true; } From 4811260f9a517fdec7513209299de80c710d2344 Mon Sep 17 00:00:00 2001 From: ecoricemon Date: Tue, 19 May 2026 11:57:58 +0900 Subject: [PATCH 3/3] chore: Simplify logic-eval benchmark --- crates/logic-eval/benches/query_threads.rs | 77 ++++++++-------------- 1 file changed, 26 insertions(+), 51 deletions(-) diff --git a/crates/logic-eval/benches/query_threads.rs b/crates/logic-eval/benches/query_threads.rs index fc53e00..1d3b703 100644 --- a/crates/logic-eval/benches/query_threads.rs +++ b/crates/logic-eval/benches/query_threads.rs @@ -1,82 +1,52 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; -use logic_eval::{Atom, Clause, ClauseDataset, Database, Expr, Term}; +use logic_eval::{parse_str, ClauseDataset, Database, Expr, InternedStr, Name, StrInterner}; const NODES: usize = 240; const QUERY_COUNT: usize = 64; -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -struct BenchAtom(String); - -impl Atom for BenchAtom { - fn is_variable(&self) -> bool { - self.0.starts_with('$') - } -} - -fn sym(name: impl Into) -> BenchAtom { - BenchAtom(name.into()) -} - -fn atom(name: impl Into) -> Term { - Term::atom(sym(name)) -} - -fn compound( - name: impl Into, - args: impl IntoIterator>, -) -> Term { - Term::compound(sym(name), args) -} - -fn parent(from: usize, to: usize) -> Clause { - Clause::fact(compound("parent", [atom(node(from)), atom(node(to))])) -} +type BenchName<'a> = Name>; fn node(index: usize) -> String { format!("n{index}") } -fn build_database() -> Database { - let mut clauses = Vec::new(); +fn build_database<'a>(interner: &'a StrInterner) -> Database> { + let mut source = String::new(); // A chain makes each ancestry query produce many answers. A few skip edges add branching so the // engine does enough independent work for parallel throughput to show up in the benchmark. for i in 0..NODES - 1 { - clauses.push(parent(i, i + 1)); + source.push_str(&format!("parent({}, {}).\n", node(i), node(i + 1))); } for i in 0..NODES - 3 { if i % 7 == 0 { - clauses.push(parent(i, i + 3)); + source.push_str(&format!("parent({}, {}).\n", node(i), node(i + 3))); } } - clauses.push(Clause::rule( - compound("ancestor", [atom("$X"), atom("$Y")]), - Expr::term(compound("parent", [atom("$X"), atom("$Y")])), - )); - clauses.push(Clause::rule( - compound("ancestor", [atom("$X"), atom("$Z")]), - Expr::expr_and([ - Expr::term(compound("parent", [atom("$X"), atom("$Y")])), - Expr::term(compound("ancestor", [atom("$Y"), atom("$Z")])), - ]), - )); + source.push_str( + " + ancestor($X, $Y) :- parent($X, $Y). + ancestor($X, $Z) :- parent($X, $Y), ancestor($Y, $Z). + ", + ); + let dataset: ClauseDataset<_> = parse_str(&source, interner).unwrap(); let mut db = Database::default(); - db.insert_dataset(ClauseDataset(clauses)); + db.insert_dataset(dataset); db } -fn build_queries() -> Vec> { +fn build_queries<'a>(interner: &'a StrInterner) -> Vec>> { (0..QUERY_COUNT) .map(|i| { let root = i % (NODES / 3); - Expr::term(compound("ancestor", [atom(node(root)), atom("$Who")])) + parse_str(&format!("ancestor({}, $Who)", node(root)), interner).unwrap() }) .collect() } -fn count_answers(db: &Database, query: Expr) -> usize { +fn count_answers(db: &Database>, query: Expr>) -> usize { let mut cx = db.query(query); let mut count = 0; while let Some(answer) = cx.prove_next() { @@ -85,7 +55,7 @@ fn count_answers(db: &Database, query: Expr) -> usize { count } -fn run_serial(db: &Database, queries: &[Expr]) -> usize { +fn run_serial(db: &Database>, queries: &[Expr>]) -> usize { queries .iter() .cloned() @@ -93,7 +63,11 @@ fn run_serial(db: &Database, queries: &[Expr]) -> usize { .sum() } -fn run_threaded(db: &Database, queries: &[Expr], threads: usize) -> usize { +fn run_threaded( + db: &Database>, + queries: &[Expr>], + threads: usize, +) -> usize { let chunk_len = (queries.len() + threads - 1) / threads; std::thread::scope(|scope| { @@ -118,8 +92,9 @@ fn run_threaded(db: &Database, queries: &[Expr], threads: } fn benchmark_query_threads(c: &mut Criterion) { - let db = build_database(); - let queries = build_queries(); + let interner = StrInterner::new(); + let db = build_database(&interner); + let queries = build_queries(&interner); let expected = run_serial(&db, &queries); let mut group = c.benchmark_group("logic_eval_query_threads");