From 9e27ada47e508c4f169a0243f8c39dfa15f5b8a3 Mon Sep 17 00:00:00 2001 From: Byron Williams Date: Sat, 30 May 2026 21:22:24 -0700 Subject: [PATCH 1/8] chore: replace darglint with pydoclint (TOOL-007/PC-006/TOOL-013) Migrate from archived darglint to pydoclint per updated standards manifest. Replaces [tool.darglint] config and local hook with the upstream pydoclint pre-commit hook. Removes .darglint config file. Lockfile swaps darglint 1.8.1 for pydoclint 0.8.4. Posture: keep raises checks, require arg type hints, lenient on Returns/Yields missing sections. NOTE: 807 DOC findings exist across src/ and packages/ (histogram below); these require a separate docstring-fix pass. DOC histogram: DOC109/110x157, DOC203x174, DOC105x156, DOC301x33, DOC103/101x28, DOC503x19, DOC603x14, DOC605x11, DOC502x10, DOC602x8, DOC601x5, DOC501x4, DOC107x2, DOC106x1. Co-Authored-By: Claude Sonnet 4.6 --- .darglint | 21 --------------------- .pre-commit-config.yaml | 15 +++++++-------- pyproject.toml | 32 ++++++++++++-------------------- uv.lock | 36 +++++++++++++++++++++++++----------- 4 files changed, 44 insertions(+), 60 deletions(-) delete mode 100644 .darglint diff --git a/.darglint b/.darglint deleted file mode 100644 index a1a6c4a..0000000 --- a/.darglint +++ /dev/null @@ -1,21 +0,0 @@ -[darglint] -# Google-style docstrings (matches ruff pydocstyle convention) -docstring_style = google - -# Strictness levels: short, long, full -# - short: Only documented items must exist in signature -# - long: All parameters must be documented (recommended) -# - full: Types in docstring must match annotations -strictness = long - -# Ignore these error codes -# DAR101: Missing parameter(s) in Docstring -# DAR201: Missing "Returns" in Docstring -# DAR202: Excess "Returns" in Docstring (documented but not returned) -# DAR301: Missing "Yields" in Docstring -# DAR401: Missing exception(s) in Raises section -# DAR402: Excess exception(s) in Raises section (documented but not raised directly) -ignore = DAR101,DAR201,DAR202,DAR301,DAR401,DAR402 - -# Ignore in these directories (tests, scripts, benchmarks, tools) -ignore_regex = ^(tests|scripts|benchmarks|tools|\.claude)/.*$ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a945573..d8d863f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ ci: autoupdate_schedule: monthly autofix_prs: true - skip: [validate-front-matter, qlty-check, qlty-full, trufflehog, darglint, bandit, bandit-full] # Skip local-only hooks + skip: [validate-front-matter, qlty-check, qlty-full, trufflehog, bandit, bandit-full] # Skip local-only hooks repos: # ============================================================================ @@ -199,14 +199,13 @@ repos: # ============================================================================ # Docstring Argument Validation # ============================================================================ - # Darglint validates that docstring arguments match function signatures - # Configuration in pyproject.toml [tool.darglint] - - repo: local + # Pydoclint validates that docstring arguments match function signatures + # Configuration in pyproject.toml [tool.pydoclint] + - repo: https://github.com/jsh9/pydoclint + rev: 88d83c94156c5e51a09938e77019f2c58e92ab58 # 0.8.4 hooks: - - id: darglint - name: Darglint docstring validation - entry: uv run darglint - language: system + - id: pydoclint + args: ["--config=pyproject.toml"] types: [python] stages: [pre-commit] exclude: ^(tests|scripts|benchmarks|tools)/ diff --git a/pyproject.toml b/pyproject.toml index 8039462..9c1d6c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ dev = [ "bandit>=1.7.0", "vulture>=2.11", # Dead code detection "pip-audit>=2.7.0", # Python vulnerability scanning - "darglint>=1.8.1", # Docstring argument validation + "pydoclint>=0.8.4", # Docstring argument validation # Pre-commit "pre-commit>=3.3.0", @@ -651,26 +651,18 @@ ignore-init-method = true ignore-init-module = true color = true -# Darglint Configuration (Docstring Argument Validation) +# Pydoclint Configuration (Docstring Argument Validation) # Validates that docstring arguments match function signatures -# Reference: https://github.com/terrencepreilly/darglint -[tool.darglint] -# Google-style docstrings (matches ruff pydocstyle convention) -docstring_style = "google" -# Strictness levels: short, long, full -# - short: Only documented items must exist in signature -# - long: All parameters must be documented (recommended) -# - full: Types in docstring must match annotations -strictness = "long" -# Ignore missing parameter documentation in these cases -ignore = [ - "DAR101", # Missing parameter(s) in Docstring (initially lenient) - "DAR201", # Missing "Returns" in Docstring (handled by pydocstyle) - "DAR301", # Missing "Yields" in Docstring - "DAR401", # Missing exception(s) in Raises section -] -# Ignore in these directories -ignore_regex = "^(tests|scripts|benchmarks|tools)/.*$" +# Reference: https://github.com/jsh9/pydoclint +[tool.pydoclint] +style = "google" +exclude = '\.git|tests/|scripts/|benchmarks/|tools/|noxfile\.py|\.claude/' +arg-type-hints-in-docstring = true +arg-type-hints-in-signature = true +skip-checking-raises = false +require-return-section-when-returning-nothing = false +require-return-section-when-returning-values = false +require-yield-section-when-yielding-values = false # Mutation Testing Configuration (mutmut) [tool.mutmut] diff --git a/uv.lock b/uv.lock index cdbaef7..cedcd04 100644 --- a/uv.lock +++ b/uv.lock @@ -918,15 +918,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/1b/534ad8a5e0f9470522811a8e5a9bc5d328fb7738ba29faf357467a4ef6d0/cyclonedx_python_lib-11.6.0-py3-none-any.whl", hash = "sha256:94f4aae97db42a452134dafdddcfab9745324198201c4777ed131e64c8380759", size = 511157, upload-time = "2025-12-02T12:28:44.158Z" }, ] -[[package]] -name = "darglint" -version = "1.8.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d4/2c/86e8549e349388c18ca8a4ff8661bb5347da550f598656d32a98eaaf91cc/darglint-1.8.1.tar.gz", hash = "sha256:080d5106df149b199822e7ee7deb9c012b49891538f14a11be681044f0bb20da", size = 74435, upload-time = "2021-10-18T03:40:37.283Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/69/28/85d1e0396d64422c5218d68e5cdcc53153aa8a2c83c7dbc3ee1502adf3a1/darglint-1.8.1-py3-none-any.whl", hash = "sha256:5ae11c259c17b0701618a20c3da343a3eb98b3bc4b5a83d31cdd94f5ebdced8d", size = 120767, upload-time = "2021-10-18T03:40:35.034Z" }, -] - [[package]] name = "debugpy" version = "1.8.17" @@ -1014,6 +1005,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/5a/18ad964b0086c6e62e2e7500f7edc89e3faa45033c71c1893d34eed2b2de/dnspython-2.8.0-py3-none-any.whl", hash = "sha256:01d9bbc4a2d76bf0db7c1f729812ded6d912bd318d3b1cf81d30c0f845dbf3af", size = 331094, upload-time = "2025-09-07T18:57:58.071Z" }, ] +[[package]] +name = "docstring-parser-fork" +version = "0.0.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/bf/27f9cab2f0cd1d17a4420572088bbc19f36d726fbcf165edf226a8926dbc/docstring_parser_fork-0.0.14.tar.gz", hash = "sha256:a2743a63d8d36c09650594f7b4ab5b2758fee8629dcf794d1b221b23179baa5c", size = 34551, upload-time = "2025-09-07T17:27:38.272Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/50/98b146aea0f1cd7531d25f12bea69fa9ce8d1662124f93fb30dc4511b65e/docstring_parser_fork-0.0.14-py3-none-any.whl", hash = "sha256:4c544f234ef2cc2749a3df32b70c437d77888b1099143a1ad5454452c574b9af", size = 43063, upload-time = "2025-09-07T17:27:37.012Z" }, +] + [[package]] name = "dparse" version = "0.6.4" @@ -2940,6 +2940,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, ] +[[package]] +name = "pydoclint" +version = "0.8.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "docstring-parser-fork" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/c9/357fb4af8ca78b97b6642e18e8b692a7073bf118709e559ec6f8a893774d/pydoclint-0.8.4.tar.gz", hash = "sha256:5eb5d8ae3d17bba9a1e4ac3ccdedf46152d45e60528895c26712cf7337b2a054", size = 195295, upload-time = "2026-05-16T20:31:37.046Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/f9/f15c95d6b200167cb22c5eca5eecfa9d28a8ee3f74095f1cd2345c71f2f9/pydoclint-0.8.4-py3-none-any.whl", hash = "sha256:5e0f94f785d0e902faacebb117aadf84d6e30c5f781e0fdd0ee03c3b80ea2098", size = 82023, upload-time = "2026-05-16T20:31:36.024Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -3097,7 +3111,6 @@ dev = [ { name = "bandit" }, { name = "basedpyright" }, { name = "cryptography" }, - { name = "darglint" }, { name = "email-validator" }, { name = "fastapi" }, { name = "google-api-core" }, @@ -3121,6 +3134,7 @@ dev = [ { name = "nox-uv" }, { name = "pip-audit" }, { name = "pre-commit" }, + { name = "pydoclint" }, { name = "pyjwt" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -3141,7 +3155,6 @@ requires-dist = [ { name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0" }, { name = "basedpyright", marker = "extra == 'dev'", specifier = ">=1.18.0" }, { name = "cryptography", marker = "extra == 'dev'", specifier = ">=41.0.0" }, - { name = "darglint", marker = "extra == 'dev'", specifier = ">=1.8.1" }, { name = "email-validator", marker = "extra == 'dev'", specifier = ">=2.0.0" }, { name = "fastapi", marker = "extra == 'dev'", specifier = ">=0.100.0" }, { name = "google-api-core", marker = "extra == 'dev'", specifier = ">=2.0.0" }, @@ -3167,6 +3180,7 @@ requires-dist = [ { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.3.0" }, { name = "pydantic", specifier = ">=2.0.0" }, { name = "pydantic-settings", specifier = ">=2.0.0" }, + { name = "pydoclint", marker = "extra == 'dev'", specifier = ">=0.8.4" }, { name = "pyjwt", marker = "extra == 'dev'", specifier = ">=2.8.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" }, { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, From eb2027fcdacb98bb651d0a66a55a3d3820b5d313 Mon Sep 17 00:00:00 2001 From: Byron Williams Date: Sat, 30 May 2026 21:55:48 -0700 Subject: [PATCH 2/8] docs(src): align docstrings with signatures for pydoclint Add argument and return type hints to docstrings in src/python_libs to match function signatures exactly. Move __init__ argument docs into the class docstring Args sections per pydoclint DOC301 guidance. Document the pydantic model_config class attribute without a type hint (the assignment carries no annotation). src/ now reports zero pydoclint violations. Co-Authored-By: Claude Sonnet 4.6 --- src/python_libs/core/config.py | 8 +- src/python_libs/core/exceptions.py | 148 +++++++++++++---------------- src/python_libs/utils/logging.py | 34 +++---- 3 files changed, 88 insertions(+), 102 deletions(-) diff --git a/src/python_libs/core/config.py b/src/python_libs/core/config.py index 4092b29..c9286ce 100644 --- a/src/python_libs/core/config.py +++ b/src/python_libs/core/config.py @@ -14,9 +14,11 @@ class Settings(BaseSettings): Configuration settings for the application, loaded from environment variables. Attributes: - log_level: The logging level for the application. - json_logs: Flag to enable or disable JSON formatted logs. - include_timestamp: Flag to include timestamps in logs. + model_config: Pydantic settings configuration. + log_level (Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]): + The logging level for the application. + json_logs (bool): Flag to enable or disable JSON formatted logs. + include_timestamp (bool): Flag to include timestamps in logs. """ model_config = SettingsConfigDict( diff --git a/src/python_libs/core/exceptions.py b/src/python_libs/core/exceptions.py index 0247ee3..47f1837 100644 --- a/src/python_libs/core/exceptions.py +++ b/src/python_libs/core/exceptions.py @@ -43,10 +43,11 @@ class ProjectBaseError(Exception): All custom exceptions in the project should inherit from this class to enable unified error handling and logging. - Attributes: - message: Human-readable error message. - details: Additional context about the error (optional). - error_code: Machine-readable error code for API responses (optional). + Args: + message (str): Human-readable error description. + details (dict[str, Any] | None): Additional context as key-value + pairs. + error_code (str | None): Machine-readable error code. Example: >>> raise ProjectBaseError("Something went wrong", error_code="ERR001") @@ -59,13 +60,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize the exception. - - Args: - message: Human-readable error description. - details: Additional context as key-value pairs. - error_code: Machine-readable error code. - """ super().__init__(message) self.message = message self.details = details or {} @@ -75,7 +69,8 @@ def to_dict(self) -> dict[str, Any]: """Convert exception to dictionary for API responses. Returns: - Dictionary with error details suitable for JSON serialization. + dict[str, Any]: Dictionary with error details suitable for JSON + serialization. """ result: dict[str, Any] = { "error": self.__class__.__name__, @@ -108,6 +103,13 @@ class ValidationError(ProjectBaseError): Raised when user input or data fails validation rules. Includes field-level error details for form validation. + Args: + message (str): Description of the validation failure. + field (str | None): Name of the field that failed validation. + value (Any): The invalid value (will be sanitized in logs). + details (dict[str, Any] | None): Additional validation context. + error_code (str | None): Machine-readable error code. + Example: >>> raise ValidationError( ... "Invalid email format", @@ -125,15 +127,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize validation error with field context. - - Args: - message: Description of the validation failure. - field: Name of the field that failed validation. - value: The invalid value (will be sanitized in logs). - details: Additional validation context. - error_code: Machine-readable error code. - """ details = details or {} if field: details["field"] = field @@ -153,6 +146,14 @@ class ResourceNotFoundError(ProjectBaseError): Raised when a requested resource (entity, file, record) cannot be found. + Args: + message (str): Description of what was not found. + resource_type (str | None): Type of resource (e.g., "User", + "Document"). + resource_id (str | None): Identifier of the missing resource. + details (dict[str, Any] | None): Additional context. + error_code (str | None): Machine-readable error code. + Example: >>> raise ResourceNotFoundError( ... "User not found", @@ -170,15 +171,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize resource not found error. - - Args: - message: Description of what was not found. - resource_type: Type of resource (e.g., "User", "Document"). - resource_id: Identifier of the missing resource. - details: Additional context. - error_code: Machine-readable error code. - """ details = details or {} if resource_type: details["resource_type"] = resource_type @@ -192,6 +184,12 @@ class AuthenticationError(ProjectBaseError): Raised when authentication fails (invalid credentials, expired tokens, etc.). + Args: + message (str): Description of authentication failure. + details (dict[str, Any] | None): Additional context (avoid including + sensitive data). + error_code (str | None): Machine-readable error code. + Example: >>> raise AuthenticationError("Invalid or expired token") """ @@ -203,13 +201,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize authentication error. - - Args: - message: Description of authentication failure. - details: Additional context (avoid including sensitive data). - error_code: Machine-readable error code. - """ super().__init__( message, details=details, error_code=error_code or "AUTH_FAILED" ) @@ -220,6 +211,13 @@ class AuthorizationError(ProjectBaseError): Raised when a user lacks permission to perform an action. + Args: + message (str): Description of permission failure. + required_permission (str | None): The permission that was required. + resource (str | None): The resource access was denied to. + details (dict[str, Any] | None): Additional context. + error_code (str | None): Machine-readable error code. + Example: >>> raise AuthorizationError( ... "Insufficient permissions", @@ -237,15 +235,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize authorization error. - - Args: - message: Description of permission failure. - required_permission: The permission that was required. - resource: The resource access was denied to. - details: Additional context. - error_code: Machine-readable error code. - """ details = details or {} if required_permission: details["required_permission"] = required_permission @@ -259,6 +248,13 @@ class ExternalServiceError(ProjectBaseError): Base class for errors from external services (APIs, databases, etc.). + Args: + message (str): Description of the service error. + service_name (str | None): Name of the external service. + status_code (int | None): HTTP status code if applicable. + details (dict[str, Any] | None): Additional context. + error_code (str | None): Machine-readable error code. + Example: >>> raise ExternalServiceError( ... "Payment gateway unavailable", @@ -276,15 +272,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize external service error. - - Args: - message: Description of the service error. - service_name: Name of the external service. - status_code: HTTP status code if applicable. - details: Additional context. - error_code: Machine-readable error code. - """ details = details or {} if service_name: details["service_name"] = service_name @@ -300,6 +287,15 @@ class APIError(ExternalServiceError): Raised when calls to external APIs fail. + Args: + message (str): Description of the API error. + service_name (str | None): Name of the external API. + status_code (int | None): HTTP status code from the API. + retry_after (int | None): Seconds to wait before retrying (for rate + limits). + details (dict[str, Any] | None): Additional context. + error_code (str | None): Machine-readable error code. + Example: >>> raise APIError( ... "GitHub API rate limit exceeded", @@ -319,16 +315,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize API error. - - Args: - message: Description of the API error. - service_name: Name of the external API. - status_code: HTTP status code from the API. - retry_after: Seconds to wait before retrying (for rate limits). - details: Additional context. - error_code: Machine-readable error code. - """ details = details or {} if retry_after: details["retry_after"] = retry_after @@ -346,6 +332,13 @@ class DatabaseError(ExternalServiceError): Raised when database operations fail (connection issues, constraint violations, etc.). + Args: + message (str): Description of the database error. + operation (str | None): The database operation that failed. + table (str | None): The table/collection involved. + details (dict[str, Any] | None): Additional context. + error_code (str | None): Machine-readable error code. + Example: >>> raise DatabaseError( ... "Unique constraint violation", @@ -363,15 +356,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize database error. - - Args: - message: Description of the database error. - operation: The database operation that failed. - table: The table/collection involved. - details: Additional context. - error_code: Machine-readable error code. - """ details = details or {} if operation: details["operation"] = operation @@ -390,6 +374,13 @@ class BusinessLogicError(ProjectBaseError): Raised when operations violate business rules or domain constraints. + Args: + message (str): Description of the rule violation. + rule (str | None): Name of the business rule violated. + context (dict[str, Any] | None): Business context for the violation. + details (dict[str, Any] | None): Additional context. + error_code (str | None): Machine-readable error code. + Example: >>> raise BusinessLogicError( ... "Insufficient funds for transfer", @@ -407,15 +398,6 @@ def __init__( details: dict[str, Any] | None = None, error_code: str | None = None, ) -> None: - """Initialize business logic error. - - Args: - message: Description of the rule violation. - rule: Name of the business rule violated. - context: Business context for the violation. - details: Additional context. - error_code: Machine-readable error code. - """ details = details or {} if rule: details["rule"] = rule diff --git a/src/python_libs/utils/logging.py b/src/python_libs/utils/logging.py index 1fc2942..671449e 100644 --- a/src/python_libs/utils/logging.py +++ b/src/python_libs/utils/logging.py @@ -42,14 +42,15 @@ def setup_logging( for the environment (JSON for production, rich console for development). Args: - level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). + level (str): Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Defaults to INFO. - json_logs: If True, output JSON logs (production mode). If False, + json_logs (bool): If True, output JSON logs (production mode). If False, use rich console formatting (development mode). Defaults to False. - include_timestamp: Whether to include timestamps in log output. + include_timestamp (bool): Whether to include timestamps in log output. Defaults to True. - include_correlation: Whether to include correlation IDs from request - context in log output. Defaults to True. Requires API framework. + include_correlation (bool): Whether to include correlation IDs from + request context in log output. Defaults to True. Requires API + framework. Example: >>> # Development setup @@ -88,12 +89,12 @@ def noop_processor( """No-op processor that passes through the event dict unchanged. Args: - _logger: The wrapped logger instance (unused). - _method_name: The logging method name (unused). - event_dict: The event dictionary to pass through. + _logger (WrappedLogger): The wrapped logger instance (unused). + _method_name (str): The logging method name (unused). + event_dict (EventDict): The event dictionary to pass through. Returns: - The event dictionary unchanged. + EventDict: The event dictionary unchanged. """ return event_dict @@ -135,10 +136,10 @@ def get_logger(name: str) -> BoundLogger: typically be called with __name__ to create module-specific loggers. Args: - name: Logger name (typically __name__ of the module). + name (str): Logger name (typically __name__ of the module). Returns: - Configured structlog logger instance with methods like: + BoundLogger: Configured structlog logger instance with methods like: - logger.info() - logger.debug() - logger.warning() @@ -170,11 +171,12 @@ def log_performance( duration, success status, and additional context. Args: - logger: Structlog logger instance from get_logger(). - operation: Name of the operation being timed. - duration_ms: Duration in milliseconds. - success: Whether the operation succeeded. Defaults to True. - **context: Additional context key-value pairs to include in the log. + logger (BoundLogger): Structlog logger instance from get_logger(). + operation (str): Name of the operation being timed. + duration_ms (float): Duration in milliseconds. + success (bool): Whether the operation succeeded. Defaults to True. + **context (object): Additional context key-value pairs to include in + the log. Example: >>> logger = get_logger(__name__) From 170d26b0c3dafa85281e8301bfdbd75604ecf322 Mon Sep 17 00:00:00 2001 From: Byron Williams Date: Sat, 30 May 2026 21:57:20 -0700 Subject: [PATCH 3/8] docs(gemini-image): align docstrings with signatures for pydoclint Add argument and return type hints to docstrings to match signatures exactly. Remove ImportError from generate_image Raises section: it is raised in the _get_genai helper, not directly in the function body, so pydoclint does not count it. Co-Authored-By: Claude Sonnet 4.6 --- .../src/gemini_image/generator.py | 70 ++++++++++--------- .../gemini-image/src/gemini_image/utils.py | 18 ++--- 2 files changed, 46 insertions(+), 42 deletions(-) diff --git a/packages/gemini-image/src/gemini_image/generator.py b/packages/gemini-image/src/gemini_image/generator.py index 7eb83c3..283a952 100644 --- a/packages/gemini-image/src/gemini_image/generator.py +++ b/packages/gemini-image/src/gemini_image/generator.py @@ -68,25 +68,29 @@ def generate_image( """Generate an image using Gemini. Args: - prompt: Text description of the image to generate. - model_key: Model to use ('flash' or 'pro'). - reference_images: Optional list of reference images for editing/style. - output_path: Optional output file path. If not provided, generates - a timestamped filename. - output_dir: Optional output directory. Defaults to current directory. - aspect_ratio: Aspect ratio for pro model (e.g., "16:9", "1:1"). - image_size: Image size for pro model ("1K", "2K", "4K"). - use_search: Enable Google Search grounding (pro model only). - save_thoughts: Save intermediate thought images (pro model only). - verbose: Show detailed thinking process and thought signatures. - is_draft: Generate at 1K resolution for fast iteration. + prompt (str): Text description of the image to generate. + model_key (ModelKey): Model to use ('flash' or 'pro'). + reference_images (list[Path] | None): Optional list of reference + images for editing/style. + output_path (Path | None): Optional output file path. If not provided, + generates a timestamped filename. + output_dir (Path | None): Optional output directory. Defaults to + current directory. + aspect_ratio (AspectRatio | None): Aspect ratio for pro model (e.g., + "16:9", "1:1"). + image_size (ImageSize | None): Image size for pro model ("1K", "2K", + "4K"). + use_search (bool): Enable Google Search grounding (pro model only). + save_thoughts (bool): Save intermediate thought images (pro model + only). + verbose (bool): Show detailed thinking process and thought signatures. + is_draft (bool): Generate at 1K resolution for fast iteration. Returns: - Path to the generated image, or None on failure. + Path | None: Path to the generated image, or None on failure. Raises: ValueError: If model_key is invalid or API key is missing. - ImportError: If google-genai is not installed. """ genai, types = _get_genai() @@ -321,18 +325,18 @@ def generate_story_sequence( visual continuity. Args: - base_prompt: Base story description. - num_parts: Number of story parts to generate. - model_key: Model to use. - output_prefix: Prefix for output files (e.g., "story" -> + base_prompt (str): Base story description. + num_parts (int): Number of story parts to generate. + model_key (ModelKey): Model to use. + output_prefix (Path | None): Prefix for output files (e.g., "story" -> story_part1.png, story_part2.png). - output_dir: Output directory for generated images. - aspect_ratio: Aspect ratio for all images. - image_size: Image size for all images. - verbose: Show detailed process. + output_dir (Path | None): Output directory for generated images. + aspect_ratio (AspectRatio | None): Aspect ratio for all images. + image_size (ImageSize | None): Image size for all images. + verbose (bool): Show detailed process. Returns: - List of paths to generated images. + list[Path]: List of paths to generated images. Raises: ValueError: If num_parts < 1. @@ -442,18 +446,18 @@ def finalize_draft( """Finalize a draft image by regenerating at higher resolution. Args: - draft_path: Path to the draft image. - prompt: Optional refinement prompt. If not provided, uses a - default upscaling prompt. - model_key: Model to use. - output_path: Output path for the final image. - output_dir: Output directory. - aspect_ratio: Aspect ratio (default: "16:9"). - image_size: Target resolution (default: "2K"). - verbose: Show detailed process. + draft_path (Path): Path to the draft image. + prompt (str | None): Optional refinement prompt. If not provided, uses + a default upscaling prompt. + model_key (ModelKey): Model to use. + output_path (Path | None): Output path for the final image. + output_dir (Path | None): Output directory. + aspect_ratio (AspectRatio | None): Aspect ratio (default: "16:9"). + image_size (ImageSize | None): Target resolution (default: "2K"). + verbose (bool): Show detailed process. Returns: - Path to the finalized image, or None on failure. + Path | None: Path to the finalized image, or None on failure. Raises: FileNotFoundError: If the draft image doesn't exist. diff --git a/packages/gemini-image/src/gemini_image/utils.py b/packages/gemini-image/src/gemini_image/utils.py index 62de1da..0cba276 100644 --- a/packages/gemini-image/src/gemini_image/utils.py +++ b/packages/gemini-image/src/gemini_image/utils.py @@ -14,11 +14,11 @@ def get_api_key(env_file: Path | None = None) -> str: """Get the Gemini API key from environment or .env file. Args: - env_file: Optional path to .env file. If not provided, checks - GEMINI_API_KEY environment variable only. + env_file (Path | None): Optional path to .env file. If not provided, + checks GEMINI_API_KEY environment variable only. Returns: - The API key string. + str: The API key string. Raises: ValueError: If no API key is found. @@ -48,10 +48,10 @@ def load_image_as_base64(image_path: Path) -> tuple[str, str]: """Load an image file and return base64 data and mime type. Args: - image_path: Path to the image file. + image_path (Path): Path to the image file. Returns: - Tuple of (base64_encoded_data, mime_type). + tuple[str, str]: Tuple of (base64_encoded_data, mime_type). Raises: FileNotFoundError: If the image file doesn't exist. @@ -82,10 +82,10 @@ def decode_base64_image(base64_data: str) -> bytes: """Decode base64 image data to bytes. Args: - base64_data: Base64-encoded image data. + base64_data (str): Base64-encoded image data. Returns: - Raw image bytes. + bytes: Raw image bytes. """ return base64.standard_b64decode(base64_data) @@ -95,10 +95,10 @@ def get_file_extension(mime_type: str) -> str: """Get file extension for a given MIME type. Args: - mime_type: MIME type string (e.g., "image/png"). + mime_type (str): MIME type string (e.g., "image/png"). Returns: - File extension including the dot (e.g., ".png"). + str: File extension including the dot (e.g., ".png"). """ extensions = { From 34e4fcbcbc35e971e70604bc4c1ae1838b1c1e6a Mon Sep 17 00:00:00 2001 From: Byron Williams Date: Sat, 30 May 2026 21:59:57 -0700 Subject: [PATCH 4/8] docs(gcs-utilities): align docstrings with signatures for pydoclint Add argument and return type hints to docstrings to match signatures exactly. Move GCSClient.__init__ argument docs into the class docstring. Remove Raises entries for exceptions raised only in helper methods (FileNotFoundError via _validate_local_path) and drop the Raises section from upload_directory which has no direct raise statements. Document the broad re-raise in _get_or_create_bucket. Co-Authored-By: Claude Sonnet 4.6 --- .../gcs-utilities/src/gcs_utilities/client.py | 172 ++++++++++-------- 1 file changed, 95 insertions(+), 77 deletions(-) diff --git a/packages/gcs-utilities/src/gcs_utilities/client.py b/packages/gcs-utilities/src/gcs_utilities/client.py index 009a019..5ca8e03 100644 --- a/packages/gcs-utilities/src/gcs_utilities/client.py +++ b/packages/gcs-utilities/src/gcs_utilities/client.py @@ -39,6 +39,18 @@ class GCSClient: GCS_BUCKET: Default bucket name (optional, can be specified per operation) GCP_PROJECT: GCP project ID (optional, extracted from service account if not provided) + Args: + service_account_key_b64 (str | None): Base64-encoded service account + JSON. If not provided, reads from GCP_SA_KEY env var. + bucket_name (str | None): Default bucket name. If not provided, reads + from GCS_BUCKET env var. + project_id (str | None): GCP project ID. If not provided, extracts from + service account. + auto_create_bucket (bool): If True, creates bucket if it doesn't exist. + + Raises: + GCSAuthError: If authentication fails. + Example: ```python from gcs_utilities import GCSClient @@ -64,19 +76,6 @@ def __init__( project_id: str | None = None, auto_create_bucket: bool = False, ) -> None: - """Initialize GCS client. - - Args: - service_account_key_b64: Base64-encoded service account JSON. - If not provided, reads from GCP_SA_KEY env var. - bucket_name: Default bucket name. If not provided, reads from GCS_BUCKET env var. - project_id: GCP project ID. If not provided, extracts from service account. - auto_create_bucket: If True, creates bucket if it doesn't exist. - - Raises: - GCSAuthError: If authentication fails. - GCSConfigError: If required configuration is missing (raised by _setup_credentials). - """ self._credentials_path: str | None = None self._cleanup_registered = False self.bucket_name = bucket_name or os.getenv("GCS_BUCKET") @@ -106,7 +105,8 @@ def _setup_credentials(self, service_account_key_b64: str | None = None) -> None """Setup GCS credentials from base64-encoded service account key. Args: - service_account_key_b64: Base64-encoded service account JSON. + service_account_key_b64 (str | None): Base64-encoded service + account JSON. Raises: GCSAuthError: If credentials setup fails. @@ -167,14 +167,16 @@ def _get_or_create_bucket(self, auto_create: bool = False) -> storage.Bucket: """Get bucket or optionally create it if it doesn't exist. Args: - auto_create: If True, creates bucket if it doesn't exist. + auto_create (bool): If True, creates bucket if it doesn't exist. Returns: - Storage bucket object. + storage.Bucket: Storage bucket object. Raises: GCSNotFoundError: If bucket doesn't exist and auto_create is False. GCSAuthError: If bucket access fails. + Exception: Re-raised if it is a GCSNotFoundError caught by the + broad handler. """ try: bucket = self.client.bucket(self.bucket_name) @@ -200,8 +202,8 @@ def set_bucket(self, bucket_name: str, auto_create: bool = False) -> None: """Set or change the default bucket. Args: - bucket_name: Name of the bucket. - auto_create: If True, creates bucket if it doesn't exist. + bucket_name (str): Name of the bucket. + auto_create (bool): If True, creates bucket if it doesn't exist. """ self.bucket_name = bucket_name self.bucket = self._get_or_create_bucket(auto_create) @@ -211,11 +213,11 @@ def _validate_local_path(path: Path, must_exist: bool = False) -> Path: """Validate local file path for security. Args: - path: Path to validate. - must_exist: If True, raises error if path doesn't exist. + path (Path): Path to validate. + must_exist (bool): If True, raises error if path doesn't exist. Returns: - Resolved absolute path. + Path: Resolved absolute path. Raises: ValueError: If path is invalid or contains traversal attempts. @@ -241,10 +243,10 @@ def _sanitize_gcs_path(gcs_path: str) -> str: """Sanitize GCS path to prevent issues. Args: - gcs_path: GCS blob path to sanitize. + gcs_path (str): GCS blob path to sanitize. Returns: - Sanitized path. + str: Sanitized path. Raises: ValueError: If path is invalid. @@ -275,18 +277,20 @@ def upload_file( """Upload a single file to GCS. Args: - local_path: Path to local file. - gcs_path: Destination path in GCS (blob name). - bucket_name: Bucket name (uses default if not specified). - content_type: Content type for the file (auto-detected if not provided). - metadata: Optional metadata dict to attach to the blob. + local_path (str): Path to local file. + gcs_path (str): Destination path in GCS (blob name). + bucket_name (str | None): Bucket name (uses default if not + specified). + content_type (str | None): Content type for the file + (auto-detected if not provided). + metadata (dict[str, str] | None): Optional metadata dict to attach + to the blob. Returns: - Full GCS URI (gs://bucket/path). + str: Full GCS URI (gs://bucket/path). Raises: GCSUploadError: If upload fails. - FileNotFoundError: If local file doesn't exist (raised by _validate_local_path). """ # Validate and sanitize paths local_file = self._validate_local_path(Path(local_path), must_exist=True) @@ -327,18 +331,18 @@ def upload_directory( """Upload a directory to GCS, preserving structure. Args: - local_dir: Path to local directory. - gcs_prefix: Prefix for GCS paths (like a directory). - bucket_name: Bucket name (uses default if not specified). - pattern: Glob pattern for files to include (default: all files). - exclude_patterns: List of glob patterns to exclude. + local_dir (str): Path to local directory. + gcs_prefix (str): Prefix for GCS paths (like a directory). + bucket_name (str | None): Bucket name (uses default if not + specified). + pattern (str): Glob pattern for files to include (default: all + files). + exclude_patterns (list[str] | None): List of glob patterns to + exclude. Returns: - Dict with stats: {"files_uploaded": int, "total_bytes": int, "failed": list}. - - Raises: - GCSUploadError: If upload fails (raised indirectly). - FileNotFoundError: If local directory doesn't exist (raised by _validate_local_path). + dict[str, Any]: Dict with stats: {"files_uploaded": int, + "total_bytes": int, "failed": list}. """ # Validate local directory path local_path = self._validate_local_path(Path(local_dir), must_exist=True) @@ -406,13 +410,15 @@ def download_file( """Download a single file from GCS. Args: - gcs_path: Path in GCS (blob name). - local_path: Destination local path. - bucket_name: Bucket name (uses default if not specified). - create_dirs: If True, creates parent directories if they don't exist. + gcs_path (str): Path in GCS (blob name). + local_path (str): Destination local path. + bucket_name (str | None): Bucket name (uses default if not + specified). + create_dirs (bool): If True, creates parent directories if they + don't exist. Returns: - Path to downloaded file. + str: Path to downloaded file. Raises: GCSDownloadError: If download fails. @@ -456,11 +462,12 @@ def download_as_bytes( """Download a file from GCS as bytes. Args: - gcs_path: Path in GCS (blob name). - bucket_name: Bucket name (uses default if not specified). + gcs_path (str): Path in GCS (blob name). + bucket_name (str | None): Bucket name (uses default if not + specified). Returns: - File contents as bytes. + bytes: File contents as bytes. Raises: GCSDownloadError: If download fails. @@ -489,12 +496,13 @@ def download_as_text( """Download a file from GCS as text. Args: - gcs_path: Path in GCS (blob name). - bucket_name: Bucket name (uses default if not specified). - encoding: Text encoding (default: utf-8). + gcs_path (str): Path in GCS (blob name). + bucket_name (str | None): Bucket name (uses default if not + specified). + encoding (str): Text encoding (default: utf-8). Returns: - File contents as string. + str: File contents as string. Raises: GCSDownloadError: If download fails. @@ -524,13 +532,16 @@ def list_files( """List files in GCS bucket. Args: - prefix: Filter to files with this prefix. - bucket_name: Bucket name (uses default if not specified). - max_results: Maximum number of results to return. - delimiter: Delimiter for directory-like listing (e.g., "/"). + prefix (str | None): Filter to files with this prefix. + bucket_name (str | None): Bucket name (uses default if not + specified). + max_results (int | None): Maximum number of results to return. + delimiter (str | None): Delimiter for directory-like listing (e.g., + "/"). Returns: - List of dicts with file info: {"name": str, "size": int, "updated": datetime}. + list[dict[str, Any]]: List of dicts with file info: {"name": str, + "size": int, "updated": datetime}. Raises: GCSDownloadError: If listing fails. @@ -571,12 +582,15 @@ def delete_file( """Delete a file from GCS. Args: - gcs_path: Path in GCS (blob name). - bucket_name: Bucket name (uses default if not specified). - ignore_missing: If True, doesn't raise error if file doesn't exist. + gcs_path (str): Path in GCS (blob name). + bucket_name (str | None): Bucket name (uses default if not + specified). + ignore_missing (bool): If True, doesn't raise error if file doesn't + exist. Returns: - True if file was deleted, False if it didn't exist (when ignore_missing=True). + bool: True if file was deleted, False if it didn't exist (when + ignore_missing=True). Raises: GCSNotFoundError: If file doesn't exist and ignore_missing=False. @@ -609,11 +623,12 @@ def delete_directory( """Delete all files with a given prefix (directory-like deletion). Args: - prefix: Prefix of files to delete. - bucket_name: Bucket name (uses default if not specified). + prefix (str): Prefix of files to delete. + bucket_name (str | None): Bucket name (uses default if not + specified). Returns: - Number of files deleted. + int: Number of files deleted. """ prefix = self._sanitize_gcs_path(prefix) if prefix else "" bucket = self._get_bucket(bucket_name) @@ -640,11 +655,12 @@ def file_exists( """Check if a file exists in GCS. Args: - gcs_path: Path in GCS (blob name). - bucket_name: Bucket name (uses default if not specified). + gcs_path (str): Path in GCS (blob name). + bucket_name (str | None): Bucket name (uses default if not + specified). Returns: - True if file exists, False otherwise. + bool: True if file exists, False otherwise. """ gcs_path = self._sanitize_gcs_path(gcs_path) bucket = self._get_bucket(bucket_name) @@ -659,11 +675,13 @@ def get_file_metadata( """Get metadata for a file in GCS. Args: - gcs_path: Path in GCS (blob name). - bucket_name: Bucket name (uses default if not specified). + gcs_path (str): Path in GCS (blob name). + bucket_name (str | None): Bucket name (uses default if not + specified). Returns: - Dict with metadata including size, content_type, updated, etc. + dict[str, Any]: Dict with metadata including size, content_type, + updated, etc. Raises: GCSNotFoundError: If file doesn't exist. @@ -694,10 +712,10 @@ def _get_bucket(self, bucket_name: str | None = None) -> storage.Bucket: """Get bucket object, using default if not specified. Args: - bucket_name: Optional bucket name. + bucket_name (str | None): Optional bucket name. Returns: - Storage bucket object. + storage.Bucket: Storage bucket object. Raises: GCSConfigError: If no bucket is specified and no default is set. @@ -734,7 +752,7 @@ def __enter__(self) -> "GCSClient": """Context manager entry. Returns: - The GCSClient instance. + GCSClient: The GCSClient instance. """ return self @@ -742,12 +760,12 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: """Context manager exit with cleanup. Args: - exc_type: Exception type, if any. - exc_val: Exception value, if any. - exc_tb: Exception traceback, if any. + exc_type (Any): Exception type, if any. + exc_val (Any): Exception value, if any. + exc_tb (Any): Exception traceback, if any. Returns: - False to propagate exceptions. + bool: False to propagate exceptions. """ self.close() return False From 0844183c0b4f984294ddf417b1f33566b0e437a4 Mon Sep 17 00:00:00 2001 From: Byron Williams Date: Sat, 30 May 2026 22:08:14 -0700 Subject: [PATCH 5/8] docs(cloudflare-api): align core module docstrings with signatures Add argument and return type hints to docstrings and move __init__ argument docs into class docstrings. Document pydantic model fields with their annotated types. Refactor CloudflareAPIClient._handle_api_error to NoReturn (it always raises) and drop the unreachable bare re-raises so the Raises sections reflect the exceptions actually raised in each body: delegating methods document none, methods that raise CloudflareConflict- Error directly keep that entry. Behavior is unchanged; 83 cloudflare-api tests pass. Co-Authored-By: Claude Sonnet 4.6 --- .../src/cloudflare_api/client.py | 167 ++++++------------ .../src/cloudflare_api/exceptions.py | 77 +++----- .../src/cloudflare_api/models.py | 47 ++--- .../src/cloudflare_api/settings.py | 31 ++-- 4 files changed, 124 insertions(+), 198 deletions(-) diff --git a/packages/cloudflare-api/src/cloudflare_api/client.py b/packages/cloudflare-api/src/cloudflare_api/client.py index 374b5fe..38dde41 100644 --- a/packages/cloudflare-api/src/cloudflare_api/client.py +++ b/packages/cloudflare-api/src/cloudflare_api/client.py @@ -5,7 +5,7 @@ import logging import time -from typing import Any +from typing import Any, NoReturn from cloudflare import Cloudflare from cloudflare._exceptions import ( @@ -45,6 +45,10 @@ class CloudflareAPIClient: Provides methods for managing IP lists, firewall rules, and other Cloudflare resources using the official SDK. + Args: + settings (CloudflareAPISettings | None): Optional settings. If not + provided, reads from environment. + Example: ```python client = CloudflareAPIClient() @@ -64,14 +68,6 @@ def __init__( self, settings: CloudflareAPISettings | None = None, ) -> None: - """Initialize the Cloudflare API client. - - Args: - settings: Optional settings. If not provided, reads from environment. - - Raises: - CloudflareAuthError: If authentication credentials are missing. - """ self.settings = settings or get_cloudflare_api_settings() self._client = Cloudflare( api_token=self.settings.get_token_value(), @@ -83,11 +79,14 @@ def __init__( self._account_id[:8] + "...", ) - def _handle_api_error(self, error: Exception) -> None: - """Convert SDK exceptions to our custom exceptions. + def _handle_api_error(self, error: Exception) -> NoReturn: + """Convert SDK exceptions to our custom exceptions and raise them. + + This helper always raises; it never returns. Callers invoke it as a + statement (``self._handle_api_error(e)``) inside an ``except`` block. Args: - error: Exception from the Cloudflare SDK. + error (Exception): Exception from the Cloudflare SDK. Raises: CloudflareAuthError: For authentication failures. @@ -136,10 +135,7 @@ def list_ip_lists(self) -> list[IPList]: """List all IP lists in the account. Returns: - List of IPList objects. - - Raises: - CloudflareAPIError: If the API request fails. + list[IPList]: List of IPList objects. """ try: response = self._client.rules.lists.list(account_id=self._account_id) @@ -161,20 +157,15 @@ def list_ip_lists(self) -> list[IPList]: return lists except Exception as e: self._handle_api_error(e) - raise # Unreachable but satisfies type checker def get_ip_list(self, list_id: str) -> IPList: """Get details of a specific IP list. Args: - list_id: The list identifier. + list_id (str): The list identifier. Returns: - IPList object. - - Raises: - CloudflareNotFoundError: If the list doesn't exist. - CloudflareAPIError: If the API request fails. + IPList: IPList object. """ try: item = self._client.rules.lists.get( @@ -193,19 +184,15 @@ def get_ip_list(self, list_id: str) -> IPList: ) except Exception as e: self._handle_api_error(e) - raise def get_ip_list_by_name(self, name: str) -> IPList | None: """Get an IP list by name. Args: - name: The list name to search for. + name (str): The list name to search for. Returns: - IPList if found, None otherwise. - - Raises: - CloudflareAPIError: If the API request fails. + IPList | None: IPList if found, None otherwise. """ lists = self.list_ip_lists() for ip_list in lists: @@ -222,17 +209,15 @@ def create_ip_list( """Create a new IP list. Args: - name: List name (must be unique per account). - kind: Type of list (ip, redirect, hostname, asn). - description: Optional description. + name (str): List name (must be unique per account). + kind (str): Type of list (ip, redirect, hostname, asn). + description (str | None): Optional description. Returns: - The created IPList. + IPList: The created IPList. Raises: CloudflareConflictError: If a list with this name already exists. - CloudflareValidationError: If the name or kind is invalid. - CloudflareAPIError: If the API request fails. """ try: response = self._client.rules.lists.create( @@ -257,10 +242,8 @@ def create_ip_list( msg = f"A list named '{name}' already exists" raise CloudflareConflictError(msg, code=409) from e self._handle_api_error(e) - raise except Exception as e: self._handle_api_error(e) - raise def update_ip_list( self, @@ -270,15 +253,11 @@ def update_ip_list( """Update an IP list's description. Args: - list_id: The list identifier. - description: New description. + list_id (str): The list identifier. + description (str | None): New description. Returns: - The updated IPList. - - Raises: - CloudflareNotFoundError: If the list doesn't exist. - CloudflareAPIError: If the API request fails. + IPList: The updated IPList. """ try: response = self._client.rules.lists.update( @@ -299,21 +278,18 @@ def update_ip_list( ) except Exception as e: self._handle_api_error(e) - raise def delete_ip_list(self, list_id: str) -> bool: """Delete an IP list and all its items. Args: - list_id: The list identifier. + list_id (str): The list identifier. Returns: - True if deleted successfully. + bool: True if deleted successfully. Raises: - CloudflareNotFoundError: If the list doesn't exist. CloudflareConflictError: If the list is in use by firewall rules. - CloudflareAPIError: If the API request fails. """ try: self._client.rules.lists.delete( @@ -329,10 +305,8 @@ def delete_ip_list(self, list_id: str) -> bool: ) raise CloudflareConflictError(msg, code=409) from e self._handle_api_error(e) - raise except Exception as e: self._handle_api_error(e) - raise # ========================================================================= # IP List Item Operations @@ -342,14 +316,10 @@ def get_ip_list_items(self, list_id: str) -> list[IPListItem]: """Get all items in an IP list. Args: - list_id: The list identifier. + list_id (str): The list identifier. Returns: - List of IPListItem objects. - - Raises: - CloudflareNotFoundError: If the list doesn't exist. - CloudflareAPIError: If the API request fails. + list[IPListItem]: List of IPListItem objects. """ try: response = self._client.rules.lists.items.list( @@ -371,7 +341,6 @@ def get_ip_list_items(self, list_id: str) -> list[IPListItem]: return items except Exception as e: self._handle_api_error(e) - raise def add_ip_list_items( self, @@ -384,19 +353,17 @@ def add_ip_list_items( This is an asynchronous operation. By default, waits for completion. Args: - list_id: The list identifier. - items: List of items to add. Each item should have 'ip' and - optionally 'comment'. - wait_for_completion: Whether to wait for the operation to complete. + list_id (str): The list identifier. + items (list[dict[str, Any] | IPListItemInput]): List of items to + add. Each item should have 'ip' and optionally 'comment'. + wait_for_completion (bool): Whether to wait for the operation to + complete. Returns: - Operation ID if not waiting, None if completed. + str | None: Operation ID if not waiting, None if completed. Raises: - CloudflareNotFoundError: If the list doesn't exist. - CloudflareBulkOperationError: If the operation fails. CloudflareConflictError: If another bulk operation is in progress. - CloudflareAPIError: If the API request fails. """ try: # Convert to API format @@ -431,10 +398,8 @@ def add_ip_list_items( msg = "Another bulk operation is already in progress" raise CloudflareConflictError(msg, code=409) from e self._handle_api_error(e) - raise except Exception as e: self._handle_api_error(e) - raise def replace_ip_list_items( self, @@ -448,19 +413,17 @@ def replace_ip_list_items( This is an asynchronous operation. Args: - list_id: The list identifier. - items: List of items to set. Each item should have 'ip' and - optionally 'comment'. - wait_for_completion: Whether to wait for the operation to complete. + list_id (str): The list identifier. + items (list[dict[str, Any] | IPListItemInput]): List of items to + set. Each item should have 'ip' and optionally 'comment'. + wait_for_completion (bool): Whether to wait for the operation to + complete. Returns: - Operation ID if not waiting, None if completed. + str | None: Operation ID if not waiting, None if completed. Raises: - CloudflareNotFoundError: If the list doesn't exist. - CloudflareBulkOperationError: If the operation fails. CloudflareConflictError: If another bulk operation is in progress. - CloudflareAPIError: If the API request fails. """ try: # Convert to API format @@ -495,10 +458,8 @@ def replace_ip_list_items( msg = "Another bulk operation is already in progress" raise CloudflareConflictError(msg, code=409) from e self._handle_api_error(e) - raise except Exception as e: self._handle_api_error(e) - raise def delete_ip_list_items( self, @@ -509,17 +470,13 @@ def delete_ip_list_items( """Delete specific items from an IP list. Args: - list_id: The list identifier. - item_ids: List of item IDs to delete. - wait_for_completion: Whether to wait for the operation to complete. + list_id (str): The list identifier. + item_ids (list[str]): List of item IDs to delete. + wait_for_completion (bool): Whether to wait for the operation to + complete. Returns: - Operation ID if not waiting, None if completed. - - Raises: - CloudflareNotFoundError: If the list doesn't exist. - CloudflareBulkOperationError: If the operation fails. - CloudflareAPIError: If the API request fails. + str | None: Operation ID if not waiting, None if completed. """ try: # Format items for deletion @@ -546,7 +503,6 @@ def delete_ip_list_items( return operation_id except Exception as e: self._handle_api_error(e) - raise # ========================================================================= # Bulk Operation Helpers @@ -556,14 +512,10 @@ def get_bulk_operation_status(self, operation_id: str) -> BulkOperation: """Get the status of a bulk operation. Args: - operation_id: The operation identifier. + operation_id (str): The operation identifier. Returns: - BulkOperation with current status. - - Raises: - CloudflareNotFoundError: If the operation doesn't exist. - CloudflareAPIError: If the API request fails. + BulkOperation: BulkOperation with current status. """ try: response = self._client.rules.lists.bulk_operations.get( @@ -578,16 +530,15 @@ def get_bulk_operation_status(self, operation_id: str) -> BulkOperation: ) except Exception as e: self._handle_api_error(e) - raise def _wait_for_bulk_operation(self, operation_id: str) -> BulkOperation: """Wait for a bulk operation to complete. Args: - operation_id: The operation identifier. + operation_id (str): The operation identifier. Returns: - The final BulkOperation status. + BulkOperation: The final BulkOperation status. Raises: CloudflareBulkOperationError: If the operation fails or times out. @@ -637,15 +588,12 @@ def ensure_ip_list( """Get or create an IP list by name. Args: - name: List name. - kind: Type of list if creating. - description: Description if creating. + name (str): List name. + kind (str): Type of list if creating. + description (str | None): Description if creating. Returns: - The existing or newly created IPList. - - Raises: - CloudflareAPIError: If the API request fails. + IPList: The existing or newly created IPList. """ existing = self.get_ip_list_by_name(name) if existing: @@ -663,12 +611,11 @@ def sync_ip_list( """Sync an IP list to contain exactly the specified IPs. Args: - list_id: The list identifier. - ips: List of IP addresses/CIDRs that should be in the list. - comments: Optional mapping of IP to comment. - - Raises: - CloudflareAPIError: If the API request fails. + list_id (str): The list identifier. + ips (list[str]): List of IP addresses/CIDRs that should be in the + list. + comments (dict[str, str] | None): Optional mapping of IP to + comment. """ comments = comments or {} items = [{"ip": ip, "comment": comments.get(ip)} for ip in ips] diff --git a/packages/cloudflare-api/src/cloudflare_api/exceptions.py b/packages/cloudflare-api/src/cloudflare_api/exceptions.py index bc7bdbb..fd09439 100644 --- a/packages/cloudflare-api/src/cloudflare_api/exceptions.py +++ b/packages/cloudflare-api/src/cloudflare_api/exceptions.py @@ -9,11 +9,12 @@ class CloudflareAPIError(Exception): """Base exception for Cloudflare API errors. - Attributes: - message: Error message - code: Cloudflare error code (if available) - errors: List of error details from Cloudflare response - response: Raw response data (if available) + Args: + message (str): Error message + code (int | None): Cloudflare error code (if available) + errors (list[dict[str, Any]] | None): List of error details from + Cloudflare response + response (dict[str, Any] | None): Raw response data (if available) """ def __init__( @@ -23,14 +24,6 @@ def __init__( errors: list[dict[str, Any]] | None = None, response: dict[str, Any] | None = None, ) -> None: - """Initialize CloudflareAPIError. - - Args: - message: Error message - code: Cloudflare error code - errors: List of error details - response: Raw response data - """ super().__init__(message) self.message = message self.code = code @@ -60,8 +53,10 @@ class CloudflareRateLimitError(CloudflareAPIError): Raised when too many requests are made in a short period. - Attributes: - retry_after: Seconds to wait before retrying (if provided) + Args: + message (str): Error message + retry_after (int | None): Seconds to wait before retrying + **kwargs (Any): Additional arguments for base class """ def __init__( @@ -70,13 +65,6 @@ def __init__( retry_after: int | None = None, **kwargs: Any, ) -> None: - """Initialize rate limit error. - - Args: - message: Error message - retry_after: Seconds to wait before retrying - **kwargs: Additional arguments for base class - """ super().__init__(message, **kwargs) self.retry_after = retry_after @@ -86,9 +74,11 @@ class CloudflareNotFoundError(CloudflareAPIError): Raised when a requested resource (list, item, zone) doesn't exist. - Attributes: - resource_type: Type of resource not found - resource_id: ID of the missing resource + Args: + message (str): Error message + resource_type (str | None): Type of resource (e.g., "list", "item") + resource_id (str | None): ID of the missing resource + **kwargs (Any): Additional arguments for base class """ def __init__( @@ -98,14 +88,6 @@ def __init__( resource_id: str | None = None, **kwargs: Any, ) -> None: - """Initialize not found error. - - Args: - message: Error message - resource_type: Type of resource (e.g., "list", "item") - resource_id: ID of the missing resource - **kwargs: Additional arguments for base class - """ super().__init__(message, **kwargs) self.resource_type = resource_type self.resource_id = resource_id @@ -116,8 +98,10 @@ class CloudflareValidationError(CloudflareAPIError): Raised when request parameters fail Cloudflare's validation. - Attributes: - field: Field that failed validation (if known) + Args: + message (str): Error message + field (str | None): Field that failed validation + **kwargs (Any): Additional arguments for base class """ def __init__( @@ -126,13 +110,6 @@ def __init__( field: str | None = None, **kwargs: Any, ) -> None: - """Initialize validation error. - - Args: - message: Error message - field: Field that failed validation - **kwargs: Additional arguments for base class - """ super().__init__(message, **kwargs) self.field = field @@ -142,9 +119,11 @@ class CloudflareBulkOperationError(CloudflareAPIError): Raised when a bulk operation fails or times out. - Attributes: - operation_id: ID of the failed operation - status: Final status of the operation + Args: + message (str): Error message + operation_id (str | None): ID of the failed operation + status (str | None): Final status of the operation + **kwargs (Any): Additional arguments for base class """ def __init__( @@ -154,14 +133,6 @@ def __init__( status: str | None = None, **kwargs: Any, ) -> None: - """Initialize bulk operation error. - - Args: - message: Error message - operation_id: ID of the failed operation - status: Final status of the operation - **kwargs: Additional arguments for base class - """ super().__init__(message, **kwargs) self.operation_id = operation_id self.status = status diff --git a/packages/cloudflare-api/src/cloudflare_api/models.py b/packages/cloudflare-api/src/cloudflare_api/models.py index abf928c..3a604bd 100644 --- a/packages/cloudflare-api/src/cloudflare_api/models.py +++ b/packages/cloudflare-api/src/cloudflare_api/models.py @@ -32,11 +32,11 @@ class IPListItem(BaseModel): """An item in an IP list. Attributes: - id: Unique identifier for the item - ip: IP address or CIDR range - comment: Optional description - created_on: When the item was created - modified_on: When the item was last modified + id (str | None): Unique identifier for the item + ip (str): IP address or CIDR range + comment (str | None): Optional description + created_on (datetime | None): When the item was created + modified_on (datetime | None): When the item was last modified """ id: str | None = None @@ -50,14 +50,15 @@ class IPList(BaseModel): """A Cloudflare IP list. Attributes: - id: Unique identifier for the list - name: List name (must be unique per account) - description: Optional description - kind: Type of list (ip, redirect, hostname, asn) - num_items: Number of items in the list - num_referencing_filters: Number of firewall filters using this list - created_on: When the list was created - modified_on: When the list was last modified + id (str): Unique identifier for the list + name (str): List name (must be unique per account) + description (str | None): Optional description + kind (ListKind): Type of list (ip, redirect, hostname, asn) + num_items (int): Number of items in the list + num_referencing_filters (int): Number of firewall filters using this + list + created_on (datetime | None): When the list was created + modified_on (datetime | None): When the list was last modified """ id: str @@ -74,10 +75,10 @@ class BulkOperation(BaseModel): """Status of a bulk operation. Attributes: - id: Operation identifier - status: Current status - error: Error message if failed - completed: When the operation completed + id (str): Operation identifier + status (BulkOperationStatus): Current status + error (str | None): Error message if failed + completed (datetime | None): When the operation completed """ id: str @@ -90,8 +91,8 @@ class IPListItemInput(BaseModel): """Input model for creating/updating IP list items. Attributes: - ip: IP address or CIDR range - comment: Optional description + ip (str): IP address or CIDR range + comment (str | None): Optional description """ ip: str = Field(description="IP address or CIDR range") @@ -101,7 +102,7 @@ def to_api_dict(self) -> dict[str, Any]: """Convert to API request format. Returns: - Dictionary for API request. + dict[str, Any]: Dictionary for API request. """ result: dict[str, Any] = {"ip": self.ip} if self.comment: @@ -113,9 +114,9 @@ class CreateIPListRequest(BaseModel): """Request to create a new IP list. Attributes: - name: List name (must be unique per account) - kind: Type of list - description: Optional description + name (str): List name (must be unique per account) + kind (ListKind): Type of list + description (str | None): Optional description """ name: str = Field(description="List name (must be unique)") diff --git a/packages/cloudflare-api/src/cloudflare_api/settings.py b/packages/cloudflare-api/src/cloudflare_api/settings.py index d686248..f241b88 100644 --- a/packages/cloudflare-api/src/cloudflare_api/settings.py +++ b/packages/cloudflare-api/src/cloudflare_api/settings.py @@ -13,13 +13,23 @@ class CloudflareAPISettings(BaseSettings): All settings can be configured via environment variables or .env file. Attributes: - cloudflare_api_token: API token with appropriate permissions - cloudflare_account_id: Cloudflare account identifier - cloudflare_api_email: Optional email for legacy API key auth - cloudflare_api_key: Optional global API key (legacy) - default_list_kind: Default kind for new IP lists (ip, redirect, hostname, asn) - request_timeout: HTTP request timeout in seconds - max_retries: Maximum number of retries for failed requests + model_config: Pydantic settings configuration. + cloudflare_api_token (SecretStr): API token with appropriate + permissions + cloudflare_account_id (str): Cloudflare account identifier + cloudflare_api_email (str | None): Optional email for legacy API key + auth + cloudflare_api_key (SecretStr | None): Optional global API key (legacy) + cloudflare_zone_id (str | None): Default zone ID for zone-scoped + operations + default_list_kind (str): Default kind for new IP lists (ip, redirect, + hostname, asn) + request_timeout (int): HTTP request timeout in seconds + max_retries (int): Maximum number of retries for failed requests + bulk_operation_poll_interval (float): Seconds between bulk operation + status checks + bulk_operation_timeout (int): Maximum seconds to wait for bulk + operations """ model_config = SettingsConfigDict( @@ -104,7 +114,7 @@ def get_token_value(self) -> str: """Get the API token as a plain string. Returns: - The API token value. + str: The API token value. """ return self.cloudflare_api_token.get_secret_value() @@ -116,10 +126,7 @@ def get_cloudflare_api_settings() -> CloudflareAPISettings: """Get default settings (singleton, reads from environment). Returns: - CloudflareAPISettings instance. - - Raises: - ValidationError: If required environment variables are missing. + CloudflareAPISettings: Configured settings instance. """ global _settings_instance if _settings_instance is None: From eb48b5c33d40a1f0faec9731f76994fe68ad058a Mon Sep 17 00:00:00 2001 From: Byron Williams Date: Sat, 30 May 2026 22:29:02 -0700 Subject: [PATCH 6/8] docs(cloudflare-api): align ip_groups docstrings with signatures for pydoclint Co-Authored-By: Claude Sonnet 4.6 --- .../src/cloudflare_api/ip_groups/cli.py | 22 +-- .../src/cloudflare_api/ip_groups/config.py | 37 ++--- .../src/cloudflare_api/ip_groups/fetchers.py | 132 +++++++++--------- .../src/cloudflare_api/ip_groups/manager.py | 95 ++++++------- 4 files changed, 138 insertions(+), 148 deletions(-) diff --git a/packages/cloudflare-api/src/cloudflare_api/ip_groups/cli.py b/packages/cloudflare-api/src/cloudflare_api/ip_groups/cli.py index dc5ab10..796029d 100644 --- a/packages/cloudflare-api/src/cloudflare_api/ip_groups/cli.py +++ b/packages/cloudflare-api/src/cloudflare_api/ip_groups/cli.py @@ -16,7 +16,7 @@ def setup_logging(verbose: bool = False) -> None: """Configure logging for CLI output. Args: - verbose: Enable debug logging. + verbose (bool): Enable debug logging. """ level = logging.DEBUG if verbose else logging.INFO logging.basicConfig( @@ -30,10 +30,10 @@ def cmd_sync(args: argparse.Namespace) -> int: """Sync IP groups to Cloudflare. Args: - args: Parsed command line arguments. + args (argparse.Namespace): Parsed command line arguments. Returns: - Exit code (0 for success). + int: Exit code (0 for success). """ manager = IPGroupManager.from_config(args.config) @@ -64,10 +64,10 @@ def cmd_preview(args: argparse.Namespace) -> int: """Preview changes for an IP group. Args: - args: Parsed command line arguments. + args (argparse.Namespace): Parsed command line arguments. Returns: - Exit code. + int: Exit code. """ manager = IPGroupManager.from_config(args.config) preview = manager.preview_group(args.group) @@ -104,10 +104,10 @@ def cmd_list(args: argparse.Namespace) -> int: """List all configured IP groups. Args: - args: Parsed command line arguments. + args (argparse.Namespace): Parsed command line arguments. Returns: - Exit code. + int: Exit code. """ manager = IPGroupManager.from_config(args.config) groups = manager.list_groups() @@ -134,10 +134,10 @@ def cmd_fetch(args: argparse.Namespace) -> int: """Fetch and display IPs for a group (without syncing). Args: - args: Parsed command line arguments. + args (argparse.Namespace): Parsed command line arguments. Returns: - Exit code. + int: Exit code. """ manager = IPGroupManager.from_config(args.config) @@ -158,10 +158,10 @@ def main(argv: list[str] | None = None) -> int: """Main CLI entry point. Args: - argv: Command line arguments. + argv (list[str] | None): Command line arguments. Returns: - Exit code. + int: Exit code. """ parser = argparse.ArgumentParser( description="Manage IP range groups for Cloudflare", diff --git a/packages/cloudflare-api/src/cloudflare_api/ip_groups/config.py b/packages/cloudflare-api/src/cloudflare_api/ip_groups/config.py index 3093594..a278806 100644 --- a/packages/cloudflare-api/src/cloudflare_api/ip_groups/config.py +++ b/packages/cloudflare-api/src/cloudflare_api/ip_groups/config.py @@ -27,12 +27,13 @@ class IPSourceConfig(BaseModel): """Configuration for an IP source. Attributes: - type: Type of IP source - ips: Static list of IPs (for static type) - url: URL to fetch IPs from (for url type) - services: Filter by service names (for provider types) - regions: Filter by regions (for provider types) - ip_version: Filter by IP version (4 or 6, or both if None) + type (SourceType): Type of IP source + ips (list[str]): Static list of IPs (for static type) + url (str | None): URL to fetch IPs from (for url type) + services (list[str]): Filter by service names (for provider types) + regions (list[str]): Filter by regions (for provider types) + ip_version (int | None): Filter by IP version (4 or 6, or both if None) + json_path (str | None): JSONPath to extract IPs from response """ type: SourceType = Field(description="Type of IP source") @@ -65,12 +66,12 @@ class IPGroupConfig(BaseModel): """Configuration for an IP range group. Attributes: - name: Human-readable name for the group - cloudflare_list_name: Name of the Cloudflare list to sync to - description: Optional description - sources: List of IP sources that make up this group - enabled: Whether this group is enabled for syncing - tags: Optional tags for categorization + name (str): Human-readable name for the group + cloudflare_list_name (str): Name of the Cloudflare list to sync to + description (str | None): Optional description + sources (list[IPSourceConfig]): List of IP sources that make up this group + enabled (bool): Whether this group is enabled for syncing + tags (list[str]): Optional tags for categorization """ name: str = Field(description="Human-readable name") @@ -87,9 +88,10 @@ class IPGroupsConfig(BaseModel): """Root configuration for all IP groups. Attributes: - version: Config schema version - groups: List of IP group configurations - defaults: Default settings for all groups + version (str): Config schema version + groups (list[IPGroupConfig]): List of IP group configurations + cache_ttl_seconds (int): How long to cache fetched IPs + cloudflare_list_prefix (str): Prefix for all Cloudflare list names """ version: str = Field(default="1.0", description="Config schema version") @@ -108,14 +110,13 @@ def load_config(path: str | Path) -> IPGroupsConfig: """Load IP groups configuration from a YAML file. Args: - path: Path to the YAML configuration file. + path (str | Path): Path to the YAML configuration file. Returns: - Parsed configuration. + IPGroupsConfig: Parsed configuration. Raises: FileNotFoundError: If the config file doesn't exist. - ValueError: If the config is invalid. """ path = Path(path) if not path.exists(): diff --git a/packages/cloudflare-api/src/cloudflare_api/ip_groups/fetchers.py b/packages/cloudflare-api/src/cloudflare_api/ip_groups/fetchers.py index 03e8f05..30a530f 100644 --- a/packages/cloudflare-api/src/cloudflare_api/ip_groups/fetchers.py +++ b/packages/cloudflare-api/src/cloudflare_api/ip_groups/fetchers.py @@ -32,10 +32,10 @@ def fetch(self, config: IPSourceConfig) -> list[str]: """Fetch IP ranges from the source. Args: - config: Source configuration. + config (IPSourceConfig): Source configuration. Returns: - List of IP addresses or CIDR ranges. + list[str]: List of IP addresses or CIDR ranges. """ @staticmethod @@ -43,10 +43,10 @@ def validate_ip(ip: str) -> bool: """Validate an IP address or CIDR range. Args: - ip: IP address or CIDR to validate. + ip (str): IP address or CIDR to validate. Returns: - True if valid, False otherwise. + bool: True if valid, False otherwise. """ try: # Try as network (CIDR) @@ -65,10 +65,10 @@ def get_ip_version(ip: str) -> int: """Get the IP version (4 or 6) of an address. Args: - ip: IP address or CIDR. + ip (str): IP address or CIDR. Returns: - 4 or 6. + int: 4 or 6. """ try: network = ipaddress.ip_network(ip, strict=False) @@ -81,11 +81,11 @@ def filter_by_version(self, ips: list[str], version: int | None) -> list[str]: """Filter IPs by version. Args: - ips: List of IP addresses. - version: IP version to filter by (4, 6, or None for all). + ips (list[str]): List of IP addresses. + version (int | None): IP version to filter by (4, 6, or None for all). Returns: - Filtered list of IPs. + list[str]: Filtered list of IPs. """ if version is None: return ips @@ -99,10 +99,10 @@ def fetch(self, config: IPSourceConfig) -> list[str]: """Return the static IP list from config. Args: - config: Source configuration with static IPs. + config (IPSourceConfig): Source configuration with static IPs. Returns: - List of validated IP addresses. + list[str]: List of validated IP addresses. """ valid_ips = [] for ip in config.ips: @@ -115,28 +115,26 @@ def fetch(self, config: IPSourceConfig) -> list[str]: class URLIPFetcher(IPFetcher): - """Fetcher for generic URL sources.""" + """Fetcher for generic URL sources. - def __init__(self, timeout: float = 30.0) -> None: - """Initialize the URL fetcher. + Args: + timeout (float): HTTP request timeout in seconds. + """ - Args: - timeout: HTTP request timeout in seconds. - """ + def __init__(self, timeout: float = 30.0) -> None: self.timeout = timeout def fetch(self, config: IPSourceConfig) -> list[str]: """Fetch IPs from a URL. Args: - config: Source configuration with URL. + config (IPSourceConfig): Source configuration with URL. Returns: - List of IP addresses extracted from the response. + list[str]: List of IP addresses extracted from the response. Raises: ValueError: If URL is not configured. - httpx.HTTPError: If the request fails. """ if not config.url: msg = "URL is required for URL source type" @@ -156,11 +154,11 @@ def _parse_json(self, text: str, config: IPSourceConfig) -> list[str]: """Parse JSON response for IPs. Args: - text: JSON response text. - config: Source configuration. + text (str): JSON response text. + config (IPSourceConfig): Source configuration. Returns: - List of IP addresses. + list[str]: List of IP addresses. """ data = json.loads(text) @@ -178,11 +176,11 @@ def _parse_text(self, text: str, config: IPSourceConfig) -> list[str]: """Parse plain text response for IPs. Args: - text: Plain text response. - config: Source configuration. + text (str): Plain text response. + config (IPSourceConfig): Source configuration. Returns: - List of IP addresses (one per line). + list[str]: List of IP addresses (one per line). """ ips = [] for line in text.strip().split("\n"): @@ -199,11 +197,11 @@ def _extract_json_path(self, data: Any, path: str) -> list[str]: Supports paths like "prefixes[*].ip_prefix" or "hooks". Args: - data: Parsed JSON data. - path: JSONPath-like expression. + data (Any): Parsed JSON data. + path (str): JSONPath-like expression. Returns: - List of extracted string values. + list[str]: List of extracted string values. """ parts = path.split(".") current = data @@ -240,11 +238,11 @@ def _auto_extract_ips( """Auto-extract IP-like values from JSON. Args: - data: Parsed JSON data. - results: Accumulator for results. + data (Any): Parsed JSON data. + results (list[str] | None): Accumulator for results. Returns: - List of IP-like strings found. + list[str]: List of IP-like strings found. """ if results is None: results = [] @@ -271,24 +269,23 @@ def _auto_extract_ips( class GitHubIPFetcher(IPFetcher): - """Fetcher for GitHub Meta API IP ranges.""" + """Fetcher for GitHub Meta API IP ranges. - def __init__(self, timeout: float = 30.0) -> None: - """Initialize the GitHub fetcher. + Args: + timeout (float): HTTP request timeout in seconds. + """ - Args: - timeout: HTTP request timeout in seconds. - """ + def __init__(self, timeout: float = 30.0) -> None: self.timeout = timeout def fetch(self, config: IPSourceConfig) -> list[str]: """Fetch GitHub IP ranges. Args: - config: Source configuration with optional service filters. + config (IPSourceConfig): Source configuration with optional service filters. Returns: - List of GitHub IP ranges. + list[str]: List of GitHub IP ranges. Available services: hooks, web, api, git, github_enterprise_importer, packages, pages, importer, actions, actions_macos, dependabot, copilot @@ -323,24 +320,23 @@ def fetch(self, config: IPSourceConfig) -> list[str]: class GoogleCloudIPFetcher(IPFetcher): - """Fetcher for Google Cloud IP ranges.""" + """Fetcher for Google Cloud IP ranges. - def __init__(self, timeout: float = 30.0) -> None: - """Initialize the Google Cloud fetcher. + Args: + timeout (float): HTTP request timeout in seconds. + """ - Args: - timeout: HTTP request timeout in seconds. - """ + def __init__(self, timeout: float = 30.0) -> None: self.timeout = timeout def fetch(self, config: IPSourceConfig) -> list[str]: """Fetch Google Cloud IP ranges. Args: - config: Source configuration with optional region/service filters. + config (IPSourceConfig): Source configuration with optional region/service filters. Returns: - List of Google Cloud IP ranges. + list[str]: List of Google Cloud IP ranges. """ with httpx.Client(timeout=self.timeout) as client: response = client.get(GOOGLE_CLOUD_URL) @@ -378,24 +374,23 @@ def fetch(self, config: IPSourceConfig) -> list[str]: class AWSIPFetcher(IPFetcher): - """Fetcher for AWS IP ranges.""" + """Fetcher for AWS IP ranges. - def __init__(self, timeout: float = 30.0) -> None: - """Initialize the AWS fetcher. + Args: + timeout (float): HTTP request timeout in seconds. + """ - Args: - timeout: HTTP request timeout in seconds. - """ + def __init__(self, timeout: float = 30.0) -> None: self.timeout = timeout def fetch(self, config: IPSourceConfig) -> list[str]: """Fetch AWS IP ranges. Args: - config: Source configuration with optional region/service filters. + config (IPSourceConfig): Source configuration with optional region/service filters. Returns: - List of AWS IP ranges. + list[str]: List of AWS IP ranges. Available services: AMAZON, EC2, S3, CLOUDFRONT, ROUTE53, ROUTE53_HEALTHCHECKS, API_GATEWAY, etc. @@ -434,11 +429,11 @@ def _matches_filters(self, prefix: dict[str, Any], config: IPSourceConfig) -> bo """Check if a prefix matches the configured filters. Args: - prefix: AWS prefix object. - config: Source configuration. + prefix (dict[str, Any]): AWS prefix object. + config (IPSourceConfig): Source configuration. Returns: - True if prefix matches filters. + bool: True if prefix matches filters. """ # Check region filter if config.regions: @@ -456,24 +451,23 @@ def _matches_filters(self, prefix: dict[str, Any], config: IPSourceConfig) -> bo class CloudflareIPFetcher(IPFetcher): - """Fetcher for Cloudflare's own IP ranges.""" + """Fetcher for Cloudflare's own IP ranges. - def __init__(self, timeout: float = 30.0) -> None: - """Initialize the Cloudflare fetcher. + Args: + timeout (float): HTTP request timeout in seconds. + """ - Args: - timeout: HTTP request timeout in seconds. - """ + def __init__(self, timeout: float = 30.0) -> None: self.timeout = timeout def fetch(self, config: IPSourceConfig) -> list[str]: """Fetch Cloudflare IP ranges. Args: - config: Source configuration. + config (IPSourceConfig): Source configuration. Returns: - List of Cloudflare IP ranges. + list[str]: List of Cloudflare IP ranges. """ ips: list[str] = [] @@ -501,10 +495,10 @@ def get_fetcher(source_type: SourceType) -> IPFetcher: """Get the appropriate fetcher for a source type. Args: - source_type: Type of IP source. + source_type (SourceType): Type of IP source. Returns: - Appropriate fetcher instance. + IPFetcher: Appropriate fetcher instance. Raises: ValueError: If source type is not supported. diff --git a/packages/cloudflare-api/src/cloudflare_api/ip_groups/manager.py b/packages/cloudflare-api/src/cloudflare_api/ip_groups/manager.py index 934f3a2..b121b7d 100644 --- a/packages/cloudflare-api/src/cloudflare_api/ip_groups/manager.py +++ b/packages/cloudflare-api/src/cloudflare_api/ip_groups/manager.py @@ -29,15 +29,15 @@ class SyncResult: """Result of syncing an IP group. Attributes: - group_name: Name of the IP group - cloudflare_list_name: Cloudflare list name - cloudflare_list_id: Cloudflare list ID - ips_count: Number of IPs synced - added: Number of IPs added - removed: Number of IPs removed - unchanged: Whether the list was unchanged - error: Error message if sync failed - duration_seconds: Time taken to sync + group_name (str): Name of the IP group + cloudflare_list_name (str): Cloudflare list name + cloudflare_list_id (str | None): Cloudflare list ID + ips_count (int): Number of IPs synced + added (int): Number of IPs added + removed (int): Number of IPs removed + unchanged (bool): Whether the list was unchanged + error (str | None): Error message if sync failed + duration_seconds (float): Time taken to sync """ group_name: str @@ -56,9 +56,9 @@ class IPCache: """Cache for fetched IP ranges. Attributes: - ips: Cached IP addresses - fetched_at: When the IPs were fetched - source_hash: Hash of the source config for invalidation + ips (list[str]): Cached IP addresses + fetched_at (datetime): When the IPs were fetched + source_hash (str): Hash of the source config for invalidation """ ips: list[str] = field(default_factory=list) @@ -71,6 +71,10 @@ class IPGroupManager: Handles fetching IPs from various sources and syncing them to Cloudflare. + Args: + config (IPGroupsConfig): IP groups configuration. + client (CloudflareAPIClient | None): Optional Cloudflare client. If not provided, creates one. + Example: ```python manager = IPGroupManager.from_config("ip_groups.yaml") @@ -91,12 +95,6 @@ def __init__( config: IPGroupsConfig, client: CloudflareAPIClient | None = None, ) -> None: - """Initialize the IP Group Manager. - - Args: - config: IP groups configuration. - client: Optional Cloudflare client. If not provided, creates one. - """ self.config = config self._client = client self._cache: dict[str, IPCache] = {} @@ -110,11 +108,11 @@ def from_config( """Create a manager from a config file. Args: - config_path: Path to the YAML config file. - client: Optional Cloudflare client. + config_path (str | Path): Path to the YAML config file. + client (CloudflareAPIClient | None): Optional Cloudflare client. Returns: - Configured IPGroupManager. + IPGroupManager: Configured IPGroupManager. """ config = load_config(config_path) return cls(config, client) @@ -130,10 +128,10 @@ def _get_source_hash(self, source: IPSourceConfig) -> str: """Get a hash of the source config for cache invalidation. Args: - source: Source configuration. + source (IPSourceConfig): Source configuration. Returns: - Hash string. + str: Hash string. """ config_str = json.dumps(source.model_dump(), sort_keys=True) # MD5 used only for cache key generation, not security purposes @@ -143,11 +141,11 @@ def _is_cache_valid(self, cache: IPCache, source: IPSourceConfig) -> bool: """Check if cached IPs are still valid. Args: - cache: Cached data. - source: Source configuration. + cache (IPCache): Cached data. + source (IPSourceConfig): Source configuration. Returns: - True if cache is valid. + bool: True if cache is valid. """ # Check if source config changed if cache.source_hash != self._get_source_hash(source): @@ -165,11 +163,11 @@ def fetch_source_ips( """Fetch IPs from a single source. Args: - source: Source configuration. - use_cache: Whether to use cached results. + source (IPSourceConfig): Source configuration. + use_cache (bool): Whether to use cached results. Returns: - List of IP addresses. + list[str]: List of IP addresses. """ cache_key = self._get_source_hash(source) @@ -201,11 +199,14 @@ def fetch_group_ips( """Fetch all IPs for a group from all sources. Args: - group: Group configuration. - use_cache: Whether to use cached results. + group (IPGroupConfig): Group configuration. + use_cache (bool): Whether to use cached results. Returns: - Deduplicated list of IP addresses. + list[str]: Deduplicated list of IP addresses. + + Raises: + Exception: If fetching from any source fails. """ all_ips: set[str] = set() @@ -233,13 +234,10 @@ def preview_group(self, group_name: str) -> dict[str, Any]: """Preview what would change for a group without applying. Args: - group_name: Name of the group to preview. + group_name (str): Name of the group to preview. Returns: - Dict with current and new IPs, and diff. - - Raises: - ValueError: If group not found. + dict[str, Any]: Dict with current and new IPs, and diff. """ group = self._get_group(group_name) @@ -275,14 +273,11 @@ def sync_group(self, group_name: str, dry_run: bool = False) -> SyncResult: """Sync a single IP group to Cloudflare. Args: - group_name: Name of the group to sync. - dry_run: If True, preview without applying changes. + group_name (str): Name of the group to sync. + dry_run (bool): If True, preview without applying changes. Returns: - SyncResult with details of the operation. - - Raises: - ValueError: If group not found or disabled. + SyncResult: SyncResult with details of the operation. """ start_time = time.time() group = self._get_group(group_name) @@ -373,10 +368,10 @@ def sync_all(self, dry_run: bool = False) -> list[SyncResult]: """Sync all enabled IP groups to Cloudflare. Args: - dry_run: If True, preview without applying changes. + dry_run (bool): If True, preview without applying changes. Returns: - List of SyncResults for each group. + list[SyncResult]: List of SyncResults for each group. """ results = [] @@ -405,7 +400,7 @@ def list_groups(self) -> list[dict[str, Any]]: """List all configured IP groups. Returns: - List of group summaries. + list[dict[str, Any]]: List of group summaries. """ return [ { @@ -424,10 +419,10 @@ def _get_group(self, group_name: str) -> IPGroupConfig: """Get a group by name. Args: - group_name: Name of the group. + group_name (str): Name of the group. Returns: - Group configuration. + IPGroupConfig: Group configuration. Raises: ValueError: If group not found. @@ -444,10 +439,10 @@ def _get_cloudflare_list_name(self, group: IPGroupConfig) -> str: """Get the Cloudflare list name for a group. Args: - group: Group configuration. + group (IPGroupConfig): Group configuration. Returns: - Cloudflare list name with optional prefix. + str: Cloudflare list name with optional prefix. """ prefix = self.config.cloudflare_list_prefix if prefix: From 182f48e71c8d276df2e1b8423ef8d1621890a5d1 Mon Sep 17 00:00:00 2001 From: Byron Williams Date: Sat, 30 May 2026 22:46:08 -0700 Subject: [PATCH 7/8] docs(cloudflare-auth): align docstrings with signatures for pydoclint Co-Authored-By: Claude Sonnet 4.6 --- .../src/cloudflare_auth/config.py | 47 +++--- .../src/cloudflare_auth/csrf.py | 34 ++-- .../src/cloudflare_auth/middleware.py | 61 +++----- .../cloudflare_auth/middleware_enhanced.py | 96 ++++++------ .../src/cloudflare_auth/models.py | 64 ++++---- .../src/cloudflare_auth/rate_limiter.py | 36 ++--- .../src/cloudflare_auth/redis_sessions.py | 58 ++++--- .../src/cloudflare_auth/security_helpers.py | 72 ++++----- .../src/cloudflare_auth/sessions.py | 52 +++---- .../src/cloudflare_auth/utils.py | 44 +++--- .../src/cloudflare_auth/validators.py | 34 ++-- .../src/cloudflare_auth/whitelist.py | 146 +++++++++--------- 12 files changed, 353 insertions(+), 391 deletions(-) diff --git a/packages/cloudflare-auth/src/cloudflare_auth/config.py b/packages/cloudflare-auth/src/cloudflare_auth/config.py index fff44d2..c4d1ad9 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/config.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/config.py @@ -31,22 +31,23 @@ class CloudflareSettings(BaseSettings): variables with sensible defaults for development. Attributes: - cloudflare_team_domain: Cloudflare Access team domain (e.g., "myteam") - cloudflare_audience_tag: Application audience tag from CF dashboard - cloudflare_enabled: Whether CF authentication is enabled - jwt_header_name: Header containing the JWT token - email_header_name: Header containing the authenticated email - jwt_algorithm: Algorithm for JWT validation (default: RS256) - jwt_cache_max_keys: Maximum cached signing keys - require_email_verification: Require email claim in token - log_auth_failures: Log failed authentication attempts - require_cloudflare_headers: Require CF-Ray header for validation - allowed_tunnel_ips: List of allowed tunnel IPs (optional) - allowed_email_domains: Restrict to specific email domains - cookie_path: Session cookie path - cookie_secure: Use secure cookies - cookie_samesite: Cookie SameSite attribute - cookie_domain: Cookie domain (optional) + model_config: Pydantic settings configuration + cloudflare_team_domain (str): Cloudflare Access team domain (e.g., "myteam") + cloudflare_audience_tag (str): Application audience tag from CF dashboard + cloudflare_enabled (bool): Whether CF authentication is enabled + jwt_header_name (str): Header containing the JWT token + email_header_name (str): Header containing the authenticated email + jwt_algorithm (str): Algorithm for JWT validation (default: RS256) + jwt_cache_max_keys (int): Maximum cached signing keys + require_email_verification (bool): Require email claim in token + log_auth_failures (bool): Log failed authentication attempts + require_cloudflare_headers (bool): Require CF-Ray header for validation + allowed_tunnel_ips (list[str]): List of allowed tunnel IPs (optional) + allowed_email_domains (list[str]): Restrict to specific email domains + cookie_path (str): Session cookie path + cookie_secure (bool): Use secure cookies + cookie_samesite (Literal['lax', 'strict', 'none']): Cookie SameSite attribute + cookie_domain (str | None): Cookie domain (optional) """ model_config = SettingsConfigDict( @@ -137,10 +138,10 @@ def parse_comma_separated(cls, v: str | list[str] | None) -> list[str]: """Parse comma-separated strings into lists. Args: - v: Comma-separated string or list of strings. + v (str | list[str] | None): Comma-separated string or list of strings. Returns: - List of strings parsed from input. + list[str]: List of strings parsed from input. """ if v is None: return [] @@ -153,7 +154,7 @@ def certs_url(self) -> str | None: """Get the Cloudflare certificate URL. Returns: - URL for JWKS endpoint, or None if team domain not configured. + str | None: URL for JWKS endpoint, or None if team domain not configured. """ if not self.cloudflare_team_domain: return None @@ -164,7 +165,7 @@ def issuer(self) -> str | None: """Get the expected token issuer. Returns: - Issuer URL, or None if team domain not configured. + str | None: Issuer URL, or None if team domain not configured. """ if not self.cloudflare_team_domain: return None @@ -174,10 +175,10 @@ def is_email_allowed(self, email: str) -> bool: """Check if an email address is allowed. Args: - email: Email address to check + email (str): Email address to check Returns: - True if email is allowed (no restrictions or matches allowed domains) + bool: True if email is allowed (no restrictions or matches allowed domains) """ if not self.allowed_email_domains: return True @@ -197,7 +198,7 @@ def get_cloudflare_settings() -> CloudflareSettings: Settings are loaded from environment variables and .env file. Returns: - CloudflareSettings instance + CloudflareSettings: CloudflareSettings instance Example: settings = get_cloudflare_settings() diff --git a/packages/cloudflare-auth/src/cloudflare_auth/csrf.py b/packages/cloudflare-auth/src/cloudflare_auth/csrf.py index 2ce55de..22a3d62 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/csrf.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/csrf.py @@ -32,6 +32,11 @@ class CSRFProtection: 2. Setting it in both a cookie and requiring it in request headers/body 3. Validating that both values match + Args: + cookie_name (str): Name of CSRF cookie (default: "csrf_token") + header_name (str): Name of CSRF header (default: "X-CSRF-Token") + secret_key (str | None): Optional secret key for token generation + Example: csrf = CSRFProtection() @@ -52,13 +57,6 @@ def __init__( header_name: str = "X-CSRF-Token", secret_key: str | None = None, ) -> None: - """Initialize CSRF protection. - - Args: - cookie_name: Name of CSRF cookie (default: "csrf_token") - header_name: Name of CSRF header (default: "X-CSRF-Token") - secret_key: Optional secret key for token generation - """ self.cookie_name = cookie_name self.header_name = header_name self.secret_key = secret_key or secrets.token_hex(32) @@ -73,10 +71,10 @@ def generate_token(self, session_id: str | None = None) -> str: """Generate a new CSRF token. Args: - session_id: Optional session ID to bind token to + session_id (str | None): Optional session ID to bind token to Returns: - CSRF token string + str: CSRF token string """ # Generate random token secrets.token_bytes(32) @@ -104,12 +102,12 @@ def validate_token( """Validate CSRF token from cookie and header. Args: - cookie_token: Token from cookie - header_token: Token from header - constant_time: Use constant-time comparison (default: True) + cookie_token (str | None): Token from cookie + header_token (str | None): Token from header + constant_time (bool): Use constant-time comparison (default: True) Returns: - True if tokens match and are valid + bool: True if tokens match and are valid """ # Both tokens must be present if not cookie_token or not header_token: @@ -137,10 +135,10 @@ def validate_request( Args: request: FastAPI/Starlette Request object - methods_to_protect: HTTP methods that require CSRF validation + methods_to_protect (set[str] | None): HTTP methods that require CSRF validation Returns: - True if validation passes or not required for this method + bool: True if validation passes or not required for this method """ # Skip CSRF check for safe methods if methods_to_protect is None: @@ -167,11 +165,11 @@ def get_csrf_protection( """Get or create global CSRF protection instance. Args: - cookie_name: Name of CSRF cookie - header_name: Name of CSRF header + cookie_name (str): Name of CSRF cookie + header_name (str): Name of CSRF header Returns: - CSRFProtection instance + CSRFProtection: CSRFProtection instance """ global _global_csrf_protection # noqa: PLW0603 diff --git a/packages/cloudflare-auth/src/cloudflare_auth/middleware.py b/packages/cloudflare-auth/src/cloudflare_auth/middleware.py index 5527f85..ed6b384 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/middleware.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/middleware.py @@ -74,11 +74,15 @@ class CloudflareAuthMiddleware(BaseHTTPMiddleware): Cloudflare Access headers. It validates tokens, extracts user information, and makes it available via request.state.user. - Attributes: - validator: JWT token validator instance - settings: Cloudflare configuration settings - excluded_paths: Paths that bypass authentication - require_auth: Whether to enforce authentication (True) or just parse it (False) + Args: + app (Any): The ASGI application + settings (CloudflareSettings | None): Optional CloudflareSettings instance + validator (CloudflareJWTValidator | None): Optional CloudflareJWTValidator instance + excluded_paths (list[str] | None): List of paths to exclude from authentication + require_auth (bool): Whether to require authentication (vs. optional) + enable_rate_limiting (bool): Whether to enable rate limiting (default: True) + rate_limit_attempts (int): Max authentication attempts per window (default: 5) + rate_limit_window (int): Rate limit window in seconds (default: 60) Example: # In your FastAPI app @@ -100,18 +104,6 @@ def __init__( rate_limit_attempts: int = 5, rate_limit_window: int = 60, ) -> None: - """Initialize Cloudflare authentication middleware. - - Args: - app: The ASGI application - settings: Optional CloudflareSettings instance - validator: Optional CloudflareJWTValidator instance - excluded_paths: List of paths to exclude from authentication - require_auth: Whether to require authentication (vs. optional) - enable_rate_limiting: Whether to enable rate limiting (default: True) - rate_limit_attempts: Max authentication attempts per window (default: 5) - rate_limit_window: Rate limit window in seconds (default: 60) - """ super().__init__(app) self.settings = settings or get_cloudflare_settings() self.validator = validator or CloudflareJWTValidator(self.settings) @@ -140,10 +132,10 @@ def _is_path_excluded(self, path: str) -> bool: """Check if a path should bypass authentication. Args: - path: Request path to check + path (str): Request path to check Returns: - True if path is in excluded list + bool: True if path is in excluded list """ # Exact match or prefix match return any( @@ -158,7 +150,7 @@ def _validate_cloudflare_origin(self, request: Request) -> None: tunnel and not directly to the application (bypassing Cloudflare Access). Args: - request: The incoming request + request (Request): The incoming request Raises: HTTPException: If request doesn't have required Cloudflare headers @@ -223,11 +215,11 @@ async def dispatch( 5. Process request through application Args: - request: The incoming request - call_next: The next middleware/endpoint in the chain + request (Request): The incoming request + call_next (Callable): The next middleware/endpoint in the chain Returns: - Response from the application + Response: Response from the application Raises: HTTPException: If authentication fails and is required @@ -403,13 +395,10 @@ async def _authenticate_request(self, request: Request) -> CloudflareUser | None """Authenticate request using Cloudflare headers. Args: - request: The incoming request + request (Request): The incoming request Returns: - CloudflareUser object if authentication succeeds, None if optional - - Raises: - HTTPException: If authentication fails and is required + CloudflareUser | None: CloudflareUser object if authentication succeeds, None if optional """ self._check_rate_limit(request) @@ -449,10 +438,10 @@ def setup_cloudflare_auth( to your FastAPI application with sensible defaults. Args: - app: The FastAPI application instance - excluded_paths: Optional list of paths to exclude from auth - require_auth: Whether authentication is required (vs. optional) - settings: Optional CloudflareSettings instance + app (Any): The FastAPI application instance + excluded_paths (list[str] | None): Optional list of paths to exclude from auth + require_auth (bool): Whether authentication is required (vs. optional) + settings (CloudflareSettings | None): Optional CloudflareSettings instance Example: from fastapi import FastAPI @@ -512,10 +501,10 @@ def get_current_user(request: Request) -> CloudflareUser: the authenticated user in route handlers. Args: - request: The FastAPI request object + request (Request): The FastAPI request object Returns: - CloudflareUser object + CloudflareUser: CloudflareUser object Raises: HTTPException: If user is not authenticated @@ -551,10 +540,10 @@ def get_current_user_optional(request: Request) -> CloudflareUser | None: raising an exception if the user is not authenticated. Args: - request: The FastAPI request object + request (Request): The FastAPI request object Returns: - CloudflareUser object or None if not authenticated + CloudflareUser | None: CloudflareUser object or None if not authenticated Example: from fastapi import Depends diff --git a/packages/cloudflare-auth/src/cloudflare_auth/middleware_enhanced.py b/packages/cloudflare-auth/src/cloudflare_auth/middleware_enhanced.py index 6fc8b67..28435ba 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/middleware_enhanced.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/middleware_enhanced.py @@ -70,6 +70,19 @@ class CloudflareAuthMiddlewareEnhanced(BaseHTTPMiddleware): - Session management - Development mode support + Args: + app (Any): ASGI application + settings (CloudflareSettings | None): Cloudflare configuration settings + validator (CloudflareJWTValidator | None): JWT token validator + whitelist_validator (EmailWhitelistValidator | None): Email whitelist validator (required) + session_manager (SimpleSessionManager | None): Session manager instance + excluded_paths (list[str] | None): Paths to exclude from authentication + enable_sessions (bool): Whether to use session cookies + require_auth (bool): Whether authentication is required + enable_rate_limiting (bool): Whether to enable rate limiting (default: True) + rate_limit_attempts (int): Max authentication attempts per window (default: 5) + rate_limit_window (int): Rate limit window in seconds (default: 60) + Example: middleware = CloudflareAuthMiddlewareEnhanced( app=app, @@ -94,21 +107,6 @@ def __init__( rate_limit_attempts: int = 5, rate_limit_window: int = 60, ) -> None: - """Initialize enhanced authentication middleware. - - Args: - app: ASGI application - settings: Cloudflare configuration settings - validator: JWT token validator - whitelist_validator: Email whitelist validator (required) - session_manager: Session manager instance - excluded_paths: Paths to exclude from authentication - enable_sessions: Whether to use session cookies - require_auth: Whether authentication is required - enable_rate_limiting: Whether to enable rate limiting (default: True) - rate_limit_attempts: Max authentication attempts per window (default: 5) - rate_limit_window: Rate limit window in seconds (default: 60) - """ super().__init__(app) self.settings = settings or get_cloudflare_settings() self.jwt_validator = validator or CloudflareJWTValidator(self.settings) @@ -157,10 +155,10 @@ def _is_path_excluded(self, path: str) -> bool: """Check if a path should bypass authentication. Args: - path: Request path to check + path (str): Request path to check Returns: - True if path is excluded from auth + bool: True if path is excluded from auth """ return any( path == excluded or path.startswith(excluded.rstrip("/") + "/") @@ -184,11 +182,14 @@ async def dispatch( 7. Inject user into request.state Args: - request: Incoming request - call_next: Next middleware/endpoint + request (Request): Incoming request + call_next (Callable): Next middleware/endpoint Returns: - Response from application + Response: Response from application + + Raises: + HTTPException: If authentication fails and is required """ # Skip authentication for excluded paths if self._is_path_excluded(request.url.path): @@ -375,13 +376,10 @@ async def _authenticate_request(self, request: Request) -> CloudflareUser | None """Authenticate request using JWT and whitelist. Args: - request: Incoming request + request (Request): Incoming request Returns: - CloudflareUser object if authenticated, None if optional - - Raises: - HTTPException: If authentication fails and is required + CloudflareUser | None: CloudflareUser object if authenticated, None if optional """ self._check_rate_limit(request) @@ -438,11 +436,11 @@ def _user_from_session( """Recreate CloudflareUser from session data. Args: - session: Session data dictionary - session_id: Session identifier + session (dict[str, Any]): Session data dictionary + session_id (str): Session identifier Returns: - CloudflareUser instance + CloudflareUser: CloudflareUser instance """ from cloudflare_auth.models import CloudflareJWTClaims @@ -473,8 +471,8 @@ def _set_session_cookie(self, response: Response, session_id: str) -> None: Uses security settings from configuration for proper cookie attributes. Args: - response: Response to modify - session_id: Session ID to set + response (Response): Response to modify + session_id (str): Session ID to set """ # Prepare cookie kwargs from settings # Get cookie configuration @@ -533,19 +531,19 @@ def setup_cloudflare_auth_enhanced( - Development mode Args: - app: FastAPI application - whitelist: List of allowed emails/domains (e.g., ["user@example.com", "@company.com"]) - admin_emails: List of admin emails - full_users: List of full-tier users - limited_users: List of limited-tier users - excluded_paths: Paths to exclude from auth - enable_sessions: Whether to use session cookies - require_auth: Whether authentication is required - session_timeout: Session timeout in seconds - settings: Optional CloudflareSettings instance + app (Any): FastAPI application + whitelist (list[str] | None): List of allowed emails/domains (e.g., ["user@example.com", "@company.com"]) + admin_emails (list[str] | None): List of admin emails + full_users (list[str] | None): List of full-tier users + limited_users (list[str] | None): List of limited-tier users + excluded_paths (list[str] | None): Paths to exclude from auth + enable_sessions (bool): Whether to use session cookies + require_auth (bool): Whether authentication is required + session_timeout (int): Session timeout in seconds + settings (CloudflareSettings | None): Optional CloudflareSettings instance Returns: - None - middleware is added directly to the app + None: Middleware is added directly to the app Example: app = FastAPI() @@ -631,10 +629,10 @@ def get_current_user(request: Request) -> CloudflareUser: """FastAPI dependency to get current authenticated user. Args: - request: FastAPI request + request (Request): FastAPI request Returns: - CloudflareUser object + CloudflareUser: CloudflareUser object Raises: HTTPException: If user is not authenticated @@ -660,10 +658,10 @@ def get_current_user_optional(request: Request) -> CloudflareUser | None: """FastAPI dependency for optional authentication. Args: - request: FastAPI request + request (Request): FastAPI request Returns: - CloudflareUser or None + CloudflareUser | None: CloudflareUser or None Example: @app.get("/info") @@ -679,10 +677,10 @@ def require_admin(request: Request) -> CloudflareUser: """FastAPI dependency requiring admin privileges. Args: - request: FastAPI request + request (Request): FastAPI request Returns: - CloudflareUser object + CloudflareUser: CloudflareUser object Raises: HTTPException: If not authenticated or not admin @@ -707,10 +705,10 @@ def require_tier(minimum_tier: UserTier) -> Callable: """Create a dependency that requires a minimum user tier. Args: - minimum_tier: Minimum required tier + minimum_tier (UserTier): Minimum required tier Returns: - Dependency function + Callable: Dependency function Example: require_full = require_tier(UserTier.FULL) diff --git a/packages/cloudflare-auth/src/cloudflare_auth/models.py b/packages/cloudflare-auth/src/cloudflare_auth/models.py index 1c4a9d5..8d3c663 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/models.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/models.py @@ -32,15 +32,15 @@ class CloudflareJWTClaims(BaseModel): Access JWT assertion token. Attributes: - email: Authenticated user's email address - iss: Token issuer (Cloudflare team domain) - aud: Audience tag for the application - sub: Subject (user identifier) - iat: Issued at timestamp - exp: Expiration timestamp - nonce: Nonce for replay protection - identity_nonce: Identity nonce - custom_claims: Any additional custom claims + email (EmailStr): Authenticated user's email address + iss (str): Token issuer (Cloudflare team domain) + aud (list[str] | str): Audience tag for the application + sub (str): Subject (user identifier) + iat (int): Issued at timestamp + exp (int): Expiration timestamp + nonce (str | None): Nonce for replay protection + identity_nonce (str | None): Identity nonce + custom (dict[str, Any]): Any additional custom claims Example: claims = CloudflareJWTClaims( @@ -73,7 +73,7 @@ def issued_at(self) -> datetime: """Get issued at time as datetime. Returns: - Datetime when token was issued + datetime: Datetime when token was issued """ return datetime.fromtimestamp(self.iat, tz=timezone.utc) @@ -82,7 +82,7 @@ def expires_at(self) -> datetime: """Get expiration time as datetime. Returns: - Datetime when token expires + datetime: Datetime when token expires """ return datetime.fromtimestamp(self.exp, tz=timezone.utc) @@ -91,7 +91,7 @@ def is_expired(self) -> bool: """Check if token is expired. Returns: - True if token is expired + bool: True if token is expired """ return datetime.now(tz=timezone.utc) >= self.expires_at @@ -99,7 +99,7 @@ def get_audience_list(self) -> list[str]: """Get audience as a list. Returns: - List of audience tags + list[str]: List of audience tags """ if isinstance(self.aud, str): return [self.aud] @@ -114,13 +114,13 @@ class CloudflareUser(BaseModel): tier-based access control and admin privileges. Attributes: - email: User's email address - user_id: Unique user identifier from JWT subject - claims: Full JWT claims object - authenticated_at: When the user was authenticated - user_tier: User's access tier (admin/full/limited) - is_admin: Whether user has admin privileges - session_id: Optional session identifier + email (EmailStr): User's email address + user_id (str): Unique user identifier from JWT subject + claims (CloudflareJWTClaims): Full JWT claims object + authenticated_at (datetime): When the user was authenticated + user_tier (UserTier): User's access tier (admin/full/limited) + is_admin (bool): Whether user has admin privileges + session_id (str | None): Optional session identifier Example: user = CloudflareUser( @@ -170,13 +170,13 @@ def from_jwt_claims( """Create CloudflareUser from JWT claims with tier information. Args: - claims: Validated JWT claims - user_tier: User's access tier - is_admin: Whether user has admin privileges - session_id: Optional session identifier + claims (CloudflareJWTClaims): Validated JWT claims + user_tier (UserTier): User's access tier + is_admin (bool): Whether user has admin privileges + session_id (str | None): Optional session identifier Returns: - CloudflareUser instance + CloudflareUser: CloudflareUser instance Example: claims = validator.validate_token(token) @@ -201,7 +201,7 @@ def email_domain(self) -> str: """Get the domain from user's email. Returns: - Email domain (e.g., 'example.com') + str: Email domain (e.g., 'example.com') """ return self.email.split("@")[-1] if "@" in self.email else "" @@ -210,7 +210,7 @@ def email_username(self) -> str: """Get the username portion of email. Returns: - Username before @ symbol + str: Username before @ symbol """ return self.email.split("@")[0] if "@" in self.email else self.email @@ -218,10 +218,10 @@ def has_email_domain(self, domain: str) -> bool: """Check if user's email is from a specific domain. Args: - domain: Domain to check (case-insensitive) + domain (str): Domain to check (case-insensitive) Returns: - True if email domain matches + bool: True if email domain matches Example: if user.has_email_domain("example.com"): @@ -235,7 +235,7 @@ def can_access_premium_models(self) -> bool: """Check if user can access premium models. Returns: - True for ADMIN and FULL tiers, False for LIMITED + bool: True for ADMIN and FULL tiers, False for LIMITED """ return self.user_tier.can_access_premium_models @@ -244,7 +244,7 @@ def role(self) -> str: """Get user role string. Returns: - 'admin' or 'user' + str: 'admin' or 'user' """ return "admin" if self.is_admin else "user" @@ -252,7 +252,7 @@ def model_dump_safe(self) -> dict[str, Any]: """Dump model with only safe fields for logging. Returns: - Dictionary with safe fields (excludes sensitive claims) + dict[str, Any]: Dictionary with safe fields (excludes sensitive claims) """ return { "email": self.email, diff --git a/packages/cloudflare-auth/src/cloudflare_auth/rate_limiter.py b/packages/cloudflare-auth/src/cloudflare_auth/rate_limiter.py index 697c058..cbd53a2 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/rate_limiter.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/rate_limiter.py @@ -40,6 +40,11 @@ class InMemoryRateLimiter: deployments. For production with multiple instances, use Redis or similar distributed solutions. + Args: + max_attempts (int): Maximum attempts allowed within window + window_seconds (int): Time window in seconds + cleanup_interval (int): Seconds between cleanup operations + Example: limiter = InMemoryRateLimiter( max_attempts=5, @@ -59,13 +64,6 @@ def __init__( window_seconds: int = 60, cleanup_interval: int = 300, ) -> None: - """Initialize rate limiter. - - Args: - max_attempts: Maximum attempts allowed within window - window_seconds: Time window in seconds - cleanup_interval: Seconds between cleanup operations - """ self.max_attempts = max_attempts self.window_seconds = window_seconds self.cleanup_interval = cleanup_interval @@ -85,10 +83,10 @@ def is_allowed(self, identifier: str) -> bool: """Check if request is allowed based on rate limit. Args: - identifier: IP address or other identifier + identifier (str): IP address or other identifier Returns: - True if request is allowed, False if rate limited + bool: True if request is allowed, False if rate limited """ with self.lock: self._cleanup_if_needed() @@ -121,7 +119,7 @@ def record_attempt(self, identifier: str) -> None: """Record an authentication attempt. Args: - identifier: IP address or other identifier + identifier (str): IP address or other identifier """ with self.lock: self.attempts[identifier].append(datetime.now(tz=timezone.utc)) @@ -130,7 +128,7 @@ def reset(self, identifier: str) -> None: """Reset rate limit for an identifier. Args: - identifier: IP address or other identifier + identifier (str): IP address or other identifier """ with self.lock: if identifier in self.attempts: @@ -141,10 +139,10 @@ def get_remaining_attempts(self, identifier: str) -> int: """Get remaining attempts for an identifier. Args: - identifier: IP address or other identifier + identifier (str): IP address or other identifier Returns: - Number of remaining attempts + int: Number of remaining attempts """ with self.lock: current_time = datetime.now(tz=timezone.utc) @@ -166,10 +164,10 @@ def get_retry_after(self, identifier: str) -> int: """Get seconds until identifier can retry. Args: - identifier: IP address or other identifier + identifier (str): IP address or other identifier Returns: - Seconds until next attempt is allowed (0 if allowed now) + int: Seconds until next attempt is allowed (0 if allowed now) """ with self.lock: if identifier not in self.attempts or not self.attempts[identifier]: @@ -231,7 +229,7 @@ def get_stats(self) -> dict: """Get rate limiter statistics. Returns: - Dictionary with current statistics + dict: Dictionary with current statistics """ with self.lock: total_tracked = len(self.attempts) @@ -259,11 +257,11 @@ def get_rate_limiter( """Get or create global rate limiter instance. Args: - max_attempts: Maximum attempts per window - window_seconds: Time window in seconds + max_attempts (int): Maximum attempts per window + window_seconds (int): Time window in seconds Returns: - InMemoryRateLimiter instance + InMemoryRateLimiter: InMemoryRateLimiter instance """ global _global_rate_limiter # noqa: PLW0603 diff --git a/packages/cloudflare-auth/src/cloudflare_auth/redis_sessions.py b/packages/cloudflare-auth/src/cloudflare_auth/redis_sessions.py index 8fde66a..2566936 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/redis_sessions.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/redis_sessions.py @@ -68,6 +68,15 @@ class RedisSessionManager: Requirements: pip install redis>=5.0.0 + Args: + redis_url (str): Redis connection URL + session_timeout (int): Session timeout in seconds (default: 1 hour) + key_prefix (str): Prefix for Redis keys (default: "cf_auth_session") + + Raises: + ImportError: If redis package is not installed + redis.ConnectionError: If cannot connect to Redis + Example: manager = RedisSessionManager( redis_url="redis://localhost:6379/0", @@ -88,17 +97,6 @@ def __init__( session_timeout: int = 3600, key_prefix: str = "cf_auth_session", ) -> None: - """Initialize Redis session manager. - - Args: - redis_url: Redis connection URL - session_timeout: Session timeout in seconds (default: 1 hour) - key_prefix: Prefix for Redis keys (default: "cf_auth_session") - - Raises: - ImportError: If redis package is not installed - redis.ConnectionError: If cannot connect to Redis - """ if not REDIS_AVAILABLE or redis is None: msg = ( "Redis package is required for RedisSessionManager. " @@ -133,10 +131,10 @@ def _make_key(self, session_id: str) -> str: """Generate Redis key for session. Args: - session_id: Session identifier + session_id (str): Session identifier Returns: - Redis key with prefix + str: Redis key with prefix """ return f"{self.key_prefix}:{session_id}" @@ -150,13 +148,13 @@ def create_session( """Create a new session in Redis. Args: - email: User email address - is_admin: Whether user has admin privileges - user_tier: User tier (admin, full, limited) - cf_context: Additional Cloudflare context + email (str): User email address + is_admin (bool): Whether user has admin privileges + user_tier (str): User tier (admin, full, limited) + cf_context (dict[str, Any] | None): Additional Cloudflare context Returns: - Session ID (cryptographically secure random token) + str: Session ID (cryptographically secure random token) Example: session_id = manager.create_session( @@ -201,10 +199,10 @@ def get_session(self, session_id: str) -> dict[str, Any] | None: - Refreshes TTL Args: - session_id: Session identifier + session_id (str): Session identifier Returns: - Session data if valid, None if expired or not found + dict[str, Any] | None: Session data if valid, None if expired or not found """ if not session_id: return None @@ -246,10 +244,10 @@ def invalidate_session(self, session_id: str) -> bool: """Invalidate (delete) a session from Redis. Args: - session_id: Session to invalidate + session_id (str): Session to invalidate Returns: - True if session was found and deleted + bool: True if session was found and deleted """ key = self._make_key(session_id) deleted = self.redis_client.delete(key) @@ -264,10 +262,10 @@ def refresh_session(self, session_id: str) -> bool: """Refresh a session's expiration time. Args: - session_id: Session to refresh + session_id (str): Session to refresh Returns: - True if session was found and refreshed + bool: True if session was found and refreshed """ key = self._make_key(session_id) @@ -288,7 +286,7 @@ def get_session_count(self) -> int: """Get the current number of active sessions. Returns: - Number of active sessions + int: Number of active sessions """ pattern = f"{self.key_prefix}:*" keys = self.redis_client.keys(pattern) @@ -301,10 +299,10 @@ def get_user_sessions(self, email: str) -> list[str]: Note: This operation can be expensive on large datasets. Args: - email: User email address + email (str): User email address Returns: - List of session IDs for the user + list[str]: List of session IDs for the user """ pattern = f"{self.key_prefix}:*" keys = self.redis_client.keys(pattern) @@ -340,7 +338,7 @@ def cleanup_expired_sessions(self) -> int: so this is a no-op for RedisSessionManager. Returns: - Always returns 0 (Redis handles cleanup automatically) + int: Always returns 0 (Redis handles cleanup automatically) """ # Redis automatically removes expired keys logger.debug("Redis handles expiration automatically (no cleanup needed)") @@ -350,7 +348,7 @@ def get_stats(self) -> dict[str, Any]: """Get session manager statistics. Returns: - Dictionary with session statistics + dict[str, Any]: Dictionary with session statistics """ active_sessions = self.get_session_count() @@ -367,7 +365,7 @@ def health_check(self) -> bool: """Check if Redis connection is healthy. Returns: - True if Redis is reachable and responsive + bool: True if Redis is reachable and responsive """ try: result = self.redis_client.ping() diff --git a/packages/cloudflare-auth/src/cloudflare_auth/security_helpers.py b/packages/cloudflare-auth/src/cloudflare_auth/security_helpers.py index 76b0ec7..943327d 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/security_helpers.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/security_helpers.py @@ -43,6 +43,11 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware): - Permissions-Policy: Restrictive permissions - Strict-Transport-Security: HSTS (production only) + Args: + app (Any): ASGI application + csp_policy (str | None): Custom Content Security Policy + enable_hsts (bool): Enable HSTS headers + Example: app.add_middleware(SecurityHeadersMiddleware) """ @@ -53,13 +58,6 @@ def __init__( csp_policy: str | None = None, enable_hsts: bool = True, ) -> None: - """Initialize security headers middleware. - - Args: - app: ASGI application - csp_policy: Custom Content Security Policy - enable_hsts: Enable HSTS headers - """ super().__init__(app) self.csp_policy = csp_policy or self._default_csp_policy() self.enable_hsts = enable_hsts @@ -68,7 +66,7 @@ def _default_csp_policy(self) -> str: """Generate default Content Security Policy. Returns: - CSP policy string + str: CSP policy string """ return ( "default-src 'self'; " @@ -90,11 +88,11 @@ async def dispatch( """Add security headers to response. Args: - request: Incoming request - call_next: Next middleware/endpoint + request (Request): Incoming request + call_next (Callable): Next middleware/endpoint Returns: - Response with security headers + Response: Response with security headers """ response = await call_next(request) @@ -133,11 +131,11 @@ def create_session_cleanup_task( expired sessions to prevent memory leaks. Args: - session_manager: Session manager to clean - cleanup_interval: Cleanup interval in seconds (default: 5 minutes) + session_manager (SimpleSessionManager): Session manager to clean + cleanup_interval (int): Cleanup interval in seconds (default: 5 minutes) Returns: - Asyncio task handle + asyncio.Task: Asyncio task handle Example: @app.on_event("startup") @@ -179,6 +177,9 @@ class AuditLogger: This class provides structured logging for admin actions, authentication events, and other security-critical operations. + Args: + logger_name (str): Name for the audit logger + Example: audit = AuditLogger() @@ -197,11 +198,6 @@ class AuditLogger: """ def __init__(self, logger_name: str = "audit") -> None: - """Initialize audit logger. - - Args: - logger_name: Name for the audit logger - """ self.logger = logging.getLogger(logger_name) def log_admin_action( @@ -215,11 +211,11 @@ def log_admin_action( """Log administrative action. Args: - admin_email: Email of admin performing action - action: Action performed (e.g., "create_user", "delete_data") - target: Target of action (e.g., affected user email) - result: Result of action ("success", "failure", "denied") - details: Additional details dictionary + admin_email (str): Email of admin performing action + action (str): Action performed (e.g., "create_user", "delete_data") + target (str | None): Target of action (e.g., affected user email) + result (str): Result of action ("success", "failure", "denied") + details (dict[str, Any] | None): Additional details dictionary """ self.logger.info( "ADMIN_ACTION: %s performed %s on %s (result: %s)", @@ -248,11 +244,11 @@ def log_auth_event( """Log authentication event. Args: - event_type: Type of auth event ("login", "logout", "failed_auth") - user_email: User's email address - ip_address: IP address of request - result: Result of event ("success", "failure") - details: Additional details dictionary + event_type (str): Type of auth event ("login", "logout", "failed_auth") + user_email (str | None): User's email address + ip_address (str | None): IP address of request + result (str): Result of event ("success", "failure") + details (dict[str, Any] | None): Additional details dictionary """ self.logger.info( "AUTH_EVENT: %s for %s from %s (result: %s)", @@ -280,10 +276,10 @@ def log_access_denied( """Log access denial. Args: - user_email: User's email address - resource: Resource that was denied - reason: Reason for denial - ip_address: IP address of request + user_email (str): User's email address + resource (str): Resource that was denied + reason (str): Reason for denial + ip_address (str | None): IP address of request """ self.logger.warning( "ACCESS_DENIED: %s denied access to %s (reason: %s) from %s", @@ -310,10 +306,10 @@ def log_security_event( """Log general security event. Args: - event_type: Type of security event - severity: Severity level ("low", "medium", "high", "critical") - description: Event description - details: Additional details dictionary + event_type (str): Type of security event + severity (str): Severity level ("low", "medium", "high", "critical") + description (str): Event description + details (dict[str, Any] | None): Additional details dictionary """ log_method = { "low": self.logger.info, @@ -344,7 +340,7 @@ def get_audit_logger() -> AuditLogger: """Get singleton audit logger instance. Returns: - AuditLogger instance + AuditLogger: AuditLogger instance """ global _audit_logger_instance # noqa: PLW0603 if _audit_logger_instance is None: diff --git a/packages/cloudflare-auth/src/cloudflare_auth/sessions.py b/packages/cloudflare-auth/src/cloudflare_auth/sessions.py index 786303f..c7d2caa 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/sessions.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/sessions.py @@ -31,7 +31,7 @@ class SimpleSessionManager: This manager provides session tracking for authenticated users, maintaining session state and handling expiration. - ⚠️ SECURITY WARNING: + SECURITY WARNING: This in-memory implementation is NOT suitable for production use: - Sessions are lost on application restart - Not shared across multiple instances @@ -50,6 +50,9 @@ class SimpleSessionManager: For production use with multiple instances, consider using a distributed session store (Redis, Memcached, etc.). + Args: + session_timeout (int): Session timeout in seconds (default: 1 hour) + Example: manager = SimpleSessionManager(session_timeout=3600) session_id = manager.create_session( @@ -61,11 +64,6 @@ class SimpleSessionManager: """ def __init__(self, session_timeout: int = 3600) -> None: - """Initialize session manager. - - Args: - session_timeout: Session timeout in seconds (default: 1 hour) - """ self.sessions: dict[str, dict[str, Any]] = {} self.session_timeout = session_timeout logger.info("Initialized session manager with %ss timeout", session_timeout) @@ -80,13 +78,13 @@ def create_session( """Create a new session for the user. Args: - email: User email address - is_admin: Whether user has admin privileges - user_tier: User tier (admin, full, limited) - cf_context: Additional Cloudflare context (headers, metadata) + email (str): User email address + is_admin (bool): Whether user has admin privileges + user_tier (str): User tier (admin, full, limited) + cf_context (dict[str, Any] | None): Additional Cloudflare context (headers, metadata) Returns: - Session ID (cryptographically secure random token) + str: Session ID (cryptographically secure random token) Example: session_id = manager.create_session( @@ -123,10 +121,10 @@ def get_session(self, session_id: str) -> dict[str, Any] | None: for valid sessions. Args: - session_id: Session identifier + session_id (str): Session identifier Returns: - Session data if valid, None if expired or not found + dict[str, Any] | None: Session data if valid, None if expired or not found Example: session = manager.get_session(session_id) @@ -155,10 +153,10 @@ def invalidate_session(self, session_id: str) -> bool: """Invalidate a session. Args: - session_id: Session to invalidate + session_id (str): Session to invalidate Returns: - True if session was found and removed + bool: True if session was found and removed Example: # Logout @@ -176,10 +174,10 @@ def refresh_session(self, session_id: str) -> bool: """Refresh a session's last accessed time. Args: - session_id: Session to refresh + session_id (str): Session to refresh Returns: - True if session was found and refreshed + bool: True if session was found and refreshed """ session = self.sessions.get(session_id) if session: @@ -191,10 +189,10 @@ def _is_session_expired(self, session: dict[str, Any]) -> bool: """Check if session has expired. Args: - session: Session data dictionary + session (dict[str, Any]): Session data dictionary Returns: - True if session has exceeded timeout + bool: True if session has exceeded timeout """ expiry = session["last_accessed"] + timedelta(seconds=self.session_timeout) return datetime.now(tz=timezone.utc) >= expiry @@ -206,7 +204,7 @@ def cleanup_expired_sessions(self) -> int: expired sessions and free memory. Returns: - Number of sessions cleaned up + int: Number of sessions cleaned up Example: # In a background task @@ -235,7 +233,7 @@ def get_session_count(self) -> int: """Get the current number of active sessions. Returns: - Number of active sessions + int: Number of active sessions """ return len(self.sessions) @@ -243,10 +241,10 @@ def get_user_sessions(self, email: str) -> list[str]: """Get all session IDs for a specific user. Args: - email: User email address + email (str): User email address Returns: - List of session IDs for the user + list[str]: List of session IDs for the user """ return [ session_id @@ -260,10 +258,10 @@ def get_session_info(self, session_id: str) -> dict[str, Any] | None: Returns session data without sensitive information. Args: - session_id: Session identifier + session_id (str): Session identifier Returns: - Safe session information or None if not found + dict[str, Any] | None: Safe session information or None if not found """ session = self.sessions.get(session_id) if not session: @@ -284,7 +282,7 @@ def get_stats(self) -> dict[str, Any]: """Get session manager statistics. Returns: - Dictionary with session statistics + dict[str, Any]: Dictionary with session statistics """ datetime.now(tz=timezone.utc) active_sessions = [] @@ -308,7 +306,7 @@ def _count_by_tier(self) -> dict[str, int]: """Count sessions by user tier. Returns: - Dictionary with tier counts + dict[str, int]: Dictionary with tier counts """ tier_counts: dict[str, int] = {"admin": 0, "full": 0, "limited": 0} diff --git a/packages/cloudflare-auth/src/cloudflare_auth/utils.py b/packages/cloudflare-auth/src/cloudflare_auth/utils.py index bd83c53..a79a4e8 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/utils.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/utils.py @@ -40,13 +40,13 @@ def sanitize_for_logging( - Converting to string safely Args: - value: Value to sanitize (any type) - max_length: Maximum length of output (default: 200) - replace_newlines: Replace newlines with space (default: True) - replace_control_chars: Replace control chars with � (default: True) + value (Any): Value to sanitize (any type) + max_length (int): Maximum length of output (default: 200) + replace_newlines (bool): Replace newlines with space (default: True) + replace_control_chars (bool): Replace control chars with replacement chars� (default: True) Returns: - Sanitized string safe for logging + str: Sanitized string safe for logging Example: >>> sanitize_for_logging("user@example.com\\nINJECTED LINE") @@ -86,11 +86,11 @@ def sanitize_email(email: str, max_length: int = 254) -> str: Validates email format and sanitizes for safe logging. Args: - email: Email address to sanitize - max_length: Maximum email length (default: 254 per RFC 5321) + email (str): Email address to sanitize + max_length (int): Maximum email length (default: 254 per RFC 5321) Returns: - Sanitized email address + str: Sanitized email address Example: >>> sanitize_email("user@example.com") @@ -111,11 +111,11 @@ def sanitize_path(path: str, max_length: int = 200) -> str: r"""Sanitize URL path for logging. Args: - path: URL path to sanitize - max_length: Maximum path length + path (str): URL path to sanitize + max_length (int): Maximum path length Returns: - Sanitized path + str: Sanitized path Example: >>> sanitize_path("/api/users/123") @@ -130,11 +130,11 @@ def sanitize_ip(ip: str, max_length: int = 45) -> str: r"""Sanitize IP address for logging. Args: - ip: IP address to sanitize - max_length: Maximum length (45 for IPv6) + ip (str): IP address to sanitize + max_length (int): Maximum length (45 for IPv6) Returns: - Sanitized IP address + str: Sanitized IP address Example: >>> sanitize_ip("192.168.1.1") @@ -165,12 +165,12 @@ def sanitize_dict_for_logging( """Sanitize dictionary for safe logging. Args: - data: Dictionary to sanitize - max_value_length: Maximum length for each value - excluded_keys: Keys to exclude (e.g., 'password', 'token') + data (dict[str, Any]): Dictionary to sanitize + max_value_length (int): Maximum length for each value + excluded_keys (set[str] | None): Keys to exclude (e.g., 'password', 'token') Returns: - Sanitized dictionary with string values + dict[str, str]: Sanitized dictionary with string values Example: >>> sanitize_dict_for_logging( @@ -210,11 +210,11 @@ def mask_sensitive_data( """Mask sensitive data in text using regex pattern. Args: - text: Text potentially containing sensitive data - pattern: Regex pattern to match sensitive data (default: email pattern) + text (str): Text potentially containing sensitive data + pattern (str): Regex pattern to match sensitive data (default: email pattern) Returns: - Text with sensitive data masked + str: Text with sensitive data masked Example: >>> mask_sensitive_data("Contact user@example.com for help") @@ -244,7 +244,7 @@ def get_client_ip(request) -> str: request: FastAPI/Starlette Request object Returns: - Client IP address string + str: Client IP address string Example: >>> from fastapi import Request diff --git a/packages/cloudflare-auth/src/cloudflare_auth/validators.py b/packages/cloudflare-auth/src/cloudflare_auth/validators.py index 4cb1cd4..80f7776 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/validators.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/validators.py @@ -50,10 +50,8 @@ class CloudflareJWTValidator: - Expiration checking - Claim extraction and validation - Attributes: - settings: Cloudflare configuration settings - jwks_client: Client for fetching JWT signing keys - _last_key_refresh: Timestamp of last key refresh + Args: + settings (CloudflareSettings | None): Optional CloudflareSettings instance (uses default if not provided) Example: validator = CloudflareJWTValidator() @@ -65,11 +63,6 @@ class CloudflareJWTValidator: """ def __init__(self, settings: CloudflareSettings | None = None) -> None: - """Initialize JWT validator. - - Args: - settings: Optional CloudflareSettings instance (uses default if not provided) - """ self.settings = settings or get_cloudflare_settings() if not self.settings.cloudflare_team_domain: @@ -104,11 +97,11 @@ def validate_token( 5. Required claims presence Args: - token: JWT token string from Cf-Access-Jwt-Assertion header - verify_exp: Whether to verify token expiration (default: True) + token (str): JWT token string from Cf-Access-Jwt-Assertion header + verify_exp (bool): Whether to verify token expiration (default: True) Returns: - CloudflareJWTClaims object with validated claims + CloudflareJWTClaims: CloudflareJWTClaims object with validated claims Raises: ValueError: If token is invalid, expired, or claims are missing @@ -211,7 +204,7 @@ def _validate_required_claims(self, payload: dict[str, Any]) -> None: """Validate that required claims are present. Args: - payload: Decoded JWT payload + payload (dict[str, Any]): Decoded JWT payload Raises: ValueError: If required claims are missing @@ -236,14 +229,11 @@ async def validate_token_async( is CPU-bound and not truly async. Args: - token: JWT token string - verify_exp: Whether to verify token expiration + token (str): JWT token string + verify_exp (bool): Whether to verify token expiration Returns: - CloudflareJWTClaims object with validated claims - - Raises: - ValueError: If token is invalid + CloudflareJWTClaims: CloudflareJWTClaims object with validated claims """ # JWT validation is CPU-bound, not I/O bound # But we provide async interface for consistency @@ -275,7 +265,7 @@ def is_configured(self) -> bool: """Check if validator is properly configured. Returns: - True if validator has necessary configuration + bool: True if validator has necessary configuration """ return bool( self.settings.cloudflare_team_domain @@ -290,10 +280,10 @@ def get_unverified_claims(self, token: str) -> dict[str, Any]: Only use for debugging or non-security-critical inspection. Args: - token: JWT token string + token (str): JWT token string Returns: - Dictionary of unverified claims + dict[str, Any]: Dictionary of unverified claims Example: # For debugging only diff --git a/packages/cloudflare-auth/src/cloudflare_auth/whitelist.py b/packages/cloudflare-auth/src/cloudflare_auth/whitelist.py index b6208de..f187e43 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/whitelist.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/whitelist.py @@ -65,10 +65,10 @@ def from_string(cls, value: str) -> "UserTier": """Create UserTier from string value. Args: - value: String representation of tier + value (str): String representation of tier Returns: - UserTier enum value + UserTier: UserTier enum value Raises: ValueError: If value is not a valid tier @@ -85,7 +85,7 @@ def can_access_premium_models(self) -> bool: """Check if tier allows access to premium models. Returns: - True for ADMIN and FULL tiers, False for LIMITED + bool: True for ADMIN and FULL tiers, False for LIMITED """ return self in (UserTier.ADMIN, UserTier.FULL) @@ -94,7 +94,7 @@ def has_admin_privileges(self) -> bool: """Check if tier has admin privileges. Returns: - True only for ADMIN tier + bool: True only for ADMIN tier """ return self == UserTier.ADMIN @@ -104,10 +104,10 @@ class WhitelistEntry: """Represents a whitelist entry with metadata. Attributes: - value: Email address or domain pattern (@domain.com) - is_domain: Whether this is a domain pattern vs individual email - added_at: ISO timestamp when entry was added - description: Optional description of why this entry exists + value (str): Email address or domain pattern (@domain.com) + is_domain (bool): Whether this is a domain pattern vs individual email + added_at (str): ISO timestamp when entry was added + description (str | None): Optional description of why this entry exists """ value: str @@ -123,11 +123,11 @@ class EmailWhitelistConfig(BaseModel): from environment variables or config files. Attributes: - whitelist: List of allowed emails/domains - admin_emails: List of emails with admin privileges - full_users: List of emails with full tier access - limited_users: List of emails with limited tier access - case_sensitive: Whether email matching is case-sensitive + whitelist (list[str]): List of allowed emails/domains + admin_emails (list[str]): List of emails with admin privileges + full_users (list[str]): List of emails with full tier access + limited_users (list[str]): List of emails with limited tier access + case_sensitive (bool): Whether email matching is case-sensitive """ whitelist: list[str] = [] @@ -144,10 +144,10 @@ def normalize_emails(cls, v: str | list[str]) -> list[str]: """Normalize email addresses to lowercase unless case_sensitive. Args: - v: String (comma-separated) or list of emails + v (str | list[str]): String (comma-separated) or list of emails Returns: - List of normalized email addresses + list[str]: List of normalized email addresses """ if isinstance(v, str): v = [email.strip() for email in v.split(",") if email.strip()] @@ -163,6 +163,13 @@ class EmailWhitelistValidator: - Admin privilege detection for specific admin emails - User tier assignment (admin/full/limited) + Args: + whitelist (list[str]): List of allowed emails and domains + admin_emails (list[str] | None): List of emails with admin privileges + full_users (list[str] | None): List of emails with full tier access + limited_users (list[str] | None): List of emails with limited tier access + case_sensitive (bool): Whether email matching should be case-sensitive + Example: validator = EmailWhitelistValidator( whitelist=["user@example.com", "@company.com"], @@ -186,15 +193,6 @@ def __init__( limited_users: list[str] | None = None, case_sensitive: bool = False, ) -> None: - """Initialize the email whitelist validator. - - Args: - whitelist: List of allowed emails and domains - admin_emails: List of emails with admin privileges - full_users: List of emails with full tier access - limited_users: List of emails with limited tier access - case_sensitive: Whether email matching should be case-sensitive - """ self.case_sensitive = case_sensitive self.admin_emails = self._normalize_emails(admin_emails or []) self.full_users = self._normalize_emails(full_users or []) @@ -224,10 +222,10 @@ def _normalize_emails(self, emails: list[str]) -> list[str]: """Normalize email list based on case sensitivity setting. Args: - emails: List of email addresses to normalize + emails (list[str]): List of email addresses to normalize Returns: - Normalized list of emails + list[str]: Normalized list of emails """ if not emails: return [] @@ -244,10 +242,10 @@ def _normalize_email(self, email: str) -> str: """Normalize a single email address. Args: - email: Email address to normalize + email (str): Email address to normalize Returns: - Normalized email address + str: Normalized email address """ return email.strip() if self.case_sensitive else email.strip().lower() @@ -257,10 +255,10 @@ def is_authorized(self, email: str) -> bool: Uses constant-time comparison to prevent timing attacks. Args: - email: Email address to validate + email (str): Email address to validate Returns: - True if email is authorized, False otherwise + bool: True if email is authorized, False otherwise """ if not email: return False @@ -292,10 +290,10 @@ def is_admin(self, email: str) -> bool: Uses constant-time comparison to prevent timing attacks. Args: - email: Email address to check + email (str): Email address to check Returns: - True if email is an admin, False otherwise + bool: True if email is an admin, False otherwise """ if not email: return False @@ -314,10 +312,10 @@ def get_user_role(self, email: str) -> str: """Get user role based on email. Args: - email: Email address to check + email (str): Email address to check Returns: - 'admin', 'user', or 'unauthorized' + str: 'admin', 'user', or 'unauthorized' """ if not self.is_authorized(email): return "unauthorized" @@ -331,10 +329,10 @@ def get_user_tier(self, email: str) -> UserTier: First checks exact email matches, then domain patterns. Args: - email: Email address to check + email (str): Email address to check Returns: - UserTier enum value + UserTier: UserTier enum value Raises: ValueError: If email is not authorized @@ -382,10 +380,10 @@ def can_access_premium_models(self, email: str) -> bool: """Check if email can access premium models. Args: - email: Email address to check + email (str): Email address to check Returns: - True if user can access premium models, False otherwise + bool: True if user can access premium models, False otherwise """ try: tier = self.get_user_tier(email) @@ -397,10 +395,10 @@ def has_admin_privileges(self, email: str) -> bool: """Check if email has admin privileges. Args: - email: Email address to check + email (str): Email address to check Returns: - True if user has admin privileges, False otherwise + bool: True if user has admin privileges, False otherwise """ try: tier = self.get_user_tier(email) @@ -412,7 +410,7 @@ def get_whitelist_stats(self) -> dict[str, Any]: """Get statistics about the current whitelist configuration. Returns: - Dictionary with whitelist statistics + dict[str, Any]: Dictionary with whitelist statistics """ return { "individual_emails": len(self.individual_emails), @@ -434,7 +432,7 @@ def _check_empty_whitelist(self) -> list[str]: """Check if whitelist is empty. Returns: - List of warning messages if whitelist is empty, empty list otherwise. + list[str]: List of warning messages if whitelist is empty, empty list otherwise. """ if not self.individual_emails and not self.domain_patterns: return ["Whitelist is empty - no users will be authorized"] @@ -444,11 +442,11 @@ def _check_tier_authorization(self, emails: list[str], tier_name: str) -> list[s """Check if tier emails are authorized in whitelist. Args: - emails: List of email addresses to check. - tier_name: Name of the tier for warning messages. + emails (list[str]): List of email addresses to check. + tier_name (str): Name of the tier for warning messages. Returns: - List of warning messages for unauthorized emails. + list[str]: List of warning messages for unauthorized emails. """ warnings = [] for email in emails: @@ -460,7 +458,7 @@ def _check_tier_conflicts(self) -> list[str]: """Check for emails assigned to multiple tiers. Returns: - List of warning messages for emails in multiple tiers. + list[str]: List of warning messages for emails in multiple tiers. """ warnings = [] all_tier_emails = ( @@ -485,7 +483,7 @@ def _check_public_domains(self) -> list[str]: """Check for potentially insecure public email domains. Returns: - List of warning messages if public domains are in whitelist. + list[str]: List of warning messages if public domains are in whitelist. """ public_domains = {"@gmail.com", "@outlook.com"} if self.domain_patterns & public_domains: @@ -498,7 +496,7 @@ def validate_whitelist_config(self) -> list[str]: """Validate the whitelist configuration and return any warnings. Returns: - List of warning messages about the configuration + list[str]: List of warning messages about the configuration """ warnings = [] warnings.extend(self._check_empty_whitelist()) @@ -518,6 +516,9 @@ class WhitelistManager: Provides runtime management of whitelist entries. Note that changes are not persisted - for permanent changes, update configuration. + Args: + validator (EmailWhitelistValidator): EmailWhitelistValidator instance to manage + Example: manager = WhitelistManager(validator) manager.add_email("newuser@company.com") @@ -525,18 +526,13 @@ class WhitelistManager: """ def __init__(self, validator: EmailWhitelistValidator) -> None: - """Initialize whitelist manager with a validator. - - Args: - validator: EmailWhitelistValidator instance to manage - """ self.validator = validator def _validate_empty_input(self, email: str) -> None: """Validate that email input is not empty. Args: - email: Email string to validate + email (str): Email string to validate Raises: ValueError: If email is empty or whitespace only @@ -549,10 +545,10 @@ def _validate_email_with_library(self, email: str) -> str: """Validate email using email-validator library. Args: - email: Email to validate + email (str): Email to validate Returns: - Normalized email address + str: Normalized email address Raises: RuntimeError: If email-validator is not available. @@ -572,7 +568,7 @@ def _validate_email_basic(self, email: str) -> None: """Basic email validation without email-validator library. Args: - email: Email to validate + email (str): Email to validate Raises: ValueError: If email format is invalid @@ -593,10 +589,10 @@ def _validate_email_format(self, email: str) -> str: May raise ValueError from helper methods if email format is invalid. Args: - email: Email address to validate + email (str): Email address to validate Returns: - Normalized email address + str: Normalized email address """ if _email_validator_available and _validate_email_func is not None: return self._validate_email_with_library(email) @@ -607,7 +603,7 @@ def _validate_domain_pattern(self, pattern: str) -> None: """Validate domain pattern format. Args: - pattern: Domain pattern to validate (e.g., @domain.tld) + pattern (str): Domain pattern to validate (e.g., @domain.tld) Raises: ValueError: If domain pattern is invalid @@ -627,9 +623,9 @@ def _add_to_collections( """Add email to appropriate whitelist collections. Args: - normalized_email: Normalized email or domain pattern - is_admin: Whether to add to admin list - original_email: Original email for logging + normalized_email (str): Normalized email or domain pattern + is_admin (bool): Whether to add to admin list + original_email (str): Original email for logging """ if normalized_email.startswith("@"): self.validator.domain_patterns.add(normalized_email) @@ -648,11 +644,11 @@ def add_email(self, email: str, is_admin: bool = False) -> bool: update the configuration file directly. Args: - email: Email or domain pattern to add - is_admin: Whether email should have admin privileges + email (str): Email or domain pattern to add + is_admin (bool): Whether email should have admin privileges Returns: - True if email was added successfully + bool: True if email was added successfully Raises: ValueError: If email format is invalid @@ -680,10 +676,10 @@ def remove_email(self, email: str) -> bool: """Remove email from whitelist (runtime operation). Args: - email: Email or domain pattern to remove + email (str): Email or domain pattern to remove Returns: - True if email was removed successfully + bool: True if email was removed successfully """ try: normalized_email = self.validator._normalize_email(email) @@ -717,10 +713,10 @@ def check_email(self, email: str) -> dict[str, Any]: """Check email status and provide detailed information. Args: - email: Email to check + email (str): Email to check Returns: - Dictionary with email status information + dict[str, Any]: Dictionary with email status information """ return { "email": email, @@ -743,13 +739,13 @@ def create_validator_from_env( comma-separated environment variable values. Args: - whitelist_str: Comma-separated string of emails/domains - admin_emails_str: Comma-separated string of admin emails - full_users_str: Comma-separated string of full tier users - limited_users_str: Comma-separated string of limited tier users + whitelist_str (str): Comma-separated string of emails/domains + admin_emails_str (str): Comma-separated string of admin emails + full_users_str (str): Comma-separated string of full tier users + limited_users_str (str): Comma-separated string of limited tier users Returns: - Configured EmailWhitelistValidator + EmailWhitelistValidator: Configured EmailWhitelistValidator Example: validator = create_validator_from_env( From 780db9defad6129934c14874e358529833a0a109 Mon Sep 17 00:00:00 2001 From: Byron Williams Date: Sat, 30 May 2026 22:47:30 -0700 Subject: [PATCH 8/8] docs(cloudflare-auth): add TYPE_CHECKING annotations for unannotated request params Co-Authored-By: Claude Sonnet 4.6 --- packages/cloudflare-auth/src/cloudflare_auth/csrf.py | 8 ++++++-- packages/cloudflare-auth/src/cloudflare_auth/utils.py | 9 ++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/packages/cloudflare-auth/src/cloudflare_auth/csrf.py b/packages/cloudflare-auth/src/cloudflare_auth/csrf.py index 22a3d62..e8ffd47 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/csrf.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/csrf.py @@ -20,6 +20,10 @@ import hmac import logging import secrets +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fastapi import Request logger = logging.getLogger(__name__) @@ -128,13 +132,13 @@ def validate_token( def validate_request( self, - request, # noqa: ANN001 - FastAPI/Starlette Request - not annotated to avoid import + request: "Request", methods_to_protect: set[str] | None = None, ) -> bool: """Validate CSRF token for a request. Args: - request: FastAPI/Starlette Request object + request (Request): FastAPI/Starlette Request object methods_to_protect (set[str] | None): HTTP methods that require CSRF validation Returns: diff --git a/packages/cloudflare-auth/src/cloudflare_auth/utils.py b/packages/cloudflare-auth/src/cloudflare_auth/utils.py index a79a4e8..9f3c246 100644 --- a/packages/cloudflare-auth/src/cloudflare_auth/utils.py +++ b/packages/cloudflare-auth/src/cloudflare_auth/utils.py @@ -18,7 +18,10 @@ """ import re -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from fastapi import Request # Patterns for dangerous characters in logs CONTROL_CHARS_PATTERN = re.compile(r"[\x00-\x1f\x7f-\x9f]") @@ -232,7 +235,7 @@ def mask_match(match): return re.sub(pattern, mask_match, text) -def get_client_ip(request) -> str: +def get_client_ip(request: "Request") -> str: """Extract client IP address from request. SECURITY NOTE: Only trusts CF-Connecting-IP header from Cloudflare. @@ -241,7 +244,7 @@ def get_client_ip(request) -> str: Cloudflare Access. Args: - request: FastAPI/Starlette Request object + request (Request): FastAPI/Starlette Request object Returns: str: Client IP address string