Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions src/cachier/cores/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
# Licensed under the MIT license:
# http://www.opensource.org/licenses/MIT-license
# Copyright (c) 2016, Shay Palachy <shaypal5@gmail.com>
import hashlib
import logging
import os
import pickle # for local caching
import tempfile
import time
from contextlib import suppress
from datetime import datetime, timedelta
Expand All @@ -28,6 +30,8 @@
class _PickleCore(_BaseCore):
"""The pickle core class for cachier."""

_SHARED_LOCK_SUFFIX = ".lock"

class CacheChangeHandler(PatternMatchingEventHandler):
"""Handles cache-file modification events."""

Expand Down Expand Up @@ -71,6 +75,10 @@ def on_modified(self, event) -> None:
"""A Watchdog Event Handler method.""" # noqa: D401
self._check_calculation()

def on_moved(self, event) -> None:
    """React to the watched cache file being moved or renamed.

    Atomic cache rewrites replace the file via ``os.replace``, which
    watchdog reports as a move rather than a modification, so treat it
    exactly like a modification event.
    """  # noqa: D401
    self._check_calculation()

def __init__(
self,
hash_func: Optional[HashFunc],
Expand All @@ -97,6 +105,21 @@ def cache_fpath(self) -> str:
os.makedirs(self.cache_dir, exist_ok=True)
return os.path.abspath(os.path.join(os.path.realpath(self.cache_dir), self.cache_fname))

@property
def _shared_lock_fpath(self) -> str:
    """Return the path of the shared inter-process lock file.

    The lock file name is derived from a SHA-256 digest of the cache
    file path, so every process locking the same cache agrees on one
    lock file. Candidate locations, in order of preference: a
    ``cachier-locks`` directory under the system temp dir, then a hidden
    ``.cachier-locks`` directory next to the cache file. If neither
    directory can be created, fall back to a hidden dot-file placed
    directly beside the cache file.
    """
    digest = hashlib.sha256(self.cache_fpath.encode("utf-8")).hexdigest()
    lock_name = f"{digest}{self._SHARED_LOCK_SUFFIX}"
    cache_parent = os.path.dirname(self.cache_fpath)
    for lock_dir in (
        os.path.join(tempfile.gettempdir(), "cachier-locks"),
        os.path.join(cache_parent, ".cachier-locks"),
    ):
        try:
            os.makedirs(lock_dir, exist_ok=True)
        except OSError:
            # Directory not writable/creatable here; try the next one.
            continue
        return os.path.join(lock_dir, lock_name)
    # Last resort: a hidden lock file right next to the cache file.
    return os.path.join(cache_parent, f".{lock_name}")

@staticmethod
def _convert_legacy_cache_entry(
entry: Union[dict, CacheEntry],
Expand All @@ -113,8 +136,8 @@ def _convert_legacy_cache_entry(

def _load_cache_dict(self) -> Dict[str, CacheEntry]:
try:
with portalocker.Lock(self.cache_fpath, mode="rb") as cf:
cache = pickle.load(cast(IO[bytes], cf))
with portalocker.Lock(self._shared_lock_fpath, mode="a+b"), open(self.cache_fpath, "rb") as cache_file:
cache = pickle.load(cast(IO[bytes], cache_file))
self._cache_used_fpath = str(self.cache_fpath)
except (FileNotFoundError, EOFError):
cache = {}
Expand Down Expand Up @@ -181,9 +204,29 @@ def _save_cache(
fpath += f"_{separate_file_key}"
elif hash_str is not None:
fpath += f"_{hash_str}"
parent_dir = os.path.dirname(fpath)
with self.lock:
with portalocker.Lock(fpath, mode="wb") as cf:
pickle.dump(cache, cast(IO[bytes], cf), protocol=4)
if isinstance(cache, CacheEntry):
with portalocker.Lock(fpath, mode="wb") as cache_file:
pickle.dump(cache, cast(IO[bytes], cache_file), protocol=4)
else:
with portalocker.Lock(self._shared_lock_fpath, mode="a+b"):
temp_path = ""
try:
with tempfile.NamedTemporaryFile(
mode="wb",
dir=parent_dir,
delete=False,
) as temp_file:
temp_path = temp_file.name
pickle.dump(cache, cast(IO[bytes], temp_file), protocol=4)
temp_file.flush()
os.fsync(temp_file.fileno())
os.replace(temp_path, fpath)
finally:
if temp_path:
with suppress(FileNotFoundError):
os.remove(temp_path)
# the same as check for separate_file, but changed for typing
if isinstance(cache, dict):
self._cache_dict = cache
Expand Down Expand Up @@ -256,6 +299,7 @@ async def amark_entry_being_calculated(self, key: str) -> None:
def mark_entry_not_calculated(self, key: str) -> None:
if self.separate_files:
self._mark_entry_not_calculated_separate_files(key)
return # pragma: no cover
with self.lock:
cache = self.get_cache_dict()
# that's ok, we don't need an entry in that case
Expand Down
188 changes: 188 additions & 0 deletions tests/pickle_tests/test_pickle_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,130 @@ def mock_func():
core._save_cache({"key": "value"}, separate_file_key="test_key")


@pytest.mark.pickle
def test_save_cache_removes_temp_file_when_fsync_fails(tmp_path):
    """A failing fsync must not leave a stray temp file in the cache dir."""
    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )

    def _noop():
        pass

    core.set_func(_noop)

    entry = CacheEntry(value="value", time=datetime.now(), stale=False, _processing=False)
    broken_fsync = patch("cachier.cores.pickle.os.fsync", side_effect=OSError("fsync failed"))
    with broken_fsync, pytest.raises(OSError, match="fsync failed"):
        core._save_cache({"key": entry})

    # The aborted temp file must have been cleaned up, leaving the dir empty.
    assert not list(tmp_path.iterdir())


@pytest.mark.pickle
def test_save_cache_propagates_tempfile_creation_failure_without_cleanup_error(tmp_path):
    """If the temp file is never created, no replace/remove is attempted."""
    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )

    def _noop():
        pass

    core.set_func(_noop)

    entry = CacheEntry(value="value", time=datetime.now(), stale=False, _processing=False)
    with (
        patch("cachier.cores.pickle.tempfile.NamedTemporaryFile", side_effect=OSError("tempfile failed")),
        patch("cachier.cores.pickle.os.replace") as replace_spy,
        patch("cachier.cores.pickle.os.remove") as remove_spy,
        pytest.raises(OSError, match="tempfile failed"),
    ):
        core._save_cache({"key": entry})

    # Cleanup must be skipped entirely when there is no temp path to remove.
    replace_spy.assert_not_called()
    remove_spy.assert_not_called()
    assert not list(tmp_path.iterdir())


@pytest.mark.pickle
def test_shared_lock_fpath_falls_back_to_cache_dir_when_temp_dir_unwritable(tmp_path):
    """An unwritable system temp dir pushes the lock dir into the cache dir."""
    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )

    def _noop():
        pass

    core.set_func(_noop)

    temp_lock_dir = os.path.join("/non-writable-temp", "cachier-locks")
    fallback_lock_dir = os.path.join(core.cache_dir, ".cachier-locks")

    def fake_makedirs(path, exist_ok=False):
        # Simulate: temp-dir lock dir fails, cache-dir locations succeed.
        if path == temp_lock_dir:
            raise PermissionError("temp dir not writable")
        if path not in (core.cache_dir, fallback_lock_dir):
            raise AssertionError(f"Unexpected os.makedirs path: {path}")
        return None

    digest = hashlib.sha256(core.cache_fpath.encode("utf-8")).hexdigest()
    with (
        patch("cachier.cores.pickle.tempfile.gettempdir", return_value="/non-writable-temp"),
        patch("cachier.cores.pickle.os.makedirs", side_effect=fake_makedirs),
    ):
        expected = os.path.join(fallback_lock_dir, f"{digest}{core._SHARED_LOCK_SUFFIX}")
        assert core._shared_lock_fpath == expected


@pytest.mark.pickle
def test_shared_lock_fpath_uses_cache_dir_file_when_lock_dirs_unwritable(tmp_path):
    """With no creatable lock dir, the lock becomes a dot-file in the cache dir."""
    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )

    def _noop():
        pass

    core.set_func(_noop)

    temp_lock_dir = os.path.join("/non-writable-temp", "cachier-locks")
    fallback_lock_dir = os.path.join(core.cache_dir, ".cachier-locks")
    digest = hashlib.sha256(core.cache_fpath.encode("utf-8")).hexdigest()

    def fake_makedirs(path, exist_ok=False):
        # Simulate: both candidate lock dirs fail; only the cache dir works.
        if path in (temp_lock_dir, fallback_lock_dir):
            raise PermissionError("lock dir not writable")
        if path != core.cache_dir:
            raise AssertionError(f"Unexpected os.makedirs path: {path}")
        return None

    with (
        patch("cachier.cores.pickle.tempfile.gettempdir", return_value="/non-writable-temp"),
        patch("cachier.cores.pickle.os.makedirs", side_effect=fake_makedirs),
    ):
        expected = os.path.join(core.cache_dir, f".{digest}{core._SHARED_LOCK_SUFFIX}")
        assert core._shared_lock_fpath == expected


