diff --git a/src/cachier/cores/pickle.py b/src/cachier/cores/pickle.py index e9395822..52f304c2 100644 --- a/src/cachier/cores/pickle.py +++ b/src/cachier/cores/pickle.py @@ -6,9 +6,11 @@ # Licensed under the MIT license: # http://www.opensource.org/licenses/MIT-license # Copyright (c) 2016, Shay Palachy +import hashlib import logging import os import pickle # for local caching +import tempfile import time from contextlib import suppress from datetime import datetime, timedelta @@ -28,6 +30,8 @@ class _PickleCore(_BaseCore): """The pickle core class for cachier.""" + _SHARED_LOCK_SUFFIX = ".lock" + class CacheChangeHandler(PatternMatchingEventHandler): """Handles cache-file modification events.""" @@ -71,6 +75,10 @@ def on_modified(self, event) -> None: """A Watchdog Event Handler method.""" # noqa: D401 self._check_calculation() + def on_moved(self, event) -> None: + """A Watchdog Event Handler method.""" # noqa: D401 + self._check_calculation() + def __init__( self, hash_func: Optional[HashFunc], @@ -97,6 +105,21 @@ def cache_fpath(self) -> str: os.makedirs(self.cache_dir, exist_ok=True) return os.path.abspath(os.path.join(os.path.realpath(self.cache_dir), self.cache_fname)) + @property + def _shared_lock_fpath(self) -> str: + cache_hash = hashlib.sha256(self.cache_fpath.encode("utf-8")).hexdigest() + candidate_dirs = ( + os.path.join(tempfile.gettempdir(), "cachier-locks"), + os.path.join(os.path.dirname(self.cache_fpath), ".cachier-locks"), + ) + for lock_dir in candidate_dirs: + try: + os.makedirs(lock_dir, exist_ok=True) + return os.path.join(lock_dir, f"{cache_hash}{self._SHARED_LOCK_SUFFIX}") + except OSError: + continue + return os.path.join(os.path.dirname(self.cache_fpath), f".{cache_hash}{self._SHARED_LOCK_SUFFIX}") + @staticmethod def _convert_legacy_cache_entry( entry: Union[dict, CacheEntry], @@ -113,8 +136,8 @@ def _convert_legacy_cache_entry( def _load_cache_dict(self) -> Dict[str, CacheEntry]: try: - with portalocker.Lock(self.cache_fpath, mode="rb") as cf: - cache = pickle.load(cast(IO[bytes], cf)) + with portalocker.Lock(self._shared_lock_fpath, mode="a+b"), open(self.cache_fpath, "rb") as cache_file: + cache = pickle.load(cast(IO[bytes], cache_file)) self._cache_used_fpath = str(self.cache_fpath) except (FileNotFoundError, EOFError): cache = {} @@ -181,9 +204,29 @@ def _save_cache( fpath += f"_{separate_file_key}" elif hash_str is not None: fpath += f"_{hash_str}" + parent_dir = os.path.dirname(fpath) with self.lock: - with portalocker.Lock(fpath, mode="wb") as cf: - pickle.dump(cache, cast(IO[bytes], cf), protocol=4) + if isinstance(cache, CacheEntry): + with portalocker.Lock(fpath, mode="wb") as cache_file: + pickle.dump(cache, cast(IO[bytes], cache_file), protocol=4) + else: + with portalocker.Lock(self._shared_lock_fpath, mode="a+b"): + temp_path = "" + try: + with tempfile.NamedTemporaryFile( + mode="wb", + dir=parent_dir, + delete=False, + ) as temp_file: + temp_path = temp_file.name + pickle.dump(cache, cast(IO[bytes], temp_file), protocol=4) + temp_file.flush() + os.fsync(temp_file.fileno()) + os.replace(temp_path, fpath) + finally: + if temp_path: + with suppress(FileNotFoundError): + os.remove(temp_path) # the same as check for separate_file, but changed for typing if isinstance(cache, dict): self._cache_dict = cache @@ -256,6 +299,7 @@ async def amark_entry_being_calculated(self, key: str) -> None: def mark_entry_not_calculated(self, key: str) -> None: if self.separate_files: self._mark_entry_not_calculated_separate_files(key) + return # pragma: no cover with self.lock: cache = self.get_cache_dict() # that's ok, we don't need an entry in that case diff --git a/tests/pickle_tests/test_pickle_core.py b/tests/pickle_tests/test_pickle_core.py index 66e27c6b..451363da 100644 --- a/tests/pickle_tests/test_pickle_core.py +++ b/tests/pickle_tests/test_pickle_core.py @@ -709,6 +709,130 @@ def mock_func(): core._save_cache({"key": "value"}, separate_file_key="test_key") +@pytest.mark.pickle +def test_save_cache_removes_temp_file_when_fsync_fails(tmp_path): + """Test _save_cache removes the temp file when fsync fails.""" + core = _PickleCore( + hash_func=None, + cache_dir=tmp_path, + pickle_reload=False, + wait_for_calc_timeout=10, + separate_files=False, + ) + + def mock_func(): + pass + + core.set_func(mock_func) + + with ( + patch("cachier.cores.pickle.os.fsync", side_effect=OSError("fsync failed")), + pytest.raises(OSError, match="fsync failed"), + ): + core._save_cache({"key": CacheEntry(value="value", time=datetime.now(), stale=False, _processing=False)}) + + assert list(tmp_path.iterdir()) == [] + + +@pytest.mark.pickle +def test_save_cache_propagates_tempfile_creation_failure_without_cleanup_error(tmp_path): + """Test _save_cache handles temp-file creation failures before temp_path exists.""" + core = _PickleCore( + hash_func=None, + cache_dir=tmp_path, + pickle_reload=False, + wait_for_calc_timeout=10, + separate_files=False, + ) + + def mock_func(): + pass + + core.set_func(mock_func) + + with ( + patch("cachier.cores.pickle.tempfile.NamedTemporaryFile", side_effect=OSError("tempfile failed")), + patch("cachier.cores.pickle.os.replace") as mock_replace, + patch("cachier.cores.pickle.os.remove") as mock_remove, + pytest.raises(OSError, match="tempfile failed"), + ): + core._save_cache({"key": CacheEntry(value="value", time=datetime.now(), stale=False, _processing=False)}) + + mock_replace.assert_not_called() + mock_remove.assert_not_called() + assert list(tmp_path.iterdir()) == [] + + +@pytest.mark.pickle +def test_shared_lock_fpath_falls_back_to_cache_dir_when_temp_dir_unwritable(tmp_path): + """Test _shared_lock_fpath falls back when the system temp dir is not writable.""" + core = _PickleCore( + hash_func=None, + cache_dir=tmp_path, + pickle_reload=False, + wait_for_calc_timeout=10, + separate_files=False, + ) + + def mock_func(): + pass + + core.set_func(mock_func) + + temp_lock_dir = os.path.join("/non-writable-temp", "cachier-locks") + fallback_lock_dir = os.path.join(core.cache_dir, ".cachier-locks") + + def mock_makedirs(path, exist_ok=False): + if path in (core.cache_dir, fallback_lock_dir): + return None + if path == temp_lock_dir: + raise PermissionError("temp dir not writable") + raise AssertionError(f"Unexpected os.makedirs path: {path}") + + with ( + patch("cachier.cores.pickle.tempfile.gettempdir", return_value="/non-writable-temp"), + patch("cachier.cores.pickle.os.makedirs", side_effect=mock_makedirs), + ): + assert core._shared_lock_fpath == os.path.join( + fallback_lock_dir, + f"{hashlib.sha256(core.cache_fpath.encode('utf-8')).hexdigest()}{core._SHARED_LOCK_SUFFIX}", + ) + + +@pytest.mark.pickle +def test_shared_lock_fpath_uses_cache_dir_file_when_lock_dirs_unwritable(tmp_path): + """Test _shared_lock_fpath falls back to a lockfile in the cache dir.""" + core = _PickleCore( + hash_func=None, + cache_dir=tmp_path, + pickle_reload=False, + wait_for_calc_timeout=10, + separate_files=False, + ) + + def mock_func(): + pass + + core.set_func(mock_func) + + temp_lock_dir = os.path.join("/non-writable-temp", "cachier-locks") + fallback_lock_dir = os.path.join(core.cache_dir, ".cachier-locks") + cache_hash = hashlib.sha256(core.cache_fpath.encode("utf-8")).hexdigest() + + def mock_makedirs(path, exist_ok=False): + if path == core.cache_dir: + return None + if path in (temp_lock_dir, fallback_lock_dir): + raise PermissionError("lock dir not writable") + raise AssertionError(f"Unexpected os.makedirs path: {path}") + + with ( + patch("cachier.cores.pickle.tempfile.gettempdir", return_value="/non-writable-temp"), + patch("cachier.cores.pickle.os.makedirs", side_effect=mock_makedirs), + ): + assert core._shared_lock_fpath == os.path.join(core.cache_dir, f".{cache_hash}{core._SHARED_LOCK_SUFFIX}") + + @pytest.mark.pickle def test_set_entry_should_not_store(tmp_path): """Test set_entry when value should not be stored.""" @@ -1053,6 +1177,70 @@ def mock_get_cache_dict(): assert result == "result" +@pytest.mark.pickle +def test_save_cache_keeps_existing_file_readable_during_write(tmp_path, monkeypatch): + """Test that cache rewrites do not expose a truncated file to plain readers.""" + core = _PickleCore( + hash_func=None, + cache_dir=tmp_path, + pickle_reload=False, + wait_for_calc_timeout=10, + separate_files=False, + ) + + def mock_func(): + pass + + core.set_func(mock_func) + + initial_cache = { + "key1": CacheEntry( + value="result-1", + time=datetime.now(), + stale=False, + _processing=False, + ) + } + updated_cache = { + **initial_cache, + "key2": CacheEntry( + value="result-2", + time=datetime.now(), + stale=False, + _processing=False, + ), + } + core._save_cache(initial_cache) + + dump_started = threading.Event() + allow_dump = threading.Event() + real_pickle_dump = pickle.dump + + def blocking_dump(obj, fh, protocol): + if obj is updated_cache: + dump_started.set() + assert allow_dump.wait(timeout=5) + return real_pickle_dump(obj, fh, protocol=protocol) + + monkeypatch.setattr("cachier.cores.pickle.pickle.dump", blocking_dump) + + writer = threading.Thread(target=core._save_cache, args=(updated_cache,), daemon=True) + writer.start() + + assert dump_started.wait(timeout=5) + with open(core.cache_fpath, "rb") as cache_file: + visible_cache = pickle.load(cache_file) + assert visible_cache == initial_cache + + allow_dump.set() + writer.join(timeout=5) + assert not writer.is_alive() + + with open(core.cache_fpath, "rb") as cache_file: + visible_cache = pickle.load(cache_file) + assert visible_cache == updated_cache + + @pytest.mark.pickle def test_wait_with_polling_calls_timeout_check_when_processing(tmp_path): """Test _wait_with_polling checks timeout while entry is processing."""