Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions engine/packages/cache/src/driver.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
fmt::Debug,
sync::OnceLock,
time::{Duration, Instant},
};

Expand Down Expand Up @@ -132,10 +133,10 @@ impl moka::Expiry<String, ExpiringValue> for ValueExpiry {
}
}

static CACHE: OnceLock<Cache<String, ExpiringValue>> = OnceLock::new();

/// In-memory cache driver implementation using the moka crate
pub struct InMemoryDriver {
cache: Cache<String, ExpiringValue>,
}
pub struct InMemoryDriver {}

impl Debug for InMemoryDriver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand All @@ -146,11 +147,20 @@ impl Debug for InMemoryDriver {
impl InMemoryDriver {
pub fn new(max_capacity: u64) -> Self {
// Create a cache with ValueExpiry implementation for custom expiration times
let cache = CacheBuilder::new(max_capacity)
.expire_after(ValueExpiry)
.build();
CACHE.get_or_init(|| {
CacheBuilder::new(max_capacity)
.expire_after(ValueExpiry)
.eviction_listener(|key, _value, cause| {
tracing::debug!(?key, ?cause, "cache eviction");
})
.build()
});

Self {}
}

Self { cache }
fn cache(&self) -> &Cache<String, ExpiringValue> {
CACHE.get().expect("should be initialized")
}

pub async fn get<'a>(
Expand All @@ -163,7 +173,7 @@ impl InMemoryDriver {
// Async block for metrics
async {
for key in keys {
result.push(self.cache.get(&**key).await.map(|x| x.value.clone()));
result.push(self.cache().get(&**key).await.map(|x| x.value.clone()));
}
}
.instrument(tracing::info_span!("get"))
Expand Down Expand Up @@ -193,7 +203,7 @@ impl InMemoryDriver {
};

// Store in cache - expiry will be handled by ValueExpiry
self.cache.insert(key.into(), entry).await;
self.cache().insert(key.into(), entry).await;
}
}
.instrument(tracing::info_span!("set"))
Expand All @@ -212,7 +222,7 @@ impl InMemoryDriver {
async {
for key in keys {
// Use remove instead of invalidate to ensure it's actually removed
self.cache.remove(&*key).await;
self.cache().remove(&*key).await;
}
}
.instrument(tracing::info_span!("delete"))
Expand Down
19 changes: 12 additions & 7 deletions engine/packages/cache/src/inner.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use std::{fmt::Debug, sync::Arc};
use std::{
fmt::Debug,
sync::{Arc, OnceLock},
};

use tokio::sync::broadcast;

use super::*;
use crate::driver::{Driver, InMemoryDriver};

static IN_FLIGHT: OnceLock<scc::HashMap<RawCacheKey, broadcast::Sender<()>>> = OnceLock::new();

pub type Cache = Arc<CacheInner>;

/// Utility type used to hold information relating to caching.
pub struct CacheInner {
pub(crate) driver: Driver,
pub(crate) in_flight: scc::HashMap<RawCacheKey, broadcast::Sender<()>>,
pub(crate) ups: Option<universalpubsub::PubSub>,
}

Expand All @@ -36,11 +40,12 @@ impl CacheInner {
#[tracing::instrument(skip(ups))]
pub fn new_in_memory(max_capacity: u64, ups: Option<universalpubsub::PubSub>) -> Cache {
let driver = Driver::InMemory(InMemoryDriver::new(max_capacity));
Arc::new(CacheInner {
driver,
in_flight: scc::HashMap::new(),
ups,
})

Arc::new(CacheInner { driver, ups })
}

pub(crate) fn in_flight(&self) -> &scc::HashMap<RawCacheKey, broadcast::Sender<()>> {
IN_FLIGHT.get_or_init(scc::HashMap::new)
}
}

Expand Down
21 changes: 12 additions & 9 deletions engine/packages/cache/src/req_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl RequestConfig {

// MARK: Fetch
impl RequestConfig {
#[tracing::instrument(err, skip(keys, getter, encoder, decoder))]
#[tracing::instrument(err, skip_all, fields(?base_key))]
async fn fetch_all_convert<Key, Value, Getter, Fut, Encoder, Decoder>(
self,
base_key: impl ToString + Debug,
Expand Down Expand Up @@ -129,7 +129,7 @@ impl RequestConfig {
// Determine which keys are currently being fetched and not
for key in remaining_keys {
let cache_key = self.cache.driver.process_key(&base_key, &key);
match self.cache.in_flight.entry_async(cache_key).await {
match self.cache.in_flight().entry_async(cache_key).await {
scc::hash_map::Entry::Occupied(broadcast) => {
waiting_keys.push((key, broadcast.subscribe()));
}
Expand Down Expand Up @@ -189,7 +189,13 @@ impl RequestConfig {
succeeded_keys.into_iter().unzip();

let (cached_values_res, ctx3_res) = tokio::join!(
cache.driver.get(&base_key2, &succeeded_cache_keys),
async {
if succeeded_cache_keys.is_empty() {
Ok(Vec::new())
} else {
cache.driver.get(&base_key2, &succeeded_cache_keys).await
}
},
async {
if failed_keys.is_empty() {
Ok(ctx3)
Expand Down Expand Up @@ -276,7 +282,7 @@ impl RequestConfig {
// Release leases
for key in leased_keys {
let cache_key = self.cache.driver.process_key(&base_key, &key);
self.cache.in_flight.remove_async(&cache_key).await;
self.cache.in_flight().remove_async(&cache_key).await;
}
}

Expand Down Expand Up @@ -310,7 +316,7 @@ impl RequestConfig {
}
}

#[tracing::instrument(err, skip(keys))]
#[tracing::instrument(err, skip_all, fields(?base_key))]
pub async fn purge<Key>(
self,
base_key: impl AsRef<str> + Debug,
Expand Down Expand Up @@ -363,7 +369,7 @@ impl RequestConfig {

/// Purges keys from the local cache without publishing to NATS.
/// This is used by the cache-purge service to avoid recursive publishing.
#[tracing::instrument(err, skip(keys))]
#[tracing::instrument(err, skip_all, fields(?base_key))]
pub async fn purge_local(
self,
base_key: impl AsRef<str> + Debug,
Expand Down Expand Up @@ -398,7 +404,6 @@ impl RequestConfig {

// MARK: JSON fetch
impl RequestConfig {
#[tracing::instrument(err, skip(key, getter))]
pub async fn fetch_one_json<Key, Value, Getter, Fut>(
self,
base_key: impl ToString + Debug,
Expand Down Expand Up @@ -428,7 +433,6 @@ impl RequestConfig {
Ok(values.into_iter().next().map(|(_, v)| v))
}

#[tracing::instrument(err, skip(keys, getter))]
pub async fn fetch_all_json<Key, Value, Getter, Fut>(
self,
base_key: impl ToString + Debug,
Expand All @@ -447,7 +451,6 @@ impl RequestConfig {
.map(|x| x.into_iter().map(|(_, v)| v).collect::<Vec<_>>())
}

#[tracing::instrument(err, skip(keys, getter))]
pub async fn fetch_all_json_with_keys<Key, Value, Getter, Fut>(
self,
base_key: impl ToString + Debug,
Expand Down
Loading