@pytest.mark.pickle
def test_set_entry_should_not_store(tmp_path):
"""Test set_entry when value should not be stored."""
Expand Down Expand Up @@ -1053,6 +1177,70 @@ def mock_get_cache_dict():
assert result == "result"


@pytest.mark.pickle
def test_save_cache_keeps_existing_file_readable_during_write(tmp_path, monkeypatch):
    """Test that cache rewrites do not expose a truncated file to plain readers."""
    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )

    def mock_func():
        pass

    core.set_func(mock_func)

    initial_cache = {
        "key1": CacheEntry(
            value="result-1",
            time=datetime.now(),
            stale=False,
            _processing=False,
        )
    }
    # Superset of initial_cache; identity of this dict is used below to
    # decide which pickle.dump call gets blocked.
    updated_cache = {
        **initial_cache,
        "key2": CacheEntry(
            value="result-2",
            time=datetime.now(),
            stale=False,
            _processing=False,
        ),
    }
    core._save_cache(initial_cache)

    # Events used to pause the writer thread mid-dump so this thread can
    # inspect the on-disk file while the rewrite is still in flight.
    dump_started = threading.Event()
    allow_dump = threading.Event()
    real_pickle_dump = pickle.dump

    def blocking_dump(obj, fh, protocol):
        # Block only the dump of updated_cache (checked by identity); the
        # initial save above runs through the real dump unimpeded.
        if obj is updated_cache:
            dump_started.set()
            assert allow_dump.wait(timeout=5)
        return real_pickle_dump(obj, fh, protocol=protocol)

    monkeypatch.setattr("cachier.cores.pickle.pickle.dump", blocking_dump)

    writer = threading.Thread(target=core._save_cache, args=(updated_cache,), daemon=True)
    writer.start()

    # While the writer is paused inside pickle.dump, a plain reader opening
    # the cache file directly must still see the previous complete contents.
    assert dump_started.wait(timeout=5)
    with open(core.cache_fpath, "rb") as cache_file:
        visible_cache = pickle.load(cache_file)
    assert visible_cache == initial_cache

    # Unblock the writer and confirm it finished within the timeout.
    allow_dump.set()
    writer.join(timeout=5)
    assert not writer.is_alive()

    # After the writer completes, readers see the updated cache contents.
    with open(core.cache_fpath, "rb") as cache_file:
        visible_cache = pickle.load(cache_file)
    assert visible_cache == updated_cache


@pytest.mark.pickle
def test_wait_with_polling_calls_timeout_check_when_processing(tmp_path):
"""Test _wait_with_polling checks timeout while entry is processing."""
Expand Down