diff --git a/Cargo.toml b/Cargo.toml index 035d7adf..d81e5900 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ redb = { version = "2.6.3", optional = true } reflink-copy = { version = "0.1.24", optional = true } n0-error = "0.1.2" nested_enum_utils = "0.2.3" +atomic_refcell = "0.1.13" [dev-dependencies] clap = { version = "4.5.31", features = ["derive"] } diff --git a/src/store/fs/bao_file.rs b/src/store/fs/bao_file.rs index 72ab4801..7881b69b 100644 --- a/src/store/fs/bao_file.rs +++ b/src/store/fs/bao_file.rs @@ -20,12 +20,12 @@ use bytes::{Bytes, BytesMut}; use derive_more::Debug; use irpc::channel::mpsc; use n0_error::{Result, StdResultExt}; -use tokio::sync::watch; use tracing::{debug, info, trace}; use super::{ entry_state::{DataLocation, EntryState, OutboardLocation}, options::{Options, PathOptions}, + util::watch, BaoFilePart, }; use crate::{ diff --git a/src/store/fs/util.rs b/src/store/fs/util.rs index 1cbd01bc..b739394a 100644 --- a/src/store/fs/util.rs +++ b/src/store/fs/util.rs @@ -2,6 +2,7 @@ use std::future::Future; use tokio::{select, sync::mpsc}; pub(crate) mod entity_manager; +pub(crate) mod watch; /// A wrapper for a tokio mpsc receiver that allows peeking at the next message. #[derive(Debug)] diff --git a/src/store/fs/util/watch.rs b/src/store/fs/util/watch.rs new file mode 100644 index 00000000..642acfec --- /dev/null +++ b/src/store/fs/util/watch.rs @@ -0,0 +1,104 @@ +use std::{ops::Deref, sync::Arc}; + +use atomic_refcell::{AtomicRef, AtomicRefCell}; + +#[derive(Debug, Default)] +struct State { + value: T, + dropped: bool, +} + +#[derive(Debug, Default)] +struct Shared { + value: AtomicRefCell>, + notify: tokio::sync::Notify, +} + +#[derive(Debug, Default)] +pub struct Sender(Arc>); + +impl Clone for Sender { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +pub struct Receiver(Arc>); + +impl Sender { + + pub fn send_modify(&self, modify: F) + where + F: FnOnce(&mut T), + { + self.send_if_modified(|value| { + modify(value); + true + }); + } + + pub fn send_replace(&self, mut value: T) -> T { + // swap old watched value with the new one + self.send_modify(|old| std::mem::swap(old, &mut value)); + + value + } + + pub fn send_if_modified(&self, modify: F) -> bool + where + F: FnOnce(&mut T) -> bool, + { + let mut state = self.0.value.borrow_mut(); + let modified = modify(&mut state.value); + if modified { + self.0.notify.notify_waiters(); + } + modified + } + + pub fn borrow(&self) -> impl Deref + '_ { + AtomicRef::map(self.0.value.borrow(), |state| &state.value) + } + + pub fn subscribe(&self) -> Receiver { + Receiver(self.0.clone()) + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.0.value.borrow_mut().dropped = true; + self.0.notify.notify_waiters(); + } +} + +impl Receiver { + pub async fn changed(&self) -> Result<(), error::RecvError> { + self.0.notify.notified().await; + if self.0.value.borrow().dropped { + Err(error::RecvError(())) + } else { + Ok(()) + } + } + + pub fn borrow(&self) -> impl Deref + '_ { + AtomicRef::map(self.0.value.borrow(), |state| &state.value) + } +} + +pub mod error { + use std::{error::Error, fmt}; + + /// Error produced when receiving a change notification. + #[derive(Debug, Clone)] + pub struct RecvError(pub(super) ()); + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl Error for RecvError {} +}