diff --git a/src/mountain_madness/_2026/__init__.py b/src/mountain_madness/_2026/__init__.py index 4be31df..a56d083 100644 --- a/src/mountain_madness/_2026/__init__.py +++ b/src/mountain_madness/_2026/__init__.py @@ -1,6 +1,8 @@ +from pathlib import Path + from .counter import CounterFile from .models import CounterResponse -mm_counter = CounterFile("/var/www/mountain_madness/2026/counter.json", save_interval=300, save_threshold=10) -mm_counter.increment("good", 0) # initialize the counter with default values if it doesn't exist -mm_counter.increment("evil", 0) # initialize the counter with default values if it doesn't exist +mm_counter = CounterFile(Path("/var/www/mountain_madness/2026/counter.json"), save_interval=300, save_threshold=10) +mm_counter.increment("good", 0) +mm_counter.increment("evil", 0) diff --git a/src/mountain_madness/_2026/counter.py b/src/mountain_madness/_2026/counter.py index b8c06c9..e22aaa1 100644 --- a/src/mountain_madness/_2026/counter.py +++ b/src/mountain_madness/_2026/counter.py @@ -1,114 +1,109 @@ -import atexit import datetime import fcntl import json import threading -import time from pathlib import Path from constants import TZ_INFO class CounterFile: - def __init__(self, filepath: str, save_interval: int, save_threshold: int): - """ - Counter that saves to a file. - - Args: - filepath: The absolute file path to save the counter to. - save_interval: How often to save to the file, in seconds - save_threshold: How often to save to the file, in number of changes - """ - self.__counters: dict[str, int] = {} - - self.filepath = Path(filepath) + """Thread-safe counter with file persistence.""" + + def __init__( + self, + filepath: Path, + save_interval: int = 60, + save_threshold: int = 10, + ): + self.filepath = filepath self.save_interval = save_interval self.save_threshold = save_threshold - self.lock = threading.Lock() - self.create_time = datetime.datetime.now(TZ_INFO) - self.last_save_time = self.create_time - self.changes_since_last_save = 0 - - self._is_auto_save_enabled = save_interval > 0 - - # If the program shuts down, turn save one last time - atexit.register(self.__save_to_file) - - self.start() - - def start(self): - self.filepath.parent.mkdir(parents=True, exist_ok=True) + self._lock = threading.Lock() + self._counters: dict[str, int] = {} + self._changes_since_save = 0 + self.last_save_time: datetime.datetime | None = None + with self._lock: + self._load_from_file_unlocked() + + def start(self) -> None: + """Start the auto-save background thread.""" + + def auto_save(): + while True: + threading.Event().wait(self.save_interval) + with self._lock: + if self._changes_since_save > 0: + self._save_to_file_unlocked() + + thread = threading.Thread(target=auto_save, daemon=True) + thread.start() + + def increment(self, key: str, amount: int = 1) -> dict[str, int]: + """Increment a counter by the given amount.""" + with self._lock: + self._counters[key] = self._counters.get(key, 0) + amount + self._maybe_save_unlocked() + return self._counters.copy() + + def get(self, key: str) -> int: + """Get the current value of a counter.""" + with self._lock: + return self._counters.get(key, 0) + + def get_all(self) -> dict[str, int]: + """Get a copy of all counters.""" + with self._lock: + return self._counters.copy() + + def reset(self, key: str) -> None: + """Reset a specific counter to zero.""" + with self._lock: + self._counters[key] = 0 + self._maybe_save_unlocked() + + def _maybe_save_unlocked(self) -> None: + """Check threshold and save if needed. Must hold lock.""" + self._changes_since_save += 1 + if self._changes_since_save >= self.save_threshold: + self._save_to_file_unlocked() + + def _save_to_file_unlocked(self) -> None: + """Save counters to file. Must hold lock.""" + last_save_time = datetime.datetime.now(TZ_INFO) save_data = { - "counters": self.__counters, - "last_save_time": self.last_save_time.isoformat(), + "counters": self._counters.copy(), + "last_save_time": last_save_time.isoformat(), } - # Create file if it doesn't exist - if not self.filepath.exists(): - self.filepath.write_text(json.dumps(save_data, indent=2)) - self.__save_to_file() - else: - # Otherwise check the file and load from it - try: - with Path.open(self.filepath) as f: - fcntl.flock(f.fileno(), fcntl.LOCK_SH) - try: - data = json.load(f) - self.__counters = data.get("counters", {}) - finally: - fcntl.flock(f.fileno(), fcntl.LOCK_UN) - except (OSError, json.JSONDecodeError): - self.__counters = {} - if self._is_auto_save_enabled: - self.daemon_thread = threading.Thread(target=self.__save_daemon, daemon=True) - self.daemon_thread.start() - - def increment(self, key: str, amount: int = 1): - with self.lock: - self.__counters[key] = self.__counters.get(key, 0) + amount - self.changes_since_last_save += 1 - - if self.changes_since_last_save >= self.save_threshold: - self.__save_to_file() - - def get_all_counters(self) -> dict[str, int]: - with self.lock: - return self.__counters.copy() - def __save_daemon(self): - while self._is_auto_save_enabled: - time.sleep(self.save_interval) - if self.changes_since_last_save: - self.__save_to_file() - - def __save_to_file(self): - with self.lock: - if not self.changes_since_last_save: - return - last_save_time = datetime.datetime.now(TZ_INFO) - save_data = {"counters": self.__counters.copy(), "last_save_time": last_save_time.isoformat()} - - # Save to temp file and then move it temp_file = self.filepath.with_suffix(".tmp") try: - with Path.open(temp_file, "w") as f: + with temp_file.open("w") as f: fcntl.flock(f.fileno(), fcntl.LOCK_EX) try: json.dump(save_data, f, indent=2) f.flush() finally: fcntl.flock(f.fileno(), fcntl.LOCK_UN) - - temp_file.replace(self.filepath) - except OSError: - print("Error saving counter file") - return - - with self.lock: - self.changes_since_last_save = 0 + temp_file.replace(self.filepath) self.last_save_time = last_save_time + self._changes_since_save = 0 + except OSError as e: + print(f"Error saving counter file: {e}") - def shutdown(self): - self._is_auto_save_enabled = False - if hasattr(self, "daemon_thread") and self.daemon_thread.is_alive(): - self.daemon_thread.join(timeout=5) - self.__save_to_file() + def _load_from_file_unlocked(self) -> None: + """Load counters from file. Must hold lock.""" + try: + with self.filepath.open("r") as f: + fcntl.flock(f.fileno(), fcntl.LOCK_SH) + try: + data = json.load(f) + finally: + fcntl.flock(f.fileno(), fcntl.LOCK_UN) + self._counters = data.get("counters", {}) + if last_save := data.get("last_save_time"): + self.last_save_time = datetime.datetime.fromisoformat(last_save) + except FileNotFoundError: + pass + except (json.JSONDecodeError, OSError) as e: + print(f"Error loading counter file: {e}") diff --git a/src/mountain_madness/_2026/models.py b/src/mountain_madness/_2026/models.py index e19af79..0205755 100644 --- a/src/mountain_madness/_2026/models.py +++ b/src/mountain_madness/_2026/models.py @@ -1,7 +1,6 @@ -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel class CounterResponse(BaseModel): - model_config = ConfigDict(from_attributes=True) good: int evil: int diff --git a/src/mountain_madness/_2026/urls.py b/src/mountain_madness/_2026/urls.py index e5c466e..031c8a6 100644 --- a/src/mountain_madness/_2026/urls.py +++ b/src/mountain_madness/_2026/urls.py @@ -16,7 +16,7 @@ operation_id="mm_get_counters", ) async def get_all_counters(): - return CounterResponse(**mm_counter.get_all_counters()) + return CounterResponse(**mm_counter.get_all()) @router.post( @@ -27,8 +27,7 @@ async def get_all_counters(): operation_id="mm_good_increment", ) async def increment_good(): - mm_counter.increment("good") - return CounterResponse(**mm_counter.get_all_counters()) + return CounterResponse(**mm_counter.increment("good")) @router.post( @@ -39,5 +38,4 @@ async def increment_good(): operation_id="mm_evil_increment", ) async def increment_evil(): - mm_counter.increment("evil") - return CounterResponse(**mm_counter.get_all_counters()) + return CounterResponse(**mm_counter.increment("evil"))