diff --git a/Cargo.toml b/Cargo.toml index c3a7bdb..5b397ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,10 +17,13 @@ comemo-macros = { version = "0.4.0", path = "macros" } once_cell = "1.18" parking_lot = "0.12" proc-macro2 = "1" +quickcheck = "1" +quickcheck_macros = "1" quote = "1" rustc-hash = "2.1" serial_test = "3" siphasher = "1" +slab = "0.4" syn = { version = "2", features = ["full"] } [package] @@ -45,8 +48,11 @@ comemo-macros = { workspace = true, optional = true } parking_lot = { workspace = true } rustc-hash = { workspace = true } siphasher = { workspace = true } +slab = { workspace = true } [dev-dependencies] +quickcheck = { workspace = true } +quickcheck_macros = { workspace = true } serial_test = { workspace = true } [[test]] diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8430f42..4329290 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -39,8 +39,7 @@ use syn::{Error, Result, parse_quote}; /// - _Mutably tracked:_ The argument is of the form `TrackedMut`. Through /// this type, you can safely mutate an argument from within a memoized /// function. If there is a cache hit, comemo will replay all mutations. -/// Mutable tracked methods can also have return values that are tracked just -/// like immutable methods. +/// Mutable tracked methods cannot have return values. /// /// # Restrictions /// The following restrictions apply to memoized functions: @@ -50,11 +49,40 @@ use syn::{Error, Result, parse_quote}; /// expose to the hasher**. Otherwise, memoized results might get reused /// invalidly. /// -/// - The **only obversable impurity memoized functions may exhibit are +/// - The **only observable impurity memoized functions may exhibit are /// mutations through `TrackedMut` arguments.** Comemo stops you from using /// basic mutable arguments, but it cannot determine all sources of impurity, /// so this is your responsibility. 
/// +/// - Memoized functions must **call tracked methods in _reorderably +/// deterministic_ fashion.** Consider two executions A and B of a memoized +/// function. We define the following two properties: +/// +/// - _In-order deterministic:_ If the first N tracked calls and their results +/// are the same in A and B, then the N+1th call must also be the same. This +/// is a fairly natural property as far as deterministic functions go, as, +/// if the first N calls and their results were the same across two +/// executions, the available information for choosing the N+1th call is the +/// same. However, this property is a bit too restrictive in practice. For +/// instance, a function that internally uses multi-threading may call +/// tracked methods out-of-order while still producing a deterministic +/// result. +/// +/// - _Reorderably deterministic:_ If, for the first N calls in A, B has +/// matching calls (same arguments, same return value) somewhere in its call +/// sequence, then the N+1th call invoked by A must also occur _somewhere_ +/// in the call sequence of B. This is a somewhat relaxed version of +/// in-order determinism that still allows comemo to perform internal +/// optimizations while permitting memoization of many more functions (e.g. +/// ones that use internal multi-threading in an outwardly deterministic +/// fashion). +/// +/// Reorderable determinism is necessary for efficient cache lookups. If a +/// memoized function is not reorderably deterministic, comemo may panic in +/// debug mode to bring your attention to this. Meanwhile, in release mode, +/// memoized functions will still yield correct results, but caching may prove +/// ineffective. +/// /// - The output of a memoized function must be `Send` and `Sync` because it is /// stored in the global cache. 
/// @@ -126,10 +154,6 @@ pub fn memoize(args: BoundaryStream, stream: BoundaryStream) -> BoundaryStream { /// arguments, tracking is the only option, so that comemo can replay the side /// effects when there is a cache hit. /// -/// If you attempt to track any mutable methods, your type must implement -/// [`Clone`] so that comemo can roll back attempted mutations which did not -/// result in a cache hit. -/// /// # Restrictions /// Tracked impl blocks or traits may not be generic and may only contain /// methods. Just like with memoized functions, certain restrictions apply to @@ -147,6 +171,11 @@ pub fn memoize(args: BoundaryStream, stream: BoundaryStream) -> BoundaryStream { /// [`Hash`](std::hash::Hash) and **must feed all the information they expose /// to the hasher**. Otherwise, memoized results might get reused invalidly. /// +/// - Mutable tracked methods must not have a return value. +/// +/// - A tracked implementation cannot have a mix of mutable and immutable +/// methods. +/// /// - The arguments to a tracked method must be `Send` and `Sync` because they /// are stored in the global cache. /// diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index a8f58f2..bd431b9 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -148,7 +148,7 @@ fn process(function: &Function) -> Result { wrapped.block = parse_quote! 
{ { static __CACHE: ::comemo::internal::Cache< - <::comemo::internal::Multi<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, + <::comemo::internal::Multi<#arg_ty_tuple> as ::comemo::internal::Input>::Call, #output, > = ::comemo::internal::Cache::new(|| { ::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age)); @@ -160,7 +160,7 @@ fn process(function: &Function) -> Result { ::comemo::internal::memoize( &__CACHE, ::comemo::internal::Multi(#arg_tuple), - &::core::default::Default::default(), + &mut ::core::default::Default::default(), #enabled, #closure, ) diff --git a/macros/src/track.rs b/macros/src/track.rs index 1cd8ade..3194026 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -42,6 +42,13 @@ pub fn expand(item: &syn::Item) -> Result { _ => bail!(item, "`track` can only be applied to impl blocks and traits"), }; + if methods.iter().any(|m| m.mutable) && methods.iter().any(|m| !m.mutable) { + bail!( + item, + "`track` cannot be applied to a mix of mutable and immutable methods" + ); + } + // Produce the necessary items for the type to become trackable. let variants = create_variants(&methods); let scope = create(&ty, generics, trait_, &methods)?; @@ -168,6 +175,12 @@ fn prepare_method(vis: syn::Visibility, sig: &syn::Signature) -> Result bail!(ty, "tracked methods cannot return mutable references"); } + if let syn::ReturnType::Type(_, ty) = &sig.output + && receiver.mutability.is_some() + { + bail!(ty, "mutable tracked methods cannot have a return value"); + } + Ok(Method { vis, sig: sig.clone(), @@ -225,11 +238,6 @@ fn create( let t: syn::GenericParam = parse_quote! { '__comemo_tracked }; let r: syn::GenericParam = parse_quote! { '__comemo_retrack }; let d: syn::GenericParam = parse_quote! { '__comemo_dynamic }; - let maybe_cloned = if methods.iter().any(|it| it.mutable) { - quote! { ::core::clone::Clone::clone(self) } - } else { - quote! { self } - }; // Prepare generics. 
let (impl_gen, type_gen, where_clause) = generics.split_for_impl(); @@ -245,37 +253,9 @@ fn create( impl_params_t.params.push(t.clone()); type_params_t.params.push(t.clone()); - // Prepare validations. let prefix = trait_.as_ref().map(|name| quote! { #name for }); - let validations: Vec<_> = methods.iter().map(create_validation).collect(); - let validate = if !methods.is_empty() { - quote! { - let mut this = #maybe_cloned; - constraint.validate(|call| match &call.0 { #(#validations,)* }) - } - } else { - quote! { true } - }; - let validate_with_id = if !methods.is_empty() { - quote! { - let mut this = #maybe_cloned; - constraint.validate_with_id( - |call| match &call.0 { #(#validations,)* }, - id, - ) - } - } else { - quote! { true } - }; - - // Prepare replying. - let immutable = methods.iter().all(|m| !m.mutable); - let replays = methods.iter().map(create_replay); - let replay = (!immutable).then(|| { - quote! { - constraint.replay(|call| match &call.0 { #(#replays,)* }); - } - }); + let calls: Vec<_> = methods.iter().map(create_call).collect(); + let calls_mut: Vec<_> = methods.iter().map(create_call_mut).collect(); // Prepare variants and wrapper methods. let wrapper_methods = methods @@ -284,32 +264,18 @@ fn create( .map(|m| create_wrapper(m, false)); let wrapper_methods_mut = methods.iter().map(|m| create_wrapper(m, true)); - let constraint = if immutable { - quote! { ImmutableConstraint } - } else { - quote! { MutableConstraint } - }; - Ok(quote! 
{ - impl #impl_params ::comemo::Track for #ty #where_clause {} - - impl #impl_params ::comemo::Validate for #ty #where_clause { - type Constraint = ::comemo::internal::#constraint<__ComemoCall>; + impl #impl_params ::comemo::Track for #ty #where_clause { + type Call = __ComemoCall; #[inline] - fn validate(&self, constraint: &Self::Constraint) -> bool { - #validate + fn call(&self, call: &Self::Call) -> u128 { + match call.0 { #(#calls,)* } } #[inline] - fn validate_with_id(&self, constraint: &Self::Constraint, id: usize) -> bool { - #validate_with_id - } - - #[inline] - #[allow(unused_variables)] - fn replay(&mut self, constraint: &Self::Constraint) { - #replay + fn call_mut(&mut self, call: &Self::Call) { + match call.0 { #(#calls_mut,)* } } } @@ -363,41 +329,50 @@ fn create( }) } -/// Produce a constraint validation for a method. +/// Produce a call enum variant for a method. fn create_variant(method: &Method) -> TokenStream { let name = &method.sig.ident; let types = &method.types; quote! { #name(#(<#types as ::std::borrow::ToOwned>::Owned),*) } } -/// Produce a constraint validation for a method. -fn create_validation(method: &Method) -> TokenStream { +/// Produce a call branch for a method. +fn create_call(method: &Method) -> TokenStream { let name = &method.sig.ident; let args = &method.args; let prepared = method.args.iter().zip(&method.kinds).map(|(arg, kind)| match kind { Kind::Normal => quote! { #arg.to_owned() }, Kind::Reference => quote! { #arg }, }); - quote! { - __ComemoVariant::#name(#(#args),*) - => ::comemo::internal::hash(&this.#name(#(#prepared),*)) + if method.mutable { + quote! { + __ComemoVariant::#name(..) => 0 + } + } else { + quote! { + __ComemoVariant::#name(#(ref #args),*) + => ::comemo::internal::hash(&self.#name(#(#prepared),*)) + } } } -/// Produce a constraint validation for a method. -fn create_replay(method: &Method) -> TokenStream { +/// Produce a mutable call branch for a method. 
+fn create_call_mut(method: &Method) -> TokenStream { let name = &method.sig.ident; let args = &method.args; let prepared = method.args.iter().zip(&method.kinds).map(|(arg, kind)| match kind { Kind::Normal => quote! { #arg.to_owned() }, Kind::Reference => quote! { #arg }, }); - let body = method.mutable.then(|| { + if method.mutable { quote! { - self.#name(#(#prepared),*); + __ComemoVariant::#name(#(ref #args),*) => self.#name(#(#prepared),*) } - }); - quote! { __ComemoVariant::#name(#(#args),*) => { #body } } + } else { + quote! { + __ComemoVariant::#name(..) => {} + } + } } /// Produce a wrapped surface method. @@ -417,16 +392,19 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream { #[track_caller] #[inline] #vis #sig { - let __comemo_variant = __ComemoVariant::#name(#(#args.to_owned()),*); - let (__comemo_value, __comemo_constraint) = ::comemo::internal::#to_parts; - let output = __comemo_value.#name(#(#args,)*); - if let Some(constraint) = __comemo_constraint { - constraint.push( + let (__comemo_value, __comemo_sink) = ::comemo::internal::#to_parts; + if let Some(__comemo_sink) = __comemo_sink { + let __comemo_variant = __ComemoVariant::#name(#(#args.to_owned()),*); + let output = __comemo_value.#name(#(#args,)*); + ::comemo::internal::Sink::emit( + __comemo_sink, __ComemoCall(__comemo_variant), ::comemo::internal::hash(&output), ); + output + } else { + __comemo_value.#name(#(#args,)*) } - output } } } diff --git a/src/constraint.rs b/src/constraint.rs index 711c0cf..6cfe574 100644 --- a/src/constraint.rs +++ b/src/constraint.rs @@ -1,292 +1,166 @@ -use std::borrow::Cow; use std::collections::hash_map::Entry; use std::hash::Hash; -use parking_lot::RwLock; +use parking_lot::Mutex; use rustc_hash::FxHashMap; -use crate::accelerate; +use crate::Track; +use crate::track::{Call, Sink}; -/// A call to a tracked function. -pub trait Call: Hash + PartialEq + Clone { - /// Whether the call is mutable. 
- fn is_mutable(&self) -> bool; -} +/// Records calls performed on a trackable type. +/// +/// Allows to validate that a different instance of the trackable type yields +/// the same outputs for the recorded calls. +/// +/// The constraint can be hooked up to a tracked type through +/// [`Track::track_with`]. +pub struct Constraint(Mutex>); -/// A constraint entry for a single call. -#[derive(Clone)] -struct ConstraintEntry { - call: T, - call_hash: u128, - ret_hash: u128, +/// The internal representation of a [`Constraint`]. +struct ConstraintRepr { + /// The immutable calls, ready for integration into a call tree. + immutable: CallSequence, + /// The mutable calls, for insertion as part of the call tree output value. + mutable: Vec, } -/// Defines a constraint for an immutably tracked type. -pub struct ImmutableConstraint(RwLock>); - -impl ImmutableConstraint { - /// Create an empty constraint. +impl Constraint { + /// Creates a new constraint. pub fn new() -> Self { Self::default() } - /// Enter a constraint for a call to an immutable function. - #[inline] - pub fn push(&self, call: T, ret_hash: u128) { - let call_hash = crate::hash::hash(&call); - let entry = ConstraintEntry { call, call_hash, ret_hash }; - self.0.write().push_inner(Cow::Owned(entry)); - } - - /// Whether the method satisfies as all input-output pairs. - #[inline] - pub fn validate(&self, mut f: F) -> bool - where - F: FnMut(&T) -> u128, - { - self.0.read().0.values().all(|entry| f(&entry.call) == entry.ret_hash) - } - - /// Whether the method satisfies as all input-output pairs. - #[inline] - pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + /// Checks whether the given value matches the constraints by invoking the + /// recorded calls one by one. 
+ pub fn validate(&self, value: &T) -> bool where - F: FnMut(&T) -> u128, + T: Track, { - let guard = self.0.read(); - if let Some(accelerator) = accelerate::get(id) { - let mut map = accelerator.lock(); - guard.0.values().all(|entry| { - *map.entry(entry.call_hash).or_insert_with(|| f(&entry.call)) - == entry.ret_hash - }) - } else { - guard.0.values().all(|entry| f(&entry.call) == entry.ret_hash) - } - } - - /// Replay all input-output pairs. - #[inline] - pub fn replay(&self, _: F) - where - F: FnMut(&T), - { - #[cfg(debug_assertions)] - for entry in self.0.read().0.values() { - assert!(!entry.call.is_mutable()); - } + self.0 + .lock() + .immutable + .vec + .iter() + .filter_map(|x| x.as_ref()) + .all(|(call, ret)| value.call(call) == *ret) } -} -impl Clone for ImmutableConstraint { - fn clone(&self) -> Self { - Self(RwLock::new(self.0.read().clone())) + /// Takes out the immutable and mutable calls, for insertion into a call + /// tree. + pub(crate) fn take(&self) -> (CallSequence, Vec) { + let mut inner = self.0.lock(); + (std::mem::take(&mut inner.immutable), std::mem::take(&mut inner.mutable)) } } -impl Default for ImmutableConstraint { +impl Default for Constraint { fn default() -> Self { - Self(RwLock::new(EntryMap::default())) + Self(Mutex::new(ConstraintRepr { + immutable: CallSequence::new(), + mutable: Vec::new(), + })) } } -/// Defines a constraint for a mutably tracked type. -pub struct MutableConstraint(RwLock>); - -impl MutableConstraint { - /// Create an empty constraint. - pub fn new() -> Self { - Self::default() - } - - /// Enter a constraint for a call to a mutable function. - #[inline] - pub fn push(&self, call: T, ret_hash: u128) { - let call_hash = crate::hash::hash(&call); - let entry = ConstraintEntry { call, call_hash, ret_hash }; - self.0.write().push_inner(Cow::Owned(entry)); - } - - /// Whether the method satisfies as all input-output pairs. 
- #[inline] - pub fn validate(&self, mut f: F) -> bool - where - F: FnMut(&T) -> u128, - { - self.0.read().0.iter().all(|entry| f(&entry.call) == entry.ret_hash) - } - - /// Whether the method satisfies as all input-output pairs. - /// - /// On mutable tracked types, this does not use an accelerator as it is - /// rarely, if ever used. Therefore, it is not worth the overhead. - #[inline] - pub fn validate_with_id(&self, f: F, _: usize) -> bool - where - F: FnMut(&T) -> u128, - { - self.validate(f) - } +impl Sink for Constraint { + type Call = C; - /// Replay all input-output pairs. - #[inline] - pub fn replay(&self, mut f: F) - where - F: FnMut(&T), - { - for entry in &self.0.read().0 { - if entry.call.is_mutable() { - f(&entry.call); - } + fn emit(&self, call: C, ret: u128) -> bool { + let mut inner = self.0.lock(); + if call.is_mutable() { + inner.mutable.push(call); + true + } else { + inner.immutable.insert(call, ret) } } } -impl Clone for MutableConstraint { - fn clone(&self) -> Self { - Self(RwLock::new(self.0.read().clone())) - } +/// A deduplicated sequence of calls to tracked functions, optimized for +/// efficient insertion into a call tree. +pub struct CallSequence { + /// The raw calls. In order, but deduplicated via the `map`. + vec: Vec>, + /// A map from hashes of calls to the indices in the vector. + map: FxHashMap, + /// A cursor for iteration in `Self::next`. + cursor: usize, } -impl Default for MutableConstraint { - fn default() -> Self { - Self(RwLock::new(EntryVec::default())) - } -} - -/// A map of calls. -#[derive(Clone)] -struct EntryMap(FxHashMap>); - -impl EntryMap { - /// Enter a constraint for a call to a function. - #[inline] - fn push_inner(&mut self, entry: Cow>) { - match self.0.entry(entry.call_hash) { - Entry::Occupied(_occupied) => { - #[cfg(debug_assertions)] - check(_occupied.get(), &entry); - } - Entry::Vacant(vacant) => { - vacant.insert(entry.into_owned()); - } +impl CallSequence { + /// Creates an empty sequence. 
+ pub fn new() -> Self { + Self { + vec: Vec::new(), + map: FxHashMap::default(), + cursor: 0, } } } -impl Default for EntryMap { - fn default() -> Self { - Self(FxHashMap::default()) - } -} - -/// A list of calls. -/// -/// Order matters here, as those are mutable & immutable calls. -#[derive(Clone)] -struct EntryVec(Vec>); - -impl EntryVec { - /// Enter a constraint for a call to a function. - #[inline] - fn push_inner(&mut self, entry: Cow>) { - // If the call is immutable check whether we already have a call - // with the same arguments and return value. - if !entry.call.is_mutable() { - for prev in self.0.iter().rev() { - if entry.call.is_mutable() { - break; - } - - if entry.call_hash == prev.call_hash && entry.ret_hash == prev.ret_hash { - #[cfg(debug_assertions)] - check(&entry, prev); - return; +impl CallSequence { + /// Inserts a pair of a call and its return hash. + /// + /// Returns true when the pair was indeed inserted and false if the call was + /// deduplicated. + pub fn insert(&mut self, call: C, ret: u128) -> bool { + match self.map.entry(crate::hash::hash(&call)) { + Entry::Vacant(entry) => { + let i = self.vec.len(); + self.vec.push(Some((call, ret))); + entry.insert(i); + true + } + #[allow(unused_variables)] + Entry::Occupied(entry) => { + #[cfg(debug_assertions)] + if let Some((_, ret2)) = &self.vec[*entry.get()] { + if ret != *ret2 { + panic!( + "comemo: found differing return values. \ + is there an impure tracked function?" + ) + } } + false } } - - // Insert the call into the call list. - self.0.push(entry.into_owned()); } -} -impl Default for EntryVec { - fn default() -> Self { - Self(Vec::new()) - } -} - -/// Extend an outer constraint by an inner one. -pub trait Join { - /// Join this constraint with the `inner` one. - fn join(&self, inner: &T); - - /// Take out the constraint. 
- fn take(&self) -> Self; -} - -impl Join for Option<&T> { - #[inline] - fn join(&self, inner: &T) { - if let Some(outer) = self { - outer.join(inner); + /// Retrieves the next call in order. + pub fn next(&mut self) -> Option<(C, u128)> { + while self.cursor < self.vec.len() { + if let Some(pair) = self.vec[self.cursor].take() { + return Some(pair); + } + self.cursor += 1; } + None } - #[inline] - fn take(&self) -> Self { - unimplemented!("cannot call `Join::take` on optional constraint") + /// Retrieves the return hash of an arbitrary upcoming call. Removes the + /// call from the sequence; it will not be yielded by `next()` anymore. + pub fn extract(&mut self, call: &C) -> Option { + let h = crate::hash::hash(&call); + let i = *self.map.get(&h)?; + let res = self.vec[i].take().map(|(_, ret)| ret); + debug_assert!(self.cursor <= i || res.is_none()); + res } } -impl Join for ImmutableConstraint { - #[inline] - fn join(&self, inner: &Self) { - let mut this = self.0.write(); - for entry in inner.0.read().0.values() { - this.push_inner(Cow::Borrowed(entry)); - } - } - - #[inline] - fn take(&self) -> Self { - Self(RwLock::new(std::mem::take(&mut *self.0.write()))) +impl Default for CallSequence { + fn default() -> Self { + Self::new() } } -impl Join for MutableConstraint { - #[inline] - fn join(&self, inner: &Self) { - let mut this = self.0.write(); - for entry in inner.0.read().0.iter() { - this.push_inner(Cow::Borrowed(entry)); +impl FromIterator<(C, u128)> for CallSequence { + fn from_iter>(iter: T) -> Self { + let mut seq = CallSequence::new(); + for (call, ret) in iter { + seq.insert(call, ret); } - } - - #[inline] - fn take(&self) -> Self { - Self(RwLock::new(std::mem::take(&mut *self.0.write()))) - } -} - -/// Check for a constraint violation. -#[inline] -#[track_caller] -#[allow(dead_code)] -fn check(lhs: &ConstraintEntry, rhs: &ConstraintEntry) { - if lhs.ret_hash != rhs.ret_hash { - panic!( - "comemo: found conflicting constraints. 
\ - is this tracked function pure?" - ) - } - - // Additional checks for debugging. - if lhs.call_hash != rhs.call_hash || lhs.call != rhs.call { - panic!( - "comemo: found conflicting `check` arguments. \ - this is a bug in comemo" - ) + seq } } diff --git a/src/input.rs b/src/input.rs index 9417498..da3ff4d 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,7 +1,6 @@ use std::hash::{Hash, Hasher}; -use crate::constraint::Join; -use crate::track::{Track, Tracked, TrackedMut, Validate}; +use crate::track::{Call, Sink, Track, Tracked, TrackedMut}; /// Ensure a type is suitable as input. #[inline] @@ -12,32 +11,39 @@ pub fn assert_hashable_or_trackable<'a, In: Input<'a>>(_: &In) {} /// This is implemented for hashable types, `Tracked<_>` types and `Args<(...)>` /// types containing tuples up to length twelve. pub trait Input<'a> { - /// The constraints for this input. - type Constraint: Default + Clone + Join + 'static; + /// An enumeration of possible tracked calls that can be performed on any + /// tracked part of this input. + type Call: Call; - /// The extracted outer constraints. - type Outer: Join; + /// Storage for a sink that combines the input's existing sink(s) + /// (if any) with `S`. + type Storage + 'a>: Default; - /// Hash the key parts of the input. + /// Hashes the key (i.e. not tracked) parts of the input. fn key(&self, state: &mut H); - /// Validate the tracked parts of the input. - fn validate(&self, constraint: &Self::Constraint) -> bool; + /// Performs a call on a tracked part of the input and returns the hash of + /// the result. + /// + /// If the call is mutable, the side effect will not be observable. + fn call(&self, call: &Self::Call) -> u128; - /// Replay mutations to the input. - fn replay(&mut self, constraint: &Self::Constraint); + /// Performs a mutable call on a tracked part of the input. + /// Mutable calls cannot have a return value and are only executed for their + /// side effect. 
As such, this function does not return a result hash. + fn call_mut(&mut self, call: &Self::Call); - /// Hook up the given constraint to the tracked parts of the input and - /// return the result alongside the outer constraints. - fn retrack(self, constraint: &'a Self::Constraint) -> (Self, Self::Outer) + /// Integrates the given sink into the tracked parts of the input, using + /// the external storage to store the new combined sink with lifetime 'a. + fn attach(&mut self, storage: &'a mut Self::Storage, sink: S) where - Self: Sized; + S: Sink + Copy + 'a; } impl<'a, T: Hash> Input<'a> for T { - // No constraint for hashed inputs. - type Constraint = (); - type Outer = (); + // No sink for hashed inputs. + type Call = (); + type Storage + 'a> = (); #[inline] fn key(&self, state: &mut H) { @@ -45,16 +51,19 @@ impl<'a, T: Hash> Input<'a> for T { } #[inline] - fn validate(&self, _: &()) -> bool { - true + fn call(&self, _: &Self::Call) -> u128 { + // No calls on hashed inputs. + 0 } #[inline] - fn replay(&mut self, _: &Self::Constraint) {} + fn call_mut(&mut self, _: &Self::Call) {} #[inline] - fn retrack(self, _: &'a ()) -> (Self, Self::Outer) { - (self, ()) + fn attach(&mut self, _: &'a mut Self::Storage, _: S) + where + S: Sink + Copy + 'a, + { } } @@ -62,29 +71,47 @@ impl<'a, T> Input<'a> for Tracked<'a, T> where T: Track + ?Sized, { - // Forward constraint from `Trackable` implementation. - type Constraint = ::Constraint; - type Outer = Option<&'a Self::Constraint>; + type Call = T::Call; + type Storage + 'a> = Option>; #[inline] fn key(&self, _: &mut H) {} #[inline] - fn validate(&self, constraint: &Self::Constraint) -> bool { - self.value.validate_with_id(constraint, self.id) + fn call(&self, call: &Self::Call) -> u128 { + let hash = if let Some(accelerator) = crate::accelerate::get(self.id) { + // When we have an accelerator for this tracked instance, we might + // already have a cached value of the return hash. 
Then, we don't + // need to actually perform the call. + let mut map = accelerator.lock(); + let call_hash = crate::hash::hash(call); + *map.entry(call_hash).or_insert_with(|| self.value.call(call)) + } else { + self.value.call(call) + }; + + // The `call` method is used during the constraint validation tree + // traversal. It's crucial that we also send calls to the outer sink + // here so that the outer sink observes the calls when we have a cache + // hit. We do _not_ replay the constraints in another way. + if let Some(sink) = self.sink { + sink.emit(call.clone(), hash); + } + + hash } #[inline] - fn replay(&mut self, _: &Self::Constraint) {} + fn call_mut(&mut self, _: &Self::Call) { + // Cannot perform a mutable call on an immutable reference. + } #[inline] - fn retrack(self, constraint: &'a Self::Constraint) -> (Self, Self::Outer) { - let tracked = Tracked { - value: self.value, - constraint: Some(constraint), - id: self.id, - }; - (tracked, self.constraint) + fn attach(&mut self, storage: &'a mut Self::Storage, sink: S) + where + S: Sink + Copy + 'a, + { + self.sink = Some(storage.insert(MergedSink { prev: self.sink, sink })); } } @@ -92,27 +119,60 @@ impl<'a, T> Input<'a> for TrackedMut<'a, T> where T: Track + ?Sized, { - // Forward constraint from `Trackable` implementation. 
- type Constraint = T::Constraint; - type Outer = Option<&'a Self::Constraint>; + type Call = T::Call; + type Storage + 'a> = Option>; #[inline] fn key(&self, _: &mut H) {} #[inline] - fn validate(&self, constraint: &Self::Constraint) -> bool { - self.value.validate(constraint) + fn call(&self, call: &Self::Call) -> u128 { + let hash = self.value.call(call); + if let Some(sink) = self.sink { + sink.emit(call.clone(), hash); + } + hash } #[inline] - fn replay(&mut self, constraint: &Self::Constraint) { - self.value.replay(constraint); + fn call_mut(&mut self, call: &Self::Call) { + self.value.call_mut(call); + if let Some(sink) = self.sink { + sink.emit(call.clone(), 0); + } } #[inline] - fn retrack(self, constraint: &'a Self::Constraint) -> (Self, Self::Outer) { - let tracked = TrackedMut { value: self.value, constraint: Some(constraint) }; - (tracked, self.constraint) + fn attach(&mut self, storage: &'a mut Self::Storage, sink: S) + where + S: Sink + Copy + 'a, + { + self.sink = Some(storage.insert(MergedSink { prev: self.sink, sink })); + } +} + +/// Combines an existing sink from a tracked type with `S`. +#[derive(Copy, Clone)] +pub struct MergedSink<'a, S: Sink> { + prev: Option<&'a dyn Sink>, + sink: S, +} + +impl<'a, C, S> Sink for MergedSink<'a, S> +where + C: Call, + S: Sink, +{ + type Call = C; + + fn emit(&self, call: C, ret: u128) -> bool { + if let Some(prev) = self.prev { + // If the current sink already deduplicated the value, we don't have + // to go the previous sink in the first place. + self.sink.emit(call.clone(), ret) && prev.emit(call, ret) + } else { + self.sink.emit(call, ret) + } } } @@ -120,11 +180,11 @@ where pub struct Multi(pub T); macro_rules! 
multi { - ($($param:tt $alt:tt $idx:tt ),*) => { - #[allow(unused_variables, non_snake_case)] + (@inner $($param:ident $alt:ident $idx:tt),*; $params:tt) => { impl<'a, $($param: Input<'a>),*> Input<'a> for Multi<($($param,)*)> { - type Constraint = ($($param::Constraint,)*); - type Outer = ($($param::Outer,)*); + type Call = MultiCall<$($param::Call),*>; + type Storage + 'a> = + ($($param::Storage>,)*); #[inline] fn key(&self, state: &mut T) { @@ -132,38 +192,66 @@ macro_rules! multi { } #[inline] - fn validate(&self, constraint: &Self::Constraint) -> bool { - true $(&& (self.0).$idx.validate(&constraint.$idx))* + fn call(&self, call: &Self::Call) -> u128 { + match *call { + $(MultiCall::$param(ref $param) => (self.0).$idx.call($param)),* + } } #[inline] - fn replay(&mut self, constraint: &Self::Constraint) { - $((self.0).$idx.replay(&constraint.$idx);)* + fn call_mut(&mut self, call: &Self::Call) { + match *call { + $(MultiCall::$param(ref $param) => (self.0).$idx.call_mut($param)),* + } } #[inline] - fn retrack( - self, - constraint: &'a Self::Constraint, - ) -> (Self, Self::Outer) { - $(let $param = (self.0).$idx.retrack(&constraint.$idx);)* - (Multi(($($param.0,)*)), ($($param.1,)*)) + fn attach(&mut self, storage: &'a mut Self::Storage, sink: S) + where + S: Sink + Copy + 'a { + $((self.0).$idx.attach(&mut storage.$idx, MappedSink::<$idx, _>(sink));)* } } - #[allow(unused_variables, clippy::unused_unit)] - impl<$($param: Join<$alt>, $alt),*> Join<($($alt,)*)> for ($($param,)*) { + #[derive(PartialEq, Clone, Hash)] + pub enum MultiCall<$($param),*> { + $($param($param),)* + } + + impl<$($param: Call),*> Call for MultiCall<$($param),*> { #[inline] - fn join(&self, constraint: &($($alt,)*)) { - $(self.$idx.join(&constraint.$idx);)* + fn is_mutable(&self) -> bool { + match *self { + $(Self::$param(ref $param) => $param.is_mutable(),)* + } } + } - #[inline] - fn take(&self) -> Self { - ($(self.$idx.take(),)*) + #[derive(Copy, Clone)] + pub struct MappedSink(S); + + 
$(multi!(@mapped $param $idx; $params);)* + }; + + (@mapped $pick:ident $idx:tt; ($($param:ident),*)) => { + impl Sink for MappedSink<$idx, S> + where + S: Sink>, + { + type Call = $pick; + + fn emit(&self, call: $pick, ret: u128) -> bool { + self.0.emit(MultiCall::$pick(call), ret) } } }; + + ($($param:ident $alt:ident $idx:tt),*) => { + #[allow(unused_variables, clippy::unused_unit, non_snake_case)] + const _: () = { + multi!(@inner $($param $alt $idx),*; ($($param),*)); + }; + }; } multi! {} diff --git a/src/lib.rs b/src/lib.rs index e9cd961..90d85cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,13 +94,15 @@ mod hash; mod input; mod memoize; mod track; +mod tree; #[cfg(feature = "testing")] pub mod testing; +pub use crate::constraint::Constraint; pub use crate::hash::Prehashed; pub use crate::memoize::evict; -pub use crate::track::{Track, Tracked, TrackedMut, Validate}; +pub use crate::track::{Track, Tracked, TrackedMut}; #[cfg(feature = "macros")] pub use comemo_macros::{memoize, track}; @@ -108,11 +110,10 @@ pub use comemo_macros::{memoize, track}; /// These are implementation details. Do not rely on them! 
#[doc(hidden)] pub mod internal { - pub use parking_lot::RwLock; - - pub use crate::constraint::{Call, ImmutableConstraint, MutableConstraint}; pub use crate::hash::hash; pub use crate::input::{Input, Multi, assert_hashable_or_trackable}; - pub use crate::memoize::{Cache, CacheData, memoize, register_evictor}; - pub use crate::track::{Surfaces, to_parts_mut_mut, to_parts_mut_ref, to_parts_ref}; + pub use crate::memoize::{Cache, memoize, register_evictor}; + pub use crate::track::{ + Call, Sink, Surfaces, to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, + }; } diff --git a/src/memoize.rs b/src/memoize.rs index a09c5ef..1cb9e16 100644 --- a/src/memoize.rs +++ b/src/memoize.rs @@ -2,21 +2,29 @@ use std::sync::LazyLock; use std::sync::atomic::{AtomicUsize, Ordering}; use parking_lot::RwLock; -use rustc_hash::FxHashMap; use siphasher::sip128::{Hasher128, SipHasher13}; use crate::accelerate; -use crate::constraint::Join; +use crate::constraint::Constraint; use crate::input::Input; +use crate::track::Call; +use crate::tree::{CallTree, InsertError}; /// The global list of eviction functions. static EVICTORS: RwLock> = RwLock::new(Vec::new()); -/// Execute a function or use a cached result for it. +/// Executes a function, trying to use a cached result for it. +#[allow(clippy::type_complexity)] pub fn memoize<'a, In, Out, F>( - cache: &Cache, + cache: &Cache, mut input: In, - constraint: &'a In::Constraint, + // These values must come from outside so that they have a lifetime that + // allows them to be attached to the `input`. On the call site, they are + // simply initialized as `&mut Default::default()`. + (storage, constraint): &'a mut ( + In::Storage<&'a Constraint>, + Constraint, + ), enabled: bool, func: F, ) -> Out @@ -44,34 +52,38 @@ where }; // Check if there is a cached output. - let borrow = cache.0.read(); - if let Some((constrained, value)) = borrow.lookup::(key, &input) { - // Replay the mutations. 
- input.replay(constrained); - - // Add the cached constraints to the outer ones. - input.retrack(constraint).1.join(constrained); + if let Some(entry) = cache.0.read().lookup(key, &input) { + // Replay mutations. + for call in &entry.mutable { + input.call_mut(call); + } #[cfg(feature = "testing")] crate::testing::register_hit(); - return value.clone(); + return entry.output.clone(); } - // Release the borrow so that nested memoized calls can access the - // cache without dead locking. - drop(borrow); + // Attach the constraint. + input.attach(storage, constraint); - // Execute the function with the new constraints hooked in. - let (input, outer) = input.retrack(constraint); + // Execute the function with the constraint attached. let output = func(input); - // Add the new constraints to the outer ones. - outer.join(constraint); - // Insert the result into the cache. - let mut borrow = cache.0.write(); - borrow.insert::(key, constraint.take(), output.clone()); + match cache.0.write().insert(key, constraint, output.clone()) { + Ok(()) => {} + Err(InsertError::AlreadyExists) => { + // A concurrent call with the same arguments may have inserted + // a value in the meantime. That's okay. + } + Err(InsertError::MissingCall) => { + // A missing call indicates a bug from a comemo user. See the + // documentation for `InsertError::MissingCall` for more details. + #[cfg(debug_assertions)] + panic!("comemo: memoized function is non-deterministic"); + } + } #[cfg(feature = "testing")] crate::testing::register_miss(); @@ -113,86 +125,68 @@ impl Cache { /// Evict all entries whose age is larger than or equal to `max_age`. pub fn evict(&self, max_age: usize) { - self.0.write().evict(max_age) + self.0.write().evict(max_age); } } /// The internal data for a cache. pub struct CacheData { /// Maps from hashes to memoized results. - entries: FxHashMap>>, + tree: CallTree>, +} + +/// A memoized result. +struct CacheEntry { + /// The memoized function's output. 
+ output: Out, + /// Mutable tracked calls that must be replayed. + mutable: Vec, + /// How many evictions have passed since the entry has last been used. + age: AtomicUsize, } impl CacheData { /// Evict all entries whose age is larger than or equal to `max_age`. fn evict(&mut self, max_age: usize) { - self.entries.retain(|_, entries| { - entries.retain_mut(|entry| { - let age = entry.age.get_mut(); - *age += 1; - *age <= max_age - }); - !entries.is_empty() + self.tree.retain(|entry| { + let age = entry.age.get_mut(); + *age += 1; + *age <= max_age }); } /// Look for a matching entry in the cache. - fn lookup<'a, In>(&self, key: u128, input: &In) -> Option<(&In::Constraint, &Out)> + fn lookup<'a, In>(&self, key: u128, input: &In) -> Option<&CacheEntry> where - In: Input<'a, Constraint = C>, + C: Call, + In: Input<'a, Call = C>, { - self.entries - .get(&key)? - .iter() - .rev() - .find_map(|entry| entry.lookup::(input)) + self.tree + .get(key, |c| input.call(c)) + .inspect(|entry| entry.age.store(0, Ordering::SeqCst)) } /// Insert an entry into the cache. - fn insert<'a, In>(&mut self, key: u128, constraint: In::Constraint, output: Out) + fn insert( + &mut self, + key: u128, + constraint: &Constraint, + output: Out, + ) -> Result<(), InsertError> where - In: Input<'a, Constraint = C>, + C: Call, { - self.entries - .entry(key) - .or_default() - .push(CacheEntry::new::(constraint, output)); + let (immutable, mutable) = constraint.take(); + self.tree.insert( + key, + immutable, + CacheEntry { output, mutable, age: AtomicUsize::new(0) }, + ) } } impl Default for CacheData { fn default() -> Self { - Self { entries: FxHashMap::default() } - } -} - -/// A memoized result. -struct CacheEntry { - /// The memoized function's constraint. - constraint: C, - /// The memoized function's output. - output: Out, - /// How many evictions have passed since the entry has been last used. - age: AtomicUsize, -} - -impl CacheEntry { - /// Create a new entry. 
- fn new<'a, In>(constraint: In::Constraint, output: Out) -> Self - where - In: Input<'a, Constraint = C>, - { - Self { constraint, output, age: AtomicUsize::new(0) } - } - - /// Return the entry's output if it is valid for the given input. - fn lookup<'a, In>(&self, input: &In) -> Option<(&In::Constraint, &Out)> - where - In: Input<'a, Constraint = C>, - { - input.validate(&self.constraint).then(|| { - self.age.store(0, Ordering::SeqCst); - (&self.constraint, &self.output) - }) + Self { tree: CallTree::new() } } } diff --git a/src/track.rs b/src/track.rs index 75d5b60..bad2075 100644 --- a/src/track.rs +++ b/src/track.rs @@ -1,80 +1,91 @@ use std::fmt::{self, Debug, Formatter}; +use std::hash::Hash; use std::ops::{Deref, DerefMut}; use crate::accelerate; -use crate::constraint::Join; /// A trackable type. /// /// This is implemented by types that have an implementation block annotated /// with `#[track]` and for trait objects whose traits are annotated with /// `#[track]`. For more details, see [its documentation](macro@crate::track). -pub trait Track: Validate + Surfaces { +pub trait Track: Surfaces { + /// An enumeration of possible tracked calls that can be performed on this + /// tracked type. + type Call: Call; + + /// Performs a call on the value and returns the hash of its results. + fn call(&self, call: &Self::Call) -> u128; + + /// Performs a mutable call on the value. + fn call_mut(&mut self, call: &Self::Call); + /// Start tracking all accesses to a value. #[inline] fn track(&self) -> Tracked<'_, Self> { - Tracked { - value: self, - constraint: None, - id: accelerate::id(), - } + Tracked { value: self, sink: None, id: accelerate::id() } } /// Start tracking all accesses and mutations to a value. #[inline] fn track_mut(&mut self) -> TrackedMut<'_, Self> { - TrackedMut { value: self, constraint: None } + TrackedMut { value: self, sink: None } } - /// Start tracking all accesses into a constraint. + /// Start tracking all accesses into a sink. 
#[inline] - fn track_with<'a>(&'a self, constraint: &'a Self::Constraint) -> Tracked<'a, Self> { + fn track_with<'a>( + &'a self, + sink: &'a dyn Sink, + ) -> Tracked<'a, Self> { Tracked { value: self, - constraint: Some(constraint), + sink: Some(sink), id: accelerate::id(), } } - /// Start tracking all accesses and mutations into a constraint. + /// Start tracking all accesses and mutations into a sink. #[inline] fn track_mut_with<'a>( &'a mut self, - constraint: &'a Self::Constraint, + sink: &'a dyn Sink, ) -> TrackedMut<'a, Self> { - TrackedMut { value: self, constraint: Some(constraint) } + TrackedMut { value: self, sink: Some(sink) } } } -/// A type that can be validated against constraints. -/// -/// Typical crate usage does not require you to interact with its trait. -/// However, it can be useful if you want to integrate comemo's tracking and -/// constraint validation mechanism directly into your code. -/// -/// This trait is implemented by the `#[track]` macro alongside [`Track`]. -pub trait Validate { - /// The constraints for this type. - type Constraint: Default + Clone + Join + 'static; +/// A destination to which recorded tracked calls can be sent. +pub trait Sink: Send + Sync { + /// An enumeration of possible tracked calls that can be sent to this sink. + type Call; - /// Whether this value fulfills the given constraints. + /// Emit a call and its return hash to the sink. /// - /// For a type `Foo`, empty constraints can be created with `::Constraint::default()` and filled with - /// [`track_with`](Track::track_with) or - /// [`track_mut_with`](Track::track_mut_with). - fn validate(&self, constraint: &Self::Constraint) -> bool; + /// Returns `false` if the call was deduplicated, so that callers can avoid + /// sending it to other sinks higher up the hierarchy. + fn emit(&self, call: Self::Call, ret: u128) -> bool; +} - /// Accelerated version of [`validate`](Self::validate). 
- /// - /// A `id` uniquely identifies a value to speed up repeated validation of - /// equal constraints against the same value. If given the same `id` twice, - /// `self` must also be identical, unless [`evict`](crate::evict) has been - /// called in between. - fn validate_with_id(&self, constraint: &Self::Constraint, id: usize) -> bool; - - /// Replay recorded mutations to the value. - fn replay(&mut self, constraint: &Self::Constraint); +impl Sink for &S { + type Call = S::Call; + + fn emit(&self, call: Self::Call, ret: u128) -> bool { + (*self).emit(call, ret) + } +} + +/// A call to a tracked function. +pub trait Call: Clone + PartialEq + Hash + Send + Sync { + /// Whether the call is mutable. + fn is_mutable(&self) -> bool; +} + +/// This implementation is used for hashed types in the `Input` trait. +impl Call for () { + fn is_mutable(&self) -> bool { + false + } } /// This type's tracked surfaces. @@ -122,29 +133,33 @@ pub trait Surfaces { /// how it can be used. In particular, invariance prevents you from creating a /// usable _chain_ of tracked types. /// -/// ```ignore +/// ``` +/// # use comemo::{Track, Tracked}; /// struct Chain<'a> { /// outer: Tracked<'a, Self>, /// data: u32, // some data for the chain link /// } +/// # #[comemo::track] impl<'a> Chain<'a> {} /// ``` /// /// However, this is sometimes a useful pattern (for example, it allows you to /// detect cycles in memoized recursive algorithms). 
If you want to create a /// tracked chain or need covariance for another reason, you need to manually -/// specify the constraint type like so: +/// specify the call type like so: /// -/// ```ignore +/// ``` +/// # use comemo::{Track, Tracked}; /// struct Chain<'a> { -/// outer: Tracked<'a, Self, as Validate>::Constraint>, +/// outer: Tracked<'a, Self, as Track>::Call>, /// data: u32, // some data for the chain link /// } +/// # #[comemo::track] impl<'a> Chain<'a> {} /// ``` /// /// Notice the `'static` lifetime: This makes the compiler understand that no /// strange business that depends on `'a` is happening in the associated /// constraint type. (In fact, all constraints are `'static`.) -pub struct Tracked<'a, T, C = ::Constraint> +pub struct Tracked<'a, T, C = ::Call> where T: Track + ?Sized, { @@ -155,7 +170,7 @@ where /// /// Starts out as `None` and is set to a stack-stored constraint in the /// preamble of memoized functions. - pub(crate) constraint: Option<&'a C>, + pub(crate) sink: Option<&'a dyn Sink>, /// A unique ID for validation acceleration. pub(crate) id: usize, } @@ -205,7 +220,7 @@ where /// details, see [its documentation](macro@crate::track). /// /// For more details, see [`Tracked`]. -pub struct TrackedMut<'a, T, C = ::Constraint> +pub struct TrackedMut<'a, T, C = ::Call> where T: Track + ?Sized, { @@ -216,7 +231,7 @@ where /// /// Starts out as `None` and is set to a stack-stored constraint in the /// preamble of memoized functions. - pub(crate) constraint: Option<&'a C>, + pub(crate) sink: Option<&'a dyn Sink>, } impl<'a, T> TrackedMut<'a, T> @@ -231,7 +246,7 @@ where pub fn downgrade(this: Self) -> Tracked<'a, T> { Tracked { value: this.value, - constraint: this.constraint, + sink: this.sink, id: accelerate::id(), } } @@ -244,7 +259,7 @@ where pub fn reborrow(this: &Self) -> Tracked<'_, T> { Tracked { value: this.value, - constraint: this.constraint, + sink: this.sink, id: accelerate::id(), } } @@ -255,7 +270,7 @@ where /// defined on `T`. 
It should be called as `TrackedMut::reborrow_mut(...)`. #[inline] pub fn reborrow_mut(this: &mut Self) -> TrackedMut<'_, T> { - TrackedMut { value: this.value, constraint: this.constraint } + TrackedMut { value: this.value, sink: this.sink } } } @@ -293,31 +308,31 @@ where /// Destructure a `Tracked<_>` into its parts. #[inline] -pub fn to_parts_ref(tracked: Tracked<'_, T>) -> (&T, Option<&T::Constraint>) +pub fn to_parts_ref(tracked: Tracked<'_, T>) -> (&T, Option<&dyn Sink>) where T: Track + ?Sized, { - (tracked.value, tracked.constraint) + (tracked.value, tracked.sink) } /// Destructure a `TrackedMut<_>` into its parts. #[inline] pub fn to_parts_mut_ref<'a, T>( tracked: &'a TrackedMut, -) -> (&'a T, Option<&'a T::Constraint>) +) -> (&'a T, Option<&'a dyn Sink>) where T: Track + ?Sized, { - (tracked.value, tracked.constraint) + (tracked.value, tracked.sink) } /// Destructure a `TrackedMut<_>` into its parts. #[inline] pub fn to_parts_mut_mut<'a, T>( tracked: &'a mut TrackedMut, -) -> (&'a mut T, Option<&'a T::Constraint>) +) -> (&'a mut T, Option<&'a dyn Sink>) where T: Track + ?Sized, { - (tracked.value, tracked.constraint) + (tracked.value, tracked.sink) } diff --git a/src/tree.rs b/src/tree.rs new file mode 100644 index 0000000..b0c6427 --- /dev/null +++ b/src/tree.rs @@ -0,0 +1,417 @@ +use std::fmt::{self, Debug}; +use std::hash::Hash; + +use rustc_hash::FxHashMap; +use slab::Slab; + +use crate::constraint::CallSequence; + +/// A tree data structure that associates a value with a key hash and a sequence +/// of (call, return hash) pairs. +/// +/// Allows to efficiently query for a value for which every call in the sequence +/// yielded the same return hash as a given oracle function will yield for that +/// call. +pub struct CallTree { + /// Inner nodes, storing calls. + inner: Slab>, + /// Leaf nodes, directly storing outputs. + leaves: Slab>, + /// The initial node for the given key hash. + start: FxHashMap, + /// Maps from parent nodes to child nodes. 
The key is a pair of an inner + /// node ID and a return hash for that call. The value is the node to + /// transition to. + edges: FxHashMap<(InnerId, u128), NodeId>, +} + +/// An inner node in the call tree. +struct InnerNode { + /// The call at this node. + call: C, + /// How many children the node has. If this reaches zero, the node is + /// deleted. + children: usize, + /// The node's parent. + parent: Option, +} + +/// A leaf node in the call tree. +struct LeafNode { + /// The value. + value: T, + /// The node's parent. + parent: Option, +} + +impl CallTree { + /// Creates an empty call tree. + pub fn new() -> Self { + Self { + inner: Slab::new(), + leaves: Slab::new(), + edges: FxHashMap::default(), + start: FxHashMap::default(), + } + } +} + +impl CallTree { + /// Retrieves the output value for the given key and oracle. + pub fn get(&self, key: u128, mut oracle: impl FnMut(&C) -> u128) -> Option<&T> { + let mut cursor = *self.start.get(&key)?; + loop { + match cursor.kind() { + NodeIdKind::Leaf(id) => { + return Some(&self.leaves[id].value); + } + NodeIdKind::Inner(id) => { + let call = &self.inner[id].call; + let ret = oracle(call); + cursor = *self.edges.get(&(id, ret))?; + } + } + } + } + + /// Inserts a key and a call sequence and its associated value into the + /// tree. + /// + /// See the documentation of [`InsertError`] for more details on when this + /// can fail. 
+ pub fn insert( + &mut self, + key: u128, + mut sequence: CallSequence, + value: T, + ) -> Result<(), InsertError> { + let mut cursor = self.start.get(&key).copied(); + let mut predecessor = None; + + loop { + if predecessor.is_none() + && let Some(pos) = cursor + { + let NodeIdKind::Inner(id) = pos.kind() else { + return Err(InsertError::AlreadyExists); + }; + + let call = &self.inner[id].call; + let Some(ret) = sequence.extract(call) else { + return Err(InsertError::MissingCall); + }; + + let pair = (id, ret); + if let Some(&next) = self.edges.get(&pair) { + // We are still on an existing path. + cursor = Some(next); + } else { + // We are now starting to build a new path in the tree. + predecessor = Some(pair); + } + } else { + // We are adding a new node to the tree for the next call in the + // sequence. + let Some((call, ret)) = sequence.next() else { break }; + + let new_inner_id = self.inner.insert(InnerNode { + call, + children: 0, + parent: predecessor.map(|(id, _)| id), + }); + let new_id = NodeId::inner(new_inner_id); + self.link(cursor.is_none(), key, predecessor.take(), new_id); + + predecessor = Some((new_inner_id, ret)); + cursor = Some(new_id); + } + } + + if predecessor.is_none() && cursor.is_some() { + return Err(InsertError::AlreadyExists); + } + + let target = NodeId::leaf( + self.leaves + .insert(LeafNode { value, parent: predecessor.map(|(id, _)| id) }), + ); + self.link(cursor.is_none(), key, predecessor, target); + + Ok(()) + } + + /// Creates a new link between two nodes. + fn link( + &mut self, + at_start: bool, + key: u128, + from: Option<(InnerId, u128)>, + to: NodeId, + ) { + if at_start { + self.start.insert(key, to); + } + if let Some(pair) = from { + self.inner[pair.0].children += 1; + self.edges.insert(pair, to); + } + } +} + +impl CallTree { + /// Removes all call sequences from the tree for whose values the predicate + /// returns `false`. 
+ pub fn retain(&mut self, mut f: impl FnMut(&mut T) -> bool) { + // Prune from the leafs upwards, starting with the outputs. + self.leaves.retain(|_, node| { + let keep = f(&mut node.value); + if !keep { + // Delete parents iteratively while we are the only child. + let mut parent = node.parent; + while let Some(inner_id) = parent { + let node = &mut self.inner[inner_id]; + if node.children > 1 { + node.children -= 1; + break; + } else { + parent = self.inner[inner_id].parent; + self.inner.remove(inner_id); + } + } + } + keep + }); + + // Checks whether the given node survived the pruning. + let exists = |node: NodeId| match node.kind() { + NodeIdKind::Inner(id) => self.inner.contains(id), + NodeIdKind::Leaf(id) => self.leaves.contains(id), + }; + + // Prune edges. + self.edges.retain(|_, node| exists(*node)); + self.start.retain(|_, node| exists(*node)); + } + + /// Checks a few invariants of the data structure. + #[cfg(test)] + fn assert_consistency(&self) { + let exists = |node: NodeId| match node.kind() { + NodeIdKind::Inner(id) => self.inner.contains(id), + NodeIdKind::Leaf(id) => self.leaves.contains(id), + }; + + for &node in self.start.values() { + assert!(exists(node)); + } + + for (&(inner_id, _), &node) in &self.edges { + assert!(exists(node)); + assert!(self.inner.contains(inner_id)); + } + } +} + +impl Debug for CallTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for (&(inner_id, ret), next) in &self.edges { + let call = &self.inner[inner_id].call; + write!(f, "[{inner_id}] ({call:?}, {ret:?}) -> ")?; + match next.kind() { + NodeIdKind::Inner(id) => writeln!(f, "{id}")?, + NodeIdKind::Leaf(id) => writeln!(f, "{:?}", &self.leaves[id].value)?, + } + } + Ok(()) + } +} + +impl Default for CallTree { + fn default() -> Self { + Self::new() + } +} + +/// Identifies a node in the call tree. 
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +struct NodeId(isize); + +impl NodeId { + /// An inner node with an index pointing into the `inner` slab allocator. + fn inner(i: usize) -> Self { + Self(i as isize) + } + + /// A leaf node with an index pointing into the `leaves` slab allocator. + fn leaf(i: usize) -> Self { + Self(-(i as isize) - 1) + } + + /// Makes this encoded node available as an enum for matching. + fn kind(self) -> NodeIdKind { + if self.0 >= 0 { + NodeIdKind::Inner(self.0 as usize) + } else { + NodeIdKind::Leaf((-self.0) as usize - 1) + } + } +} + +/// An unpacked representation of a `NodeId`. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +enum NodeIdKind { + Inner(InnerId), + Leaf(LeafId), +} + +type InnerId = usize; +type LeafId = usize; + +/// An error that can occur during insertion of a call sequence into the call +/// tree. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub enum InsertError { + /// A call sequence that is a prefix of the used one was already inserted. + AlreadyExists, + /// The calls from the sequence gave results matching the first N ones from + /// an existing sequence S in the tree, but the N+1th call from S does not + /// exist in the given `sequence`. This points towards non-determinism in + /// the memoized function the `sequence` belongs to.
+ MissingCall, +} + +#[cfg(test)] +mod tests { + use quickcheck::Arbitrary; + + use super::*; + + #[test] + fn test_call_tree() { + test_ops([ + Op::Insert(0, vec![('a', 10), ('b', 15)], "first"), + Op::Insert(0, vec![('a', 10), ('b', 20)], "second"), + Op::Insert(0, vec![('a', 15), ('c', 15)], "third"), + ]); + test_ops([ + Op::Insert(0, vec![('a', 10), ('b', 15)], "first"), + Op::Insert(0, vec![('a', 10), ('c', 15), ('b', 20)], "second"), + Op::Insert(0, vec![('a', 15), ('b', 30), ('c', 15)], "third"), + Op::Manual(|tree| { + assert_eq!(tree.inner.len(), 5); + assert_eq!(tree.leaves.len(), 3); + assert_eq!(tree.edges.len(), 7); + assert_eq!(tree.start.len(), 1); + }), + Op::Retain(Box::new(|v| *v == "second")), + Op::Manual(|tree| { + assert_eq!(tree.inner.len(), 3); + assert_eq!(tree.leaves.len(), 1); + assert_eq!(tree.edges.len(), 3); + assert_eq!(tree.start.len(), 1); + }), + ]); + } + + #[quickcheck_macros::quickcheck] + fn test_call_tree_quickcheck(ops: Vec) { + test_ops( + std::iter::once(Op::IgnoreInsertErrors) + .chain(ops.into_iter().map(ArbitraryOp::into_op)), + ); + } + + #[derive(Debug, Clone)] + enum ArbitraryOp { + Insert(u128, Vec, u8), + Retain(u8), + } + + impl ArbitraryOp { + fn into_op(self) -> Op { + match self { + Self::Insert(key, nums, output) => { + let mut state = 50; + Op::Insert( + key, + nums.iter() + .map(move |&v| { + let pair = (state, v as u128); + state += 1 + v as u64; + pair + }) + .collect(), + output, + ) + } + Self::Retain(mid) => Op::Retain(Box::new(move |v| *v > mid)), + } + } + } + + impl Arbitrary for ArbitraryOp { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + if bool::arbitrary(g) { + Self::Insert( + Arbitrary::arbitrary(g), + Arbitrary::arbitrary(g), + Arbitrary::arbitrary(g), + ) + } else { + Self::Retain(Arbitrary::arbitrary(g)) + } + } + } + + enum Op { + IgnoreInsertErrors, + Insert(u128, Vec<(C, u128)>, T), + Retain(Box bool>), + Manual(fn(&mut CallTree)), + } + + #[track_caller] + fn test_ops(ops: impl 
IntoIterator>) + where + C: Clone + Hash + Eq, + T: Debug + PartialEq + Clone, + { + let mut tree = CallTree::new(); + let mut kept = Vec::<(u128, FxHashMap, T)>::new(); + let mut ignore_insert_errors = false; + + for op in ops { + match op { + Op::IgnoreInsertErrors => ignore_insert_errors = true, + Op::Insert(key, seq, value) => { + match tree.insert(key, seq.iter().cloned().collect(), value.clone()) { + Ok(()) => kept.push(( + key, + seq.iter().map(|(k, v)| (k.clone(), *v)).collect(), + value.clone(), + )), + Err(_) if ignore_insert_errors => {} + Err(e) => panic!("{e:?}"), + } + } + Op::Retain(f) => { + tree.retain(|v| f(v)); + kept.retain_mut(|(key, map, v)| { + let keep = f(v); + if !keep { + assert_eq!(tree.get(*key, |s| map[s]), None); + } + keep + }); + } + Op::Manual(f) => f(&mut tree), + } + + tree.assert_consistency(); + + for (key, map, value) in &kept { + assert_eq!(tree.get(*key, |s| map[s]), Some(value)); + } + } + } +} diff --git a/tests/tests.rs b/tests/tests.rs index 8eb925b..baf52d1 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -3,8 +3,10 @@ use std::collections::HashMap; use std::hash::Hash; use std::path::{Path, PathBuf}; +use std::sync::atomic::AtomicUsize; -use comemo::{Track, Tracked, TrackedMut, Validate, evict, memoize, track}; +use comemo::{Track, Tracked, TrackedMut, evict, memoize, track}; +use quickcheck::Arbitrary; use serial_test::serial; macro_rules! test { @@ -80,6 +82,7 @@ fn evaluate(script: &str, files: Tracked) -> i32 { }) .sum() } + /// Test the calc language. #[test] #[serial] @@ -265,6 +268,10 @@ impl Tester { } } +/// A non-copy struct that is passed by value to a tracked method. +#[derive(Clone, PartialEq, Hash)] +struct Heavy(String); + /// Test empty type without methods. struct Empty; @@ -340,7 +347,7 @@ fn test_variance() { struct Chain<'a> { // Need to override the lifetime here so that a `Tracked` is covariant over // `Chain`. 
- outer: Option as Validate>::Constraint>>, + outer: Option as Track>::Call>>, value: u32, } @@ -363,33 +370,29 @@ impl<'a> Chain<'a> { } } -/// Test mutable tracking. +/// Test purely mutable tracking. #[test] #[serial] #[rustfmt::skip] -fn test_mutable() { +fn test_purely_mutable() { #[comemo::memoize] - fn dump(mut sink: TrackedMut) { - sink.emit("a"); - sink.emit("b"); - let c = sink.len_or_ten().to_string(); - sink.emit(&c); + fn dump(mut emitter: TrackedMut, value: &str) { + emitter.emit(value); + emitter.emit("1"); } let mut emitter = Emitter(vec![]); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(hit: dump(emitter.track_mut()), ()); - test!(hit: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut(), "a"), ()); + test!(miss: dump(emitter.track_mut(), "b"), ()); + test!(miss: dump(emitter.track_mut(), "c"), ()); + test!(hit: dump(emitter.track_mut(), "a"), ()); + test!(hit: dump(emitter.track_mut(), "b"), ()); assert_eq!(emitter.0, [ - "a", "b", "2", - "a", "b", "5", - "a", "b", "8", - "a", "b", "10", - "a", "b", "10", - "a", "b", "10", + "a", "1", + "b", "1", + "c", "1", + "a", "1", + "b", "1", ]) } @@ -402,22 +405,166 @@ impl Emitter { fn emit(&mut self, msg: &str) { self.0.push(msg.into()); } +} + +/// Ensures that mutable tracked calls in a nested function are properly +/// replayed when there is a cache hit for a top-level function.
+#[test] +#[serial] +fn test_mutable_nested() { + #[comemo::memoize] + fn a(counter: TrackedMut, _k: usize) { + b(counter); + } - fn len_or_ten(&self) -> usize { - self.0.len().min(10) + #[comemo::memoize] + fn b(mut counter: TrackedMut) { + counter.add(3); } + + let mut c1 = Counter(0); + a(c1.track_mut(), 0); + assert_eq!(c1.0, 3); + + let mut c2 = Counter(0); + a(c2.track_mut(), 1); + assert_eq!(c2.0, 3); + + let mut c3 = Counter(0); + a(c3.track_mut(), 1); + assert_eq!(c3.0, 3); } -/// A non-copy struct that is passed by value to a tracked method. -#[derive(Clone, PartialEq, Hash)] -struct Heavy(String); +/// Ensures that we don't run into quadratic runtime during cache validation of +/// many cache entries with the same key hash. +#[test] +#[serial] +fn test_many_with_same_key() { + #[memoize] + fn contextual(context: Tracked) -> String { + if let Some(loc) = context.location() { + if loc == 5 { + format!("Twenty has {}", context.styles()) + } else { + format!("Location: {loc}") + } + } else { + "No location".into() + } + } + + fn oracle(context: &Context) -> String { + if let Some(loc) = context.location { + if loc == 5 { + format!("Twenty has {}", context.styles) + } else { + format!("Location: {loc}") + } + } else { + "No location".into() + } + } + + for i in 0..1000 { + let context = Context { location: Some(i), styles: "styles" }; + test!(miss: contextual(context.track()), oracle(&context)); + } + + for i in 0..1000 { + let context = Context { location: Some(i), styles: "styles" }; + test!(hit: contextual(context.track()), oracle(&context)); + } +} + +/// Tests a memoized function that calls tracked functions in non-deterministic +/// fashion. 
(Not just out of order, but some call that appeared in one run does +/// not appear at all in the other even though the same calls and return hashes +/// led up to that point.) +#[test] +#[serial] +#[should_panic(expected = "comemo: memoized function is non-deterministic")] +fn test_non_deterministic() { + use std::sync::atomic::Ordering::SeqCst; + + static FOO: AtomicUsize = AtomicUsize::new(0); + + #[memoize] + fn contextual(context: Tracked) -> String { + if FOO.load(SeqCst) == 0 { + let _ = context.location(); + } else { + let _ = context.styles(); + } + String::new() + } + + let context = Context { location: Some(0), styles: "styles" }; + FOO.store(0, SeqCst); + contextual(context.track()); + + let context = Context { location: Some(1), styles: "styles" }; + FOO.store(1, SeqCst); + contextual(context.track()); +} + +/// Tests a memoized function that calls tracked functions out of order, but in +/// a fashion that is still deterministic in which functions are called overall +/// (this happens in deterministic functions that use multi-threading +/// internally).
+#[test] +#[serial] +fn test_deterministic_out_of_order() { + use std::sync::atomic::Ordering::SeqCst; + + static FOO: AtomicUsize = AtomicUsize::new(0); + + #[memoize] + fn contextual(context: Tracked) -> String { + let (a, b) = if FOO.load(SeqCst) == 0 { + let a = context.location(); + let b = context.styles(); + (a, b) + } else { + let b = context.styles(); + let a = context.location(); + (a, b) + }; + format!("{a:?} {b}") + } + + let context = Context { location: Some(0), styles: "styles" }; + FOO.store(0, SeqCst); + test!(miss: contextual(context.track()), "Some(0) styles"); + + FOO.store(1, SeqCst); + test!(hit: contextual(context.track()), "Some(0) styles"); + + let context = Context { location: Some(1), styles: "styles" }; + test!(miss: contextual(context.track()), "Some(1) styles"); +} + +struct Context { + location: Option, + styles: &'static str, +} + +#[track] +impl Context { + fn location(&self) -> Option { + self.location + } + + fn styles(&self) -> &'static str { + self.styles + } +} /// Test a tracked method that is impure. #[test] #[serial] #[cfg(debug_assertions)] #[should_panic( - expected = "comemo: found conflicting constraints. is this tracked function pure?" + expected = "comemo: found differing return values. is there an impure tracked function?" 
)] fn test_impure_tracked_method() { #[comemo::memoize] @@ -455,3 +602,132 @@ fn test_with_disabled() { test!(miss: disabled(2000), 2000); test!(hit: disabled(2000), 2000); } + +#[quickcheck_macros::quickcheck] +fn test_memoize_quickcheck(cases: Cases) { + for Case(map, tree) in cases.0 { + let mut c1 = Counter(0); + let r1 = fuzzable_unmemoized(&map, &mut c1, &tree); + + let mut c2 = Counter(0); + let r2 = fuzzable(map.track(), c2.track_mut(), &tree); + + let mut c3 = Counter(0); + let r3 = fuzzable(map.track(), c3.track_mut(), &tree); + assert!(comemo::testing::last_was_hit()); + + assert_eq!(r1, r2); + assert_eq!(r2, r3); + assert_eq!(c1, c2); + assert_eq!(c2, c3); + } + comemo::evict(2) +} + +#[memoize] +fn fuzzable( + map: Tracked, + mut counter: TrackedMut, + tree: &[Node], +) -> u32 { + tree.iter() + .filter_map(|node| match node { + Node::Leaf(leaf) => { + if *leaf == 7 { + counter.add(1); + } + map.get(*leaf) + } + Node::Inner(inner, _) => { + Some(fuzzable(map, TrackedMut::reborrow_mut(&mut counter), inner)) + } + }) + .fold(0, |a, b| a.saturating_add(b)) +} + +fn fuzzable_unmemoized(map: &IntMap, counter: &mut Counter, tree: &[Node]) -> u32 { + tree.iter() + .filter_map(|node| match node { + Node::Leaf(leaf) => { + if *leaf == 7 { + counter.add(1); + } + map.get(*leaf) + } + Node::Inner(inner, _) => Some(fuzzable_unmemoized(map, counter, inner)), + }) + .fold(0, |a, b| a.saturating_add(b)) +} + +#[derive(Debug, Clone)] +struct Cases(Vec); + +impl Arbitrary for Cases { + fn arbitrary(_: &mut quickcheck::Gen) -> Self { + Self(Arbitrary::arbitrary(&mut quickcheck::Gen::new(5))) + } +} + +#[derive(Debug, Clone)] +struct Case(IntMap, Vec); + +impl Arbitrary for Case { + fn arbitrary(_: &mut quickcheck::Gen) -> Self { + let g = &mut quickcheck::Gen::new(100); + Self(Arbitrary::arbitrary(g), Arbitrary::arbitrary(g)) + } +} + +#[derive(Debug, Clone, Hash)] +enum Node { + Leaf(u32), + Inner(Vec, usize), +} + +impl Node { + fn depth(&self) -> usize { + match self 
{ + Self::Leaf(_) => 0, + Self::Inner(_, depth) => *depth, + } + } +} + +impl Arbitrary for Node { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + if g.size() == 0 || bool::arbitrary(g) { + Self::Leaf(Arbitrary::arbitrary(g)) + } else { + let g = &mut quickcheck::Gen::new(g.size() / 3); + let nodes: Vec = Arbitrary::arbitrary(g); + let depth = nodes.iter().map(|node| node.depth() + 1).max().unwrap_or(0); + Self::Inner(nodes, depth) + } + } +} + +#[derive(Debug, Clone)] +struct IntMap(HashMap); + +impl Arbitrary for IntMap { + fn arbitrary(g: &mut quickcheck::Gen) -> Self { + Self(Arbitrary::arbitrary(g)) + } +} + +#[track] +impl IntMap { + fn get(&self, k: u32) -> Option { + self.0.get(&k).copied() + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +struct Counter(usize); + +#[track] +impl Counter { + fn add(&mut self, v: usize) { + self.0 += v; + } +}