diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index 7fe6ba2..de90d54 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -40,13 +40,17 @@ jobs: run: | python -m pip install --upgrade pip pip install -e .[dev] - pip install flake8 flake8-pyproject pytest + pip install flake8 flake8-pyproject pyright pytest - name: Lint with flake8 run: | python -m flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics python -m flake8 . --exit-zero --max-complexity=12 --statistics + - name: Type check with pyright + run: | + pyright + - name: Test with pytest run: | python -m pytest --verbose diff --git a/.gitignore b/.gitignore index 378fbcf..89c1f6e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ __pycache__/ __pypackages__/ .mypy_cache/ .pytest_cache/ -*.py[cod] +*.py[cdio] *$py.class # BUILD ARTIFACTS diff --git a/CHANGELOG.md b/CHANGELOG.md index a4997b2..650d4f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,12 +15,28 @@ #
Changelog
+ + +## 13.04.2026 `v1.9.6` + +* The compiled version of the library now includes the type stub files (`.pyi`), so type checkers can properly check types. +* Made all type hints in the whole library way more strict and accurate. +* Removed leftover unnecessary runtime type-checks in several methods throughout the whole library. + +**BREAKING CHANGES:** +* All methods that should use positional-only and/or keyword-only params, now actually enforce that by using the `/` and `*` syntax in the method definitions. +* Renamed the `Spinner` class from the `console` module to `Throbber`, since that name is closer to what it's actually used for. +* Changed the name of the TypeAlias `DataStructure` to `DataObj` because that name is shorter and more general. +* Changed both names `DataStructureTypes` and `IndexIterableTypes` to `DataObjTT` and `IndexIterableTT` respectively (`TT` *stands for type-tuple*). +* Made the return value of `String.single_char_repeats()` always be *`int`* and not *int* | *bool*. + + ## 25.01.2026 `v1.9.5` -* Add new class property `Console.encoding`, which returns the encoding used by the console (*e.g.* `utf-8`*,* `cp1252`*, …*). -* Add multiple new class properties to the `System` class: +* Added a new class property `Console.encoding`, which returns the encoding used by the console (*e.g.* `utf-8`*,* `cp1252`*, …*). +* Added multiple new class properties to the `System` class: - `is_linux` Whether the current OS is Linux or not. - `is_mac` Whether the current OS is macOS or not. - `is_unix` Whether the current OS is a Unix-like OS (Linux, macOS, BSD, …) or not. @@ -482,10 +498,10 @@ ## 21.12.2024 `v1.5.9` * Fixed bugs in method `to_ansi()` in module `xx_format_codes`:
- 1. The method always returned an empty string, because the color validation was broken, and it would identify all colors as invalid.
- Now the validation `Color.is_valid_rgba()` and `Color.is_valid_hexa()` are fixed and now, if a color is identified as invalid, the method returns the original string instead of an empty string. - 2. Previously the method `to_ansi()` couldn't handle formats inside `[]` because everything inside the brackets was recognized as an invalid format.
- Now you are able to use formats inside `[]` (*e.g.* `"[[red](Red text [b](inside) square brackets!)]"`). + 1. The method always returned an empty string, because the color validation was broken, and it would identify all colors as invalid.
+ Now the validation `Color.is_valid_rgba()` and `Color.is_valid_hexa()` are fixed and now, if a color is identified as invalid, the method returns the original string instead of an empty string. + 2. Previously the method `to_ansi()` couldn't handle formats inside `[]` because everything inside the brackets was recognized as an invalid format.
+ Now you are able to use formats inside `[]` (*e.g.* `"[[red](Red text [b](inside) square brackets!)]"`). * Introduced a new test for the `xx_format_codes` module. * Fixed a small bug in the help client-command:
Added back the default text color. @@ -849,8 +865,8 @@ ## 15.10.2024 `v1.0.1` `v1.0.2` `v1.0.3` `v1.0.4` `v1.0.5` * Fixed `f-string` issues for Python 3.10: - 1. Not making use of same quotes inside f-strings any more. - 2. No backslash escaping in f-strings. + 1. Not making use of same quotes inside f-strings any more. + 2. No backslash escaping in f-strings. diff --git a/pyproject.toml b/pyproject.toml index 0aa89b2..037b4e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ build-backend = "setuptools.build_meta" [project] name = "xulbux" -version = "1.9.5" +version = "1.9.6" description = "A Python library to simplify common programming tasks." readme = "README.md" authors = [{ name = "XulbuX", email = "xulbux.real@gmail.com" }] @@ -30,6 +30,7 @@ dependencies = [ optional-dependencies = { dev = [ "flake8-pyproject>=1.2.3", "flake8>=6.1.0", + "pyright>=1.1.408", "pytest>=7.4.2", "toml>=0.10.2", ] } @@ -121,8 +122,8 @@ xulbux-help = "xulbux.cli.help:show_help" max-complexity = 12 max-line-length = 127 select = ["E", "F", "W", "C90"] -extend-ignore = ["E203", "E266", "E502", "W503"] -per-file-ignores = ["__init__.py:F403,F405"] +extend-ignore = ["E124", "E203", "E266", "E502", "W503"] +per-file-ignores = ["__init__.py:F403,F405", "types.py:E302,E305"] [tool.setuptools] package-dir = { "" = "src" } @@ -130,6 +131,9 @@ package-dir = { "" = "src" } [tool.setuptools.packages.find] where = ["src"] +[tool.setuptools.package-data] +xulbux = ["py.typed", "*.pyi", "**/*.pyi"] + [tool.pytest.ini_options] minversion = "7.0" addopts = "-ra -q" @@ -150,3 +154,7 @@ testpaths = [ "tests/test_string.py", "tests/test_system.py", ] + +[tool.pyright] +include = ["src", "tests"] +typeCheckingMode = "strict" diff --git a/setup.py b/setup.py index 5761a66..deb0e04 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,8 @@ from setuptools import setup from pathlib import Path +import subprocess +import shutil +import sys import os @@ -10,14 +13,71 @@ def find_python_files(directory: str) -> list[str]: return python_files -# OPTIONALLY USE MYPYC COMPILATION +def generate_stubs_for_package(): + print("\nGenerating stub files with stubgen...\n") + + try: + skip_stubgen = { + Path("src/xulbux/base/types.py"), # COMPLEX TYPE DEFINITIONS + Path("src/xulbux/__init__.py"), # PRESERVE PACKAGE METADATA CONSTANTS + } + + src_dir = Path("src/xulbux") + generated_count = 0 + skipped_count = 0 + + for py_file in src_dir.rglob("*.py"): + pyi_file = py_file.with_suffix(".pyi") + rel_path = py_file.relative_to(src_dir.parent) + + if py_file in skip_stubgen: + pyi_file.write_text(py_file.read_text(encoding="utf-8"), encoding="utf-8") + print(f" copied {rel_path.with_suffix('.pyi')} (preserving type definitions)") + skipped_count += 1 + continue + + stubgen_exe = ( + shutil.which("stubgen") + or str(Path(sys.executable).parent / ("stubgen.exe" if sys.platform == "win32" else "stubgen")) + ) + result = subprocess.run( + [stubgen_exe, + str(py_file), + "-o", "src", + "--include-private", + "--export-less"], + capture_output=True, + text=True + ) + + if result.returncode == 0: + print(f" generated {rel_path.with_suffix('.pyi')}") + generated_count += 1 + else: + print(f" failed {rel_path}") + if result.stderr: + print(f" {result.stderr.strip()}") + + print(f"\nStub generation complete. ({generated_count} generated, {skipped_count} copied)\n") + + except Exception as e: + fmt_error = "\n ".join(str(e).splitlines()) + print(f"[WARNING] Could not generate stubs:\n {fmt_error}\n") + + ext_modules = [] + +# OPTIONALLY USE MYPYC COMPILATION if os.environ.get("XULBUX_USE_MYPYC", "1") == "1": try: from mypyc.build import mypycify + print("\nCompiling with mypyc...\n") source_files = find_python_files("src/xulbux") - ext_modules = mypycify(source_files) + ext_modules = mypycify(source_files, opt_level="3") + print("\nMypyc compilation complete.\n") + + generate_stubs_for_package() except (ImportError, Exception) as e: fmt_error = "\n ".join(str(e).splitlines()) diff --git a/src/xulbux/__init__.py b/src/xulbux/__init__.py index 2f95e29..16e5bfd 100644 --- a/src/xulbux/__init__.py +++ b/src/xulbux/__init__.py @@ -1,5 +1,5 @@ __package_name__ = "xulbux" -__version__ = "1.9.5" +__version__ = "1.9.6" __description__ = "A Python library to simplify common programming tasks." __status__ = "Production/Stable" diff --git a/src/xulbux/base/consts.py b/src/xulbux/base/consts.py index 8d5425a..0fd895c 100644 --- a/src/xulbux/base/consts.py +++ b/src/xulbux/base/consts.py @@ -88,7 +88,7 @@ class ANSI: """End of an ANSI escape sequence.""" @classmethod - def seq(cls, placeholders: int = 1) -> FormattableString: + def seq(cls, placeholders: int = 1, /) -> FormattableString: """Generates an ANSI escape sequence with the specified number of placeholders.""" return cls.CHAR + cls.START + cls.SEP.join(["{}" for _ in range(placeholders)]) + cls.END diff --git a/src/xulbux/base/exceptions.py b/src/xulbux/base/exceptions.py index 352c929..4b0d5d6 100644 --- a/src/xulbux/base/exceptions.py +++ b/src/xulbux/base/exceptions.py @@ -4,9 +4,8 @@ from .decorators import mypyc_attr -# -################################################## FILE ################################################## +################################################## FILE ################################################## @mypyc_attr(native_class=False) class SameContentFileExistsError(FileExistsError): @@ -16,7 +15,6 @@ class SameContentFileExistsError(FileExistsError): ################################################## PATH ################################################## - @mypyc_attr(native_class=False) class PathNotFoundError(FileNotFoundError): """Raised when a file system path does not exist or cannot be accessed.""" diff --git a/src/xulbux/base/types.py b/src/xulbux/base/types.py index abd1562..b61c42e 100644 --- a/src/xulbux/base/types.py +++ b/src/xulbux/base/types.py @@ -2,14 +2,10 @@ This module contains all custom type definitions used throughout the library. """ -from typing import TYPE_CHECKING, Annotated, TypeAlias, TypedDict, Optional, Protocol, Literal, Union, Any +from typing import Annotated, TypeAlias, TypedDict, Optional, Protocol, Literal, Union, Any from pathlib import Path -# PREVENT CIRCULAR IMPORTS -if TYPE_CHECKING: - from ..color import rgba, hsla, hexa -# ################################################## Annotated ################################################## Int_0_100 = Annotated[int, "Integer constrained to the range [0, 100] inclusive."] @@ -24,60 +20,80 @@ FormattableString = Annotated[str, "String made to be formatted with the `.format()` method."] """String made to be formatted with the `.format()` method.""" -# + ################################################## TypeAlias ################################################## -PathsList: TypeAlias = Union[list[Path], list[str], list[Path | str]] +PathsList: TypeAlias = Union[list[Path], list[str], list[Union[Path, str]]] """Union of all supported list types for a list of paths.""" -DataStructure: TypeAlias = Union[list, tuple, set, frozenset, dict] +DataObj: TypeAlias = Union[list[Any], tuple[Any, ...], set[Any], frozenset[Any], dict[Any, Any]] """Union of supported data structures used in the `data` module.""" -DataStructureTypes = (list, tuple, set, frozenset, dict) +DataObjTT = (list, tuple, set, frozenset, dict) """Tuple of supported data structures used in the `data` module.""" -IndexIterable: TypeAlias = Union[list, tuple, set, frozenset] +IndexIterable: TypeAlias = Union[list[Any], tuple[Any, ...], set[Any], frozenset[Any]] """Union of all iterable types that support indexing operations.""" -IndexIterableTypes = (list, tuple, set, frozenset) +IndexIterableTT = (list, tuple, set, frozenset) """Tuple of all iterable types that support indexing operations.""" +class _RgbaObj(Protocol): + """Protocol for rgba-like color objects (structurally matches `rgba`).""" + r: int + g: int + b: int + a: Optional[float] + +class _HslaObj(Protocol): + """Protocol for hsla-like color objects (structurally matches `hsla`).""" + h: int + s: int + l: int + a: Optional[float] + +class _HexaObj(Protocol): + """Protocol for hexa-like color objects (structurally matches `hexa`).""" + r: int + g: int + b: int + a: Optional[float] + Rgba: TypeAlias = Union[ tuple[Int_0_255, Int_0_255, Int_0_255], - tuple[Int_0_255, Int_0_255, Int_0_255, Float_0_1], + tuple[Int_0_255, Int_0_255, Int_0_255, Optional[Float_0_1]], list[Int_0_255], - list[Union[Int_0_255, Float_0_1]], - dict[str, Union[int, float]], - "rgba", + list[Union[Int_0_255, Optional[Float_0_1]]], + "RgbaDict", + _RgbaObj, str, ] """Matches all supported RGBA color value formats.""" Hsla: TypeAlias = Union[ tuple[Int_0_360, Int_0_100, Int_0_100], - tuple[Int_0_360, Int_0_100, Int_0_100, Float_0_1], + tuple[Int_0_360, Int_0_100, Int_0_100, Optional[Float_0_1]], list[Union[Int_0_360, Int_0_100]], - list[Union[Int_0_360, Int_0_100, Float_0_1]], - dict[str, Union[int, float]], - "hsla", + list[Union[Int_0_360, Int_0_100, Optional[Float_0_1]]], + "HslaDict", + _HslaObj, str, ] """Matches all supported HSLA color value formats.""" -Hexa: TypeAlias = Union[str, int, "hexa"] -"""Matches all supported hexadecimal color value formats.""" +Hexa: TypeAlias = Union[str, int, _HexaObj] +"""Matches all supported HEXA color value formats.""" AnyRgba: TypeAlias = Any -"""Generic type alias for RGBA color values in any supported format (type checking disabled).""" +"""Generic type alias for RGBA color values in any format (type checking disabled).""" AnyHsla: TypeAlias = Any -"""Generic type alias for HSLA color values in any supported format (type checking disabled).""" +"""Generic type alias for HSLA color values in any format (type checking disabled).""" AnyHexa: TypeAlias = Any -"""Generic type alias for hexadecimal color values in any supported format (type checking disabled).""" +"""Generic type alias for HEXA color values in any format (type checking disabled).""" ArgParseConfig: TypeAlias = Union[set[str], "ArgConfigWithDefault", Literal["before", "after"]] """Matches the command-line-parsing configuration of a single argument.""" ArgParseConfigs: TypeAlias = dict[str, ArgParseConfig] """Matches the command-line-parsing configurations of multiple arguments, packed in a dictionary.""" -# -################################################## Sentinel ################################################## +################################################## Sentinel ################################################## class AllTextChars: """Sentinel class indicating all characters are allowed.""" @@ -86,13 +102,11 @@ class AllTextChars: ################################################## TypedDict ################################################## - class ArgConfigWithDefault(TypedDict): """Configuration schema for a flagged command-line argument that has a specified default value.""" flags: set[str] default: str - class ArgData(TypedDict): """Schema for the resulting data of parsing a single command-line argument.""" exists: bool @@ -100,6 +114,26 @@ class ArgData(TypedDict): values: list[str] flag: Optional[str] +class RgbaDict(TypedDict): + """Dictionary schema for RGBA color components.""" + r: Int_0_255 + g: Int_0_255 + b: Int_0_255 + a: Optional[Float_0_1] + +class HslaDict(TypedDict): + """Dictionary schema for HSLA color components.""" + h: Int_0_360 + s: Int_0_100 + l: Int_0_100 + a: Optional[Float_0_1] + +class HexaDict(TypedDict): + """Dictionary schema for HEXA color components.""" + r: str + g: str + b: str + a: Optional[str] class MissingLibsMsgs(TypedDict): """Configuration schema for custom messages in `System.check_libs()` when checking library dependencies.""" @@ -109,7 +143,6 @@ class MissingLibsMsgs(TypedDict): ################################################## Protocol ################################################## - class ProgressUpdater(Protocol): """Protocol for a progress updater function used in console progress bars.""" diff --git a/src/xulbux/cli/help.py b/src/xulbux/cli/help.py index 0952187..80e6f41 100644 --- a/src/xulbux/cli/help.py +++ b/src/xulbux/cli/help.py @@ -72,6 +72,6 @@ def is_latest_version() -> Optional[bool]: def show_help() -> None: - FormatCodes._config_console() + FormatCodes._config_console() # type: ignore[protected-access] print(CLI_HELP) - Console.pause_exit(pause=True, prompt=" [dim](Press any key to exit...)\n\n") + Console.pause_exit(" [dim](Press any key to exit...)\n\n", pause=True) diff --git a/src/xulbux/code.py b/src/xulbux/code.py index 1971176..dc789ed 100644 --- a/src/xulbux/code.py +++ b/src/xulbux/code.py @@ -6,6 +6,7 @@ from .regex import Regex from .data import Data +from typing import Any import regex as _rx @@ -13,7 +14,7 @@ class Code: """This class includes methods to work with code strings.""" @classmethod - def add_indent(cls, code: str, indent: int) -> str: + def add_indent(cls, code: str, indent: int, /) -> str: """Adds `indent` spaces at the beginning of each line.\n -------------------------------------------------------------------------- - `code` -⠀the code to indent @@ -24,7 +25,7 @@ def add_indent(cls, code: str, indent: int) -> str: return "\n".join(" " * indent + line for line in code.splitlines()) @classmethod - def get_tab_spaces(cls, code: str) -> int: + def get_tab_spaces(cls, code: str, /) -> int: """Will try to get the amount of spaces used for indentation.\n ---------------------------------------------------------------- - `code` -⠀the code to analyze""" @@ -32,7 +33,7 @@ def get_tab_spaces(cls, code: str) -> int: return min(non_zero_indents) if (non_zero_indents := [i for i in indents if i > 0]) else 0 @classmethod - def change_tab_size(cls, code: str, new_tab_size: int, remove_empty_lines: bool = False) -> str: + def change_tab_size(cls, code: str, new_tab_size: int, /, *, remove_empty_lines: bool = False) -> str: """Replaces all tabs with `new_tab_size` spaces.\n -------------------------------------------------------------------------------- - `code` -⠀the code to modify the tab size of @@ -48,7 +49,7 @@ def change_tab_size(cls, code: str, new_tab_size: int, remove_empty_lines: bool return "\n".join(code_lines) return code - result = [] + result: list[str] = [] for line in code_lines: indent_level = (len(line) - len(stripped := line.lstrip())) // tab_spaces result.append((" " * (indent_level * new_tab_size)) + stripped) @@ -56,11 +57,11 @@ def change_tab_size(cls, code: str, new_tab_size: int, remove_empty_lines: bool return "\n".join(result) @classmethod - def get_func_calls(cls, code: str) -> list: + def get_func_calls(cls, code: str, /) -> list[list[Any]]: """Will try to get all function calls and return them as a list.\n ------------------------------------------------------------------- - `code` -⠀the code to analyze""" - nested_func_calls = [] + nested_func_calls: list[list[Any]] = [] for _, func_attrs in (funcs := _rx.findall(r"(?i)" + Regex.func_call(), code)): if (nested_calls := _rx.findall(r"(?i)" + Regex.func_call(), func_attrs)): @@ -69,7 +70,7 @@ def get_func_calls(cls, code: str) -> list: return list(Data.remove_duplicates(funcs + nested_func_calls)) @classmethod - def is_js(cls, code: str, funcs: set[str] = {"__", "$t", "$lang"}) -> bool: + def is_js(cls, code: str, /, *, funcs: set[str] = {"__", "$t", "$lang"}) -> bool: """Will check if the code is very likely to be JavaScript.\n ------------------------------------------------------------- - `code` -⠀the code to analyze diff --git a/src/xulbux/color.py b/src/xulbux/color.py index 2a7eb4d..dd912ed 100644 --- a/src/xulbux/color.py +++ b/src/xulbux/color.py @@ -6,10 +6,12 @@ includes methods to work with colors in various formats. """ -from .base.types import AnyRgba, AnyHsla, AnyHexa, Rgba, Hsla, Hexa +from __future__ import annotations + +from .base.types import RgbaDict, HslaDict, HexaDict, AnyRgba, AnyHsla, AnyHexa, Rgba, Hsla, Hexa from .regex import Regex -from typing import Iterator, Optional, Literal, cast +from typing import Iterator, Optional, Literal, Any, overload, cast import re as _re @@ -41,7 +43,7 @@ class rgba: - `with_alpha(alpha)` to create a new color with different alpha - `complementary()` to get the complementary color""" - def __init__(self, r: int, g: int, b: int, a: Optional[float] = None, _validate: bool = True): + def __init__(self, r: int, g: int, b: int, a: Optional[float] = None, /, *, _validate: bool = True): self.r: int """The red channel in range [0, 255] inclusive.""" self.g: int @@ -69,19 +71,31 @@ def __len__(self) -> int: """The number of components in the color (3 or 4).""" return 3 if self.a is None else 4 - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[int | Optional[float]]: return iter((self.r, self.g, self.b) + (() if self.a is None else (self.a, ))) - def __getitem__(self, index: int) -> int | float: + @overload + def __getitem__(self, index: Literal[0, 1, 2], /) -> int: + ... + + @overload + def __getitem__(self, index: Literal[3], /) -> Optional[float]: + ... + + @overload + def __getitem__(self, index: int, /) -> int | Optional[float]: + ... + + def __getitem__(self, index: int, /) -> int | Optional[float]: return ((self.r, self.g, self.b) + (() if self.a is None else (self.a, )))[index] - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: """Check if two `rgba` objects are the same color.""" if not isinstance(other, rgba): return False return (self.r, self.g, self.b, self.a) == (other.r, other.g, other.b, other.a) - def __ne__(self, other: object) -> bool: + def __ne__(self, other: object, /) -> bool: """Check if two `rgba` objects are different colors.""" return not self.__eq__(other) @@ -91,87 +105,72 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def dict(self) -> dict: + def dict(self) -> RgbaDict: """Returns the color components as a dictionary with keys `"r"`, `"g"`, `"b"` and optionally `"a"`.""" - return dict(r=self.r, g=self.g, b=self.b) if self.a is None else dict(r=self.r, g=self.g, b=self.b, a=self.a) + return {"r": self.r, "g": self.g, "b": self.b, "a": self.a} - def values(self) -> tuple: + def values(self) -> tuple[int, int, int, Optional[float]]: """Returns the color components as separate values `r, g, b, a`.""" return self.r, self.g, self.b, self.a - def to_hsla(self) -> "hsla": + def to_hsla(self) -> hsla: """Returns the color as `hsla()` color object.""" h, s, l = self._rgb_to_hsl(self.r, self.g, self.b) return hsla(h, s, l, self.a, _validate=False) - def to_hexa(self) -> "hexa": + def to_hexa(self) -> hexa: """Returns the color as `hexa()` color object.""" - return hexa("", self.r, self.g, self.b, self.a) + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) def has_alpha(self) -> bool: """Returns `True` if the color has an alpha channel and `False` otherwise.""" return self.a is not None - def lighten(self, amount: float) -> "rgba": + def lighten(self, amount: float, /) -> rgba: """Increases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_hsla().lighten(amount).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def darken(self, amount: float) -> "rgba": + def darken(self, amount: float, /) -> rgba: """Decreases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_hsla().darken(amount).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def saturate(self, amount: float) -> "rgba": + def saturate(self, amount: float, /) -> rgba: """Increases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_hsla().saturate(amount).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def desaturate(self, amount: float) -> "rgba": + def desaturate(self, amount: float, /) -> rgba: """Decreases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.r, self.g, self.b, self.a = self.to_hsla().desaturate(amount).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def rotate(self, degrees: int) -> "rgba": + def rotate(self, degrees: int, /) -> rgba: """Rotates the colors hue by the specified number of degrees.""" - if not isinstance(degrees, int): - raise TypeError(f"The 'degrees' parameter must be an integer, got {type(degrees)}") - self.r, self.g, self.b, self.a = self.to_hsla().rotate(degrees).to_rgba().values() return rgba(self.r, self.g, self.b, self.a, _validate=False) - def invert(self, invert_alpha: bool = False) -> "rgba": + def invert(self, *, invert_alpha: bool = False) -> rgba: """Inverts the color by rotating hue by 180 degrees and inverting lightness.""" - if not isinstance(invert_alpha, bool): - raise TypeError(f"The 'invert_alpha' parameter must be a boolean, got {type(invert_alpha)}") - self.r, self.g, self.b = 255 - self.r, 255 - self.g, 255 - self.b if invert_alpha and self.a is not None: self.a = 1 - self.a - return rgba(self.r, self.g, self.b, self.a, _validate=False) - def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> "rgba": + def grayscale(self, *, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> rgba: """Converts the color to grayscale using the luminance formula.\n --------------------------------------------------------------------------- - `method` -⠀the luminance calculation method to use: @@ -183,7 +182,7 @@ def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag self.r = self.g = self.b = int(Color.luminance(self.r, self.g, self.b, method=method)) return rgba(self.r, self.g, self.b, self.a, _validate=False) - def blend(self, other: Rgba, ratio: float = 0.5, additive_alpha: bool = False) -> "rgba": + def blend(self, other: Rgba, /, ratio: float = 0.5, *, additive_alpha: bool = False) -> rgba: """Blends the current color with another color using the specified ratio in range [0.0, 1.0] inclusive.\n ---------------------------------------------------------------------------------------------------------- - `other` -⠀the other RGBA color to blend with @@ -192,27 +191,20 @@ def blend(self, other: Rgba, ratio: float = 0.5, additive_alpha: bool = False) - * if `ratio` is `0.5` it means 50% of both colors (1:1 mixture) * if `ratio` is `1.0` it means 0% of the current color and 100% of the `other` color (0:2 mixture) - `additive_alpha` -⠀whether to blend the alpha channels additively or not""" - if not isinstance(other, rgba): - if Color.is_valid_rgba(other): - other = Color.to_rgba(other) - else: - raise TypeError(f"The 'other' parameter must be a valid RGBA color, got {type(other)}") - if not isinstance(ratio, float): - raise TypeError(f"The 'ratio' parameter must be a float, got {type(ratio)}") - elif not (0.0 <= ratio <= 1.0): + if not (0.0 <= ratio <= 1.0): raise ValueError(f"The 'ratio' parameter must be in range [0.0, 1.0] inclusive, got {ratio!r}") - if not isinstance(additive_alpha, bool): - raise TypeError(f"The 'additive_alpha' parameter must be a boolean, got {type(additive_alpha)}") + + other_rgba = Color.to_rgba(other) ratio *= 2 - self.r = int(max(0, min(255, int(round((self.r * (2 - ratio)) + (other.r * ratio)))))) - self.g = int(max(0, min(255, int(round((self.g * (2 - ratio)) + (other.g * ratio)))))) - self.b = int(max(0, min(255, int(round((self.b * (2 - ratio)) + (other.b * ratio)))))) - none_alpha = self.a is None and (len(other) <= 3 or other[3] is None) + self.r = int(max(0, min(255, int((self.r * (2 - ratio)) + (other_rgba.r * ratio) + 0.5)))) + self.g = int(max(0, min(255, int((self.g * (2 - ratio)) + (other_rgba.g * ratio) + 0.5)))) + self.b = int(max(0, min(255, int((self.b * (2 - ratio)) + (other_rgba.b * ratio) + 0.5)))) + none_alpha = self.a is None and (len(other_rgba) <= 3 or other_rgba[3] is None) if not none_alpha: - self_a = 1 if self.a is None else self.a - other_a = (other[3] if other[3] is not None else 1) if len(other) > 3 else 1 + self_a: float = 1.0 if self.a is None else self.a + other_a: float = cast(float, 1.0 if other_rgba[3] is None else other_rgba[3]) if len(other_rgba) > 3 else 1.0 if additive_alpha: self.a = max(0, min(1, (self_a * (2 - ratio)) + (other_a * ratio))) @@ -240,21 +232,19 @@ def is_opaque(self) -> bool: """Returns `True` if the color has no transparency.""" return self.a == 1 or self.a is None - def with_alpha(self, alpha: float) -> "rgba": + def with_alpha(self, alpha: float, /) -> rgba: """Returns a new color with the specified alpha value.""" - if not isinstance(alpha, float): - raise TypeError(f"The 'alpha' parameter must be a float, got {type(alpha)}") - elif not (0.0 <= alpha <= 1.0): + if not (0.0 <= alpha <= 1.0): raise ValueError(f"The 'alpha' parameter must be in range [0.0, 1.0] inclusive, got {alpha!r}") return rgba(self.r, self.g, self.b, alpha, _validate=False) - def complementary(self) -> "rgba": + def complementary(self) -> rgba: """Returns the complementary color (180 degrees on the color wheel).""" return self.to_hsla().complementary().to_rgba() @staticmethod - def _rgb_to_hsl(r: int, g: int, b: int) -> tuple: + def _rgb_to_hsl(r: int, g: int, b: int) -> tuple[int, int, int]: """Internal method to convert RGB to HSL color space.""" _r, _g, _b = r / 255.0, g / 255.0, b / 255.0 max_c, min_c = max(_r, _g, _b), min(_r, _g, _b) @@ -305,7 +295,7 @@ class hsla: - `with_alpha(alpha)` to create a new color with different alpha - `complementary()` to get the complementary color""" - def __init__(self, h: int, s: int, l: int, a: Optional[float] = None, _validate: bool = True): + def __init__(self, h: int, s: int, l: int, a: Optional[float] = None, /, *, _validate: bool = True): self.h: int """The hue channel in range [0, 360] inclusive.""" self.s: int @@ -333,19 +323,31 @@ def __len__(self) -> int: """The number of components in the color (3 or 4).""" return 3 if self.a is None else 4 - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[int | Optional[float]]: return iter((self.h, self.s, self.l) + (() if self.a is None else (self.a, ))) - def __getitem__(self, index: int) -> int | float: + @overload + def __getitem__(self, index: Literal[0, 1, 2], /) -> int: + ... + + @overload + def __getitem__(self, index: Literal[3], /) -> Optional[float]: + ... + + @overload + def __getitem__(self, index: int, /) -> int | Optional[float]: + ... + + def __getitem__(self, index: int, /) -> int | Optional[float]: return ((self.h, self.s, self.l) + (() if self.a is None else (self.a, )))[index] - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: """Check if two `hsla` objects are the same color.""" if not isinstance(other, hsla): return False return (self.h, self.s, self.l, self.a) == (other.h, other.s, other.l, other.a) - def __ne__(self, other: object) -> bool: + def __ne__(self, other: object, /) -> bool: """Check if two `hsla` objects are different colors.""" return not self.__eq__(other) @@ -355,81 +357,67 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def dict(self) -> dict: + def dict(self) -> HslaDict: """Returns the color components as a dictionary with keys `"h"`, `"s"`, `"l"` and optionally `"a"`.""" - return dict(h=self.h, s=self.s, l=self.l) if self.a is None else dict(h=self.h, s=self.s, l=self.l, a=self.a) + return {"h": self.h, "s": self.s, "l": self.l, "a": self.a} - def values(self) -> tuple: + def values(self) -> tuple[int, int, int, Optional[float]]: """Returns the color components as separate values `h, s, l, a`.""" return self.h, self.s, self.l, self.a - def to_rgba(self) -> "rgba": + def to_rgba(self) -> rgba: """Returns the color as `rgba()` color object.""" r, g, b = self._hsl_to_rgb(self.h, self.s, self.l) return rgba(r, g, b, self.a, _validate=False) - def to_hexa(self) -> "hexa": + def to_hexa(self) -> hexa: """Returns the color as `hexa()` color object.""" r, g, b = self._hsl_to_rgb(self.h, self.s, self.l) - return hexa("", r, g, b, self.a) + return hexa(_r=r, _g=g, _b=b, _a=self.a) def has_alpha(self) -> bool: """Returns `True` if the color has an alpha channel and `False` otherwise.""" return self.a is not None - def lighten(self, amount: float) -> "hsla": + def lighten(self, amount: float, /) -> hsla: """Increases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.l = int(min(100, self.l + (100 - self.l) * amount)) return hsla(self.h, self.s, self.l, self.a, _validate=False) - def darken(self, amount: float) -> "hsla": + def darken(self, amount: float, /) -> hsla: """Decreases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.l = int(max(0, self.l * (1 - amount))) return hsla(self.h, self.s, self.l, self.a, _validate=False) - def saturate(self, amount: float) -> "hsla": + def saturate(self, amount: float, /) -> hsla: """Increases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.s = int(min(100, self.s + (100 - self.s) * amount)) return hsla(self.h, self.s, self.l, self.a, _validate=False) - def desaturate(self, amount: float) -> "hsla": + def desaturate(self, amount: float, /) -> hsla: """Decreases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") self.s = int(max(0, self.s * (1 - amount))) return hsla(self.h, self.s, self.l, self.a, _validate=False) - def rotate(self, degrees: int) -> "hsla": + def rotate(self, degrees: int, /) -> hsla: """Rotates the colors hue by the specified number of degrees.""" - if not isinstance(degrees, int): - raise TypeError(f"The 'degrees' parameter must be an integer, got {type(degrees)}") - self.h = (self.h + degrees) % 360 return hsla(self.h, self.s, self.l, self.a, _validate=False) - def invert(self, invert_alpha: bool = False) -> "hsla": + def invert(self, *, invert_alpha: bool = False) -> hsla: """Inverts the color by rotating hue by 180 degrees and inverting lightness.""" - if not isinstance(invert_alpha, bool): - raise TypeError(f"The 'invert_alpha' parameter must be a boolean, got {type(invert_alpha)}") - self.h = (self.h + 180) % 360 self.l = 100 - self.l if invert_alpha and self.a is not None: @@ -437,7 +425,7 @@ def invert(self, invert_alpha: bool = False) -> "hsla": return hsla(self.h, self.s, self.l, self.a, _validate=False) - def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> "hsla": + def grayscale(self, *, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> hsla: """Converts the color to grayscale using the luminance formula.\n --------------------------------------------------------------------------- - `method` -⠀the luminance calculation method to use: @@ -451,7 +439,7 @@ def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag self.h, self.s, self.l, _ = rgba(l, l, l, _validate=False).to_hsla().values() return hsla(self.h, self.s, self.l, self.a, _validate=False) - def blend(self, other: Hsla, ratio: float = 0.5, additive_alpha: bool = False) -> "hsla": + def blend(self, other: Hsla, /, ratio: float = 0.5, *, additive_alpha: bool = False) -> hsla: """Blends the current color with another color using the specified ratio in range [0.0, 1.0] inclusive.\n ---------------------------------------------------------------------------------------------------------- - `other` -⠀the other HSLA color to blend with @@ -462,14 +450,14 @@ def blend(self, other: Hsla, ratio: float = 0.5, additive_alpha: bool = False) - - `additive_alpha` -⠀whether to blend the alpha channels additively or not""" if not Color.is_valid_hsla(other): raise TypeError(f"The 'other' parameter must be a valid HSLA color, got {type(other)}") - if not isinstance(ratio, float): - raise TypeError(f"The 'ratio' parameter must be a float, got {type(ratio)}") - elif not (0.0 <= ratio <= 1.0): + if not (0.0 <= ratio <= 1.0): raise ValueError(f"The 'ratio' parameter must be in range [0.0, 1.0] inclusive, got {ratio!r}") - if not isinstance(additive_alpha, bool): - raise TypeError(f"The 'additive_alpha' parameter must be a boolean, got {type(additive_alpha)}") - self.h, self.s, self.l, self.a = self.to_rgba().blend(Color.to_rgba(other), ratio, additive_alpha).to_hsla().values() + self.h, self.s, self.l, self.a = self.to_rgba().blend( + Color.to_rgba(other), + ratio, + additive_alpha=additive_alpha, + ).to_hsla().values() return hsla(self.h, self.s, self.l, self.a, _validate=False) def is_dark(self) -> bool: @@ -488,7 +476,7 @@ def is_opaque(self) -> bool: """Returns `True` if the color has no transparency.""" return self.a == 1 or self.a is None - def with_alpha(self, alpha: float) -> "hsla": + def with_alpha(self, alpha: float, /) -> hsla: """Returns a new color with the specified alpha value.""" if not isinstance(alpha, float): raise TypeError(f"The 'alpha' parameter must be a float, got {type(alpha)}") @@ -497,12 +485,12 @@ def with_alpha(self, alpha: float) -> "hsla": return hsla(self.h, self.s, self.l, alpha, _validate=False) - def complementary(self) -> "hsla": + def complementary(self) -> hsla: """Returns the complementary color (180 degrees on the color wheel).""" return hsla((self.h + 180) % 360, self.s, self.l, self.a, _validate=False) @classmethod - def _hsl_to_rgb(cls, h: int, s: int, l: int) -> tuple: + def _hsl_to_rgb(cls, h: int, s: int, l: int) -> tuple[int, int, int]: """Internal method to convert HSL to RGB color space.""" _h, _s, _l = h / 360, s / 100, l / 100 @@ -562,12 +550,14 @@ class hexa: def __init__( self, - color: str | int, + color: Optional[str | int] = None, + /, + *, _r: Optional[int] = None, _g: Optional[int] = None, _b: Optional[int] = None, _a: Optional[float] = None, - ): + ) -> None: self.r: int """The red channel in range [0, 255] inclusive.""" self.g: int @@ -623,28 +613,26 @@ def __init__( elif isinstance(color, int): self.r, self.g, self.b, self.a = Color.hex_int_to_rgba(color).values() - else: - raise TypeError(f"The 'color' parameter must be a string or integer, got {type(color)}") def __len__(self) -> int: """The number of components in the color (3 or 4).""" return 3 if self.a is None else 4 - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[str]: return iter((f"{self.r:02X}", f"{self.g:02X}", f"{self.b:02X}") + (() if self.a is None else (f"{int(self.a * 255):02X}", ))) - def __getitem__(self, index: int) -> str | int: + def __getitem__(self, index: int, /) -> str: return ((f"{self.r:02X}", f"{self.g:02X}", f"{self.b:02X}") \ + (() if self.a is None else (f"{int(self.a * 255):02X}", )))[index] - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: """Check if two `hexa` objects are the same color.""" if not isinstance(other, hexa): return False return (self.r, self.g, self.b, self.a) == (other.r, other.g, other.b, other.a) - def __ne__(self, other: object) -> bool: + def __ne__(self, other: object, /) -> bool: """Check if two `hexa` objects are different colors.""" return not self.__eq__(other) @@ -654,29 +642,19 @@ def __repr__(self) -> str: def __str__(self) -> str: return f"#{self.r:02X}{self.g:02X}{self.b:02X}{'' if self.a is None else f'{int(self.a * 255):02X}'}" - def dict(self) -> dict: + def dict(self) -> HexaDict: """Returns the color components as a dictionary with hex string values for keys `"r"`, `"g"`, `"b"` and optionally `"a"`.""" - return ( - dict(r=f"{self.r:02X}", g=f"{self.g:02X}", b=f"{self.b:02X}") if self.a is None else dict( - r=f"{self.r:02X}", - g=f"{self.g:02X}", - b=f"{self.b:02X}", - a=f"{int(self.a * 255):02X}", - ) - ) + return { + "r": f"{self.r:02X}", "g": f"{self.g:02X}", "b": f"{self.b:02X}", "a": + None if self.a is None else f"{int(self.a * 255):02X}" + } - def values(self, round_alpha: bool = True) -> tuple: + def values(self, *, round_alpha: bool = True) -> tuple[int, int, int, Optional[float]]: """Returns the color components as separate values `r, g, b, a`.""" - if not isinstance(round_alpha, bool): - raise TypeError(f"The 'round_alpha' parameter must be a boolean, got {type(round_alpha)}") - return self.r, self.g, self.b, None if self.a is None else (round(self.a, 2) if round_alpha else self.a) - def to_rgba(self, round_alpha: bool = True) -> "rgba": + def to_rgba(self, *, round_alpha: bool = True) -> rgba: """Returns the color as `rgba()` color object.""" - if not isinstance(round_alpha, bool): - raise TypeError(f"The 'round_alpha' parameter must be a boolean, got {type(round_alpha)}") - return rgba( self.r, self.g, @@ -685,77 +663,60 @@ def to_rgba(self, round_alpha: bool = True) -> "rgba": _validate=False, ) - def to_hsla(self, round_alpha: bool = True) -> "hsla": + def to_hsla(self, *, round_alpha: bool = True) -> hsla: """Returns the color as `hsla()` color object.""" - if not isinstance(round_alpha, bool): - raise TypeError(f"The 'round_alpha' parameter must be a boolean, got {type(round_alpha)}") - - return self.to_rgba(round_alpha).to_hsla() + return self.to_rgba(round_alpha=round_alpha).to_hsla() def has_alpha(self) -> bool: """Returns `True` if the color has an alpha channel and `False` otherwise.""" return self.a is not None - def lighten(self, amount: float) -> "hexa": + def lighten(self, amount: float, /) -> hexa: """Increases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") - self.r, self.g, self.b, self.a = self.to_rgba(False).lighten(amount).values() - return hexa("", self.r, self.g, self.b, self.a) + self.r, self.g, self.b, self.a = self.to_rgba(round_alpha=False).lighten(amount).values() + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) - def darken(self, amount: float) -> "hexa": + def darken(self, amount: float, /) -> hexa: """Decreases the colors lightness by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") - self.r, self.g, self.b, self.a = self.to_rgba(False).darken(amount).values() - return hexa("", self.r, self.g, self.b, self.a) + self.r, self.g, self.b, self.a = self.to_rgba(round_alpha=False).darken(amount).values() + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) - def saturate(self, amount: float) -> "hexa": + def saturate(self, amount: float, /) -> hexa: """Increases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") - self.r, self.g, self.b, self.a = self.to_rgba(False).saturate(amount).values() - return hexa("", self.r, self.g, self.b, self.a) + self.r, self.g, self.b, self.a = self.to_rgba(round_alpha=False).saturate(amount).values() + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) - def desaturate(self, amount: float) -> "hexa": + def desaturate(self, amount: float, /) -> hexa: """Decreases the colors saturation by the specified amount in range [0.0, 1.0] inclusive.""" - if not isinstance(amount, float): - raise TypeError(f"The 'amount' parameter must be a float, got {type(amount)}") - elif not (0.0 <= amount <= 1.0): + if not (0.0 <= amount <= 1.0): raise ValueError(f"The 'amount' parameter must be in range [0.0, 1.0] inclusive, got {amount!r}") - self.r, self.g, self.b, self.a = self.to_rgba(False).desaturate(amount).values() - return hexa("", self.r, self.g, self.b, self.a) + self.r, self.g, self.b, self.a = self.to_rgba(round_alpha=False).desaturate(amount).values() + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) - def rotate(self, degrees: int) -> "hexa": + def rotate(self, degrees: int, /) -> hexa: """Rotates the colors hue by the specified number of degrees.""" - if not isinstance(degrees, int): - raise TypeError(f"The 'degrees' parameter must be an integer, got {type(degrees)}") - - self.r, self.g, self.b, self.a = self.to_rgba(False).rotate(degrees).values() - return hexa("", self.r, self.g, self.b, self.a) + self.r, self.g, self.b, self.a = self.to_rgba(round_alpha=False).rotate(degrees).values() + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) - def invert(self, invert_alpha: bool = False) -> "hexa": + def invert(self, *, invert_alpha: bool = False) -> hexa: """Inverts the color by rotating hue by 180 degrees and inverting lightness.""" - if not isinstance(invert_alpha, bool): - raise TypeError(f"The 'invert_alpha' parameter must be a boolean, got {type(invert_alpha)}") - - self.r, self.g, self.b, self.a = self.to_rgba(False).invert().values() + self.r, self.g, self.b, self.a = self.to_rgba(round_alpha=False).invert().values() if invert_alpha and self.a is not None: self.a = 1 - self.a - return hexa("", self.r, self.g, self.b, self.a) + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) - def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> "hexa": + def grayscale(self, *, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2") -> hexa: """Converts the color to grayscale using the luminance formula.\n --------------------------------------------------------------------------- - `method` -⠀the luminance calculation method to use: @@ -765,9 +726,9 @@ def grayscale(self, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag * `"bt601"` ITU-R BT.601 standard (older TV standard)""" # THE 'method' PARAM IS CHECKED IN 'Color.luminance()' self.r = self.g = self.b = int(Color.luminance(self.r, self.g, self.b, method=method)) - return hexa("", self.r, self.g, self.b, self.a) + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) - def blend(self, other: Hexa, ratio: float = 0.5, additive_alpha: bool = False) -> "hexa": + def blend(self, other: Hexa, /, ratio: float = 0.5, *, additive_alpha: bool = False) -> hexa: """Blends the current color with another color using the specified ratio in range [0.0, 1.0] inclusive.\n ---------------------------------------------------------------------------------------------------------- - `other` -⠀the other HEXA color to blend with @@ -778,19 +739,19 @@ def blend(self, other: Hexa, ratio: float = 0.5, additive_alpha: bool = False) - - `additive_alpha` -⠀whether to blend the alpha channels additively or not""" if not Color.is_valid_hexa(other): raise TypeError(f"The 'other' parameter must be a valid HEXA color, got {type(other)}") - if not isinstance(ratio, float): - raise TypeError(f"The 'ratio' parameter must be a float, got {type(ratio)}") - elif not (0.0 <= ratio <= 1.0): + if not (0.0 <= ratio <= 1.0): raise ValueError(f"The 'ratio' parameter must be in range [0.0, 1.0] inclusive, got {ratio!r}") - if not isinstance(additive_alpha, bool): - raise TypeError(f"The 'additive_alpha' parameter must be a boolean, got {type(additive_alpha)}") - self.r, self.g, self.b, self.a = self.to_rgba(False).blend(Color.to_rgba(other), ratio, additive_alpha).values() - return hexa("", self.r, self.g, self.b, self.a) + self.r, self.g, self.b, self.a = self.to_rgba(round_alpha=False).blend( + Color.to_rgba(other), + ratio, + additive_alpha=additive_alpha, + ).values() + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=self.a) def is_dark(self) -> bool: """Returns `True` if the color is considered dark (`lightness < 50%`).""" - return self.to_hsla(False).is_dark() + return self.to_hsla(round_alpha=False).is_dark() def is_light(self) -> bool: """Returns `True` if the color is considered light (`lightness >= 50%`).""" @@ -798,61 +759,70 @@ def is_light(self) -> bool: def is_grayscale(self) -> bool: """Returns `True` if the color is grayscale (`saturation == 0`).""" - return self.to_hsla(False).is_grayscale() + return self.to_hsla(round_alpha=False).is_grayscale() def is_opaque(self) -> bool: """Returns `True` if the color has no transparency (`alpha == 1.0`).""" return self.a == 1 or self.a is None - def with_alpha(self, alpha: float) -> "hexa": + def with_alpha(self, alpha: float, /) -> hexa: """Returns a new color with the specified alpha value.""" if not isinstance(alpha, float): raise TypeError(f"The 'alpha' parameter must be a float, got {type(alpha)}") elif not (0.0 <= alpha <= 1.0): raise ValueError(f"The 'alpha' parameter must be in range [0.0, 1.0] inclusive, got {alpha!r}") - return hexa("", self.r, self.g, self.b, alpha) + return hexa(_r=self.r, _g=self.g, _b=self.b, _a=alpha) - def complementary(self) -> "hexa": + def complementary(self) -> hexa: """Returns the complementary color (180 degrees on the color wheel).""" - return self.to_hsla(False).complementary().to_hexa() + return self.to_hsla(round_alpha=False).complementary().to_hexa() class Color: """This class includes methods to work with colors in different formats.""" @classmethod - def is_valid_rgba(cls, color: AnyRgba, allow_alpha: bool = True) -> bool: + def is_valid_rgba(cls, color: AnyRgba, /, *, allow_alpha: bool = True) -> bool: """Check if the given color is a valid RGBA color.\n ----------------------------------------------------------------- - `color` -⠀the color to check (can be in any supported format) - `allow_alpha` -⠀whether to allow alpha channel in the color""" - if not isinstance(allow_alpha, bool): - raise TypeError(f"The 'new_tab_size' parameter must be an boolean, got {type(allow_alpha)}") - try: if isinstance(color, rgba): return True elif isinstance(color, (list, tuple)): - if allow_alpha and cls.has_alpha(color): + array_color = cast(list[Any] | tuple[Any, ...], color) + + if (allow_alpha \ + and len(array_color) == 4 + and all(isinstance(val, int) for val in array_color[:3]) + and isinstance(array_color[3], (float, type(None))) + ): return ( - 0 <= color[0] <= 255 and 0 <= color[1] <= 255 and 0 <= color[2] <= 255 - and (0 <= color[3] <= 1 or color[3] is None) + 0 <= array_color[0] <= 255 and 0 <= array_color[1] <= 255 and 0 <= array_color[2] <= 255 + and (array_color[3] is None or 0 <= array_color[3] <= 1) ) - elif len(color) == 3: - return 0 <= color[0] <= 255 and 0 <= color[1] <= 255 and 0 <= color[2] <= 255 + elif len(array_color) == 3 and all(isinstance(val, int) for val in array_color): + return 0 <= array_color[0] <= 255 and 0 <= array_color[1] <= 255 and 0 <= array_color[2] <= 255 else: return False elif isinstance(color, dict): - if allow_alpha and cls.has_alpha(color): + dict_color = cast(dict[str, Any], color) + + if (allow_alpha \ + and len(dict_color) == 4 + and all(isinstance(dict_color.get(ch), int) for ch in ("r", "g", "b")) + and isinstance(dict_color.get("a", "no alpha"), (float, type(None))) + ): return ( - 0 <= color["r"] <= 255 and 0 <= color["g"] <= 255 and 0 <= color["b"] <= 255 - and (0 <= color["a"] <= 1 or color["a"] is None) + 0 <= dict_color["r"] <= 255 and 0 <= dict_color["g"] <= 255 and 0 <= dict_color["b"] <= 255 + and (dict_color["a"] is None or 0 <= dict_color["a"] <= 1) ) - elif len(color) == 3: - return 0 <= color["r"] <= 255 and 0 <= color["g"] <= 255 and 0 <= color["b"] <= 255 + elif len(dict_color) == 3 and all(isinstance(dict_color.get(ch), int) for ch in ("r", "g", "b")): + return 0 <= dict_color["r"] <= 255 and 0 <= dict_color["g"] <= 255 and 0 <= dict_color["b"] <= 255 else: return False @@ -864,7 +834,7 @@ def is_valid_rgba(cls, color: AnyRgba, allow_alpha: bool = True) -> bool: return False @classmethod - def is_valid_hsla(cls, color: AnyHsla, allow_alpha: bool = True) -> bool: + def is_valid_hsla(cls, color: AnyHsla, /, *, allow_alpha: bool = True) -> bool: """Check if the given color is a valid HSLA color.\n ----------------------------------------------------------------- - `color` -⠀the color to check (can be in any supported format) @@ -874,24 +844,36 @@ def is_valid_hsla(cls, color: AnyHsla, allow_alpha: bool = True) -> bool: return True elif isinstance(color, (list, tuple)): - if allow_alpha and cls.has_alpha(color): + array_color = cast(list[Any] | tuple[Any, ...], color) + + if (allow_alpha \ + and len(array_color) == 4 + and all(isinstance(val, int) for val in array_color[:3]) + and isinstance(array_color[3], (float, type(None))) + ): return ( - 0 <= color[0] <= 360 and 0 <= color[1] <= 100 and 0 <= color[2] <= 100 - and (0 <= color[3] <= 1 or color[3] is None) + 0 <= array_color[0] <= 360 and 0 <= array_color[1] <= 100 and 0 <= array_color[2] <= 100 + and (array_color[3] is None or 0 <= array_color[3] <= 1) ) - elif len(color) == 3: - return 0 <= color[0] <= 360 and 0 <= color[1] <= 100 and 0 <= color[2] <= 100 + elif len(array_color) == 3 and all(isinstance(val, int) for val in array_color): + return 0 <= array_color[0] <= 360 and 0 <= array_color[1] <= 100 and 0 <= array_color[2] <= 100 else: return False elif isinstance(color, dict): - if allow_alpha and cls.has_alpha(color): + dict_color = cast(dict[str, Any], color) + + if (allow_alpha \ + and len(dict_color) == 4 + and all(isinstance(dict_color.get(ch), int) for ch in ("h", "s", "l")) + and isinstance(dict_color.get("a", "no alpha"), (float, type(None))) + ): return ( - 0 <= color["h"] <= 360 and 0 <= color["s"] <= 100 and 0 <= color["l"] <= 100 - and (0 <= color["a"] <= 1 or color["a"] is None) + 0 <= dict_color["h"] <= 360 and 0 <= dict_color["s"] <= 100 and 0 <= dict_color["l"] <= 100 + and (dict_color["a"] is None or 0 <= dict_color["a"] <= 1) ) - elif len(color) == 3: - return 0 <= color["h"] <= 360 and 0 <= color["s"] <= 100 and 0 <= color["l"] <= 100 + elif len(dict_color) == 3 and all(isinstance(dict_color.get(ch), int) for ch in ("h", "s", "l")): + return 0 <= dict_color["h"] <= 360 and 0 <= dict_color["s"] <= 100 and 0 <= dict_color["l"] <= 100 else: return False @@ -902,10 +884,48 @@ def is_valid_hsla(cls, color: AnyHsla, allow_alpha: bool = True) -> bool: pass return False + @overload + @classmethod + def is_valid_hexa( + cls, + color: AnyHexa, + /, + *, + allow_alpha: bool = True, + get_prefix: Literal[True], + ) -> tuple[bool, Optional[Literal["#", "0x"]]]: + ... + + @overload + @classmethod + def is_valid_hexa( + cls, + color: AnyHexa, + /, + *, + allow_alpha: bool = True, + get_prefix: Literal[False] = False, + ) -> bool: + ... + + @overload @classmethod def is_valid_hexa( cls, color: AnyHexa, + /, + *, + allow_alpha: bool = True, + get_prefix: bool = False, + ) -> bool | tuple[bool, Optional[Literal["#", "0x"]]]: + ... + + @classmethod + def is_valid_hexa( + cls, + color: AnyHexa, + /, + *, allow_alpha: bool = True, get_prefix: bool = False, ) -> bool | tuple[bool, Optional[Literal["#", "0x"]]]: @@ -936,19 +956,19 @@ def is_valid_hexa( return (False, None) if get_prefix else False @classmethod - def is_valid(cls, color: AnyRgba | AnyHsla | AnyHexa, allow_alpha: bool = True) -> bool: + def is_valid(cls, color: AnyRgba | AnyHsla | AnyHexa, /, *, allow_alpha: bool = True) -> bool: """Check if the given color is a valid RGBA, HSLA or HEXA color.\n ------------------------------------------------------------------- - `color` -⠀the color to check (can be in any supported format) - `allow_alpha` -⠀whether to allow alpha channel in the color""" return bool( - cls.is_valid_rgba(color, allow_alpha) \ - or cls.is_valid_hsla(color, allow_alpha) \ - or cls.is_valid_hexa(color, allow_alpha) + cls.is_valid_rgba(color, allow_alpha=allow_alpha) \ + or cls.is_valid_hsla(color, allow_alpha=allow_alpha) \ + or cls.is_valid_hexa(color, allow_alpha=allow_alpha) ) @classmethod - def has_alpha(cls, color: Rgba | Hsla | Hexa) -> bool: + def has_alpha(cls, color: Rgba | Hsla | Hexa, /) -> bool: """Check if the given color has an alpha channel.\n --------------------------------------------------------------------------- - `color` -⠀the color to check (can be in any supported format)""" @@ -968,64 +988,79 @@ def has_alpha(cls, color: Rgba | Hsla | Hexa) -> bool: elif isinstance(color, str): if parsed_rgba := cls.str_to_rgba(color, only_first=True): - return cast(rgba, parsed_rgba).has_alpha() + return parsed_rgba.has_alpha() if parsed_hsla := cls.str_to_hsla(color, only_first=True): - return cast(hsla, parsed_hsla).has_alpha() + return parsed_hsla.has_alpha() - elif isinstance(color, (list, tuple)) and len(color) == 4 and color[3] is not None: + elif isinstance(color, (list, tuple)) and len(color) == 4: return True - elif isinstance(color, dict) and len(color) == 4 and color["a"] is not None: + elif isinstance(color, dict) and len(color) == 4: return True return False @classmethod - def to_rgba(cls, color: Rgba | Hsla | Hexa) -> rgba: + def to_rgba(cls, color: Rgba | Hsla | Hexa, /) -> rgba: """Will try to convert any color type to a color of type RGBA.\n --------------------------------------------------------------------- - `color` -⠀the color to convert (can be in any supported format)""" if isinstance(color, (hsla, hexa)): return color.to_rgba() elif cls.is_valid_hsla(color): - return cls._parse_hsla(color).to_rgba() + return cls._parse_hsla(cast(Hsla, color)).to_rgba() elif cls.is_valid_hexa(color): return hexa(cast(str | int, color)).to_rgba() elif cls.is_valid_rgba(color): - return cls._parse_rgba(color) + return cls._parse_rgba(cast(Rgba, color)) raise ValueError(f"Could not convert color {color!r} to RGBA.") @classmethod - def to_hsla(cls, color: Rgba | Hsla | Hexa) -> hsla: + def to_hsla(cls, color: Rgba | Hsla | Hexa, /) -> hsla: """Will try to convert any color type to a color of type HSLA.\n --------------------------------------------------------------------- - `color` -⠀the color to convert (can be in any supported format)""" if isinstance(color, (rgba, hexa)): return color.to_hsla() elif cls.is_valid_rgba(color): - return cls._parse_rgba(color).to_hsla() + return cls._parse_rgba(cast(Rgba, color)).to_hsla() elif cls.is_valid_hexa(color): return hexa(cast(str | int, color)).to_hsla() elif cls.is_valid_hsla(color): - return cls._parse_hsla(color) + return cls._parse_hsla(cast(Hsla, color)) raise ValueError(f"Could not convert color {color!r} to HSLA.") @classmethod - def to_hexa(cls, color: Rgba | Hsla | Hexa) -> hexa: + def to_hexa(cls, color: Rgba | Hsla | Hexa, /) -> hexa: """Will try to convert any color type to a color of type HEXA.\n --------------------------------------------------------------------- - `color` -⠀the color to convert (can be in any supported format)""" if isinstance(color, (rgba, hsla)): return color.to_hexa() elif cls.is_valid_rgba(color): - return cls._parse_rgba(color).to_hexa() + return cls._parse_rgba(cast(Rgba, color)).to_hexa() elif cls.is_valid_hsla(color): - return cls._parse_hsla(color).to_hexa() + return cls._parse_hsla(cast(Hsla, color)).to_hexa() elif cls.is_valid_hexa(color): return color if isinstance(color, hexa) else hexa(cast(str | int, color)) raise ValueError(f"Could not convert color {color!r} to HEXA") + @overload + @classmethod + def str_to_rgba(cls, string: str, /, *, only_first: Literal[True]) -> Optional[rgba]: + ... + + @overload + @classmethod + def str_to_rgba(cls, string: str, /, *, only_first: Literal[False] = False) -> Optional[list[rgba]]: + ... + + @overload @classmethod - def str_to_rgba(cls, string: str, only_first: bool = False) -> Optional[rgba | list[rgba]]: + def str_to_rgba(cls, string: str, /, *, only_first: bool = False) -> Optional[rgba | list[rgba]]: + ... + + @classmethod + def str_to_rgba(cls, string: str, /, *, only_first: bool = False) -> Optional[rgba | list[rgba]]: """Will try to recognize RGBA colors inside a string and output the found ones as RGBA objects.\n --------------------------------------------------------------------------------------------------------------- - `string` -⠀the string to search for RGBA colors @@ -1033,12 +1068,12 @@ def str_to_rgba(cls, string: str, only_first: bool = False) -> Optional[rgba | l if only_first: if not (match := _re.search(Regex.rgba_str(allow_alpha=True), string)): return None - m = match.groups() + groups = match.groups() return rgba( - int(m[0]), - int(m[1]), - int(m[2]), - ((int(m[3]) if "." not in m[3] else float(m[3])) if m[3] else None), + int(groups[0]), + int(groups[1]), + int(groups[2]), + ((int(groups[3]) if "." not in groups[3] else float(groups[3])) if groups[3] else None), _validate=False, ) @@ -1047,16 +1082,31 @@ def str_to_rgba(cls, string: str, only_first: bool = False) -> Optional[rgba | l return None return [ rgba( - int(m[0]), - int(m[1]), - int(m[2]), - ((int(m[3]) if "." not in m[3] else float(m[3])) if m[3] else None), + int(match[0]), + int(match[1]), + int(match[2]), + ((int(match[3]) if "." not in match[3] else float(match[3])) if match[3] else None), _validate=False, - ) for m in matches + ) for match in matches ] + @overload + @classmethod + def str_to_hsla(cls, string: str, /, *, only_first: Literal[True]) -> Optional[hsla]: + ... + + @overload + @classmethod + def str_to_hsla(cls, string: str, /, *, only_first: Literal[False] = False) -> Optional[list[hsla]]: + ... + + @overload + @classmethod + def str_to_hsla(cls, string: str, /, *, only_first: bool = False) -> Optional[hsla | list[hsla]]: + ... + @classmethod - def str_to_hsla(cls, string: str, only_first: bool = False) -> Optional[hsla | list[hsla]]: + def str_to_hsla(cls, string: str, /, *, only_first: bool = False) -> Optional[hsla | list[hsla]]: """Will try to recognize HSLA colors inside a string and output the found ones as HSLA objects.\n --------------------------------------------------------------------------------------------------------------- - `string` -⠀the string to search for HSLA colors @@ -1093,6 +1143,8 @@ def rgba_to_hex_int( g: int, b: int, a: Optional[float] = None, + /, + *, preserve_original: bool = False, ) -> int: """Convert RGBA channels to a HEXA integer (alpha is optional).\n @@ -1128,7 +1180,7 @@ def rgba_to_hex_int( return hex_int @classmethod - def hex_int_to_rgba(cls, hex_int: int, preserve_original: bool = False) -> rgba: + def hex_int_to_rgba(cls, hex_int: int, /, *, preserve_original: bool = False) -> rgba: """Convert a HEX integer to RGBA channels.\n ------------------------------------------------------------------------------------------- - `hex_int` -⠀the HEX integer to convert @@ -1163,12 +1215,70 @@ def hex_int_to_rgba(cls, hex_int: int, preserve_original: bool = False) -> rgba: else: raise ValueError(f"Could not convert HEX integer 0x{hex_int:X} to RGBA color.") + @overload @classmethod def luminance( cls, r: int, g: int, b: int, + /, + *, + output_type: type[int], + method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2", + ) -> int: + ... + + @overload + @classmethod + def luminance( + cls, + r: int, + g: int, + b: int, + /, + *, + output_type: type[float], + method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2", + ) -> float: + ... + + @overload + @classmethod + def luminance( + cls, + r: int, + g: int, + b: int, + /, + *, + output_type: None = None, + method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2", + ) -> int: + ... + + @overload + @classmethod + def luminance( + cls, + r: int, + g: int, + b: int, + /, + *, + output_type: Optional[type[int | float]] = None, + method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2", + ) -> int | float: + ... + + @classmethod + def luminance( + cls, + r: int, + g: int, + b: int, + /, + *, output_type: Optional[type[int | float]] = None, method: Literal["wcag2", "wcag3", "simple", "bt601"] = "wcag2", ) -> int | float: @@ -1186,8 +1296,6 @@ def luminance( * `"bt601"` ITU-R BT.601 standard (older TV standard)""" if not all(0 <= c <= 255 for c in (r, g, b)): raise ValueError(f"The 'r', 'g' and 'b' parameters must be integers in [0, 255], got {r=} {g=} {b=}") - if output_type not in {int, float, None}: - raise TypeError(f"The 'output_type' parameter must be either 'int', 'float' or 'None', got {output_type!r}") _r, _g, _b = r / 255.0, g / 255.0, b / 255.0 @@ -1213,43 +1321,77 @@ def luminance( else: return round(luminance * 255) + @overload + @classmethod + def text_color_for_on_bg(cls, text_bg_color: rgba, /) -> rgba: + ... + + @overload + @classmethod + def text_color_for_on_bg(cls, text_bg_color: hexa, /) -> hexa: + ... + + @overload + @classmethod + def text_color_for_on_bg(cls, text_bg_color: int, /) -> int: + ... + + @overload @classmethod - def text_color_for_on_bg(cls, text_bg_color: Rgba | Hexa) -> rgba | hexa | int: + def text_color_for_on_bg(cls, text_bg_color: Rgba | Hexa, /) -> rgba | hexa | int: + ... + + @classmethod + def text_color_for_on_bg(cls, text_bg_color: Rgba | Hexa, /) -> rgba | hexa | int: """Returns either black or white text color for optimal contrast on the given background color.\n -------------------------------------------------------------------------------------------------- - `text_bg_color` -⠀the background color (can be in RGBA or HEXA format)""" was_hexa, was_int = cls.is_valid_hexa(text_bg_color), isinstance(text_bg_color, int) - text_bg_color = cls.to_rgba(text_bg_color) - brightness = 0.2126 * text_bg_color[0] + 0.7152 * text_bg_color[1] + 0.0722 * text_bg_color[2] + text_bg_rgba = cls.to_rgba(text_bg_color) + brightness = 0.2126 * text_bg_rgba[0] + 0.7152 * text_bg_rgba[1] + 0.0722 * text_bg_rgba[2] return ( - (0xFFFFFF if was_int else hexa("", 255, 255, 255)) if was_hexa \ + (0xFFFFFF if was_int else hexa(_r=255, _g=255, _b=255)) if was_hexa \ else rgba(255, 255, 255, _validate=False) ) if brightness < 128 else ( - (0x000 if was_int else hexa("", 0, 0, 0)) if was_hexa \ + (0x000 if was_int else hexa(_r=0, _g=0, _b=0)) if was_hexa \ else rgba(0, 0, 0, _validate=False) ) + @overload @classmethod - def adjust_lightness(cls, color: Rgba | Hexa, lightness_change: float) -> rgba | hexa: + def adjust_lightness(cls, color: rgba, lightness_change: float, /) -> rgba: + ... + + @overload + @classmethod + def adjust_lightness(cls, color: hexa, lightness_change: float, /) -> hexa: + ... + + @overload + @classmethod + def adjust_lightness(cls, color: Rgba | Hexa, lightness_change: float, /) -> rgba | hexa: + ... + + @classmethod + def adjust_lightness(cls, color: Rgba | Hexa, lightness_change: float, /) -> rgba | hexa: """In- or decrease the lightness of the input color.\n ------------------------------------------------------------------ - `color` -⠀the color to adjust (can be in RGBA or HEXA format) - `lightness_change` -⠀the amount to change the lightness by, in range `-1.0` (darken by 100%) and `1.0` (lighten by 100%)""" - was_hexa = cls.is_valid_hexa(color) - if not (-1.0 <= lightness_change <= 1.0): raise ValueError( f"The 'lightness_change' parameter must be in range [-1.0, 1.0] inclusive, got {lightness_change!r}" ) - hsla_color: hsla = cls.to_hsla(color) + was_hexa = cls.is_valid_hexa(color) + hsla_color = cls.to_hsla(color) h, s, l, a = ( int(hsla_color[0]), int(hsla_color[1]), int(hsla_color[2]), \ - hsla_color[3] if cls.has_alpha(hsla_color) else None + hsla_color[3] if hsla_color.has_alpha() else None ) l = int(max(0, min(100, l + lightness_change * 100))) @@ -1258,25 +1400,39 @@ def adjust_lightness(cls, color: Rgba | Hexa, lightness_change: float) -> rgba | else hsla(h, s, l, a, _validate=False).to_rgba() ) + @overload + @classmethod + def adjust_saturation(cls, color: rgba, saturation_change: float, /) -> rgba: + ... + + @overload @classmethod - def adjust_saturation(cls, color: Rgba | Hexa, saturation_change: float) -> rgba | hexa: + def adjust_saturation(cls, color: hexa, saturation_change: float, /) -> hexa: + ... + + @overload + @classmethod + def adjust_saturation(cls, color: Rgba | Hexa, saturation_change: float, /) -> rgba | hexa: + ... + + @classmethod + def adjust_saturation(cls, color: Rgba | Hexa, saturation_change: float, /) -> rgba | hexa: """In- or decrease the saturation of the input color.\n ----------------------------------------------------------------------- - `color` -⠀the color to adjust (can be in RGBA or HEXA format) - `saturation_change` -⠀the amount to change the saturation by, in range `-1.0` (saturate by 100%) and `1.0` (desaturate by 100%)""" - was_hexa = cls.is_valid_hexa(color) - if not (-1.0 <= saturation_change <= 1.0): raise ValueError( f"The 'saturation_change' parameter must be in range [-1.0, 1.0] inclusive, got {saturation_change!r}" ) - hsla_color: hsla = cls.to_hsla(color) + was_hexa = cls.is_valid_hexa(color) + hsla_color = cls.to_hsla(color) h, s, l, a = ( int(hsla_color[0]), int(hsla_color[1]), int(hsla_color[2]), \ - hsla_color[3] if cls.has_alpha(hsla_color) else None + hsla_color[3] if hsla_color.has_alpha() else None ) s = int(max(0, min(100, s + saturation_change * 100))) @@ -1286,41 +1442,59 @@ def adjust_saturation(cls, color: Rgba | Hexa, saturation_change: float) -> rgba ) @classmethod - def _parse_rgba(cls, color: AnyRgba) -> rgba: + def _parse_rgba(cls, color: Rgba, /) -> rgba: """Internal method to parse a color to an RGBA object.""" if isinstance(color, rgba): return color + elif isinstance(color, (list, tuple)): - if len(color) == 4: - return rgba(color[0], color[1], color[2], color[3], _validate=False) - elif len(color) == 3: - return rgba(color[0], color[1], color[2], None, _validate=False) + array_color = cast(list[Any] | tuple[Any, ...], color) + if len(array_color) == 4: + return rgba( + int(array_color[0]), int(array_color[1]), int(array_color[2]), float(array_color[3]), _validate=False + ) + elif len(array_color) == 3: + return rgba(int(array_color[0]), int(array_color[1]), int(array_color[2]), None, _validate=False) + raise ValueError(f"Could not parse RGBA color: {color!r}") + elif isinstance(color, dict): - return rgba(color["r"], color["g"], color["b"], color.get("a"), _validate=False) + dict_color = cast(dict[str, Any], color) + return rgba(int(dict_color["r"]), int(dict_color["g"]), int(dict_color["b"]), dict_color.get("a"), _validate=False) + elif isinstance(color, str): if parsed := cls.str_to_rgba(color, only_first=True): - return cast(rgba, parsed) + return parsed + raise ValueError(f"Could not parse RGBA color: {color!r}") @classmethod - def _parse_hsla(cls, color: AnyHsla) -> hsla: + def _parse_hsla(cls, color: Hsla, /) -> hsla: """Internal method to parse a color to an HSLA object.""" if isinstance(color, hsla): return color + elif isinstance(color, (list, tuple)): + array_color = cast(list[Any] | tuple[Any, ...], color) if len(color) == 4: - return hsla(color[0], color[1], color[2], color[3], _validate=False) + return hsla( + int(array_color[0]), int(array_color[1]), int(array_color[2]), float(array_color[3]), _validate=False + ) elif len(color) == 3: - return hsla(color[0], color[1], color[2], None, _validate=False) + return hsla(int(array_color[0]), int(array_color[1]), int(array_color[2]), None, _validate=False) + raise ValueError(f"Could not parse HSLA color: {color!r}") + elif isinstance(color, dict): - return hsla(color["h"], color["s"], color["l"], color.get("a"), _validate=False) + dict_color = cast(dict[str, Any], color) + return hsla(int(dict_color["h"]), int(dict_color["s"]), int(dict_color["l"]), dict_color.get("a"), _validate=False) + elif isinstance(color, str): if parsed := cls.str_to_hsla(color, only_first=True): - return cast(hsla, parsed) + return parsed + raise ValueError(f"Could not parse HSLA color: {color!r}") @staticmethod - def _linearize_srgb(c: float) -> float: + def _linearize_srgb(c: float, /) -> float: """Helper method to linearize sRGB component following the WCAG standard.""" if not (0.0 <= c <= 1.0): raise ValueError(f"The 'c' parameter must be in range [0.0, 1.0] inclusive, got {c!r}") diff --git a/src/xulbux/console.py b/src/xulbux/console.py index 44f898c..d63f1ce 100644 --- a/src/xulbux/console.py +++ b/src/xulbux/console.py @@ -1,5 +1,5 @@ """ -This module provides the `Console`, `ProgressBar`, and `Spinner` classes +This module provides the `Console`, `ProgressBar`, and `Throbber` classes which offer methods for logging and other actions within the console. """ @@ -7,19 +7,21 @@ from .base.decorators import mypyc_attr from .base.consts import COLOR, CHARS, ANSI -from .format_codes import _PATTERNS as _FC_PATTERNS, FormatCodes +from .format_codes import _PATTERNS as _FC_PATTERNS, FormatCodes # type: ignore[private-access] from .string import String from .color import Color, hexa from .regex import LazyRegex -from typing import Generator, Callable, Optional, Literal, TypeVar, TextIO, Any, overload, cast +from typing import ValuesView, Generator, Callable, KeysView, Optional, Literal, TypeVar, TextIO, Any, overload, cast from prompt_toolkit.key_binding import KeyPressEvent, KeyBindings from prompt_toolkit.validation import ValidationError, Validator +from prompt_toolkit.document import Document from prompt_toolkit.styles import Style from prompt_toolkit.keys import Keys from contextlib import contextmanager from io import StringIO import prompt_toolkit as _pt +import subprocess as _subprocess import threading as _threading import keyboard as _keyboard import getpass as _getpass @@ -57,7 +59,7 @@ class ParsedArgData: ------------------------------------------------------------------------------------------------------------ When the `ParsedArgData` instance is accessed as a boolean it will correspond to the `exists` attribute.""" - def __init__(self, exists: bool, values: list[str], is_pos: bool, flag: Optional[str] = None): + def __init__(self, *, exists: bool, values: list[str], is_pos: bool, flag: Optional[str] = None): self.exists: bool = exists """Whether the argument was found or not.""" self.is_pos: bool = is_pos @@ -71,7 +73,7 @@ def __bool__(self) -> bool: """Whether the argument was found or not (i.e. the `exists` attribute).""" return self.exists - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: """Check if two `ParsedArgData` objects are equal by comparing their attributes.""" if not isinstance(other, ParsedArgData): return False @@ -82,7 +84,7 @@ def __eq__(self, other: object) -> bool: and self.flag == other.flag ) - def __ne__(self, other: object) -> bool: + def __ne__(self, other: object, /) -> bool: """Check if two `ParsedArgData` objects are not equal by comparing their attributes.""" return not self.__eq__(other) @@ -115,7 +117,7 @@ def __len__(self): """The number of arguments stored in the `ParsedArgs` object.""" return len(vars(self)) - def __contains__(self, key): + def __contains__(self, key: str, /) -> bool: """Checks if an argument with the given alias exists in the `ParsedArgs` object.""" return key in vars(self) @@ -123,25 +125,25 @@ def __bool__(self) -> bool: """Whether the `ParsedArgs` object contains any arguments.""" return len(self) > 0 - def __getattr__(self, name: str) -> ParsedArgData: + def __getattr__(self, name: str, /) -> ParsedArgData: raise AttributeError(f"'{type(self).__name__}' object has no attribute {name}") - def __getitem__(self, key): + def __getitem__(self, key: str | int, /) -> ParsedArgData: if isinstance(key, int): - return list(self.__iter__())[key] + return list(self.values())[key] return getattr(self, key) def __iter__(self) -> Generator[tuple[str, ParsedArgData], None, None]: for key, val in cast(dict[str, ParsedArgData], vars(self)).items(): yield (key, val) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: """Check if two `ParsedArgs` objects are equal by comparing their stored arguments.""" if not isinstance(other, ParsedArgs): return False return vars(self) == vars(other) - def __ne__(self, other: object) -> bool: + def __ne__(self, other: object, /) -> bool: """Check if two `ParsedArgs` objects are not equal by comparing their stored arguments.""" return not self.__eq__(other) @@ -160,15 +162,15 @@ def dict(self) -> dict[str, ArgData]: """Returns the arguments as a dictionary.""" return {key: val.dict() for key, val in self.__iter__()} - def get(self, key: str, default: Any = None) -> ParsedArgData | Any: + def get(self, key: str, /, default: Any = None) -> ParsedArgData | Any: """Returns the argument result for the given alias, or `default` if not found.""" return getattr(self, key, default) - def keys(self): + def keys(self) -> KeysView[str]: """Returns the argument aliases as `dict_keys([…])`.""" return vars(self).keys() - def values(self): + def values(self) -> ValuesView[ParsedArgData]: """Returns the argument results as `dict_values([…])`.""" return vars(self).values() @@ -260,7 +262,7 @@ class Console(metaclass=_ConsoleMeta): """This class provides methods for logging and other actions within the console.""" @classmethod - def get_args(cls, arg_parse_configs: ArgParseConfigs, flag_value_sep: str = "=") -> ParsedArgs: + def get_args(cls, arg_parse_configs: ArgParseConfigs, /, *, flag_value_sep: str = "=") -> ParsedArgs: """Will search for the specified args in the command-line arguments and return the results as a special `ParsedArgs` object.\n ------------------------------------------------------------------------------------------------- @@ -269,24 +271,24 @@ def get_args(cls, arg_parse_configs: ArgParseConfigs, flag_value_sep: str = "=") - `flag_value_sep` - the character/s used to separate flags from their values\n ------------------------------------------------------------------------------------------------- The `arg_parse_configs` dictionary can have the following structures for each item: - 1. Simple set of flags (when no default value is needed): - ```python + 1. Simple set of flags (when no default value is needed): + ```python "alias_name": {"-f", "--flag"} - ``` - 2. Dictionary with the`"flags"` set, plus a specified `"default"` value: - ```python + ``` + 2. Dictionary with the`"flags"` set, plus a specified `"default"` value: + ```python "alias_name": { "flags": {"-f", "--flag"}, "default": "some_value", } - ``` - 3. Positional value collection using the literals `"before"` or `"after"`: - ```python + ``` + 3. Positional value collection using the literals `"before"` or `"after"`: + ```python # COLLECT ALL NON-FLAGGED VALUES THAT APPEAR BEFORE THE FIRST FLAG "alias_name": "before" # COLLECT ALL NON-FLAGGED VALUES THAT APPEAR AFTER THE LAST FLAG'S VALUE "alias_name": "after" - ``` + ``` #### Example usage: If you call the `get_args()` method in your script like this: ```python @@ -327,6 +329,8 @@ def get_args(cls, arg_parse_configs: ArgParseConfigs, flag_value_sep: str = "=") def pause_exit( cls, prompt: object = "", + /, + *, pause: bool = True, exit: bool = False, exit_code: int = 0, @@ -351,9 +355,9 @@ def pause_exit( def cls(cls) -> None: """Will clear the console in addition to completely resetting the ANSI formats.""" if _shutil.which("cls"): - _os.system("cls") + _subprocess.run(["cls"]) elif _shutil.which("clear"): - _os.system("clear") + _subprocess.run(["clear"]) print("\033[0m", end="", flush=True) @classmethod @@ -361,6 +365,8 @@ def log( cls, title: Optional[str] = None, prompt: object = "", + /, + *, format_linebreaks: bool = True, start: str = "", end: str = "\n", @@ -387,7 +393,7 @@ def log( information about formatting codes, see `format_codes` module documentation.""" has_title_bg: bool = False if title_bg_color is not None and (Color.is_valid_rgba(title_bg_color) or Color.is_valid_hexa(title_bg_color)): - title_bg_color, has_title_bg = Color.to_hexa(cast(Rgba | Hexa, title_bg_color)), True + title_bg_color, has_title_bg = Color.to_hexa(title_bg_color), True if tab_size < 0: raise ValueError("The 'tab_size' parameter must be a non-negative integer.") if title_px < 0: @@ -402,21 +408,14 @@ def log( tab = " " * (tab_size - 1 - ((len(mx) + (title_len := len(title) + 2 * len(px))) % tab_size)) if format_linebreaks: - clean_prompt, removals = cast( - tuple[str, tuple[tuple[int, str], ...]], - FormatCodes.remove(str(prompt), get_removals=True, _ignore_linebreaks=True), - ) + clean_prompt, removals = *FormatCodes.remove(str(prompt), get_removals=True, _ignore_linebreaks=True), prompt_lst: list[str] = [ - item for lst in - ( + item for lst in [ String.split_count(line, cls.w - (title_len + len(tab) + 2 * len(mx))) \ for line in str(clean_prompt).splitlines() - ) - for item in ([""] if lst == [] else (lst if isinstance(lst, list) else [lst])) + ] for item in ([""] if lst == [] else lst) ] - prompt = f"\n{mx}{' ' * title_len}{mx}{tab}".join( - cls._add_back_removed_parts(prompt_lst, cast(tuple[tuple[int, str], ...], removals)) - ) + prompt = f"\n{mx}{' ' * title_len}{mx}{tab}".join(cls._add_back_removed_parts(prompt_lst, removals)) if title == "": FormatCodes.print( @@ -436,6 +435,8 @@ def log( def debug( cls, prompt: object = "Point in program reached.", + /, + *, active: bool = True, format_linebreaks: bool = True, start: str = "", @@ -451,8 +452,8 @@ def debug( If `active` is false, no debug message will be printed.""" if active: cls.log( - title="DEBUG", - prompt=prompt, + "DEBUG", + prompt, format_linebreaks=format_linebreaks, start=start, end=end, @@ -465,6 +466,8 @@ def debug( def info( cls, prompt: object = "Program running.", + /, + *, format_linebreaks: bool = True, start: str = "", end: str = "\n", @@ -477,8 +480,8 @@ def info( """A preset for `log()`: `INFO` log message with the options to pause at the message and exit the program after the message was printed.""" cls.log( - title="INFO", - prompt=prompt, + "INFO", + prompt, format_linebreaks=format_linebreaks, start=start, end=end, @@ -491,6 +494,8 @@ def info( def done( cls, prompt: object = "Program finished.", + /, + *, format_linebreaks: bool = True, start: str = "", end: str = "\n", @@ -503,8 +508,8 @@ def done( """A preset for `log()`: `DONE` log message with the options to pause at the message and exit the program after the message was printed.""" cls.log( - title="DONE", - prompt=prompt, + "DONE", + prompt, format_linebreaks=format_linebreaks, start=start, end=end, @@ -517,6 +522,8 @@ def done( def warn( cls, prompt: object = "Important message.", + /, + *, format_linebreaks: bool = True, start: str = "", end: str = "\n", @@ -529,8 +536,8 @@ def warn( """A preset for `log()`: `WARN` log message with the options to pause at the message and exit the program after the message was printed.""" cls.log( - title="WARN", - prompt=prompt, + "WARN", + prompt, format_linebreaks=format_linebreaks, start=start, end=end, @@ -543,6 +550,8 @@ def warn( def fail( cls, prompt: object = "Program error.", + /, + *, format_linebreaks: bool = True, start: str = "", end: str = "\n", @@ -555,8 +564,8 @@ def fail( """A preset for `log()`: `FAIL` log message with the options to pause at the message and exit the program after the message was printed.""" cls.log( - title="FAIL", - prompt=prompt, + "FAIL", + prompt, format_linebreaks=format_linebreaks, start=start, end=end, @@ -569,6 +578,8 @@ def fail( def exit( cls, prompt: object = "Program ended.", + /, + *, format_linebreaks: bool = True, start: str = "", end: str = "\n", @@ -581,8 +592,8 @@ def exit( """A preset for `log()`: `EXIT` log message with the options to pause at the message and exit the program after the message was printed.""" cls.log( - title="EXIT", - prompt=prompt, + "EXIT", + prompt, format_linebreaks=format_linebreaks, start=start, end=end, @@ -688,15 +699,15 @@ def log_box_bordered( - `"strong" = ('┏', '━', '┓', '┃', '┛', '━', '┗', '┃', '┣', '━', '┫')` - `"double" = ('╔', '═', '╗', '║', '╝', '═', '╚', '║', '╠', '═', '╣')`\n The order of the characters is always: - 1. top-left corner - 2. top border - 3. top-right corner - 4. right border - 5. bottom-right corner - 6. bottom border - 7. bottom-left corner - 8. left border - 9. left horizontal rule connector + 1. top-left corner + 2. top border + 3. top-right corner + 4. right border + 5. bottom-right corner + 6. bottom border + 7. bottom-left corner + 8. left border + 9. left horizontal rule connector 10. horizontal rule 11. right horizontal rule connector""" if w_padding < 0: @@ -709,7 +720,7 @@ def log_box_bordered( if not all(len(char) == 1 for char in _border_chars): raise ValueError("The '_border_chars' parameter must only contain single-character strings.") - if border_style is not None and Color.is_valid(border_style): + if Color.is_valid(border_style): border_style = Color.to_hexa(border_style) borders = { @@ -754,6 +765,8 @@ def log_box_bordered( def confirm( cls, prompt: object = "Do you want to continue?", + /, + *, start: str = "", end: str = "", default_color: Optional[Rgba | Hexa] = None, @@ -784,6 +797,8 @@ def confirm( def multiline_input( cls, prompt: object = "", + /, + *, start: str = "", end: str = "\n", default_color: Optional[Rgba | Hexa] = None, @@ -819,6 +834,8 @@ def multiline_input( def input( cls, prompt: object = "", + /, + *, start: str = "", end: str = "", default_color: Optional[Rgba | Hexa] = None, @@ -839,6 +856,8 @@ def input( def input( cls, prompt: object = "", + /, + *, start: str = "", end: str = "", default_color: Optional[Rgba | Hexa] = None, @@ -858,6 +877,8 @@ def input( def input( cls, prompt: object = "", + /, + *, start: str = "", end: str = "", default_color: Optional[Rgba | Hexa] = None, @@ -915,10 +936,10 @@ def input( kb.add(Keys.Any)(helper.handle_any) custom_style = Style.from_dict({"bottom-toolbar": "noreverse"}) - session: _pt.PromptSession = _pt.PromptSession( + session: _pt.PromptSession[str] = _pt.PromptSession( message=_pt.formatted_text.ANSI(FormatCodes.to_ansi(str(prompt), default_color=default_color)), validator=_ConsoleInputValidator( - get_text=helper.get_text, + helper.get_text, mask_char=mask_char, min_len=min_len, validator=validator, @@ -951,10 +972,10 @@ def input( raise @classmethod - def _add_back_removed_parts(cls, split_string: list[str], removals: tuple[tuple[int, str], ...]) -> list[str]: + def _add_back_removed_parts(cls, split_string: list[str], removals: tuple[tuple[int, str], ...], /) -> list[str]: """Adds back the removed parts into the split string parts at their original positions.""" cumulative_pos = [0] - for length in (len(s) for s in split_string): + for length in [len(part) for part in split_string]: cumulative_pos.append(cumulative_pos[-1] + length) result, offset_adjusts = split_string.copy(), [0] * len(split_string) @@ -974,7 +995,7 @@ def _add_back_removed_parts(cls, split_string: list[str], removals: tuple[tuple[ return result @staticmethod - def _find_string_part(pos: int, cumulative_pos: list[int]) -> int: + def _find_string_part(pos: int, cumulative_pos: list[int], /) -> int: """Finds the index of the string part that contains the given position.""" left, right = 0, len(cumulative_pos) - 1 while left < right: @@ -990,14 +1011,19 @@ def _find_string_part(pos: int, cumulative_pos: list[int]) -> int: @staticmethod def _prepare_log_box( values: list[object] | tuple[object, ...], + /, default_color: Optional[Rgba | Hexa] = None, + *, has_rules: bool = False, ) -> tuple[list[str], list[str], int]: """Prepares the log box content and returns it along with the max line length.""" if has_rules: - lines = [] + lines: list[str] = [] + for val in values: - val_str, result_parts, current_pos = str(val), [], 0 + result_parts: list[str] = [] + val_str, current_pos = str(val), 0 + for match in _PATTERNS.hr.finditer(val_str): start, end = match.span() should_split_before = start > 0 and val_str[start - 1] != "\n" @@ -1027,19 +1053,19 @@ def _prepare_log_box( else: lines = [line for val in values for line in str(val).splitlines()] - unfmt_lines = [cast(str, FormatCodes.remove(line, default_color)) for line in lines] + unfmt_lines = [FormatCodes.remove(line, default_color) for line in lines] max_line_len = max(len(line) for line in unfmt_lines) if unfmt_lines else 0 return lines, unfmt_lines, max_line_len @staticmethod - def _multiline_input_submit(event: KeyPressEvent) -> None: + def _multiline_input_submit(event: KeyPressEvent, /) -> None: event.app.exit(result=event.app.current_buffer.document.text) class _ConsoleArgsParseHelper: """Internal, callable helper class to parse command-line arguments.""" - def __init__(self, arg_parse_configs: ArgParseConfigs, flag_value_sep: str): + def __init__(self, arg_parse_configs: ArgParseConfigs, /, flag_value_sep: str): self.arg_parse_configs = arg_parse_configs self.flag_value_sep = flag_value_sep @@ -1078,7 +1104,7 @@ def parse_arg_configs(self) -> None: ) self.arg_lookup[flag] = alias - def _parse_arg_config(self, alias: str, config: ArgParseConfig) -> Optional[set[str]]: + def _parse_arg_config(self, alias: str, config: ArgParseConfig, /) -> Optional[set[str]]: """Parse an individual argument configuration.""" # POSITIONAL ARGUMENT CONFIGURATION if isinstance(config, str): @@ -1110,25 +1136,19 @@ def _parse_arg_config(self, alias: str, config: ArgParseConfig) -> Optional[set[ return config # SET OF FLAGS WITH SPECIFIED DEFAULT VALUE - elif isinstance(config, dict): - if not config.get("flags"): + else: + if not config["flags"]: raise ValueError( f"No flags provided under alias '{alias}'.\n" "The 'flags'-key set must contain at least one flag to search for." ) self.parsed_args[alias] = ParsedArgData( exists=False, - values=[default] if (default := config.get("default")) is not None else [], + values=[config["default"]], is_pos=False, ) return config["flags"] - else: - raise TypeError( - f"Invalid configuration type under alias '{alias}'.\n" - "Must be a set, dict, literal 'before' or literal 'after'." - ) - def find_flag_positions(self) -> None: """Find positions of first and last flags for positional argument collection.""" i = 0 @@ -1174,7 +1194,7 @@ def process_positional_args(self) -> None: "Must be either 'before' or 'after'." ) - def _collect_before_arg(self, alias: str) -> None: + def _collect_before_arg(self, alias: str, /) -> None: """Collect positional `"before"` arguments.""" before_args: list[str] = [] end_pos: int = self.first_flag_pos if self.first_flag_pos is not None else self.args_len @@ -1187,7 +1207,7 @@ def _collect_before_arg(self, alias: str) -> None: self.parsed_args[alias].values = before_args self.parsed_args[alias].exists = len(before_args) > 0 - def _collect_after_arg(self, alias: str) -> None: + def _collect_after_arg(self, alias: str, /) -> None: """Collect positional `"after"` arguments.""" after_args: list[str] = [] start_pos: int = (self.last_flag_pos + 1) if self.last_flag_pos is not None else 0 @@ -1216,7 +1236,7 @@ def _collect_after_arg(self, alias: str) -> None: self.parsed_args[alias].values = after_args self.parsed_args[alias].exists = len(after_args) > 0 - def _is_positional_arg(self, arg: str, allow_separator: bool = True) -> bool: + def _is_positional_arg(self, arg: str, /, *, allow_separator: bool = True) -> bool: """Check if an argument is positional (not a flag or separator).""" if self.flag_value_sep in arg and arg.split(self.flag_value_sep, 1)[0].strip() not in self.arg_lookup: return True @@ -1269,11 +1289,11 @@ def process_flagged_args(self) -> None: class _ConsoleLogBoxBgReplacer: """Internal, callable class to replace matched text with background-colored text for log boxes.""" - def __init__(self, box_bg_color: str | Rgba | Hexa) -> None: + def __init__(self, box_bg_color: str | Rgba | Hexa, /) -> None: self.box_bg_color = box_bg_color - def __call__(self, m: _rx.Match[str]) -> str: - return f"{cast(str, m.group(0))}[bg:{self.box_bg_color}]" + def __call__(self, m: _rx.Match[str], /) -> str: + return f"{m.group(0)}[bg:{self.box_bg_color}]" class _ConsoleInputHelper: @@ -1334,7 +1354,7 @@ def bottom_toolbar(self) -> _pt.formatted_text.ANSI: except Exception: return _pt.formatted_text.ANSI("") - def process_insert_text(self, text: str) -> tuple[str, set[str]]: + def process_insert_text(self, text: str, /) -> tuple[str, set[str]]: """Processes the inserted text according to the allowed characters and max length.""" removed_chars: set[str] = set() @@ -1360,7 +1380,7 @@ def process_insert_text(self, text: str) -> tuple[str, set[str]]: return processed_text, removed_chars - def insert_text_event(self, event: KeyPressEvent) -> None: + def insert_text_event(self, event: KeyPressEvent, /) -> None: """Handles text insertion events (typing/pasting).""" try: if not (insert_text := event.data): @@ -1381,7 +1401,7 @@ def insert_text_event(self, event: KeyPressEvent) -> None: except Exception: pass - def remove_text_event(self, event: KeyPressEvent, is_backspace: bool = False) -> None: + def remove_text_event(self, event: KeyPressEvent, /, *, is_backspace: bool = False) -> None: """Handles text removal events (backspace/delete).""" try: buffer = event.app.current_buffer @@ -1406,26 +1426,26 @@ def remove_text_event(self, event: KeyPressEvent, is_backspace: bool = False) -> except Exception: pass - def handle_delete(self, event: KeyPressEvent) -> None: + def handle_delete(self, event: KeyPressEvent, /) -> None: self.remove_text_event(event) - def handle_backspace(self, event: KeyPressEvent) -> None: + def handle_backspace(self, event: KeyPressEvent, /) -> None: self.remove_text_event(event, is_backspace=True) @staticmethod - def handle_control_a(event: KeyPressEvent) -> None: + def handle_control_a(event: KeyPressEvent, /) -> None: buffer = event.app.current_buffer buffer.cursor_position = 0 buffer.start_selection() buffer.cursor_position = len(buffer.text) - def handle_paste(self, event: KeyPressEvent) -> None: + def handle_paste(self, event: KeyPressEvent, /) -> None: if self.allow_paste: self.insert_text_event(event) else: self.tried_pasting = True - def handle_any(self, event: KeyPressEvent) -> None: + def handle_any(self, event: KeyPressEvent, /) -> None: self.insert_text_event(event) @@ -1434,6 +1454,8 @@ class _ConsoleInputValidator(Validator): def __init__( self, get_text: Callable[[], str], + /, + *, mask_char: Optional[str], min_len: Optional[int], validator: Optional[Callable[[str], Optional[str]]], @@ -1443,7 +1465,7 @@ def __init__( self.min_len = min_len self.validator = validator - def validate(self, document) -> None: + def validate(self, document: Document) -> None: text_to_validate = self.get_text() if self.mask_char else document.text if self.min_len and len(text_to_validate) < self.min_len: raise ValidationError(message="", cursor_position=len(document.text)) @@ -1475,6 +1497,7 @@ class ProgressBar: def __init__( self, + *, min_width: int = 10, max_width: int = 50, bar_format: list[str] | tuple[str, ...] = ["{l}", "▕{b}▏", "[b]({c:,})/{t:,}", "[dim](([i]({p}%)))"], @@ -1498,7 +1521,7 @@ def __init__( """A tuple of characters ordered from full to empty progress.""" self.set_width(min_width, max_width) - self.set_bar_format(bar_format, limited_bar_format, sep) + self.set_bar_format(bar_format, limited_bar_format, sep=sep) self.set_chars(chars) self._buffer: list[str] = [] @@ -1529,6 +1552,7 @@ def set_bar_format( self, bar_format: Optional[list[str] | tuple[str, ...]] = None, limited_bar_format: Optional[list[str] | tuple[str, ...]] = None, + *, sep: Optional[str] = None, ) -> None: """Set the format string used to render the progress bar.\n @@ -1546,13 +1570,13 @@ def set_bar_format( The bar format (also limited) can additionally be formatted with special formatting codes. For more detailed information about formatting codes, see the `format_codes` module documentation.""" if bar_format is not None: - if not any(_PATTERNS.bar.search(s) for s in bar_format): + if not any(_PATTERNS.bar.search(part) for part in bar_format): raise ValueError("The 'bar_format' parameter value must contain the '{bar}' or '{b}' placeholder.") self.bar_format = bar_format if limited_bar_format is not None: - if not any(_PATTERNS.bar.search(s) for s in limited_bar_format): + if not any(_PATTERNS.bar.search(part) for part in limited_bar_format): raise ValueError("The 'limited_bar_format' parameter value must contain the '{bar}' or '{b}' placeholder.") self.limited_bar_format = limited_bar_format @@ -1560,7 +1584,7 @@ def set_bar_format( if sep is not None: self.sep = sep - def set_chars(self, chars: tuple[str, ...]) -> None: + def set_chars(self, chars: tuple[str, ...], /) -> None: """Set the characters used to render the progress bar.\n -------------------------------------------------------------------------- - `chars` -⠀a tuple of characters ordered from full to empty progress
@@ -1569,12 +1593,12 @@ def set_chars(self, chars: tuple[str, ...]) -> None: empty sections. If None, uses default Unicode block characters.""" if len(chars) < 2: raise ValueError("The 'chars' parameter must contain at least two characters (full and empty).") - elif not all(isinstance(c, str) and len(c) == 1 for c in chars): + elif not all(len(char) == 1 for char in chars): raise ValueError("All elements of 'chars' must be single-character strings.") self.chars = chars - def show_progress(self, current: int, total: int, label: Optional[str] = None) -> None: + def show_progress(self, current: int, total: int, /, label: Optional[str] = None) -> None: """Show or update the progress bar.\n ------------------------------------------------------------------------------------------- - `current` -⠀the current progress value (below `0` or greater than `total` hides the bar) @@ -1612,7 +1636,7 @@ def hide_progress(self) -> None: self._stop_intercepting() @contextmanager - def progress_context(self, total: int, label: Optional[str] = None) -> Generator[ProgressUpdater, None, None]: + def progress_context(self, total: int, /, label: Optional[str] = None) -> Generator[ProgressUpdater, None, None]: """Context manager for automatic cleanup. Returns a function to update progress.\n ---------------------------------------------------------------------------------------------------- - `total` -⠀the total value representing 100% progress (must be greater than `0`) @@ -1648,7 +1672,7 @@ def progress_context(self, total: int, label: Optional[str] = None) -> Generator finally: self.hide_progress() - def _draw_progress_bar(self, current: int, total: int, label: Optional[str] = None) -> None: + def _draw_progress_bar(self, current: int, total: int, /, label: Optional[str] = None) -> None: if total <= 0 or not self._original_stdout: return @@ -1674,12 +1698,13 @@ def _get_formatted_info_and_bar_width( current: int, total: int, percentage: float, + /, label: Optional[str] = None, ) -> tuple[str, int]: - fmt_parts = [] + fmt_parts: list[str] = [] - for s in bar_format: - fmt_part = _PATTERNS.label.sub(label or "", s) + for part in bar_format: + fmt_part = _PATTERNS.label.sub(label or "", part) fmt_part = _PATTERNS.current.sub(_ProgressBarCurrentReplacer(current), fmt_part) fmt_part = _PATTERNS.total.sub(_ProgressBarTotalReplacer(total), fmt_part) fmt_part = _PATTERNS.percentage.sub(_ProgressBarPercentageReplacer(percentage), fmt_part) @@ -1694,9 +1719,9 @@ def _get_formatted_info_and_bar_width( return fmt_str, bar_width - def _create_bar(self, current: int, total: int, bar_width: int) -> str: + def _create_bar(self, current: int, total: int, bar_width: int, /) -> str: progress = current / total if total > 0 else 0 - bar = [] + bar: list[str] = [] for i in range(bar_width): pos_progress = (i + 1) / bar_width @@ -1759,7 +1784,7 @@ class _ProgressContextHelper: - `type_checking` -⠀whether to check the parameters' types: Is false per default to save performance, but can be set to true for debugging purposes.""" - def __init__(self, progress_bar: ProgressBar, total: int, label: Optional[str]): + def __init__(self, progress_bar: ProgressBar, total: int, label: Optional[str], /): self.progress_bar = progress_bar self.total = total self.current_label = label @@ -1788,16 +1813,16 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: if label is not None: self.current_label = label - self.progress_bar.show_progress(current=self.current_progress, total=self.total, label=self.current_label) + self.progress_bar.show_progress(self.current_progress, self.total, label=self.current_label) class _ProgressBarCurrentReplacer: """Internal, callable class to replace `{current}` placeholder with formatted number.""" - def __init__(self, current: int) -> None: + def __init__(self, current: int, /) -> None: self.current = current - def __call__(self, match: _rx.Match[str]) -> str: + def __call__(self, match: _rx.Match[str], /) -> str: if (sep := match.group(1)): return f"{self.current:,}".replace(",", sep) return str(self.current) @@ -1806,10 +1831,10 @@ def __call__(self, match: _rx.Match[str]) -> str: class _ProgressBarTotalReplacer: """Internal, callable class to replace `{total}` placeholder with formatted number.""" - def __init__(self, total: int) -> None: + def __init__(self, total: int, /) -> None: self.total = total - def __call__(self, match: _rx.Match[str]) -> str: + def __call__(self, match: _rx.Match[str], /) -> str: if (sep := match.group(1)): return f"{self.total:,}".replace(",", sep) return str(self.total) @@ -1818,39 +1843,40 @@ def __call__(self, match: _rx.Match[str]) -> str: class _ProgressBarPercentageReplacer: """Internal, callable class to replace `{percentage}` placeholder with formatted float.""" - def __init__(self, percentage: float) -> None: + def __init__(self, percentage: float, /) -> None: self.percentage = percentage - def __call__(self, match: _rx.Match[str]) -> str: + def __call__(self, match: _rx.Match[str], /) -> str: return f"{self.percentage:.{match.group(1) if match.group(1) else '1'}f}" -class Spinner: - """A console spinner for indeterminate processes with customizable appearance. +class Throbber: + """A console throbber for indeterminate processes with customizable appearance. This class intercepts stdout to allow printing while the animation is active.\n --------------------------------------------------------------------------------------------- - `label` -⠀the current label text - - `spinner_format` -⠀the format string used to render the spinner, containing placeholders: + - `throbber_format` -⠀the format string used to render the throbber, containing placeholders: * `{label}` `{l}` * `{animation}` `{a}` - `frames` -⠀a tuple of strings representing the animation frames - `interval` -⠀the time in seconds between each animation frame --------------------------------------------------------------------------------------------- - The `spinner_format` can additionally be formatted with special formatting codes. For more + The `throbber_format` can additionally be formatted with special formatting codes. For more detailed information about formatting codes, see the `format_codes` module documentation.""" def __init__( self, + *, label: Optional[str] = None, - spinner_format: list[str] | tuple[str, ...] = ["{l}", "[b]({a}) "], + throbber_format: list[str] | tuple[str, ...] = ["{l}", "[b]({a}) "], sep: str = " ", frames: tuple[str, ...] = ("· ", "·· ", "···", " ··", " ·", " ·", " ··", "···", "·· ", "· "), interval: float = 0.2, ): - self.spinner_format: list[str] | tuple[str, ...] - """The format strings used to render the spinner (joined by `sep`).""" + self.throbber_format: list[str] | tuple[str, ...] + """The format strings used to render the throbber (joined by `sep`).""" self.sep: str - """The separator string used to join multiple spinner-format strings.""" + """The separator string used to join multiple throbber-format strings.""" self.frames: tuple[str, ...] """A tuple of strings representing the animation frames.""" self.interval: float @@ -1858,10 +1884,10 @@ def __init__( self.label: Optional[str] """The current label text.""" self.active: bool = False - """Whether the spinner is currently active (intercepting stdout) or not.""" + """Whether the throbber is currently active (intercepting stdout) or not.""" self.update_label(label) - self.set_format(spinner_format, sep) + self.set_format(throbber_format, sep=sep) self.set_frames(frames) self.set_interval(interval) @@ -1873,23 +1899,23 @@ def __init__( self._stop_event: Optional[_threading.Event] = None self._animation_thread: Optional[_threading.Thread] = None - def set_format(self, spinner_format: list[str] | tuple[str, ...], sep: Optional[str] = None) -> None: - """Set the format string used to render the spinner.\n + def set_format(self, throbber_format: list[str] | tuple[str, ...], *, sep: Optional[str] = None) -> None: + """Set the format string used to render the throbber.\n --------------------------------------------------------------------------------------------- - - `spinner_format` -⠀the format strings used to render the spinner, containing placeholders: + - `throbber_format` -⠀the format strings used to render the throbber, containing placeholders: * `{label}` `{l}` * `{animation}` `{a}` - `sep` -⠀the separator string used to join multiple format strings""" - if not any(_PATTERNS.animation.search(fmt) for fmt in spinner_format): + if not any(_PATTERNS.animation.search(fmt) for fmt in throbber_format): raise ValueError( - "At least one format string in 'spinner_format' must contain the '{animation}' or '{a}' placeholder." + "At least one format string in 'throbber_format' must contain the '{animation}' or '{a}' placeholder." ) - self.spinner_format = spinner_format + self.throbber_format = throbber_format self.sep = sep or self.sep - def set_frames(self, frames: tuple[str, ...]) -> None: - """Set the frames used for the spinner animation.\n + def set_frames(self, frames: tuple[str, ...], /) -> None: + """Set the frames used for the throbber animation.\n --------------------------------------------------------------------- - `frames` -⠀a tuple of strings representing the animation frames""" if len(frames) < 2: @@ -1897,7 +1923,7 @@ def set_frames(self, frames: tuple[str, ...]) -> None: self.frames = frames - def set_interval(self, interval: int | float) -> None: + def set_interval(self, interval: int | float, /) -> None: """Set the time interval between each animation frame.\n ------------------------------------------------------------------- - `interval` -⠀the time in seconds between each animation frame""" @@ -1906,10 +1932,10 @@ def set_interval(self, interval: int | float) -> None: self.interval = interval - def start(self, label: Optional[str] = None) -> None: - """Start the spinner animation and intercept stdout.\n + def start(self, label: Optional[str] = None, /) -> None: + """Start the throbber animation and intercept stdout.\n ---------------------------------------------------------- - - `label` -⠀the label to display alongside the spinner""" + - `label` -⠀the label to display alongside the throbber""" if self.active: return @@ -1920,7 +1946,7 @@ def start(self, label: Optional[str] = None) -> None: self._animation_thread.start() def stop(self) -> None: - """Stop and hide the spinner and restore normal console output.""" + """Stop and hide the throbber and restore normal console output.""" if self.active: if self._stop_event: self._stop_event.set() @@ -1931,27 +1957,27 @@ def stop(self) -> None: self._animation_thread = None self._frame_index = 0 - self._clear_spinner_line() + self._clear_throbber_line() self._stop_intercepting() - def update_label(self, label: Optional[str]) -> None: - """Update the spinner's label text.\n + def update_label(self, label: Optional[str], /) -> None: + """Update the throbber's label text.\n -------------------------------------- - `new_label` -⠀the new label text""" self.label = label @contextmanager - def context(self, label: Optional[str] = None) -> Generator[Callable[[str], None], None, None]: + def context(self, label: Optional[str] = None, /) -> Generator[Callable[[str], None], None, None]: """Context manager for automatic cleanup. Returns a function to update the label.\n ---------------------------------------------------------------------------------------------- - - `label` -⠀the label to display alongside the spinner + - `label` -⠀the label to display alongside the throbber ----------------------------------------------------------------------------------------------- The returned callable accepts a single parameter: - `new_label` -⠀the new label text\n #### Example usage: ```python - with Spinner().context("Starting...") as update_label: + with Throbber().context("Starting...") as update_label: time.sleep(2) update_label("Processing...") time.sleep(3) @@ -1979,10 +2005,8 @@ def _animation_loop(self) -> None: frame = FormatCodes.to_ansi(f"{self.frames[self._frame_index % len(self.frames)]}[*]") formatted = FormatCodes.to_ansi(self.sep.join( - s for s in ( \ - _PATTERNS.animation.sub(frame, _PATTERNS.label.sub(self.label or "", s)) - for s in self.spinner_format - ) if s + fmt_part for part in self.throbber_format if \ + (fmt_part := _PATTERNS.animation.sub(frame, _PATTERNS.label.sub(self.label or "", part))) )) self._current_animation_str = formatted @@ -2018,14 +2042,14 @@ def _emergency_cleanup(self) -> None: except Exception: pass - def _clear_spinner_line(self) -> None: + def _clear_throbber_line(self) -> None: if self._last_line_len > 0 and self._original_stdout: self._original_stdout.write(f"{ANSI.CHAR}[2K\r") self._original_stdout.flush() def _flush_buffer(self) -> None: if self._buffer and self._original_stdout: - self._clear_spinner_line() + self._clear_throbber_line() for content in self._buffer: self._original_stdout.write(content) self._original_stdout.flush() @@ -2041,29 +2065,29 @@ def _redraw_display(self) -> None: class _InterceptedOutput: """Custom StringIO that captures output and stores it in the progress bar buffer.""" - def __init__(self, progress_bar: ProgressBar | Spinner): - self.progress_bar = progress_bar + def __init__(self, status_indicator: ProgressBar | Throbber, /): + self.status_indicator = status_indicator self.string_io = StringIO() - def write(self, content: str) -> int: + def write(self, content: str, /) -> int: self.string_io.write(content) try: if content and content != "\r": - self.progress_bar._buffer.append(content) + cast(ProgressBar | Throbber, self.status_indicator)._buffer.append(content) # type: ignore[protected-access] return len(content) except Exception: - self.progress_bar._emergency_cleanup() + self.status_indicator._emergency_cleanup() # type: ignore[protected-access] raise def flush(self) -> None: self.string_io.flush() try: - if self.progress_bar.active and self.progress_bar._buffer: - self.progress_bar._flush_buffer() - self.progress_bar._redraw_display() + if self.status_indicator.active and self.status_indicator._buffer: # type: ignore[protected-access] + self.status_indicator._flush_buffer() # type: ignore[protected-access] + self.status_indicator._redraw_display() # type: ignore[protected-access] except Exception: - self.progress_bar._emergency_cleanup() + self.status_indicator._emergency_cleanup() # type: ignore[protected-access] raise - def __getattr__(self, name: str) -> Any: + def __getattr__(self, name: str, /) -> Any: return getattr(self.string_io, name) diff --git a/src/xulbux/data.py b/src/xulbux/data.py index 06d1c81..f7e6109 100644 --- a/src/xulbux/data.py +++ b/src/xulbux/data.py @@ -3,18 +3,20 @@ methods to work with nested data structures. """ -from .base.types import DataStructureTypes, IndexIterableTypes, DataStructure, IndexIterable +from .base.types import IndexIterableTT, IndexIterable, DataObjTT, DataObj as DataObjType from .format_codes import FormatCodes from .string import String from .regex import Regex -from typing import Optional, Literal, Final, Any, cast +from typing import Optional, Literal, TypeVar, Final, Any, overload, cast import base64 as _base64 import math as _math import re as _re +DataObj = TypeVar("DataObj", bound=DataObjType) + _DEFAULT_SYNTAX_HL: Final[dict[str, tuple[str, str]]] = { "str": ("[br:blue]", "[_c]"), "number": ("[br:magenta]", "[_c]"), @@ -29,21 +31,21 @@ class Data: """This class includes methods to work with nested data structures (dictionaries and lists).""" @classmethod - def serialize_bytes(cls, data: bytes | bytearray) -> dict[str, str]: + def serialize_bytes(cls, data: bytes | bytearray, /) -> dict[str, str]: """Converts bytes or bytearray to a JSON-compatible format (dictionary) with explicit keys.\n ---------------------------------------------------------------------------------------------- - `data` -⠀the bytes or bytearray to serialize""" key = "bytearray" if isinstance(data, bytearray) else "bytes" try: - return {key: cast(bytes | bytearray, data).decode("utf-8"), "encoding": "utf-8"} + return {key: data.decode("utf-8"), "encoding": "utf-8"} except UnicodeDecodeError: pass return {key: _base64.b64encode(data).decode("utf-8"), "encoding": "base64"} @classmethod - def deserialize_bytes(cls, obj: dict[str, str]) -> bytes | bytearray: + def deserialize_bytes(cls, obj: dict[str, str], /) -> bytes | bytearray: """Tries to converts a JSON-compatible bytes/bytearray format (dictionary) back to its original type.\n -------------------------------------------------------------------------------------------------------- - `obj` -⠀the dictionary to deserialize\n @@ -64,71 +66,92 @@ def deserialize_bytes(cls, obj: dict[str, str]) -> bytes | bytearray: raise ValueError(f"Invalid serialized data:\n {obj}") @classmethod - def chars_count(cls, data: DataStructure) -> int: + def chars_count(cls, data: DataObjType, /) -> int: """The sum of all the characters amount including the keys in dictionaries.\n ------------------------------------------------------------------------------ - `data` -⠀the data structure to count the characters from""" chars_count = 0 if isinstance(data, dict): - for k, v in data.items(): - chars_count += len(str(k)) + (cls.chars_count(v) if isinstance(v, DataStructureTypes) else len(str(v))) - - elif isinstance(data, IndexIterableTypes): + for key, val in data.items(): + chars_count += len(str(key)) + ( + cls.chars_count(cast(DataObjType, val)) \ + if isinstance(val, DataObjTT) + else len(str(val)) + ) + else: for item in data: - chars_count += cls.chars_count(item) if isinstance(item, DataStructureTypes) else len(str(item)) + chars_count += ( + cls.chars_count(cast(DataObjType, item)) \ + if isinstance(item, DataObjTT) + else len(str(item)) + ) return chars_count @classmethod - def strip(cls, data: DataStructure) -> DataStructure: + def strip(cls, data: DataObj, /) -> DataObj: """Removes leading and trailing whitespaces from the data structure's items.\n ------------------------------------------------------------------------------- - `data` -⠀the data structure to strip the items from""" if isinstance(data, dict): - return {k.strip(): cls.strip(v) if isinstance(v, DataStructureTypes) else v.strip() for k, v in data.items()} - - if isinstance(data, IndexIterableTypes): - return type(data)(cls.strip(item) if isinstance(item, DataStructureTypes) else item.strip() for item in data) + return type(data)({key.strip(): ( + cls.strip(cast(DataObjType, val)) \ + if isinstance(val, DataObjTT) + else val.strip() + ) for key, val in data.items()}) - raise TypeError(f"Unsupported data structure type: {type(data)}") + else: + return cast(DataObj, type(data)([ + cls.strip(cast(DataObjType, item)) \ + if isinstance(item, DataObjTT) + else item.strip() + for item in data + ])) @classmethod - def remove_empty_items(cls, data: DataStructure, spaces_are_empty: bool = False) -> DataStructure: + def remove_empty_items(cls, data: DataObj, /, *, spaces_are_empty: bool = False) -> DataObj: """Removes empty items from the data structure.\n --------------------------------------------------------------------------------- - `data` -⠀the data structure to remove empty items from. - `spaces_are_empty` -⠀if true, it will count items with only spaces as empty""" if isinstance(data, dict): - return { - k: (v if not isinstance(v, DataStructureTypes) else cls.remove_empty_items(v, spaces_are_empty)) - for k, v in data.items() if not String.is_empty(v, spaces_are_empty) - } - - if isinstance(data, IndexIterableTypes): - return type(data)( - item for item in - ( - (item if not isinstance(item, DataStructureTypes) else cls.remove_empty_items(item, spaces_are_empty)) \ - for item in data if not (isinstance(item, (str, type(None))) and String.is_empty(item, spaces_are_empty)) + return type(data)({ + key: ( + val if not isinstance(val, DataObjTT) else + cls.remove_empty_items(cast(DataObjType, val), spaces_are_empty=spaces_are_empty) ) - if item not in ([], (), {}, set(), frozenset()) - ) + for key, val in data.items() if not String.is_empty(val, spaces_are_empty=spaces_are_empty) + }) - raise TypeError(f"Unsupported data structure type: {type(data)}") + else: + return cast(DataObj, type(data)([ + item for item in [ + ( + item \ + if not isinstance(item, DataObjTT) + else cls.remove_empty_items(cast(DataObjType, item), spaces_are_empty=spaces_are_empty) + ) + for item in data + if not (isinstance(item, (str, type(None))) and String.is_empty(item, spaces_are_empty=spaces_are_empty)) + ] if item not in ([], (), {}, set(), frozenset()) + ])) @classmethod - def remove_duplicates(cls, data: DataStructure) -> DataStructure: + def remove_duplicates(cls, data: DataObj, /) -> DataObj: """Removes all duplicates from the data structure.\n ----------------------------------------------------------- - `data` -⠀the data structure to remove duplicates from""" if isinstance(data, dict): - return {k: cls.remove_duplicates(v) if isinstance(v, DataStructureTypes) else v for k, v in data.items()} + return type(data)({ + key: cls.remove_duplicates(cast(DataObjType, val)) if isinstance(val, DataObjTT) else val + for key, val in data.items() + }) - if isinstance(data, (list, tuple)): + elif isinstance(data, (list, tuple)): result: list[Any] = [] for item in data: - processed_item = cls.remove_duplicates(item) if isinstance(item, DataStructureTypes) else item + processed_item = cls.remove_duplicates(cast(DataObjType, item)) if isinstance(item, DataObjTT) else item is_duplicate: bool = False for existing_item in result: @@ -139,25 +162,25 @@ def remove_duplicates(cls, data: DataStructure) -> DataStructure: if not is_duplicate: result.append(processed_item) - return type(data)(result) + return cast(DataObj, type(data)(result)) - if isinstance(data, (set, frozenset)): - processed_elements = set() + else: + processed_elements: set[Any] = set() for item in data: - processed_item = cls.remove_duplicates(item) if isinstance(item, DataStructureTypes) else item + processed_item = cls.remove_duplicates(cast(DataObjType, item)) if isinstance(item, DataObjTT) else item processed_elements.add(processed_item) - return type(data)(processed_elements) - - raise TypeError(f"Unsupported data structure type: {type(data)}") + return cast(DataObj, type(data)(processed_elements)) @classmethod def remove_comments( cls, - data: DataStructure, + data: DataObj, + /, + *, comment_start: str = ">>", comment_end: str = "<<", comment_sep: str = "", - ) -> DataStructure: + ) -> DataObj: """Remove comments from a list, tuple or dictionary.\n --------------------------------------------------------------------------------------------------------------- - `data` -⠀list, tuple or dictionary, where the comments should get removed from @@ -209,19 +232,21 @@ def remove_comments( if len(comment_start) == 0: raise ValueError("The 'comment_start' parameter string must not be empty.") - return _DataRemoveCommentsHelper( - data=data, + return cast(DataObj, _DataRemoveCommentsHelper( + data, comment_start=comment_start, comment_end=comment_end, comment_sep=comment_sep, - )() + )()) @classmethod def is_equal( cls, - data1: DataStructure, - data2: DataStructure, + data1: DataObjType, + data2: DataObjType, + /, ignore_paths: str | list[str] = "", + *, path_sep: str = "->", comment_start: str = ">>", comment_end: str = "<<", @@ -247,16 +272,63 @@ def is_equal( ignore_paths = [ignore_paths] return cls._compare_nested( - data1=cls.remove_comments(data1, comment_start, comment_end), - data2=cls.remove_comments(data2, comment_start, comment_end), + cls.remove_comments(data1, comment_start=comment_start, comment_end=comment_end), + cls.remove_comments(data2, comment_start=comment_start, comment_end=comment_end), ignore_paths=[str(path).split(path_sep) for path in ignore_paths if path], ) + @overload + @classmethod + def get_path_id( + cls, + data: DataObjType, + value_paths: str, + /, + *, + path_sep: str = "->", + comment_start: str = ">>", + comment_end: str = "<<", + ignore_not_found: bool = False, + ) -> Optional[str]: + ... + + @overload + @classmethod + def get_path_id( + cls, + data: DataObjType, + value_paths: list[str], + /, + *, + path_sep: str = "->", + comment_start: str = ">>", + comment_end: str = "<<", + ignore_not_found: bool = False, + ) -> list[Optional[str]]: + ... + + @overload + @classmethod + def get_path_id( + cls, + data: DataObjType, + value_paths: str | list[str], + /, + *, + path_sep: str = "->", + comment_start: str = ">>", + comment_end: str = "<<", + ignore_not_found: bool = False, + ) -> Optional[str | list[Optional[str]]]: + ... + @classmethod def get_path_id( cls, - data: DataStructure, + data: DataObjType, value_paths: str | list[str], + /, + *, path_sep: str = "->", comment_start: str = ">>", comment_end: str = "<<", @@ -288,40 +360,45 @@ def get_path_id( if len(path_sep) == 0: raise ValueError("The 'path_sep' parameter string must not be empty.") - data = cls.remove_comments(data, comment_start, comment_end) + data = cls.remove_comments(data, comment_start=comment_start, comment_end=comment_end) if isinstance(value_paths, str): - return _DataGetPathIdHelper(value_paths, path_sep, data, ignore_not_found)() + return _DataGetPathIdHelper(value_paths, path_sep=path_sep, data_obj=data, ignore_not_found=ignore_not_found)() - results = [_DataGetPathIdHelper(path, path_sep, data, ignore_not_found)() for path in value_paths] + results = [ + _DataGetPathIdHelper(path, path_sep=path_sep, data_obj=data, ignore_not_found=ignore_not_found)() + for path in value_paths + ] return results if len(results) > 1 else results[0] if results else None @classmethod - def get_value_by_path_id(cls, data: DataStructure, path_id: str, get_key: bool = False) -> Any: + def get_value_by_path_id(cls, data: DataObjType, path_id: str, /, *, get_key: bool = False) -> Any: """Retrieves the value from `data` using the provided `path_id`, as long as the data structure hasn't changed since creating the path ID.\n -------------------------------------------------------------------------------------------------- - `data` -⠀the list, tuple, or dictionary to retrieve the value from - `path_id` -⠀the path ID to the value to retrieve, created before using `Data.get_path_id()` - `get_key` -⠀if true and the final item is in a dict, it returns the key instead of the value""" - parent: Optional[DataStructure] = None + parent: Optional[DataObjType] = None path = cls._sep_path_id(path_id) current_data: Any = data for i, path_idx in enumerate(path): if isinstance(current_data, dict): - keys = list(current_data.keys()) + dict_data = cast(dict[Any, Any], current_data) + keys: list[str] = list(dict_data.keys()) if i == len(path) - 1 and get_key: return keys[path_idx] - parent = current_data - current_data = current_data[keys[path_idx]] + parent = dict_data + current_data = dict_data[keys[path_idx]] - elif isinstance(current_data, IndexIterableTypes): + elif isinstance(current_data, IndexIterableTT): + idx_iterable_data = cast(IndexIterable, current_data) if i == len(path) - 1 and get_key: if parent is None or not isinstance(parent, dict): raise ValueError(f"Cannot get key from a non-dict parent at path '{path[:i + 1]}'") - return next(key for key, value in parent.items() if value is current_data) - parent = current_data - current_data = list(current_data)[path_idx] # CONVERT TO LIST FOR INDEXING + return next(key for key, value in parent.items() if value is idx_iterable_data) + parent = idx_iterable_data + current_data = list(idx_iterable_data)[path_idx] # CONVERT TO LIST FOR INDEXING else: raise TypeError(f"Unsupported type '{type(current_data)}' at path '{path[:i + 1]}'") @@ -329,7 +406,7 @@ def get_value_by_path_id(cls, data: DataStructure, path_id: str, get_key: bool = return current_data @classmethod - def set_value_by_path_id(cls, data: DataStructure, update_values: dict[str, Any]) -> DataStructure: + def set_value_by_path_id(cls, data: DataObj, update_values: dict[str, Any], /) -> DataObj: """Updates the value/s from `update_values` in the `data`, as long as the data structure hasn't changed since creating the path ID to that value.\n ----------------------------------------------------------------------------------------- @@ -344,14 +421,16 @@ def set_value_by_path_id(cls, data: DataStructure, update_values: dict[str, Any] raise ValueError(f"No valid 'update_values' found in dictionary:\n{update_values!r}") for path_id, new_val in valid_update_values: - data = cls._set_nested_val(data, id_path=cls._sep_path_id(path_id), value=new_val) + data = cls._set_nested_val(data, cls._sep_path_id(path_id), new_val) return data @classmethod def render( cls, - data: DataStructure, + data: DataObjType, + /, + *, indent: int = 4, compactness: Literal[0, 1, 2] = 1, max_width: int = 127, @@ -395,7 +474,7 @@ def render( return _DataRenderHelper( cls, - data=data, + data, indent=indent, compactness=compactness, max_width=max_width, @@ -407,7 +486,9 @@ def render( @classmethod def print( cls, - data: DataStructure, + data: DataObjType, + /, + *, indent: int = 4, compactness: Literal[0, 1, 2] = 1, max_width: int = 127, @@ -448,7 +529,7 @@ def print( For more detailed information about formatting codes, see the `format_codes` module documentation.""" FormatCodes.print( cls.render( - data=data, + data, indent=indent, compactness=compactness, max_width=max_width, @@ -464,6 +545,7 @@ def _compare_nested( cls, data1: Any, data2: Any, + /, ignore_paths: list[list[str]], current_path: list[str] = [], ) -> bool: @@ -474,24 +556,26 @@ def _compare_nested( return False if isinstance(data1, dict) and isinstance(data2, dict): - if set(data1.keys()) != set(data2.keys()): + dict_data1, dict_data2 = cast(dict[Any, Any], data1), cast(dict[Any, Any], data2) + if set(dict_data1.keys()) != set(dict_data2.keys()): return False return all(cls._compare_nested( \ - data1=data1[key], - data2=data2[key], + dict_data1[key], + dict_data2[key], ignore_paths=ignore_paths, current_path=current_path + [key], - ) for key in data1) + ) for key in dict_data1) - elif isinstance(data1, (list, tuple)): - if len(data1) != len(data2): + elif isinstance(data1, (list, tuple)) and isinstance(data2, (list, tuple)): + array_data1, array_data2 = cast(IndexIterable, data1), cast(IndexIterable, data2) + if len(array_data1) != len(array_data2): return False return all(cls._compare_nested( \ - data1=item1, - data2=item2, + item1, + item2, ignore_paths=ignore_paths, current_path=current_path + [str(i)], - ) for i, (item1, item2) in enumerate(zip(data1, data2))) + ) for i, (item1, item2) in enumerate(zip(array_data1, array_data2))) elif isinstance(data1, (set, frozenset)): return data1 == data2 @@ -499,7 +583,7 @@ def _compare_nested( return data1 == data2 @staticmethod - def _sep_path_id(path_id: str) -> list[int]: + def _sep_path_id(path_id: str, /) -> list[int]: """Internal method to separate a path-ID string into its ID parts as a list of integers.""" if len(split_id := path_id.split(">")) == 2: id_part_len, path_id_parts = split_id @@ -513,29 +597,33 @@ def _sep_path_id(path_id: str) -> list[int]: raise ValueError(f"Path ID '{path_id}' is an invalid format.") @classmethod - def _set_nested_val(cls, data: DataStructure, id_path: list[int], value: Any) -> Any: + def _set_nested_val(cls, data: DataObjType, id_path: list[int], value: Any, /) -> Any: """Internal method to set a value in a nested data structure based on the provided ID path.""" current_data: Any = data if len(id_path) == 1: if isinstance(current_data, dict): - keys, data_dict = list(current_data.keys()), dict(current_data) - data_dict[keys[id_path[0]]] = value - return data_dict - elif isinstance(current_data, IndexIterableTypes): - was_t, data_list = type(current_data), list(current_data) - data_list[id_path[0]] = value - return was_t(data_list) + dict_data = cast(dict[Any, Any], current_data) + keys, dict_data = list(dict_data.keys()), dict(dict_data) + dict_data[keys[id_path[0]]] = value + return dict_data + elif isinstance(current_data, IndexIterableTT): + idx_iterable_data = cast(IndexIterable, current_data) + was_t, idx_iterable_data = type(idx_iterable_data), list(idx_iterable_data) + idx_iterable_data[id_path[0]] = value + return was_t(idx_iterable_data) else: if isinstance(current_data, dict): - keys, data_dict = list(current_data.keys()), dict(current_data) - data_dict[keys[id_path[0]]] = cls._set_nested_val(data_dict[keys[id_path[0]]], id_path[1:], value) - return data_dict - elif isinstance(current_data, IndexIterableTypes): - was_t, data_list = type(current_data), list(current_data) - data_list[id_path[0]] = cls._set_nested_val(data_list[id_path[0]], id_path[1:], value) - return was_t(data_list) + dict_data = cast(dict[Any, Any], current_data) + keys, dict_data = list(dict_data.keys()), dict(dict_data) + dict_data[keys[id_path[0]]] = cls._set_nested_val(dict_data[keys[id_path[0]]], id_path[1:], value) + return dict_data + elif isinstance(current_data, IndexIterableTT): + idx_iterable_data = cast(IndexIterable, current_data) + was_t, idx_iterable_data = type(idx_iterable_data), list(idx_iterable_data) + idx_iterable_data[id_path[0]] = cls._set_nested_val(idx_iterable_data[id_path[0]], id_path[1:], value) + return was_t(idx_iterable_data) return current_data @@ -543,13 +631,13 @@ def _set_nested_val(cls, data: DataStructure, id_path: list[int], value: Any) -> class _DataRemoveCommentsHelper: """Internal, callable helper class to remove all comments from nested data structures.""" - def __init__(self, data: DataStructure, comment_start: str, comment_end: str, comment_sep: str): + def __init__(self, data: DataObjType, /, *, comment_start: str, comment_end: str, comment_sep: str): self.data = data self.comment_start = comment_start self.comment_end = comment_end self.comment_sep = comment_sep - self.pattern = _re.compile(Regex._clean( \ + self.pattern = _re.compile(Regex._clean( # type: ignore[protected-access] rf"""^( (?:(?!{_re.escape(comment_start)}).)* ) @@ -559,21 +647,23 @@ def __init__(self, data: DataStructure, comment_start: str, comment_end: str, co (.*?)$""" )) if len(comment_end) > 0 else None - def __call__(self) -> DataStructure: + def __call__(self) -> DataObjType: return self.remove_nested_comments(self.data) - def remove_nested_comments(self, item: Any) -> Any: + def remove_nested_comments(self, item: Any, /) -> Any: if isinstance(item, dict): + dict_item = cast(dict[Any, Any], item) return { key: val - for key, val in ( \ - (self.remove_nested_comments(k), self.remove_nested_comments(v)) for k, v in item.items() - ) if key is not None + for key, val in [ + (self.remove_nested_comments(k), self.remove_nested_comments(v)) for k, v in dict_item.items() \ + ] if key is not None } - if isinstance(item, IndexIterableTypes): - processed = (v for v in map(self.remove_nested_comments, item) if v is not None) - return type(item)(processed) + if isinstance(item, IndexIterableTT): + idx_iterable_item = cast(IndexIterable, item) + processed = [val for val in map(self.remove_nested_comments, idx_iterable_item) if val is not None] + return type(idx_iterable_item)(processed) if isinstance(item, str): if self.pattern: @@ -590,7 +680,7 @@ def remove_nested_comments(self, item: Any) -> Any: class _DataGetPathIdHelper: """Internal, callable helper class to process a data path and generate its unique path ID.""" - def __init__(self, path: str, path_sep: str, data_obj: DataStructure, ignore_not_found: bool): + def __init__(self, path: str, /, *, path_sep: str, data_obj: DataObjType, ignore_not_found: bool): self.keys = path.split(path_sep) self.data_obj = data_obj self.ignore_not_found = ignore_not_found @@ -608,14 +698,14 @@ def __call__(self) -> Optional[str]: return None return f"{self.max_id_length}>{''.join(id.zfill(self.max_id_length) for id in self.path_ids)}" - def process_key(self, key: str) -> bool: + def process_key(self, key: str, /) -> bool: """Process a single key and update `path_ids`. Returns `False` if processing should stop.""" idx: Optional[int] = None if isinstance(self.current_data, dict): if (idx := self.process_dict_key(key)) is None: return False - elif isinstance(self.current_data, IndexIterableTypes): + elif isinstance(self.current_data, IndexIterableTT): if (idx := self.process_iterable_key(key)) is None: return False else: @@ -625,7 +715,7 @@ def process_key(self, key: str) -> bool: self.max_id_length = max(self.max_id_length, len(str(idx))) return True - def process_dict_key(self, key: str) -> Optional[int]: + def process_dict_key(self, key: str, /) -> Optional[int]: """Process a key for dictionary data. Returns the index or `None` if not found.""" if key.isdigit(): if self.ignore_not_found: @@ -641,7 +731,7 @@ def process_dict_key(self, key: str) -> Optional[int]: return None raise KeyError(f"Key '{key}' not found in dict.") - def process_iterable_key(self, key: str) -> Optional[int]: + def process_iterable_key(self, key: str, /) -> Optional[int]: """Process a key for iterable data. Returns the index or `None` if not found.""" try: idx = int(key) @@ -664,7 +754,9 @@ class _DataRenderHelper: def __init__( self, cls: type[Data], - data: DataStructure, + data: DataObjType, + /, + *, indent: int, compactness: Literal[0, 1, 2], max_width: int, @@ -689,8 +781,8 @@ def __init__( raise TypeError(f"Expected 'syntax_highlighting' to be a dict or bool. Got: {type(syntax_highlighting)}") self.syntax_hl.update({ - k: (f"[{v}]", "[_]") if k in self.syntax_hl and v not in {"", None} else ("", "") - for k, v in syntax_highlighting.items() + key: (f"[{val}]", "[_]") if key in self.syntax_hl and val not in {"", None} else ("", "") + for key, val in syntax_highlighting.items() }) sep = f"{self.syntax_hl['punctuation'][0]}{sep}{self.syntax_hl['punctuation'][1]}" @@ -699,10 +791,19 @@ def __init__( punct_map: dict[str, str | tuple[str, str]] = {"(": ("/(", "("), **{c: c for c in "'\":)[]{}"}} self.punct: dict[str, str] = { - k: ((f"{self.syntax_hl['punctuation'][0]}{v[0]}{self.syntax_hl['punctuation'][1]}" if self.do_syntax_hl else v[1]) - if isinstance(v, (list, tuple)) else - (f"{self.syntax_hl['punctuation'][0]}{v}{self.syntax_hl['punctuation'][1]}" if self.do_syntax_hl else v)) - for k, v in punct_map.items() + key: ( + ( + f"{self.syntax_hl['punctuation'][0]}{val[0]}{self.syntax_hl['punctuation'][1]}" \ + if self.do_syntax_hl + else val[1] + ) \ + if isinstance(val, (list, tuple)) + else ( + f"{self.syntax_hl['punctuation'][0]}{val}{self.syntax_hl['punctuation'][1]}" \ + if self.do_syntax_hl + else val + ) + ) for key, val in punct_map.items() } def __call__(self) -> str: @@ -711,21 +812,21 @@ def __call__(self) -> str: self.format_dict(self.data, 0) if isinstance(self.data, dict) else self.format_sequence(self.data, 0) ) - def format_value(self, value: Any, current_indent: Optional[int] = None) -> str: + def format_value(self, value: Any, /, current_indent: Optional[int] = None) -> str: if current_indent is not None and isinstance(value, dict): - return self.format_dict(value, current_indent + self.indent) + return self.format_dict(cast(dict[Any, Any], value), current_indent + self.indent) elif current_indent is not None and hasattr(value, "__dict__"): return self.format_dict(value.__dict__, current_indent + self.indent) - elif current_indent is not None and isinstance(value, IndexIterableTypes): - return self.format_sequence(value, current_indent + self.indent) + elif current_indent is not None and isinstance(value, IndexIterableTT): + return self.format_sequence(cast(IndexIterable, value), current_indent + self.indent) elif current_indent is not None and isinstance(value, (bytes, bytearray)): obj_dict = self.cls.serialize_bytes(value) return ( self.format_dict(obj_dict, current_indent + self.indent) if self.as_json else ( - f"{self.syntax_hl['type'][0]}{(k := next(iter(obj_dict)))}{self.syntax_hl['type'][1]}" - + self.format_sequence((obj_dict[k], obj_dict["encoding"]), current_indent + self.indent) - if self.do_syntax_hl else (k := next(iter(obj_dict))) - + self.format_sequence((obj_dict[k], obj_dict["encoding"]), current_indent + self.indent) + f"{self.syntax_hl['type'][0]}{(key := next(iter(obj_dict)))}{self.syntax_hl['type'][1]}" + + self.format_sequence((obj_dict[key], obj_dict["encoding"]), current_indent + self.indent) + if self.do_syntax_hl else (key := next(iter(obj_dict))) + + self.format_sequence((obj_dict[key], obj_dict["encoding"]), current_indent + self.indent) ) ) elif isinstance(value, bool): @@ -754,7 +855,7 @@ def format_value(self, value: Any, current_indent: Optional[int] = None) -> str: + self.punct["'"] if self.do_syntax_hl else self.punct["'"] + String.escape(str(value), "'") + self.punct["'"] )) - def should_expand(self, seq: IndexIterable) -> bool: + def should_expand(self, seq: IndexIterable, /) -> bool: if self.compactness == 0: return True if self.compactness == 2: @@ -770,20 +871,21 @@ def should_expand(self, seq: IndexIterable) -> bool: or (complex_items == 1 and len(seq) > 1) \ or self.cls.chars_count(seq) + (len(seq) * len(self.sep)) > self.max_width - def format_dict(self, d: dict, current_indent: int) -> str: - if self.compactness == 2 or not d or not self.should_expand(list(d.values())): + def format_dict(self, data_dict: dict[Any, Any], current_indent: int, /) -> str: + if self.compactness == 2 or not data_dict or not self.should_expand(list(data_dict.values())): return self.punct["{"] + self.sep.join( - f"{self.format_value(k)}{self.punct[':']} {self.format_value(v, current_indent)}" for k, v in d.items() + f"{self.format_value(key)}{self.punct[':']} {self.format_value(val, current_indent)}" + for key, val in data_dict.items() ) + self.punct["}"] - items = [] - for k, val in d.items(): + items: list[str] = [] + for key, val in data_dict.items(): formatted_value = self.format_value(val, current_indent) - items.append(f"{' ' * (current_indent + self.indent)}{self.format_value(k)}{self.punct[':']} {formatted_value}") + items.append(f"{' ' * (current_indent + self.indent)}{self.format_value(key)}{self.punct[':']} {formatted_value}") return self.punct["{"] + "\n" + f"{self.sep}\n".join(items) + f"\n{' ' * current_indent}" + self.punct["}"] - def format_sequence(self, seq, current_indent: int) -> str: + def format_sequence(self, seq: IndexIterable, current_indent: int, /) -> str: if self.as_json: seq = list(seq) diff --git a/src/xulbux/env_path.py b/src/xulbux/env_path.py index b70d125..7831182 100644 --- a/src/xulbux/env_path.py +++ b/src/xulbux/env_path.py @@ -5,8 +5,9 @@ from .file_sys import FileSys -from typing import Optional, cast +from typing import Optional, Literal, overload from pathlib import Path +import subprocess as _subprocess import sys as _sys import os as _os @@ -14,8 +15,23 @@ class EnvPath: """This class includes methods to work with the PATH environment variable.""" + @overload @classmethod - def paths(cls, as_list: bool = False) -> Path | list[Path]: + def paths(cls, *, as_list: Literal[True]) -> list[Path]: + ... + + @overload + @classmethod + def paths(cls, *, as_list: Literal[False] = False) -> Path: + ... + + @overload + @classmethod + def paths(cls, *, as_list: bool = False) -> Path | list[Path]: + ... + + @classmethod + def paths(cls, *, as_list: bool = False) -> Path | list[Path]: """Get the PATH environment variable.\n ------------------------------------------------------------------------------------------------ - `as_list` -⠀if true, returns the paths as a list of `Path`s; otherwise, as a single `Path`""" @@ -25,39 +41,39 @@ def paths(cls, as_list: bool = False) -> Path | list[Path]: return Path(paths_str) @classmethod - def has_path(cls, path: Optional[Path | str] = None, cwd: bool = False, base_dir: bool = False) -> bool: + def has_path(cls, path: Optional[Path | str] = None, /, *, cwd: bool = False, base_dir: bool = False) -> bool: """Check if a path is present in the PATH environment variable.\n ------------------------------------------------------------------------ - `path` -⠀the path to check for - `cwd` -⠀if true, uses the current working directory as the path - `base_dir` -⠀if true, uses the script's base directory as the path""" - check_path = cls._get(path, cwd, base_dir).resolve() - return check_path in {path.resolve() for path in cast(list[Path], cls.paths(as_list=True))} + check_path = cls._get(path, cwd=cwd, base_dir=base_dir).resolve() + return check_path in {path.resolve() for path in cls.paths(as_list=True)} @classmethod - def add_path(cls, path: Optional[Path | str] = None, cwd: bool = False, base_dir: bool = False) -> None: + def add_path(cls, path: Optional[Path | str] = None, /, *, cwd: bool = False, base_dir: bool = False) -> None: """Add a path to the PATH environment variable.\n ------------------------------------------------------------------------ - `path` -⠀the path to add - `cwd` -⠀if true, uses the current working directory as the path - `base_dir` -⠀if true, uses the script's base directory as the path""" - path_obj = cls._get(path, cwd, base_dir) + path_obj = cls._get(path, cwd=cwd, base_dir=base_dir) if not cls.has_path(path_obj): cls._persistent(path_obj) @classmethod - def remove_path(cls, path: Optional[Path | str] = None, cwd: bool = False, base_dir: bool = False) -> None: + def remove_path(cls, path: Optional[Path | str] = None, /, *, cwd: bool = False, base_dir: bool = False) -> None: """Remove a path from the PATH environment variable.\n ------------------------------------------------------------------------ - `path` -⠀the path to remove - `cwd` -⠀if true, uses the current working directory as the path - `base_dir` -⠀if true, uses the script's base directory as the path""" - path_obj = cls._get(path, cwd, base_dir) + path_obj = cls._get(path, cwd=cwd, base_dir=base_dir) if cls.has_path(path_obj): cls._persistent(path_obj, remove=True) @staticmethod - def _get(path: Optional[Path | str] = None, cwd: bool = False, base_dir: bool = False) -> Path: + def _get(path: Optional[Path | str] = None, /, *, cwd: bool = False, base_dir: bool = False) -> Path: """Internal method to get the normalized `path`, CWD path or script directory path.\n -------------------------------------------------------------------------------------- Raise an error if no path is provided and neither `cwd` or `base_dir` is true.""" @@ -74,10 +90,10 @@ def _get(path: Optional[Path | str] = None, cwd: bool = False, base_dir: bool = return Path(path) if isinstance(path, str) else path @classmethod - def _persistent(cls, path: Path, remove: bool = False) -> None: + def _persistent(cls, path: Path, /, *, remove: bool = False) -> None: """Internal method to add or remove a path from the PATH environment variable, persistently, across sessions, as well as the current session.""" - current_paths = cast(list[Path], cls.paths(as_list=True)) + current_paths = cls.paths(as_list=True) path_resolved = path.resolve() if remove: @@ -120,4 +136,4 @@ def _persistent(cls, path: Path, remove: bool = False) -> None: file.truncate() - _os.system(f"source {shell_rc_file}") + _subprocess.run(f"source {shell_rc_file}", shell=True, executable='/bin/bash') diff --git a/src/xulbux/file.py b/src/xulbux/file.py index 7621b8d..a199fbc 100644 --- a/src/xulbux/file.py +++ b/src/xulbux/file.py @@ -17,6 +17,8 @@ def rename_extension( cls, file_path: Path | str, new_extension: str, + /, + *, full_extension: bool = False, camel_case_filename: bool = False, ) -> Path: @@ -48,7 +50,7 @@ def rename_extension( return path.parent / f"{filename}{new_extension}" @classmethod - def create(cls, file_path: Path | str, content: str = "", force: bool = False) -> Path: + def create(cls, file_path: Path | str, content: str = "", /, *, force: bool = False) -> Path: """Create a file with ot without content.\n ------------------------------------------------------------------ - `file_path` -⠀the path where the file should be created diff --git a/src/xulbux/file_sys.py b/src/xulbux/file_sys.py index f5ef031..5caa309 100644 --- a/src/xulbux/file_sys.py +++ b/src/xulbux/file_sys.py @@ -52,7 +52,9 @@ class FileSys(metaclass=_FileSysMeta): def extend_path( cls, rel_path: Path | str, + /, search_in: Optional[Path | str | PathsList] = None, + *, fuzzy_match: bool = False, raise_error: bool = False, ) -> Optional[Path]: @@ -89,26 +91,18 @@ def extend_path( if search_in is not None: if isinstance(search_in, (str, Path)): search_dirs.extend([Path(search_in)]) - elif isinstance(search_in, list): - search_dirs.extend([Path(path) for path in search_in]) else: - raise TypeError( - f"The 'search_in' parameter must be a string, Path, or a list of strings/Paths, got {type(search_in)}" - ) - - return _ExtendPathHelper( - cls, - rel_path=path, - search_dirs=search_dirs, - fuzzy_match=fuzzy_match, - raise_error=raise_error, - )() + search_dirs.extend([Path(path) for path in search_in]) + + return _ExtendPathHelper(cls, path, search_dirs=search_dirs, fuzzy_match=fuzzy_match, raise_error=raise_error)() @classmethod def extend_or_make_path( cls, rel_path: Path | str, + /, search_in: Optional[Path | str | list[Path | str]] = None, + *, prefer_script_dir: bool = True, fuzzy_match: bool = False, ) -> Path: @@ -131,12 +125,7 @@ def extend_or_make_path( If `prefer_script_dir` is false, it will instead make a path that points to where the `rel_path` would be in the CWD.""" try: - result = cls.extend_path( - rel_path=rel_path, - search_in=search_in, - raise_error=True, - fuzzy_match=fuzzy_match, - ) + result = cls.extend_path(rel_path, search_in=search_in, raise_error=True, fuzzy_match=fuzzy_match) return result if result is not None else Path() except PathNotFoundError: @@ -145,7 +134,7 @@ def extend_or_make_path( return base_dir / path @classmethod - def remove(cls, path: Path | str, only_content: bool = False) -> None: + def remove(cls, path: Path | str, /, *, only_content: bool = False) -> None: """Removes the directory or the directory's content at the specified path.\n ----------------------------------------------------------------------------- - `path` -⠀the path to the directory or file to remove @@ -179,7 +168,9 @@ def __init__( self, cls: type[FileSys], rel_path: Path, + /, search_dirs: list[Path], + *, fuzzy_match: bool, raise_error: bool, ): @@ -213,7 +204,7 @@ def __call__(self) -> Optional[Path]: return self.search_in_dirs(expanded_path) @staticmethod - def expand_env_vars(path: Path) -> Path: + def expand_env_vars(path: Path, /) -> Path: """Expand all environment variables in the given path.""" if "%" not in (str_path := str(path)): return path @@ -224,24 +215,20 @@ def expand_env_vars(path: Path) -> Path: return Path("".join(parts)) - def search_in_dirs(self, path: Path) -> Optional[Path]: + def search_in_dirs(self, path: Path, /) -> Optional[Path]: """Search for the path in all configured directories.""" for search_dir in self.search_dirs: if (full_path := search_dir / path).exists(): return full_path elif self.fuzzy_match: - if (match := self.find_path( \ - base_dir=search_dir, - target_path=path, - fuzzy_match=self.fuzzy_match, - )) is not None: + if (match := self.find_path(search_dir, path, fuzzy_match=self.fuzzy_match)) is not None: return match if self.raise_error: raise PathNotFoundError(f"Path {self.rel_path!r} not found in specified directories.") return None - def find_path(self, base_dir: Path, target_path: Path, fuzzy_match: bool) -> Optional[Path]: + def find_path(self, base_dir: Path, target_path: Path, /, *, fuzzy_match: bool) -> Optional[Path]: """Find a path by traversing the given parts from the base directory, optionally using closest matches for each part.""" current_path: Path = base_dir @@ -256,7 +243,7 @@ def find_path(self, base_dir: Path, target_path: Path, fuzzy_match: bool) -> Opt return current_path if current_path.exists() and current_path != base_dir else None @staticmethod - def get_closest_match(dir: Path, path_part: str) -> Optional[str]: + def get_closest_match(dir: Path, path_part: str, /) -> Optional[str]: """Internal method to get the closest matching file or folder name in the given directory for the given path part.""" try: diff --git a/src/xulbux/format_codes.py b/src/xulbux/format_codes.py index 62ea6c4..a001ea2 100644 --- a/src/xulbux/format_codes.py +++ b/src/xulbux/format_codes.py @@ -133,12 +133,12 @@ ------------------------------------------------------------------------------------------------------------------------------------ #### Additional Formatting Codes when a `default_color` is set -1. `[*]` resets everything, just like `[_]`, but the text color will remain in `default_color` - (if no `default_color` is set, it resets everything, exactly like `[_]`) -2. `[default]` will just color the text in `default_color` - (if no `default_color` is set, it's treated as an invalid formatting code) -3. `[background:default]` `[BG:default]` will color the background in `default_color` - (if no `default_color` is set, both are treated as invalid formatting codes)\n +1. `[*]` resets everything, just like `[_]`, but the text color will remain in `default_color` + (if no `default_color` is set, it resets everything, exactly like `[_]`) +2. `[default]` will just color the text in `default_color` + (if no `default_color` is set, it's treated as an invalid formatting code) +3. `[background:default]` `[BG:default]` will color the background in `default_color` + (if no `default_color` is set, both are treated as invalid formatting codes)\n Unlike the standard console colors, the default color can be changed by using the following modifiers: @@ -161,7 +161,7 @@ from .regex import LazyRegex, Regex from .color import Color, rgba, hexa -from typing import Optional, Literal, Final, cast +from typing import Optional, Literal, Final, overload, cast import ctypes as _ctypes import regex as _rx import sys as _sys @@ -245,8 +245,10 @@ def print( def input( cls, prompt: object = "", + /, default_color: Optional[Rgba | Hexa] = None, brightness_steps: int = 20, + *, reset_ansi: bool = False, ) -> str: """An input, whose `prompt` can be formatted using formatting codes.\n @@ -270,8 +272,10 @@ def input( def to_ansi( cls, string: str, + /, default_color: Optional[Rgba | Hexa] = None, brightness_steps: int = 20, + *, _default_start: bool = True, _validate_default: bool = True, ) -> str: @@ -301,8 +305,14 @@ def to_ansi( string = _PATTERNS.star_reset.sub(r"[\1_\2]", string) # REPLACE `[…|*|…]` WITH `[…|_|…]` string = "\n".join( - _PATTERNS.formatting.sub(_ReplaceKeysHelper(cls, use_default, default_color, brightness_steps), line) - for line in string.split("\n") + _PATTERNS.formatting.sub( + _ReplaceKeysHelper( + cls, + use_default=use_default, + default_color=default_color, + brightness_steps=brightness_steps, + ), line + ) for line in string.split("\n") ) return ( @@ -314,7 +324,9 @@ def to_ansi( def escape( cls, string: str, + /, default_color: Optional[Rgba | Hexa] = None, + *, _escape_char: Literal["/", "\\"] = "/", ) -> str: """Escapes all valid formatting codes in the string, so they are visible when output @@ -329,22 +341,65 @@ def escape( use_default, default_color = cls._validate_default_color(default_color) return "\n".join( - _PATTERNS.formatting.sub(_EscapeFormatCodeHelper(cls, use_default, default_color, _escape_char), line) - for line in string.split("\n") + _PATTERNS.formatting.sub( + _EscapeFormatCodeHelper(cls, use_default=use_default, default_color=default_color, escape_char=_escape_char), + line, + ) for line in string.split("\n") ) @classmethod - def escape_ansi(cls, ansi_string: str) -> str: + def escape_ansi(cls, ansi_string: str, /) -> str: """Escapes all ANSI codes in the string, so they are visible when output to the console.\n ------------------------------------------------------------------------------------------- - `ansi_string` -⠀the string that contains the ANSI codes to escape""" return ansi_string.replace(ANSI.CHAR, ANSI.CHAR_ESCAPED) + @overload @classmethod def remove( cls, string: str, + /, default_color: Optional[Rgba | Hexa] = None, + *, + get_removals: Literal[True], + _ignore_linebreaks: bool = False, + ) -> tuple[str, tuple[tuple[int, str], ...]]: + ... + + @overload + @classmethod + def remove( + cls, + string: str, + /, + default_color: Optional[Rgba | Hexa] = None, + *, + get_removals: Literal[False] = False, + _ignore_linebreaks: bool = False, + ) -> str: + ... + + @overload + @classmethod + def remove( + cls, + string: str, + /, + default_color: Optional[Rgba | Hexa] = None, + *, + get_removals: bool = False, + _ignore_linebreaks: bool = False, + ) -> str | tuple[str, tuple[tuple[int, str], ...]]: + ... + + @classmethod + def remove( + cls, + string: str, + /, + default_color: Optional[Rgba | Hexa] = None, + *, get_removals: bool = False, _ignore_linebreaks: bool = False, ) -> str | tuple[str, tuple[tuple[int, str], ...]]: @@ -361,10 +416,48 @@ def remove( _ignore_linebreaks=_ignore_linebreaks, ) + @overload @classmethod def remove_ansi( cls, ansi_string: str, + /, + *, + get_removals: Literal[True], + _ignore_linebreaks: bool = False, + ) -> tuple[str, tuple[tuple[int, str], ...]]: + ... + + @overload + @classmethod + def remove_ansi( + cls, + ansi_string: str, + /, + *, + get_removals: Literal[False] = False, + _ignore_linebreaks: bool = False, + ) -> str: + ... + + @overload + @classmethod + def remove_ansi( + cls, + ansi_string: str, + /, + *, + get_removals: bool = False, + _ignore_linebreaks: bool = False, + ) -> str | tuple[str, tuple[tuple[int, str], ...]]: + ... + + @classmethod + def remove_ansi( + cls, + ansi_string: str, + /, + *, get_removals: bool = False, _ignore_linebreaks: bool = False, ) -> str | tuple[str, tuple[tuple[int, str], ...]]: @@ -407,27 +500,27 @@ def _config_console(cls) -> None: kernel32.SetConsoleMode(h, mode.value | 0x0004) except Exception: pass - _CONSOLE_ANSI_CONFIGURED = True + _CONSOLE_ANSI_CONFIGURED = True # type: ignore[assignment] @staticmethod - def _validate_default_color(default_color: Optional[Rgba | Hexa]) -> tuple[bool, Optional[rgba]]: + def _validate_default_color(default_color: Optional[Rgba | Hexa], /) -> tuple[bool, Optional[rgba]]: """Internal method to validate and convert `default_color` to a `rgba` color object.""" if default_color is None: return False, None - if Color.is_valid_hexa(default_color, False): + if Color.is_valid_hexa(default_color, allow_alpha=False): return True, hexa(cast(str | int, default_color)).to_rgba() - elif Color.is_valid_rgba(default_color, False): - return True, Color._parse_rgba(default_color) + elif Color.is_valid_rgba(default_color, allow_alpha=False): + return True, Color._parse_rgba(cast(Rgba, default_color)) # type: ignore[protected-access] raise TypeError("The 'default_color' parameter must be either a valid RGBA or HEXA color, or None.") @staticmethod - def _formats_to_keys(formats: str) -> list[str]: + def _formats_to_keys(formats: str, /) -> list[str]: """Internal method to convert a string of multiple format keys to a list of individual, stripped format keys.""" - return [k.strip() for k in formats.split("|") if k.strip()] + return [key.strip() for key in formats.split("|") if key.strip()] @classmethod - def _get_replacement(cls, format_key: str, default_color: Optional[rgba], brightness_steps: int = 20) -> str: + def _get_replacement(cls, format_key: str, default_color: Optional[rgba], /, brightness_steps: int = 20) -> str: """Internal method that gives you the corresponding ANSI code for the given format key. If `default_color` is not `None`, the text color will be `default_color` if all formats are reset or you can get lighter or darker version of `default_color` (also as BG)""" @@ -438,7 +531,8 @@ def _get_replacement(cls, format_key: str, default_color: Optional[rgba], bright if (isinstance(map_key, tuple) and format_key in map_key) or format_key == map_key: return _ANSI_SEQ_1.format( next(( - v for k, v in ANSI.CODES_MAP.items() if format_key == k or (isinstance(k, tuple) and format_key in k) + val for key, val in ANSI.CODES_MAP.items() \ + if format_key == key or (isinstance(key, tuple) and format_key in key) ), None) ) rgb_match = _PATTERNS.rgb.match(format_key) @@ -463,14 +557,14 @@ def _get_replacement(cls, format_key: str, default_color: Optional[rgba], bright @staticmethod def _get_default_ansi( default_color: rgba, + /, format_key: Optional[str] = None, brightness_steps: Optional[int] = None, + *, _modifiers: tuple[str, str] = (_DEFAULT_COLOR_MODS["lighten"], _DEFAULT_COLOR_MODS["darken"]), ) -> Optional[str]: """Internal method to get the `default_color` and lighter/darker versions of it as ANSI code.""" - if not isinstance(default_color, rgba): - return None - _default_color: tuple[int, int, int] = tuple(default_color)[:3] + _default_color: tuple[int, int, int] = (default_color[0], default_color[1], default_color[2]) if brightness_steps is None or (format_key and _PATTERNS.bg_opt_default.search(format_key)): return (ANSI.SEQ_BG_COLOR if format_key and _PATTERNS.bg_default.search(format_key) else ANSI.SEQ_COLOR).format( *_default_color @@ -488,13 +582,15 @@ def _get_default_ansi( if adjust == 0: return None elif modifiers in _modifiers[0]: - new_rgb = tuple(Color.adjust_lightness(default_color, (brightness_steps / 100) * adjust)) + adjusted_rgb = Color.adjust_lightness(default_color, (brightness_steps / 100) * adjust) + new_rgb = (adjusted_rgb[0], adjusted_rgb[1], adjusted_rgb[2]) elif modifiers in _modifiers[1]: - new_rgb = tuple(Color.adjust_lightness(default_color, -(brightness_steps / 100) * adjust)) + adjusted_rgb = Color.adjust_lightness(default_color, -(brightness_steps / 100) * adjust) + new_rgb = (adjusted_rgb[0], adjusted_rgb[1], adjusted_rgb[2]) return (ANSI.SEQ_BG_COLOR if is_bg else ANSI.SEQ_COLOR).format(*new_rgb[:3]) @staticmethod - def _normalize_key(format_key: str) -> str: + def _normalize_key(format_key: str, /) -> str: """Internal method to normalize the given format key.""" k_parts = format_key.replace(" ", "").lower().split(":") prefix_str = "".join( @@ -513,6 +609,7 @@ class _EscapeFormatCodeHelper: def __init__( self, cls: type[FormatCodes], + *, use_default: bool, default_color: Optional[rgba], escape_char: Literal["/", "\\"], @@ -522,7 +619,7 @@ def __init__( self.default_color = default_color self.escape_char: Literal["/", "\\"] = escape_char - def __call__(self, match: _rx.Match[str]) -> str: + def __call__(self, match: _rx.Match[str], /) -> str: formats, auto_reset_txt = match.group(1), match.group(3) # CHECK IF ALREADY ESCAPED OR CONTAINS NO FORMATTING @@ -536,12 +633,14 @@ def __call__(self, match: _rx.Match[str]) -> str: else: _formats = _PATTERNS.star_reset_inside.sub(r"\1_\2", formats) - if all((self.cls._get_replacement(k, self.default_color) != k) for k in self.cls._formats_to_keys(_formats)): + if all(self.cls._get_replacement(format_key, self.default_color) != format_key # type: ignore[protected-access] + for format_key in self.cls._formats_to_keys(_formats) # type: ignore[protected-access] + ): # ESCAPE THE FORMATTING CODE escaped = f"[{self.escape_char}{formats}]" if auto_reset_txt: # RECURSIVELY ESCAPE FORMATTING IN AUTO-RESET TEXT - escaped_auto_reset = self.cls.escape(auto_reset_txt, self.default_color, self.escape_char) + escaped_auto_reset = self.cls.escape(auto_reset_txt, self.default_color, _escape_char=self.escape_char) escaped += f"({escaped_auto_reset})" return escaped else: @@ -549,7 +648,7 @@ def __call__(self, match: _rx.Match[str]) -> str: result = f"[{formats}]" if auto_reset_txt: # STILL RECURSIVELY PROCESS AUTO-RESET TEXT - escaped_auto_reset = self.cls.escape(auto_reset_txt, self.default_color, self.escape_char) + escaped_auto_reset = self.cls.escape(auto_reset_txt, self.default_color, _escape_char=self.escape_char) result += f"({escaped_auto_reset})" return result @@ -557,10 +656,10 @@ def __call__(self, match: _rx.Match[str]) -> str: class _RemAnsiSeqHelper: """Internal, callable helper class to remove ANSI sequences and track their removal positions.""" - def __init__(self, removals: list[tuple[int, str]]): + def __init__(self, removals: list[tuple[int, str]], /): self.removals = removals - def __call__(self, match: _rx.Match[str]) -> str: + def __call__(self, match: _rx.Match[str], /) -> str: start_pos = match.start() - sum(len(removed) for _, removed in self.removals) if self.removals and self.removals[-1][0] == start_pos: start_pos = self.removals[-1][0] @@ -574,6 +673,7 @@ class _ReplaceKeysHelper: def __init__( self, cls: type[FormatCodes], + *, use_default: bool, default_color: Optional[rgba], brightness_steps: int, @@ -593,7 +693,7 @@ def __init__( self.ansi_formats: list[str] = [] self.ansi_resets: list[str] = [] - def __call__(self, match: _rx.Match[str]) -> str: + def __call__(self, match: _rx.Match[str], /) -> str: self.original_formats = self.formats = match.group(1) self.auto_reset_escaped = bool(match.group(2)) self.auto_reset_txt = match.group(3) @@ -635,11 +735,12 @@ def process_formats_and_auto_reset(self) -> None: def convert_to_ansi(self) -> None: """Convert format keys to ANSI codes and generate resets if needed.""" - self.format_keys = self.cls._formats_to_keys(self.formats) - self.ansi_formats = [ - r if (r := self.cls._get_replacement(k, self.default_color, self.brightness_steps)) != k else f"[{k}]" - for k in self.format_keys - ] + self.format_keys = self.cls._formats_to_keys(self.formats) # type: ignore[protected-access] + self.ansi_formats = [( + ansi_code \ + if (ansi_code := self.cls._get_replacement(format_key, self.default_color, self.brightness_steps)) != format_key # type: ignore[protected-access] + else f"[{format_key}]" + ) for format_key in self.format_keys] # GENERATE RESET CODES IF AUTO-RESET IS ACTIVE if self.auto_reset_txt and not self.auto_reset_escaped: @@ -652,44 +753,44 @@ def gen_reset_codes(self) -> None: default_color_resets = ("_bg", "default") if self.use_default else ("_bg", "_c") reset_keys: list[str] = [] - for k in self.format_keys: - k_lower = k.lower() + for format_key in self.format_keys: + k_lower = format_key.lower() k_set = set(k_lower.split(":")) # BACKGROUND COLOR FORMAT if _PREFIX["BG"] & k_set and len(k_set) <= 3: if k_set & _PREFIX["BR"]: # BRIGHT BACKGROUND COLOR - RESET BOTH BG AND COLOR - for i in range(len(k)): - if self.is_valid_color(k[i:]): + for i in range(len(format_key)): + if self.is_valid_color(format_key[i:]): reset_keys.extend(default_color_resets) break else: # REGULAR BACKGROUND COLOR - RESET ONLY BG - for i in range(len(k)): - if self.is_valid_color(k[i:]): + for i in range(len(format_key)): + if self.is_valid_color(format_key[i:]): reset_keys.append("_bg") break # TEXT COLOR FORMAT - elif self.is_valid_color(k) or any( - k_lower.startswith(pref_colon := f"{prefix}:") and self.is_valid_color(k[len(pref_colon):]) \ + elif self.is_valid_color(format_key) or any( + k_lower.startswith(pref_colon := f"{prefix}:") and self.is_valid_color(format_key[len(pref_colon):]) \ for prefix in _PREFIX["BR"] ): reset_keys.append(default_color_resets[1]) # TEXT STYLE FORMAT else: - reset_keys.append(f"_{k}") + reset_keys.append(f"_{format_key}") # CONVERT RESET KEYS TO ANSI CODES self.ansi_resets = [ - r for k in reset_keys if ( \ - r := self.cls._get_replacement(k, self.default_color, self.brightness_steps) + ansi_code for reset_key in reset_keys if ( \ + ansi_code := self.cls._get_replacement(reset_key, self.default_color, self.brightness_steps) # type: ignore[protected-access] ).startswith(f"{ANSI.CHAR}{ANSI.START}") ] - def build_output(self, match: _rx.Match[str]) -> str: + def build_output(self, match: _rx.Match[str], /) -> str: """Build the final output string based on processed formats and resets.""" # CHECK IF ALL FORMATS WERE VALID has_single_valid_ansi = len(self.ansi_formats) == 1 and self.ansi_formats[0].count(f"{ANSI.CHAR}{ANSI.START}") >= 1 @@ -717,6 +818,6 @@ def build_output(self, match: _rx.Match[str]) -> str: return output - def is_valid_color(self, color: str) -> bool: + def is_valid_color(self, color: str, /) -> bool: """Check whether the given color string is a valid formatting-key color.""" return bool((color in ANSI.COLOR_MAP) or Color.is_valid_rgba(color) or Color.is_valid_hexa(color)) diff --git a/src/xulbux/json.py b/src/xulbux/json.py index c60bca7..48c429f 100644 --- a/src/xulbux/json.py +++ b/src/xulbux/json.py @@ -3,11 +3,12 @@ create and update JSON files, with support for comments inside the JSON data. """ +from .base.types import DataObj from .file_sys import FileSys from .data import Data from .file import File -from typing import Literal, Any, cast +from typing import Literal, Any, overload, cast from pathlib import Path import json as _json @@ -16,14 +17,42 @@ class Json: """This class provides methods to read, create and update JSON files, with support for comments inside the JSON data.""" + @overload @classmethod def read( cls, json_file: Path | str, + /, + *, + comment_start: str = ">>", + comment_end: str = "<<", + return_original: Literal[True], + ) -> tuple[dict[str, Any], dict[str, Any]]: + ... + + @overload + @classmethod + def read( + cls, + json_file: Path | str, + /, + *, + comment_start: str = ">>", + comment_end: str = "<<", + return_original: Literal[False] = False, + ) -> dict[str, Any]: + ... + + @classmethod + def read( + cls, + json_file: Path | str, + /, + *, comment_start: str = ">>", comment_end: str = "<<", return_original: bool = False, - ) -> dict | tuple[dict, dict]: + ) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any]]: """Read JSON files, ignoring comments.\n ------------------------------------------------------------------------------------ - `json_file` -⠀the path (relative or absolute) to the JSON file to read @@ -49,7 +78,7 @@ def read( fmt_error = "\n ".join(str(e).splitlines()) raise ValueError(f"Error parsing JSON in {file_path!r}:\n {fmt_error}") from e - if not (processed_data := dict(Data.remove_comments(data, comment_start, comment_end))): + if not (processed_data := dict(Data.remove_comments(data, comment_start=comment_start, comment_end=comment_end))): raise ValueError(f"The JSON file {file_path!r} is empty or contains only comments.") return (processed_data, data) if return_original else processed_data @@ -58,7 +87,9 @@ def read( def create( cls, json_file: Path | str, - data: dict, + data: dict[str, Any], + /, + *, indent: int = 2, compactness: Literal[0, 1, 2] = 1, force: bool = False, @@ -81,14 +112,8 @@ def create( file_path = FileSys.extend_or_make_path(json_path, prefer_script_dir=True) File.create( - file_path=file_path, - content=Data.render( - data=data, - indent=indent, - compactness=compactness, - as_json=True, - syntax_highlighting=False, - ), + file_path, + Data.render(data, indent=indent, compactness=compactness, as_json=True, syntax_highlighting=False), force=force, ) @@ -99,6 +124,8 @@ def update( cls, json_file: Path | str, update_values: dict[str, Any], + /, + *, comment_start: str = ">>", comment_end: str = "<<", path_sep: str = "->", @@ -142,7 +169,7 @@ def update( you can use the items list index inside the value-path, so `healthy->fruits->0`.\n ⇾ If the given value-path doesn't exist, it will be created.""" processed_data, data = cls.read( - json_file=json_file, + json_file, comment_start=comment_start, comment_end=comment_end, return_original=True, @@ -151,8 +178,8 @@ def update( update: dict[str, Any] = {} for val_path, new_val in update_values.items(): try: - if (path_id := Data.get_path_id(data=processed_data, value_paths=val_path, path_sep=path_sep)) is not None: - update[cast(str, path_id)] = new_val + if (path_id := Data.get_path_id(cast(DataObj, processed_data), val_path, path_sep=path_sep)) is not None: + update[path_id] = new_val else: data = cls._create_nested_path(data, val_path.split(path_sep), new_val) except Exception: @@ -161,10 +188,10 @@ def update( if update: data = Data.set_value_by_path_id(data, update) - cls.create(json_file=json_file, data=dict(data), force=True) + cls.create(json_file, data, force=True) @staticmethod - def _create_nested_path(data_obj: dict, path_keys: list[str], value: Any) -> dict: + def _create_nested_path(data_obj: dict[str, Any], path_keys: list[str], value: Any, /) -> dict[str, Any]: """Internal method that creates nested dictionaries/lists based on the given path keys and sets the specified value at the end of the path.""" last_idx, current = len(path_keys) - 1, data_obj @@ -175,11 +202,11 @@ def _create_nested_path(data_obj: dict, path_keys: list[str], value: Any) -> dic current[key] = value elif isinstance(current, list) and key.isdigit(): idx = int(key) - while len(current) <= idx: - current.append(None) + while len(cast(list[Any], current)) <= idx: + cast(list[Any], current).append(None) current[idx] = value else: - raise TypeError(f"Cannot set key '{key}' on {type(current)}") + raise TypeError(f"Cannot set key '{key}' on {type(cast(Any, current))}") else: next_key = path_keys[i + 1] @@ -189,12 +216,12 @@ def _create_nested_path(data_obj: dict, path_keys: list[str], value: Any) -> dic current = current[key] elif isinstance(current, list) and key.isdigit(): idx = int(key) - while len(current) <= idx: - current.append(None) + while len(cast(list[Any], current)) <= idx: + cast(list[Any], current).append(None) if current[idx] is None: current[idx] = [] if next_key.isdigit() else {} - current = current[idx] + current = cast(list[Any], current)[idx] else: - raise TypeError(f"Cannot navigate through {type(current)}") + raise TypeError(f"Cannot navigate through {type(cast(Any, current))}") return data_obj diff --git a/src/xulbux/py.typed b/src/xulbux/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/xulbux/regex.py b/src/xulbux/regex.py index a9ef866..16729f0 100644 --- a/src/xulbux/regex.py +++ b/src/xulbux/regex.py @@ -29,6 +29,8 @@ def brackets( cls, bracket1: str = "(", bracket2: str = ")", + /, + *, is_group: bool = False, strip_spaces: bool = False, ignore_in_strings: bool = True, @@ -75,12 +77,12 @@ def brackets( ) @classmethod - def outside_strings(cls, pattern: str = r".*") -> str: + def outside_strings(cls, pattern: str = r".*", /) -> str: """Matches the `pattern` only when it is not found inside a string (`'…'` or `"…"`).""" return rf"""(? str: + def all_except(cls, disallowed_pattern: str, /, ignore_pattern: str = "", *, is_group: bool = False) -> str: """Matches everything up to the `disallowed_pattern`, unless the `disallowed_pattern` is found inside a string/quotes (`'…'` or `"…"`).\n ------------------------------------------------------------------------------------- @@ -100,10 +102,10 @@ def all_except(cls, disallowed_pattern: str, ignore_pattern: str = "", is_group: ) @classmethod - def func_call(cls, func_name: Optional[str] = None) -> str: + def func_call(cls, func_name: Optional[str] = None, /) -> str: """Match a function call, and get back two groups: - 1. function name - 2. the function's arguments\n + 1. The function name + 2. The function's arguments (content inside the parentheses)\n If no `func_name` is given, it will match any function call.\n --------------------------------------------------------------------------------- Attention: Requires non-standard library `regex`, not standard library `re`!""" @@ -113,7 +115,7 @@ def func_call(cls, func_name: Optional[str] = None) -> str: return rf"""(?<=\b)({func_name})\s*{cls.brackets("(", ")", is_group=True)}""" @classmethod - def rgba_str(cls, fix_sep: Optional[str] = ",", allow_alpha: bool = True) -> str: + def rgba_str(cls, fix_sep: Optional[str] = ",", *, allow_alpha: bool = True) -> str: """Matches an RGBA color inside a string.\n ---------------------------------------------------------------------------------- - `fix_sep` -⠀the fixed separator between the RGBA values (e.g. `,`, `;` …)
@@ -155,7 +157,7 @@ def rgba_str(cls, fix_sep: Optional[str] = ",", allow_alpha: bool = True) -> str ) @classmethod - def hsla_str(cls, fix_sep: Optional[str] = ",", allow_alpha: bool = True) -> str: + def hsla_str(cls, fix_sep: Optional[str] = ",", *, allow_alpha: bool = True) -> str: """Matches a HSLA color inside a string.\n ---------------------------------------------------------------------------------- - `fix_sep` -⠀the fixed separator between the HSLA values (e.g. `,`, `;` …)
@@ -197,7 +199,7 @@ def hsla_str(cls, fix_sep: Optional[str] = ",", allow_alpha: bool = True) -> str ) @classmethod - def hexa_str(cls, allow_alpha: bool = True) -> str: + def hexa_str(cls, *, allow_alpha: bool = True) -> str: """Matches a HEXA color inside a string.\n ---------------------------------------------------------------------- - `allow_alpha` -⠀whether to include the alpha channel in the match\n @@ -239,7 +241,7 @@ class LazyRegex: def __init__(self, **patterns: str): self._patterns = patterns - def __getattr__(self, name: str) -> _rx.Pattern: + def __getattr__(self, name: str, /) -> _rx.Pattern[str]: if name in self._patterns: setattr(self, name, compiled := _rx.compile(self._patterns[name])) return compiled diff --git a/src/xulbux/string.py b/src/xulbux/string.py index 35746c2..aa16ff6 100644 --- a/src/xulbux/string.py +++ b/src/xulbux/string.py @@ -13,7 +13,7 @@ class String: """This class provides various utility methods for string manipulation and conversion.""" @classmethod - def to_type(cls, string: str) -> Any: + def to_type(cls, string: str, /) -> Any: """Will convert a string to the found type, including complex nested structures.\n ----------------------------------------------------------------------------------- - `string` -⠀the string to convert""" @@ -26,7 +26,7 @@ def to_type(cls, string: str) -> Any: return string @classmethod - def normalize_spaces(cls, string: str, tab_spaces: int = 4) -> str: + def normalize_spaces(cls, string: str, /, tab_spaces: int = 4) -> str: """Replaces all special space characters with normal spaces.\n --------------------------------------------------------------- - `tab_spaces` -⠀number of spaces to replace tab chars with""" @@ -38,7 +38,7 @@ def normalize_spaces(cls, string: str, tab_spaces: int = 4) -> str: .replace("\u2007", " ").replace("\u2008", " ").replace("\u2009", " ").replace("\u200A", " ") @classmethod - def escape(cls, string: str, str_quotes: Optional[Literal["'", '"']] = None) -> str: + def escape(cls, string: str, /, str_quotes: Optional[Literal["'", '"']] = None) -> str: """Escapes Python's special characters (e.g. `\\n`, `\\t`, …) and quotes inside the string.\n -------------------------------------------------------------------------------------------------------- - `string` -⠀the string to escape @@ -57,7 +57,7 @@ def escape(cls, string: str, str_quotes: Optional[Literal["'", '"']] = None) -> return string @classmethod - def is_empty(cls, string: Optional[str], spaces_are_empty: bool = False) -> bool: + def is_empty(cls, string: Optional[str], /, *, spaces_are_empty: bool = False) -> bool: """Returns `True` if the string is considered empty and `False` otherwise.\n ----------------------------------------------------------------------------------------------- - `string` -⠀the string to check (or `None`, which is considered empty) @@ -68,22 +68,22 @@ def is_empty(cls, string: Optional[str], spaces_are_empty: bool = False) -> bool ) @classmethod - def single_char_repeats(cls, string: str, char: str) -> int | bool: + def single_char_repeats(cls, string: str, char: str, /) -> int: """- If the string consists of only the same `char`, it returns the number of times it is present. - - If the string doesn't consist of only the same character, it returns `False`.\n + - If the string is empty or doesn't consist of only the same character, it returns `0`.\n --------------------------------------------------------------------------------------------------- - `string` -⠀the string to check - `char` -⠀the character to check for repetition""" if len(char) != 1: raise ValueError(f"The 'char' parameter must be a single character, got {char!r}") - if len(string) == (len(char) * string.count(char)): + if len(string) == string.count(char): return string.count(char) else: - return False + return 0 @classmethod - def decompose(cls, case_string: str, seps: str = "-_", lower_all: bool = True) -> list[str]: + def decompose(cls, case_string: str, /, seps: str = "-_", *, lower_all: bool = True) -> list[str]: """Will decompose the string (any type of casing, also mixed) into parts.\n ---------------------------------------------------------------------------- - `case_string` -⠀the string to decompose @@ -95,7 +95,7 @@ def decompose(cls, case_string: str, seps: str = "-_", lower_all: bool = True) - ] @classmethod - def to_camel_case(cls, string: str, upper: bool = True) -> str: + def to_camel_case(cls, string: str, /, *, upper: bool = True) -> str: """Will convert the string of any type of casing to CamelCase.\n ----------------------------------------------------------------- - `string` -⠀the string to convert @@ -109,7 +109,7 @@ def to_camel_case(cls, string: str, upper: bool = True) -> str: ) @classmethod - def to_delimited_case(cls, string: str, delimiter: str = "_", screaming: bool = False) -> str: + def to_delimited_case(cls, string: str, /, delimiter: str = "_", *, screaming: bool = False) -> str: """Will convert the string of any type of casing to delimited case.\n ----------------------------------------------------------------------- - `string` -⠀the string to convert @@ -121,7 +121,7 @@ def to_delimited_case(cls, string: str, delimiter: str = "_", screaming: bool = ) @classmethod - def get_lines(cls, string: str, remove_empty_lines: bool = False) -> list[str]: + def get_lines(cls, string: str, /, *, remove_empty_lines: bool = False) -> list[str]: """Will split the string into lines.\n ------------------------------------------------------------------------------------ - `string` -⠀the string to split @@ -136,7 +136,7 @@ def get_lines(cls, string: str, remove_empty_lines: bool = False) -> list[str]: return non_empty_lines @classmethod - def remove_consecutive_empty_lines(cls, string: str, max_consecutive: int = 0) -> str: + def remove_consecutive_empty_lines(cls, string: str, /, max_consecutive: int = 0) -> str: """Will remove consecutive empty lines from the string.\n ------------------------------------------------------------------------------------- - `string` -⠀the string to process @@ -150,7 +150,7 @@ def remove_consecutive_empty_lines(cls, string: str, max_consecutive: int = 0) - return _re.sub(r"(\n\s*){2,}", r"\1" * (max_consecutive + 1), string) @classmethod - def split_count(cls, string: str, count: int) -> list[str]: + def split_count(cls, string: str, count: int, /) -> list[str]: """Will split the string every `count` characters.\n ----------------------------------------------------- - `string` -⠀the string to split diff --git a/src/xulbux/system.py b/src/xulbux/system.py index 500dffd..350107e 100644 --- a/src/xulbux/system.py +++ b/src/xulbux/system.py @@ -10,11 +10,11 @@ from .console import Console from typing import Optional -import subprocess as _subprocess import multiprocessing as _multiprocessing +import subprocess as _subprocess import platform as _platform -import ctypes as _ctypes import getpass as _getpass +import ctypes as _ctypes import socket as _socket import time as _time import sys as _sys @@ -97,8 +97,7 @@ def architecture(cls) -> str: def cpu_count(cls) -> int: """The number of CPU cores available.""" try: - count = _multiprocessing.cpu_count() - return count if count is not None else 1 + return _multiprocessing.cpu_count() except (NotImplementedError, AttributeError): return 1 @@ -112,7 +111,7 @@ class System(metaclass=_SystemMeta): """This class provides methods to interact with the underlying operating system.""" @classmethod - def restart(cls, prompt: object = "", wait: int = 0, continue_program: bool = False, force: bool = False) -> None: + def restart(cls, prompt: object = "", /, *, wait: int = 0, continue_program: bool = False, force: bool = False) -> None: """Restarts the system with some advanced options\n -------------------------------------------------------------------------------------------------- - `prompt` -⠀the message to be displayed in the systems restart notification @@ -122,12 +121,14 @@ def restart(cls, prompt: object = "", wait: int = 0, continue_program: bool = Fa if wait < 0: raise ValueError(f"The 'wait' parameter must be non-negative, got {wait!r}") - _SystemRestartHelper(prompt, wait, continue_program, force)() + _SystemRestartHelper(prompt, wait=wait, continue_program=continue_program, force=force)() @classmethod def check_libs( cls, lib_names: list[str], + /, + *, install_missing: bool = False, missing_libs_msgs: MissingLibsMsgs = { "found_missing": "The following required libraries are missing:", @@ -145,10 +146,15 @@ def check_libs( ------------------------------------------------------------------------------------------------------------ If some libraries are missing or they could not be installed, their names will be returned as a list. If all libraries are installed (or were installed successfully), `None` will be returned.""" - return _SystemCheckLibsHelper(lib_names, install_missing, missing_libs_msgs, confirm_install)() + return _SystemCheckLibsHelper( + lib_names, + install_missing=install_missing, + missing_libs_msgs=missing_libs_msgs, + confirm_install=confirm_install, + )() @classmethod - def elevate(cls, win_title: Optional[str] = None, args: Optional[list] = None) -> bool: + def elevate(cls, win_title: Optional[str] = None, args: Optional[list[str]] = None) -> bool: """Attempts to start a new process with elevated privileges.\n --------------------------------------------------------------------------------- - `win_title` -⠀the window title of the elevated process (only on Windows) @@ -193,7 +199,7 @@ def elevate(cls, win_title: Optional[str] = None, args: Optional[list] = None) - class _SystemRestartHelper: """Internal, callable helper class to handle system restart with platform-specific logic.""" - def __init__(self, prompt: object, wait: int, continue_program: bool, force: bool): + def __init__(self, prompt: object, /, *, wait: int, continue_program: bool, force: bool): self.prompt = prompt self.wait = wait self.continue_program = continue_program @@ -207,7 +213,7 @@ def __call__(self) -> None: else: raise NotImplementedError(f"Restart not implemented for '{system}' systems.") - def check_running_processes(self, command: str | list[str], skip_lines: int = 0) -> None: + def check_running_processes(self, command: str | list[str], /, skip_lines: int = 0) -> None: """Check if processes are running and raise error if force is False.""" if self.force: return @@ -226,9 +232,9 @@ def restart_windows(self) -> None: self.check_running_processes("tasklist", skip_lines=3) if self.prompt: - _os.system(f'shutdown /r /t {self.wait} /c "{self.prompt}"') + _subprocess.run(["shutdown", "/r", "/t", str(self.wait), "/c", str(self.prompt)]) else: - _os.system("shutdown /r /t 0") + _subprocess.run(["shutdown", "/r", "/t", "0"]) if self.continue_program: self.wait_for_restart() @@ -261,6 +267,8 @@ class _SystemCheckLibsHelper: def __init__( self, lib_names: list[str], + /, + *, install_missing: bool, missing_libs_msgs: MissingLibsMsgs, confirm_install: bool, @@ -285,7 +293,7 @@ def __call__(self) -> Optional[list[str]]: def find_missing_libs(self) -> list[str]: """Find which libraries are missing.""" - missing = [] + missing: list[str] = [] for lib in self.lib_names: try: __import__(lib) @@ -293,7 +301,7 @@ def find_missing_libs(self) -> list[str]: missing.append(lib) return missing - def confirm_installation(self, missing: list[str]) -> bool: + def confirm_installation(self, missing: list[str], /) -> bool: """Ask user for confirmation before installing libraries.""" FormatCodes.print(f"[b]({self.missing_libs_msgs['found_missing']})") for lib in missing: @@ -301,7 +309,7 @@ def confirm_installation(self, missing: list[str]) -> bool: print() return Console.confirm(self.missing_libs_msgs["should_install"], end="\n") - def install_libs(self, missing: list[str]) -> Optional[list[str]]: + def install_libs(self, missing: list[str], /) -> Optional[list[str]]: """Install missing libraries using pip.""" for lib in missing[:]: try: diff --git a/tests/test_color.py b/tests/test_color.py index d1028eb..0652b14 100644 --- a/tests/test_color.py +++ b/tests/test_color.py @@ -127,8 +127,8 @@ def test_str_to_rgba(): def test_luminance(): assert Color.luminance(255, 0, 0) == 54 - assert Color.luminance(255, 0, 0, int) == 21 - assert 0.20 < Color.luminance(255, 0, 0, float) < 0.22 + assert Color.luminance(255, 0, 0, output_type=int) == 21 + assert 0.20 < Color.luminance(255, 0, 0, output_type=float) < 0.22 assert Color.luminance(0, 0, 0) == 0 assert Color.luminance(255, 255, 255) == 255 assert Color.luminance(128, 128, 128) == 55 diff --git a/tests/test_color_types.py b/tests/test_color_types.py index 0c58c1f..416e572 100644 --- a/tests/test_color_types.py +++ b/tests/test_color_types.py @@ -1,7 +1,9 @@ from xulbux.color import rgba, hexa, hsla +from typing import Optional -def assert_rgba_equal(actual: rgba, expected: tuple): + +def assert_rgba_equal(actual: rgba, expected: tuple[int, int, int, Optional[float]]): assert isinstance(actual, rgba) assert actual[0] == expected[0] assert actual[1] == expected[1] @@ -9,7 +11,7 @@ def assert_rgba_equal(actual: rgba, expected: tuple): assert actual[3] == expected[3] -def assert_hsla_equal(actual: hsla, expected: tuple): +def assert_hsla_equal(actual: hsla, expected: tuple[int, int, int, Optional[float]]): assert isinstance(actual, hsla) assert actual[0] == expected[0] assert actual[1] == expected[1] @@ -177,7 +179,7 @@ def test_hexa_return_values(): def test_hexa_construction(): assert hexa("#F00").values() == (255, 0, 0, None) - assert hexa("#F008").values(True) == (255, 0, 0, 0.53) + assert hexa("#F008").values(round_alpha=True) == (255, 0, 0, 0.53) assert hexa("#FF0000").values() == (255, 0, 0, None) assert hexa("#FF000080").values() == (255, 0, 0, 0.5) assert hexa(0xFF0000).values() == (255, 0, 0, None) diff --git a/tests/test_console.py b/tests/test_console.py index bd9c705..6eb03e7 100644 --- a/tests/test_console.py +++ b/tests/test_console.py @@ -1,48 +1,48 @@ from xulbux.console import ParsedArgData, ParsedArgs -from xulbux.console import Spinner, ProgressBar +from xulbux.console import Throbber, ProgressBar from xulbux.console import Console from xulbux import console +from typing import Any from unittest.mock import MagicMock, patch -from collections import namedtuple import builtins import pytest import sys import io +import os @pytest.fixture -def mock_terminal_size(monkeypatch): - TerminalSize = namedtuple("TerminalSize", ["columns", "lines"]) +def mock_terminal_size(monkeypatch: pytest.MonkeyPatch): def mock_get_terminal_size(): - return TerminalSize(columns=80, lines=24) + return os.terminal_size((80, 24)) - monkeypatch.setattr(console._os, "get_terminal_size", mock_get_terminal_size) + monkeypatch.setattr("xulbux.console._os.get_terminal_size", mock_get_terminal_size) @pytest.fixture -def mock_formatcodes_print(monkeypatch): +def mock_formatcodes_print(monkeypatch: pytest.MonkeyPatch): mock = MagicMock() # PATCH IN THE ORIGINAL MODULE WHERE IT IS DEFINED import xulbux.format_codes monkeypatch.setattr(xulbux.format_codes.FormatCodes, "print", mock) # ALSO PATCH IN CONSOLE MODULE JUST IN CASE - monkeypatch.setattr(console.FormatCodes, "print", mock) + monkeypatch.setattr("xulbux.console.FormatCodes.print", mock) return mock @pytest.fixture -def mock_builtin_input(monkeypatch): +def mock_builtin_input(monkeypatch: pytest.MonkeyPatch): mock = MagicMock() monkeypatch.setattr(builtins, "input", mock) return mock @pytest.fixture -def mock_prompt_toolkit(monkeypatch): +def mock_prompt_toolkit(monkeypatch: pytest.MonkeyPatch): mock = MagicMock(return_value="mocked multiline input") - monkeypatch.setattr(console._pt, "prompt", mock) + monkeypatch.setattr("xulbux.console._pt.prompt", mock) return mock @@ -55,19 +55,19 @@ def test_console_user(): assert user_output != "" -def test_console_width(mock_terminal_size): +def test_console_width(mock_terminal_size: MagicMock): width_output = Console.w assert isinstance(width_output, int) assert width_output == 80 -def test_console_height(mock_terminal_size): +def test_console_height(mock_terminal_size: MagicMock): height_output = Console.h assert isinstance(height_output, int) assert height_output == 24 -def test_console_size(mock_terminal_size): +def test_console_size(mock_terminal_size: MagicMock): size_output = Console.size assert isinstance(size_output, tuple) assert len(size_output) == 2 @@ -249,7 +249,7 @@ def test_console_supports_color(): ), ] ) -def test_get_args(monkeypatch, argv, arg_parse_configs, expected_parsed_args): +def test_get_args(monkeypatch: pytest.MonkeyPatch, argv: list[str], arg_parse_configs: dict[str, Any], expected_parsed_args: dict[str, dict[str, Any]]): monkeypatch.setattr(sys, "argv", argv) args_result = Console.get_args(arg_parse_configs) assert isinstance(args_result, ParsedArgs) @@ -273,7 +273,7 @@ def test_get_args_invalid_params(): Console.get_args({"arg": {"-a"}}, flag_value_sep="") -def test_get_args_custom_sep(monkeypatch): +def test_get_args_custom_sep(monkeypatch: pytest.MonkeyPatch): """Test custom flag-value separator handling""" monkeypatch.setattr(sys, "argv", ["script.py", "--msg::This is a message", "-d::42"]) result = Console.get_args({"message": {"--msg"}, "data": {"-d"}}, flag_value_sep="::") @@ -294,7 +294,7 @@ def test_get_args_custom_sep(monkeypatch): } -def test_get_args_mixed_dash_scenarios(monkeypatch): +def test_get_args_mixed_dash_scenarios(monkeypatch: pytest.MonkeyPatch): """Test complex scenario mixing defined flags with dash-prefixed values""" monkeypatch.setattr( sys, "argv", \ @@ -369,7 +369,7 @@ def test_args_dunder_methods(): assert (args != ParsedArgs()) is True -def test_multiline_input(mock_prompt_toolkit, capsys): +def test_multiline_input(mock_prompt_toolkit: MagicMock, capsys: pytest.CaptureFixture[str]): expected_input = "mocked multiline input" result = Console.multiline_input("Enter text:", show_keybindings=True, default_color="#BCA") @@ -388,7 +388,7 @@ def test_multiline_input(mock_prompt_toolkit, capsys): assert "key_bindings" in pt_kwargs -def test_multiline_input_no_bindings(mock_prompt_toolkit, capsys): +def test_multiline_input_no_bindings(mock_prompt_toolkit: MagicMock, capsys: pytest.CaptureFixture[str]): Console.multiline_input("Enter text:", show_keybindings=False, end="DONE") captured = capsys.readouterr() @@ -399,24 +399,24 @@ def test_multiline_input_no_bindings(mock_prompt_toolkit, capsys): mock_prompt_toolkit.assert_called_once() -def test_pause_exit_pause_only(monkeypatch, capsys): +def test_pause_exit_pause_only(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]): mock_keyboard = MagicMock() - monkeypatch.setattr(console._keyboard, "read_key", mock_keyboard) + monkeypatch.setattr("xulbux.console._keyboard.read_key", mock_keyboard) - Console.pause_exit(pause=True, exit=False, prompt="Press any key...") + Console.pause_exit("Press any key...", pause=True, exit=False) captured = capsys.readouterr() assert "Press any key..." in captured.out mock_keyboard.assert_called_once_with(suppress=True) -def test_pause_exit_with_exit(monkeypatch, capsys): +def test_pause_exit_with_exit(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]): mock_keyboard = MagicMock() mock_sys_exit = MagicMock() - monkeypatch.setattr(console._keyboard, "read_key", mock_keyboard) - monkeypatch.setattr(console._sys, "exit", mock_sys_exit) + monkeypatch.setattr("xulbux.console._keyboard.read_key", mock_keyboard) + monkeypatch.setattr("xulbux.console._sys.exit", mock_sys_exit) - Console.pause_exit(pause=True, exit=True, prompt="Exiting...", exit_code=1) + Console.pause_exit("Exiting...", pause=True, exit=True, exit_code=1) captured = capsys.readouterr() assert "Exiting..." in captured.out @@ -424,9 +424,9 @@ def test_pause_exit_with_exit(monkeypatch, capsys): mock_sys_exit.assert_called_once_with(1) -def test_pause_exit_reset_ansi(monkeypatch, capsys): +def test_pause_exit_reset_ansi(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]): mock_keyboard = MagicMock() - monkeypatch.setattr(console._keyboard, "read_key", mock_keyboard) + monkeypatch.setattr("xulbux.console._keyboard.read_key", mock_keyboard) Console.pause_exit(pause=True, exit=False, reset_ansi=True) @@ -435,29 +435,29 @@ def test_pause_exit_reset_ansi(monkeypatch, capsys): assert "\033[0m" in captured.out or captured.out.strip() == "" -def test_cls(monkeypatch): +def test_cls(monkeypatch: pytest.MonkeyPatch): mock_shutil = MagicMock() - mock_os_system = MagicMock() + mock_subprocess_run = MagicMock() mock_print = MagicMock() - monkeypatch.setattr(console._shutil, "which", mock_shutil) - monkeypatch.setattr(console._os, "system", mock_os_system) + monkeypatch.setattr("xulbux.console._shutil.which", mock_shutil) + monkeypatch.setattr("xulbux.console._subprocess.run", mock_subprocess_run) monkeypatch.setattr(builtins, "print", mock_print) - mock_shutil.side_effect = lambda cmd: "/bin/cls" if cmd == "cls" else None + mock_shutil.side_effect = lambda cmd: "/bin/cls" if cmd == "cls" else None # type: ignore Console.cls() - mock_os_system.assert_called_with("cls") + mock_subprocess_run.assert_called_with(["cls"]) mock_print.assert_called_with("\033[0m", end="", flush=True) - mock_os_system.reset_mock() + mock_subprocess_run.reset_mock() mock_print.reset_mock() - mock_shutil.side_effect = lambda cmd: "/bin/clear" if cmd == "clear" else None + mock_shutil.side_effect = lambda cmd: "/bin/clear" if cmd == "clear" else None # type: ignore Console.cls() - mock_os_system.assert_called_with("clear") + mock_subprocess_run.assert_called_with(["clear"]) mock_print.assert_called_with("\033[0m", end="", flush=True) -def test_log_basic(capsys): +def test_log_basic(capsys: pytest.CaptureFixture[str]): Console.log("INFO", "Test message") captured = capsys.readouterr() @@ -465,14 +465,14 @@ def test_log_basic(capsys): assert "Test message" in captured.out -def test_log_no_title(capsys): - Console.log(title=None, prompt="Just a message") +def test_log_no_title(capsys: pytest.CaptureFixture[str]): + Console.log(None, "Just a message") captured = capsys.readouterr() assert "Just a message" in captured.out -def test_debug_active(capsys): +def test_debug_active(capsys: pytest.CaptureFixture[str]): Console.debug("Debug message", active=True) captured = capsys.readouterr() @@ -480,13 +480,13 @@ def test_debug_active(capsys): assert "Debug message" in captured.out -def test_debug_inactive(mock_formatcodes_print): +def test_debug_inactive(mock_formatcodes_print: MagicMock): Console.debug("Debug message", active=False) mock_formatcodes_print.assert_not_called() -def test_info(capsys): +def test_info(capsys: pytest.CaptureFixture[str]): Console.info("Info message") captured = capsys.readouterr() @@ -494,7 +494,7 @@ def test_info(capsys): assert "Info message" in captured.out -def test_done(capsys): +def test_done(capsys: pytest.CaptureFixture[str]): Console.done("Task completed") captured = capsys.readouterr() @@ -502,7 +502,7 @@ def test_done(capsys): assert "Task completed" in captured.out -def test_warn(capsys): +def test_warn(capsys: pytest.CaptureFixture[str]): Console.warn("Warning message") captured = capsys.readouterr() @@ -510,9 +510,9 @@ def test_warn(capsys): assert "Warning message" in captured.out -def test_fail(capsys, monkeypatch): +def test_fail(capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch): mock_sys_exit = MagicMock() - monkeypatch.setattr(console._sys, "exit", mock_sys_exit) + monkeypatch.setattr("xulbux.console._sys.exit", mock_sys_exit) Console.fail("Error occurred") @@ -522,9 +522,9 @@ def test_fail(capsys, monkeypatch): mock_sys_exit.assert_called_once_with(1) -def test_exit_method(capsys, monkeypatch): +def test_exit_method(capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch): mock_sys_exit = MagicMock() - monkeypatch.setattr(console._sys, "exit", mock_sys_exit) + monkeypatch.setattr("xulbux.console._sys.exit", mock_sys_exit) Console.exit("Program ending") @@ -534,7 +534,7 @@ def test_exit_method(capsys, monkeypatch): mock_sys_exit.assert_called_once_with(0) -def test_log_box_filled(capsys): +def test_log_box_filled(capsys: pytest.CaptureFixture[str]): Console.log_box_filled("Line 1", "Line 2", box_bg_color="green") captured = capsys.readouterr() @@ -542,7 +542,7 @@ def test_log_box_filled(capsys): assert "Line 2" in captured.out -def test_log_box_bordered(capsys): +def test_log_box_bordered(capsys: pytest.CaptureFixture[str]): Console.log_box_bordered("Content line", border_type="rounded") captured = capsys.readouterr() @@ -550,43 +550,43 @@ def test_log_box_bordered(capsys): @patch("xulbux.console.Console.input") -def test_confirm_yes(mock_input): +def test_confirm_yes(mock_input: MagicMock): mock_input.return_value = "y" result = Console.confirm("Continue?") assert result is True @patch("xulbux.console.Console.input") -def test_confirm_no(mock_input): +def test_confirm_no(mock_input: MagicMock): mock_input.return_value = "n" result = Console.confirm("Continue?") assert result is False @patch("xulbux.console.Console.input") -def test_confirm_default_yes(mock_input): +def test_confirm_default_yes(mock_input: MagicMock): mock_input.return_value = "" result = Console.confirm("Continue?", default_is_yes=True) assert result is True @patch("xulbux.console.Console.input") -def test_confirm_default_no(mock_input): +def test_confirm_default_no(mock_input: MagicMock): mock_input.return_value = "" result = Console.confirm("Continue?", default_is_yes=False) assert result is False @pytest.fixture -def mock_prompt_session(monkeypatch): +def mock_prompt_session(monkeypatch: pytest.MonkeyPatch): mock_session = MagicMock() mock_session_class = MagicMock(return_value=mock_session) mock_session.prompt.return_value = None - monkeypatch.setattr(console._pt, "PromptSession", mock_session_class) + monkeypatch.setattr("xulbux.console._pt.PromptSession", mock_session_class) return mock_session_class, mock_session -def test_input_creates_prompt_session(mock_prompt_session, mock_formatcodes_print): +def test_input_creates_prompt_session(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that Console.input creates a PromptSession with correct parameters.""" mock_session_class, mock_session = mock_prompt_session @@ -603,7 +603,7 @@ def test_input_creates_prompt_session(mock_prompt_session, mock_formatcodes_prin mock_session.prompt.assert_called_once() -def test_input_with_placeholder(mock_prompt_session, mock_formatcodes_print): +def test_input_with_placeholder(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that placeholder is correctly passed to PromptSession.""" mock_session_class, _ = mock_prompt_session @@ -615,7 +615,7 @@ def test_input_with_placeholder(mock_prompt_session, mock_formatcodes_print): assert call_kwargs["placeholder"] != "" -def test_input_without_placeholder(mock_prompt_session, mock_formatcodes_print): +def test_input_without_placeholder(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that placeholder is empty when not provided.""" mock_session_class, _ = mock_prompt_session @@ -627,11 +627,11 @@ def test_input_without_placeholder(mock_prompt_session, mock_formatcodes_print): assert call_kwargs["placeholder"] == "" -def test_input_with_validator_function(mock_prompt_session, mock_formatcodes_print): +def test_input_with_validator_function(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that a custom validator function is properly handled.""" mock_session_class, _ = mock_prompt_session - def email_validator(text): + def email_validator(text: str) -> str | None: if "@" not in text: return "Invalid email" return None @@ -645,7 +645,7 @@ def email_validator(text): assert hasattr(validator_instance, "validate") -def test_input_with_length_constraints(mock_prompt_session, mock_formatcodes_print): +def test_input_with_length_constraints(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that min_len and max_len are properly handled.""" mock_session_class, _ = mock_prompt_session @@ -658,7 +658,7 @@ def test_input_with_length_constraints(mock_prompt_session, mock_formatcodes_pri assert hasattr(validator_instance, "validate") -def test_input_with_allowed_chars(mock_prompt_session, mock_formatcodes_print): +def test_input_with_allowed_chars(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that allowed_chars parameter is handled.""" mock_session_class, _ = mock_prompt_session @@ -670,7 +670,7 @@ def test_input_with_allowed_chars(mock_prompt_session, mock_formatcodes_print): assert call_kwargs["key_bindings"] is not None -def test_input_disable_paste(mock_prompt_session, mock_formatcodes_print): +def test_input_disable_paste(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that allow_paste=False is handled.""" mock_session_class, _ = mock_prompt_session @@ -682,7 +682,7 @@ def test_input_disable_paste(mock_prompt_session, mock_formatcodes_print): assert call_kwargs["key_bindings"] is not None -def test_input_with_start_end_formatting(mock_prompt_session, capsys): +def test_input_with_start_end_formatting(mock_prompt_session: tuple[MagicMock, MagicMock], capsys: pytest.CaptureFixture[str]): """Test that start and end parameters trigger FormatCodes.print calls.""" mock_session_class, _ = mock_prompt_session @@ -694,7 +694,7 @@ def test_input_with_start_end_formatting(mock_prompt_session, capsys): assert captured.out != "" or True # OUTPUT MAY BE CAPTURED OR GO TO REAL STDOUT -def test_input_message_formatting(mock_prompt_session, mock_formatcodes_print): +def test_input_message_formatting(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that the prompt message is properly formatted.""" mock_session_class, _ = mock_prompt_session @@ -706,7 +706,7 @@ def test_input_message_formatting(mock_prompt_session, mock_formatcodes_print): assert call_kwargs["message"] is not None -def test_input_bottom_toolbar_function(mock_prompt_session, capsys): +def test_input_bottom_toolbar_function(mock_prompt_session: tuple[MagicMock, MagicMock], capsys: pytest.CaptureFixture[str]): """Test that bottom toolbar function is set up.""" mock_session_class, _ = mock_prompt_session @@ -725,7 +725,7 @@ def test_input_bottom_toolbar_function(mock_prompt_session, capsys): pass -def test_input_style_configuration(mock_prompt_session, capsys): +def test_input_style_configuration(mock_prompt_session: tuple[MagicMock, MagicMock], capsys: pytest.CaptureFixture[str]): """Test that custom style is applied.""" mock_session_class, _ = mock_prompt_session @@ -737,7 +737,7 @@ def test_input_style_configuration(mock_prompt_session, capsys): assert call_kwargs["style"] is not None -def test_input_validate_while_typing_enabled(mock_prompt_session, mock_formatcodes_print): +def test_input_validate_while_typing_enabled(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that validate_while_typing is enabled.""" mock_session_class, _ = mock_prompt_session @@ -749,7 +749,7 @@ def test_input_validate_while_typing_enabled(mock_prompt_session, mock_formatcod assert call_kwargs["validate_while_typing"] is True -def test_input_validator_class_creation(mock_prompt_session, mock_formatcodes_print): +def test_input_validator_class_creation(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that InputValidator class is properly instantiated.""" mock_session_class, _ = mock_prompt_session @@ -763,7 +763,7 @@ def test_input_validator_class_creation(mock_prompt_session, mock_formatcodes_pr assert callable(getattr(validator_instance, "validate", None)) -def test_input_key_bindings_setup(mock_prompt_session, mock_formatcodes_print): +def test_input_key_bindings_setup(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that key bindings are properly set up.""" mock_session_class, _ = mock_prompt_session @@ -777,7 +777,7 @@ def test_input_key_bindings_setup(mock_prompt_session, mock_formatcodes_print): assert hasattr(kb, "bindings") -def test_input_mask_char_single_character(mock_prompt_session, mock_formatcodes_print): +def test_input_mask_char_single_character(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that mask_char works with single characters.""" mock_session_class, _ = mock_prompt_session @@ -786,7 +786,7 @@ def test_input_mask_char_single_character(mock_prompt_session, mock_formatcodes_ assert mock_session_class.called -def test_input_output_type_int(mock_prompt_session, mock_formatcodes_print): +def test_input_output_type_int(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that output_type parameter is handled for int conversion.""" mock_session_class, _ = mock_prompt_session @@ -795,7 +795,7 @@ def test_input_output_type_int(mock_prompt_session, mock_formatcodes_print): assert mock_session_class.called -def test_input_default_val_handling(mock_prompt_session, mock_formatcodes_print): +def test_input_default_val_handling(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that default_val parameter is properly handled.""" mock_session_class, _ = mock_prompt_session @@ -804,7 +804,7 @@ def test_input_default_val_handling(mock_prompt_session, mock_formatcodes_print) assert mock_session_class.called -def test_input_custom_style_object(mock_prompt_session, mock_formatcodes_print): +def test_input_custom_style_object(mock_prompt_session: tuple[MagicMock, MagicMock], mock_formatcodes_print: MagicMock): """Test that a custom Style object is created.""" mock_session_class, _ = mock_prompt_session @@ -887,7 +887,7 @@ def test_progressbar_show_progress_invalid_total(): @patch("sys.stdout", new_callable=io.StringIO) -def test_progressbar_show_progress(mock_stdout): +def test_progressbar_show_progress(mock_stdout: MagicMock): pb = ProgressBar() # MANUALLY SET AND RESTORE _original_stdout TO AVOID PATCHING ISSUES WITH COMPILED CLASSES original = pb._original_stdout @@ -911,7 +911,7 @@ def test_progressbar_hide_progress(): assert pb._original_stdout is None -def test_progressbar_progress_context(capsys): +def test_progressbar_progress_context(capsys: pytest.CaptureFixture[str]): pb = ProgressBar() # TEST CONTEXT MANAGER BEHAVIOR BY CHECKING ACTUAL EFFECTS @@ -958,7 +958,7 @@ def test_progressbar_create_bar(): def test_progressbar_intercepted_output(): pb = ProgressBar() - intercepted = console._InterceptedOutput(pb) + intercepted = console._InterceptedOutput(pb) # type: ignore result = intercepted.write("test content") assert result == len("test content") assert "test content" in pb._buffer @@ -974,7 +974,7 @@ def test_progressbar_emergency_cleanup(): assert pb.active is False -def test_progressbar_get_formatted_info_and_bar_width(mock_terminal_size): +def test_progressbar_get_formatted_info_and_bar_width(mock_terminal_size: MagicMock): pb = ProgressBar() formatted, bar_width = pb._get_formatted_info_and_bar_width( ["{l}", "|{b}|", "{c}/{t}", "({p}%)"], @@ -998,7 +998,7 @@ def test_progressbar_start_stop_intercepting(): pb._start_intercepting() assert pb.active is True assert pb._original_stdout == original_stdout - assert isinstance(sys.stdout, console._InterceptedOutput) + assert isinstance(sys.stdout, console._InterceptedOutput) # type: ignore pb._stop_intercepting() assert pb.active is False @@ -1029,128 +1029,128 @@ def test_progressbar_redraw_progress_bar(): mock_stdout.flush.assert_called_once() -################################################## Spinner TESTS ################################################## +################################################## Throbber TESTS ################################################## -def test_spinner_init_defaults(): - spinner = Spinner() - assert spinner.label is None - assert spinner.interval == 0.2 - assert spinner.active is False - assert spinner.sep == " " - assert len(spinner.frames) > 0 +def test_throbber_init_defaults(): + throbber = Throbber() + assert throbber.label is None + assert throbber.interval == 0.2 + assert throbber.active is False + assert throbber.sep == " " + assert len(throbber.frames) > 0 -def test_spinner_init_custom(): - spinner = Spinner(label="Loading", interval=0.5, sep="-") - assert spinner.label == "Loading" - assert spinner.interval == 0.5 - assert spinner.sep == "-" +def test_throbber_init_custom(): + throbber = Throbber(label="Loading", interval=0.5, sep="-") + assert throbber.label == "Loading" + assert throbber.interval == 0.5 + assert throbber.sep == "-" -def test_spinner_set_format_valid(): - spinner = Spinner() - spinner.set_format(["{l}", "{a}"]) - assert spinner.spinner_format == ["{l}", "{a}"] +def test_throbber_set_format_valid(): + throbber = Throbber() + throbber.set_format(["{l}", "{a}"]) + assert throbber.throbber_format == ["{l}", "{a}"] -def test_spinner_set_format_invalid(): - spinner = Spinner() +def test_throbber_set_format_invalid(): + throbber = Throbber() with pytest.raises(ValueError): - spinner.set_format(["{l}"]) # MISSING {a} + throbber.set_format(["{l}"]) # MISSING {a} -def test_spinner_set_frames_valid(): - spinner = Spinner() - spinner.set_frames(("a", "b")) - assert spinner.frames == ("a", "b") +def test_throbber_set_frames_valid(): + throbber = Throbber() + throbber.set_frames(("a", "b")) + assert throbber.frames == ("a", "b") -def test_spinner_set_frames_invalid(): - spinner = Spinner() +def test_throbber_set_frames_invalid(): + throbber = Throbber() with pytest.raises(ValueError): - spinner.set_frames(("a", )) # LESS THAN 2 FRAMES + throbber.set_frames(("a", )) # LESS THAN 2 FRAMES -def test_spinner_set_interval_valid(): - spinner = Spinner() - spinner.set_interval(1.0) - assert spinner.interval == 1.0 +def test_throbber_set_interval_valid(): + throbber = Throbber() + throbber.set_interval(1.0) + assert throbber.interval == 1.0 -def test_spinner_set_interval_invalid(): - spinner = Spinner() +def test_throbber_set_interval_invalid(): + throbber = Throbber() with pytest.raises(ValueError): - spinner.set_interval(0) + throbber.set_interval(0) with pytest.raises(ValueError): - spinner.set_interval(-1) + throbber.set_interval(-1) @patch("xulbux.console._threading.Thread") @patch("xulbux.console._threading.Event") @patch("sys.stdout", new_callable=MagicMock) -def test_spinner_start(mock_stdout, mock_event, mock_thread): +def test_throbber_start(mock_stdout: MagicMock, mock_event: MagicMock, mock_thread: MagicMock): mock_thread.return_value.start.return_value = None - spinner = Spinner() - spinner.start("Test") + throbber = Throbber() + throbber.start("Test") - assert spinner.active is True - assert spinner.label == "Test" + assert throbber.active is True + assert throbber.label == "Test" mock_event.assert_called_once() mock_thread.assert_called_once() # TEST CALLING START AGAIN DOESN'T DO ANYTHING - spinner.start("Test2") + throbber.start("Test2") assert mock_event.call_count == 1 @patch("xulbux.console._threading.Thread") @patch("xulbux.console._threading.Event") -def test_spinner_stop(mock_event, mock_thread): - spinner = Spinner() +def test_throbber_stop(mock_event: MagicMock, mock_thread: MagicMock): + throbber = Throbber() # MANUALLY SET ACTIVE TO SIMULATE RUNNING - spinner.active = True + throbber.active = True mock_stop_event = MagicMock() mock_stop_event.set.return_value = None - spinner._stop_event = mock_stop_event + throbber._stop_event = mock_stop_event mock_animation_thread = MagicMock() mock_animation_thread.join.return_value = None - spinner._animation_thread = mock_animation_thread + throbber._animation_thread = mock_animation_thread - spinner.stop() + throbber.stop() - assert spinner.active is False + assert throbber.active is False mock_stop_event.set.assert_called_once() mock_animation_thread.join.assert_called_once() -def test_spinner_update_label(): - spinner = Spinner() - spinner.update_label("New Label") - assert spinner.label == "New Label" +def test_throbber_update_label(): + throbber = Throbber() + throbber.update_label("New Label") + assert throbber.label == "New Label" -def test_spinner_context_manager(): - spinner = Spinner() +def test_throbber_context_manager(): + throbber = Throbber() # TEST CONTEXT MANAGER BEHAVIOR BY CHECKING ACTUAL EFFECTS - with spinner.context("Test") as update: - assert spinner.active is True - assert spinner.label == "Test" + with throbber.context("Test") as update: + assert throbber.active is True + assert throbber.label == "Test" update("New Label") - assert spinner.label == "New Label" + assert throbber.label == "New Label" - # AFTER CONTEXT EXITS, SPINNER SHOULD BE STOPPED - assert spinner.active is False + # AFTER CONTEXT EXITS, THROBBER SHOULD BE STOPPED + assert throbber.active is False -def test_spinner_context_manager_exception(): - spinner = Spinner() +def test_throbber_context_manager_exception(): + throbber = Throbber() # TEST THAT CLEANUP HAPPENS EVEN WITH EXCEPTIONS with pytest.raises(ValueError): - with spinner.context("Test"): + with throbber.context("Test"): raise ValueError("Oops") - # AFTER EXCEPTION, SPINNER SHOULD STILL BE CLEANED UP - assert spinner.active is False + # AFTER EXCEPTION, THROBBER SHOULD STILL BE CLEANED UP + assert throbber.active is False diff --git a/tests/test_data.py b/tests/test_data.py index 1b8c23c..85feda6 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,9 +1,11 @@ +from xulbux.base.types import DataObj from xulbux.data import Data +from typing import Literal, Any, cast import pytest # ! DON'T CHANGE THIS DATA ! -d_comments = { +d_comments: dict[str, Any] = { "key1": [ ">> COMMENT IN THE BEGINNING OF THE STRING << value1", "value2 >> COMMENT IN THE END OF THE STRING", @@ -15,12 +17,12 @@ ">> ALL THE KEYS VALUES ARE COMMENTS value", } -d1_equal = { +d1_equal: dict[str, Any] = { "key1": ["value1", "value2", "value3", ["value1", "value2", "value3"]], "key2": ["value1", "value2", "value3", ["value1", "value2", "value3"]], "key3": "value", } -d2_equal = { +d2_equal: dict[str, Any] = { "key1": ["value1", "value2", "value3", ["value1", "value2", "value3"]], "key2": ["value1", "value2", "value3", ["value1", "value2", "value3"]], "key3": "CHANGED value", @@ -76,7 +78,7 @@ def test_deserialize_bytes(): ({}, 0), ] ) -def test_chars_count(input_data, expected_count): +def test_chars_count(input_data: DataObj, expected_count: int): assert Data.chars_count(input_data) == expected_count @@ -88,27 +90,30 @@ def test_chars_count(input_data, expected_count): ([" a ", [" b ", " c"]], ["a", ["b", "c"]]), ] ) -def test_strip(input_data, expected_output): +def test_strip(input_data: DataObj, expected_output: DataObj): assert Data.strip(input_data) == expected_output @pytest.mark.parametrize( - "input_data, spaces_are_empty, expected_output", [ - (["a", "", "b", None, " "], False, ["a", "b", " "]), - (["a", "", "b", None, " "], True, ["a", "b"]), - (("a", "", "b", None, " "), False, ("a", "b", " ")), - (("a", "", "b", None, " "), True, ("a", "b")), - ({"k1": "a", "k2": "", "k3": "b", "k4": None, "k5": " "}, False, {"k1": "a", "k3": "b", "k5": " "}), - ({"k1": "a", "k2": "", "k3": "b", "k4": None, "k5": " "}, True, {"k1": "a", "k3": "b"}), - (["a", ["", "b"], "c"], False, ["a", ["b"], "c"]), - (["a", ["", "b"], "c"], True, ["a", ["b"], "c"]), - (["a", {"x": "", "y": "b"}, "c"], False, ["a", {"y": "b"}, "c"]), - (["a", {"x": "", "y": "b"}, "c"], True, ["a", {"y": "b"}, "c"]), - (["a", [], {}], False, ["a"]), - ] + "input_data, spaces_are_empty, expected_output", cast( + list[tuple[DataObj, bool, DataObj]], + [ + (["a", "", "b", None, " "], False, ["a", "b", " "]), + (["a", "", "b", None, " "], True, ["a", "b"]), + (("a", "", "b", None, " "), False, ("a", "b", " ")), + (("a", "", "b", None, " "), True, ("a", "b")), + ({"k1": "a", "k2": "", "k3": "b", "k4": None, "k5": " "}, False, {"k1": "a", "k3": "b", "k5": " "}), + ({"k1": "a", "k2": "", "k3": "b", "k4": None, "k5": " "}, True, {"k1": "a", "k3": "b"}), + (["a", ["", "b"], "c"], False, ["a", ["b"], "c"]), + (["a", ["", "b"], "c"], True, ["a", ["b"], "c"]), + (["a", {"x": "", "y": "b"}, "c"], False, ["a", {"y": "b"}, "c"]), + (["a", {"x": "", "y": "b"}, "c"], True, ["a", {"y": "b"}, "c"]), + (["a", [], {}], False, ["a"]), + ] + ) ) -def test_remove_empty_items(input_data, spaces_are_empty, expected_output): - assert Data.remove_empty_items(input_data, spaces_are_empty) == expected_output +def test_remove_empty_items(input_data: DataObj, spaces_are_empty: bool, expected_output: DataObj): + assert Data.remove_empty_items(input_data, spaces_are_empty=spaces_are_empty) == expected_output @pytest.mark.parametrize( @@ -121,7 +126,7 @@ def test_remove_empty_items(input_data, spaces_are_empty, expected_output): ({"k": ["v", "v"]}, {"k": ["v"]}), ] ) -def test_remove_duplicates(input_data, expected_output): +def test_remove_duplicates(input_data: DataObj, expected_output: DataObj): assert Data.remove_duplicates(input_data) == expected_output @@ -153,7 +158,7 @@ def test_get_path_id(): def test_get_value_by_path_id(): - data = {"a": [1, {"b": "c"}], "d": ("e", "f")} + data: dict[str, Any] = {"a": [1, {"b": "c"}], "d": ("e", "f")} path_id_1 = str(Data.get_path_id(data, "a->1->b")) path_id_2 = str(Data.get_path_id(data, "d->1")) @@ -171,16 +176,16 @@ def test_get_value_by_path_id(): def test_set_value_by_path_id(): - data = {"a": [1, {"b": "c"}], "d": ("e", "f")} + data: dict[str, Any] = {"a": [1, {"b": "c"}], "d": ("e", "f")} path_id_c = Data.get_path_id(data, "a->1->b") path_id_f = Data.get_path_id(data, "d->1") updated_data = Data.set_value_by_path_id(data, {path_id_c: "NEW_C", path_id_f: "NEW_F"}) # type: ignore[assignment] - expected_data = {"a": [1, {"b": "NEW_C"}], "d": ("e", "NEW_F")} + expected_data: dict[str, Any] = {"a": [1, {"b": "NEW_C"}], "d": ("e", "NEW_F")} assert updated_data == expected_data updated_data_types = Data.set_value_by_path_id(data, {path_id_c: [1, 2], path_id_f: {"x": 1}}) # type: ignore[assignment] - expected_data_types = {"a": [1, {"b": [1, 2]}], "d": ("e", {"x": 1})} + expected_data_types: dict[str, Any] = {"a": [1, {"b": [1, 2]}], "d": ("e", {"x": 1})} assert updated_data_types == expected_data_types with pytest.raises(ValueError): @@ -206,8 +211,24 @@ def test_set_value_by_path_id(): ({"data": b"hello"}, 4, 1, 80, ", ", False, "{'data': bytes('hello', 'utf-8')}"), ] ) -def test_render(data, indent, compactness, max_width, sep, as_json, expected_str): - result = Data.render(data, indent, compactness, max_width, sep, as_json, syntax_highlighting=False) +def test_render( + data: DataObj, + indent: int, + compactness: Literal[1, 0, 2], + max_width: int, + sep: str, + as_json: bool, + expected_str: str +): + result = Data.render( + data, + indent=indent, + compactness=compactness, + max_width=max_width, + sep=sep, + as_json=as_json, + syntax_highlighting=False + ) normalized_result = "\n".join(line.rstrip() for line in result.splitlines()) normalized_expected = "\n".join(line.rstrip() for line in expected_str.splitlines()) assert normalized_result == normalized_expected diff --git a/tests/test_file.py b/tests/test_file.py index 0f04aa2..50ecc27 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -38,13 +38,13 @@ ("no_dot_file", ".txt", True, True, "NoDotFile.txt"), ] ) -def test_rename_extension(input_file, new_extension, full_extension, camel_case, expected_output): - result = File.rename_extension(input_file, new_extension, full_extension, camel_case) +def test_rename_extension(input_file: str | Path, new_extension: str, full_extension: bool, camel_case: bool, expected_output: str): + result = File.rename_extension(input_file, new_extension, full_extension=full_extension, camel_case_filename=camel_case) assert isinstance(result, Path) assert str(result) == expected_output -def test_create_new_file(tmp_path): +def test_create_new_file(tmp_path: Path): file_path = tmp_path / "new_file.txt" abs_path = File.create(str(file_path)) assert isinstance(abs_path, Path) @@ -54,10 +54,10 @@ def test_create_new_file(tmp_path): assert file.read() == "" -def test_create_file_with_content(tmp_path): +def test_create_file_with_content(tmp_path: Path): file_path = tmp_path / "content_file.log" content = "This is the file content.\nWith multiple lines." - abs_path = File.create(str(file_path), content=content) + abs_path = File.create(str(file_path), content) assert isinstance(abs_path, Path) assert file_path.exists() assert abs_path.resolve() == file_path.resolve() @@ -65,31 +65,31 @@ def test_create_file_with_content(tmp_path): assert file.read() == content -def test_create_file_exists_error(tmp_path): +def test_create_file_exists_error(tmp_path: Path): file_path = tmp_path / "existing_file.txt" with open(file_path, "w", encoding="utf-8") as file: file.write("Initial content") with pytest.raises(FileExistsError): - File.create(str(file_path), content="New content", force=False) + File.create(str(file_path), "New content", force=False) -def test_create_file_same_content_exists_error(tmp_path): +def test_create_file_same_content_exists_error(tmp_path: Path): file_path = tmp_path / "same_content_file.data" content = "Identical content" - File.create(str(file_path), content=content) + File.create(str(file_path), content) with pytest.raises(SameContentFileExistsError): - File.create(str(file_path), content=content, force=False) + File.create(str(file_path), content, force=False) -def test_create_file_force_overwrite_different_content(tmp_path): +def test_create_file_force_overwrite_different_content(tmp_path: Path): file_path = tmp_path / "overwrite_file.cfg" initial_content = "Old config" new_content = "New configuration values" - File.create(str(file_path), content=initial_content) + File.create(str(file_path), initial_content) assert open(file_path, "r", encoding="utf-8").read() == initial_content - abs_path = File.create(str(file_path), content=new_content, force=True) + abs_path = File.create(str(file_path), new_content, force=True) assert isinstance(abs_path, Path) assert file_path.exists() assert abs_path.resolve() == file_path.resolve() @@ -97,14 +97,14 @@ def test_create_file_force_overwrite_different_content(tmp_path): assert file.read() == new_content -def test_create_file_force_overwrite_same_content(tmp_path): +def test_create_file_force_overwrite_same_content(tmp_path: Path): file_path = tmp_path / "overwrite_same_file.ini" content = "[Settings]\nValue=1" - File.create(str(file_path), content=content) + File.create(str(file_path), content) assert open(file_path, "r", encoding="utf-8").read() == content - abs_path = File.create(str(file_path), content=content, force=True) + abs_path = File.create(str(file_path), content, force=True) assert isinstance(abs_path, Path) assert file_path.exists() assert abs_path.resolve() == file_path.resolve() @@ -112,16 +112,16 @@ def test_create_file_force_overwrite_same_content(tmp_path): assert file.read() == content -def test_create_file_in_subdirectory(tmp_path): +def test_create_file_in_subdirectory(tmp_path: Path): dir_path = tmp_path / "subdir" file_path = dir_path / "sub_file.txt" content = "Content in subdirectory" with pytest.raises(FileNotFoundError): - File.create(str(file_path), content=content) + File.create(str(file_path), content) dir_path.mkdir() - abs_path = File.create(str(file_path), content=content) + abs_path = File.create(str(file_path), content) assert isinstance(abs_path, Path) assert file_path.exists() assert abs_path.resolve() == file_path.resolve() diff --git a/tests/test_file_sys.py b/tests/test_file_sys.py index 0444ee8..42f3c9a 100644 --- a/tests/test_file_sys.py +++ b/tests/test_file_sys.py @@ -9,7 +9,7 @@ @pytest.fixture -def setup_test_environment(tmp_path, monkeypatch): +def setup_test_environment(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> dict[str, Path]: """Sets up a controlled environment for path tests.""" mock_cwd = tmp_path / "mock_cwd" mock_script_dir = tmp_path / "mock_script_dir" @@ -34,7 +34,11 @@ def setup_test_environment(tmp_path, monkeypatch): monkeypatch.setattr(Path, "cwd", staticmethod(lambda: mock_cwd)) monkeypatch.setattr(Path, "home", staticmethod(lambda: mock_home)) monkeypatch.setattr(sys.modules["__main__"], "__file__", str(mock_script_dir / "mock_script.py")) - monkeypatch.setattr(os.path, "expanduser", lambda path: str(mock_home) if path == "~" else path) + + def mock_expanduser(path: str) -> str: + return str(mock_home) if path == "~" else path + + monkeypatch.setattr(os.path, "expanduser", mock_expanduser) monkeypatch.setattr(tempfile, "gettempdir", lambda: str(mock_temp)) return { @@ -50,13 +54,13 @@ def setup_test_environment(tmp_path, monkeypatch): ################################################## Path TESTS ################################################## -def test_path_cwd(setup_test_environment): +def test_path_cwd(setup_test_environment: dict[str, Path]): cwd_output = FileSys.cwd assert isinstance(cwd_output, Path) assert str(cwd_output) == str(setup_test_environment["cwd"]) -def test_path_script_dir(setup_test_environment): +def test_path_script_dir(setup_test_environment: dict[str, Path]): script_dir_output = FileSys.script_dir assert isinstance(script_dir_output, Path) assert str(script_dir_output) == str(setup_test_environment["script_dir"]) @@ -70,7 +74,7 @@ def test_path_home(): assert home.is_dir() -def test_extend(setup_test_environment): +def test_extend(setup_test_environment: dict[str, Path]): env = setup_test_environment search_dir = str(env["search_in"]) search_dirs = [str(env["cwd"]), search_dir] @@ -112,7 +116,7 @@ def test_extend(setup_test_environment): assert FileSys.extend_path("CompletelyWrong/no_file_here.dat", search_in=search_dir, fuzzy_match=True) is None -def test_extend_or_make(setup_test_environment): +def test_extend_or_make(setup_test_environment: dict[str, Path]): env = setup_test_environment search_dir = str(env["search_in"]) @@ -142,7 +146,7 @@ def test_extend_or_make(setup_test_environment): assert str(FileSys.extend_or_make_path(rel_path_wrong, search_in=search_dir, fuzzy_match=True)) == str(expected_made) -def test_remove(tmp_path): +def test_remove(tmp_path: Path): # NON-EXISTENT non_existent_path = tmp_path / "does_not_exist" assert not non_existent_path.exists() diff --git a/tests/test_json.py b/tests/test_json.py index 3344e30..03614c8 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,29 +1,30 @@ from xulbux.base.exceptions import SameContentFileExistsError from xulbux.json import Json +from typing import Any from pathlib import Path import pytest import json -def create_test_json(tmp_path, filename, data): +def create_test_json(tmp_path: Path, filename: str, data: Any) -> Path: file_path = tmp_path / filename with open(file_path, "w") as file: json.dump(data, file, indent=2) return file_path -def create_test_json_string(tmp_path, filename, content): +def create_test_json_string(tmp_path: Path, filename: str, content: str) -> Path: file_path = tmp_path / filename with open(file_path, "w") as file: file.write(content) return file_path -SIMPLE_DATA = {"name": "test", "value": 123} +SIMPLE_DATA: dict[str, Any] = {"name": "test", "value": 123} SIMPLE_DATA_STR = '{"name": "test", "value": 123}' -COMMENT_DATA = { +COMMENT_DATA: dict[str, Any] = { "key1": "value with no comments", "key2": "value >>inline comment<<", "list": [1, ">>item is a comment", 2, "item >>inline comment<<"], @@ -37,7 +38,7 @@ def create_test_json_string(tmp_path, filename, content): "object": {">>": "whole key & value is a comment"}, ">>": "whole key & value is a comment" }""" -COMMENT_DATA_PROCESSED = { +COMMENT_DATA_PROCESSED: dict[str, Any] = { "key1": "value with no comments", "key2": "value", "list": [1, 2, "item"], @@ -52,7 +53,7 @@ def create_test_json_string(tmp_path, filename, content): }, "user": "Test User >>DON'T TOUCH<<" }""" -COMMENT_UPDATE_VALUES = { +COMMENT_UPDATE_VALUES: dict[str, Any] = { "config->version": 2.0, "config->features->0": "c", "user": "Cool Test User", @@ -66,16 +67,16 @@ def create_test_json_string(tmp_path, filename, content): "user": "Cool Test User >>DON'T TOUCH<<" }""" -UPDATE_DATA_START = { +UPDATE_DATA_START: dict[str, Any] = { "config": {"version": 1.0, "features": ["a", "b"]}, "user": "Test User", } -UPDATE_VALUES = { +UPDATE_VALUES: dict[str, Any] = { "config->version": 2.0, "config->features->1": "c", "user": {"name": "Test User", "admin": True}, } -UPDATE_DATA_END = { +UPDATE_DATA_END: dict[str, Any] = { "config": {"version": 2.0, "features": ["a", "c"]}, "user": {"name": "Test User", "admin": True}, } @@ -84,19 +85,19 @@ def create_test_json_string(tmp_path, filename, content): ################################################## Json TESTS ################################################## -def test_read_simple(tmp_path): +def test_read_simple(tmp_path: Path): file_path = create_test_json(tmp_path, "simple.json", SIMPLE_DATA) data = Json.read(str(file_path)) assert data == SIMPLE_DATA -def test_read_with_comments(tmp_path): +def test_read_with_comments(tmp_path: Path): file_path = create_test_json_string(tmp_path, "comments.json", COMMENT_DATA_STR) data = Json.read(str(file_path)) assert data == COMMENT_DATA_PROCESSED -def test_read_with_comments_return_original(tmp_path): +def test_read_with_comments_return_original(tmp_path: Path): file_path = create_test_json_string(tmp_path, "comments_orig.json", COMMENT_DATA_STR) processed, original = Json.read(str(file_path), return_original=True) assert processed == COMMENT_DATA_PROCESSED @@ -109,13 +110,13 @@ def test_read_non_existent_file(): Json.read("non_existent_file.json") -def test_read_invalid_json(tmp_path): +def test_read_invalid_json(tmp_path: Path): file_path = create_test_json_string(tmp_path, "invalid.json", "{invalid json") with pytest.raises(ValueError, match="Error parsing JSON"): Json.read(str(file_path)) -def test_read_empty_json(tmp_path): +def test_read_empty_json(tmp_path: Path): file_path = create_test_json_string(tmp_path, "empty.json", "{}") try: data = Json.read(str(file_path)) @@ -124,13 +125,13 @@ def test_read_empty_json(tmp_path): assert "empty or contains only comments" in str(e) -def test_read_comment_only_json(tmp_path): +def test_read_comment_only_json(tmp_path: Path): file_path = create_test_json_string(tmp_path, "comment_only.json", '{\n">>": "comment"\n}') with pytest.raises(ValueError, match="empty or contains only comments"): Json.read(str(file_path)) -def test_create_simple(tmp_path): +def test_create_simple(tmp_path: Path): file_path_str = str(tmp_path / "created.json") created_path = Json.create(file_path_str, SIMPLE_DATA) assert isinstance(created_path, Path) @@ -140,7 +141,7 @@ def test_create_simple(tmp_path): assert data == SIMPLE_DATA -def test_create_with_indent_compactness(tmp_path): +def test_create_with_indent_compactness(tmp_path: Path): file_path_str = str(tmp_path / "formatted.json") Json.create(file_path_str, SIMPLE_DATA, indent=4, compactness=0) with open(file_path_str, "r") as file: @@ -148,13 +149,13 @@ def test_create_with_indent_compactness(tmp_path): assert '\n "name":' in content -def test_create_force_false_exists(tmp_path): +def test_create_force_false_exists(tmp_path: Path): file_path = create_test_json(tmp_path, "existing.json", {"a": 1}) with pytest.raises(FileExistsError): Json.create(str(file_path), {"b": 2}, force=False) -def test_create_force_false_same_content(tmp_path): +def test_create_force_false_same_content(tmp_path: Path): from pathlib import Path file_path = Json.create(f"{tmp_path}/existing_same.json", SIMPLE_DATA, force=False) assert isinstance(file_path, Path) @@ -162,7 +163,7 @@ def test_create_force_false_same_content(tmp_path): Json.create(file_path, SIMPLE_DATA, force=False) -def test_create_force_true_exists(tmp_path): +def test_create_force_true_exists(tmp_path: Path): file_path = create_test_json(tmp_path, "overwrite.json", {"a": 1}) Json.create(str(file_path), {"b": 2}, force=True) with open(file_path, "r") as file: @@ -170,7 +171,7 @@ def test_create_force_true_exists(tmp_path): assert data == {"b": 2} -def test_update_existing_values(tmp_path): +def test_update_existing_values(tmp_path: Path): file_path = create_test_json(tmp_path, "update_test.json", UPDATE_DATA_START) Json.update(str(file_path), UPDATE_VALUES) with open(file_path, "r") as file: @@ -178,7 +179,7 @@ def test_update_existing_values(tmp_path): assert data == UPDATE_DATA_END -def test_update_with_comments(tmp_path): +def test_update_with_comments(tmp_path: Path): file_path = create_test_json_string(tmp_path, "update_comments.json", COMMENT_DATA_START) Json.update(str(file_path), COMMENT_UPDATE_VALUES) @@ -191,7 +192,7 @@ def test_update_with_comments(tmp_path): pytest.fail("JSON became invalid after update with comments") -def test_update_different_path_sep(tmp_path): +def test_update_different_path_sep(tmp_path: Path): file_path = create_test_json(tmp_path, "update_sep.json", {"a": {"b": 1}}) Json.update(str(file_path), {"a/b": 2}, path_sep="/") with open(file_path, "r") as file: @@ -199,7 +200,7 @@ def test_update_different_path_sep(tmp_path): assert data == {"a": {"b": 2}} -def test_update_create_non_existent_path(tmp_path): +def test_update_create_non_existent_path(tmp_path: Path): file_path = create_test_json(tmp_path, "update_create.json", {"existing": 1}) Json.update(str(file_path), {"new->nested->value": "created"}) with open(file_path, "r") as file: diff --git a/tests/test_regex.py b/tests/test_regex.py index 633e708..f47d276 100644 --- a/tests/test_regex.py +++ b/tests/test_regex.py @@ -1,5 +1,6 @@ from xulbux.regex import LazyRegex, Regex +from typing import cast import regex as rx import pytest import re @@ -341,9 +342,9 @@ def test_regex_hsla_str_valid_values(): # VERIFY THAT % AND ° SYMBOLS ARE NOT IN THE CAPTURED GROUPS for match in matches: - groups = match if isinstance(match, tuple) else (match, ) + groups = cast(tuple[str], match if isinstance(match, tuple) else (match, )) for group in groups: - if group: # Skip empty groups + if group: # SKIP EMPTY GROUPS assert "%" not in group, f"Percent sign should not be in captured group: {group}" assert "°" not in group, f"Degree sign should not be in captured group: {group}" diff --git a/tests/test_string.py b/tests/test_string.py index 9eb3650..66c2eab 100644 --- a/tests/test_string.py +++ b/tests/test_string.py @@ -64,13 +64,13 @@ def test_is_empty(): def test_single_char_repeats(): - assert String.single_char_repeats("aaaaa", "a") == 5 assert String.single_char_repeats("-----", "-") == 5 - assert String.single_char_repeats("bbbbb", "a") is False - assert String.single_char_repeats("abcde", "a") is False assert String.single_char_repeats("", "a") == 0 assert String.single_char_repeats("a", "a") == 1 - assert String.single_char_repeats("aaaba", "a") is False + assert String.single_char_repeats("aaaaa", "a") == 5 + assert String.single_char_repeats("aaaba", "a") == 0 + assert String.single_char_repeats("abcde", "a") == 0 + assert String.single_char_repeats("bbbbb", "a") == 0 def test_decompose(): diff --git a/tests/test_system.py b/tests/test_system.py index f73d816..1e46218 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -1,6 +1,6 @@ from xulbux.system import System -from unittest.mock import patch +from unittest.mock import MagicMock, patch import platform import pytest import os @@ -101,7 +101,7 @@ def test_check_libs_nonexistent_module(): @patch("xulbux.system._subprocess.check_call") @patch("xulbux.console.Console.confirm", return_value=False) # DECLINE INSTALLATION -def test_check_libs_decline_install(mock_confirm, mock_subprocess): +def test_check_libs_decline_install(mock_confirm: MagicMock, mock_subprocess: MagicMock): """Test check_libs when user declines installation""" result = System.check_libs(["nonexistent_module_12345"], install_missing=True) assert isinstance(result, list) @@ -111,18 +111,18 @@ def test_check_libs_decline_install(mock_confirm, mock_subprocess): @patch("xulbux.system._platform.system") @patch("xulbux.system._subprocess.check_output") -@patch("xulbux.system._os.system") -def test_restart_windows_simple(mock_os_system, mock_subprocess, mock_platform): +@patch("xulbux.system._subprocess.run") +def test_restart_windows_simple(mock_subprocess_run: MagicMock, mock_check_output: MagicMock, mock_platform: MagicMock): """Test simple restart on Windows""" mock_platform.return_value = "Windows" - mock_subprocess.return_value = b"minimal\nprocess\nlist\n" + mock_check_output.return_value = b"minimal\nprocess\nlist\n" System.restart() - mock_os_system.assert_called_once_with("shutdown /r /t 0") + mock_subprocess_run.assert_called_once_with(["shutdown", "/r", "/t", "0"]) @patch("xulbux.system._platform.system") @patch("xulbux.system._subprocess.check_output") -def test_restart_too_many_processes(mock_subprocess, mock_platform): +def test_restart_too_many_processes(mock_subprocess: MagicMock, mock_platform: MagicMock): """Test restart fails when too many processes running""" mock_platform.return_value = "Windows" mock_subprocess.return_value = b"many\nprocess\nlines\nhere\nmore\nprocesses\neven\nmore\n" @@ -132,7 +132,7 @@ def test_restart_too_many_processes(mock_subprocess, mock_platform): @patch("xulbux.system._platform.system") @patch("xulbux.system._subprocess.check_output") -def test_restart_unsupported_system(mock_subprocess, mock_platform): +def test_restart_unsupported_system(mock_subprocess: MagicMock, mock_platform: MagicMock): """Test restart on unsupported system""" mock_platform.return_value = "Unknown" mock_subprocess.return_value = b"some output" @@ -142,7 +142,7 @@ def test_restart_unsupported_system(mock_subprocess, mock_platform): @pytest.mark.skipif(os.name != "nt", reason="Windows-specific test") @patch("xulbux.system._ctypes") -def test_elevate_windows_already_elevated(mock_ctypes): +def test_elevate_windows_already_elevated(mock_ctypes: MagicMock): """Test elevate on WINDOWS when already elevated""" # SETUP THE MOCK TO RETURN 1 (True) FOR IsUserAnAdmin mock_ctypes.windll.shell32.IsUserAnAdmin.return_value = 1 @@ -153,7 +153,7 @@ def test_elevate_windows_already_elevated(mock_ctypes): @pytest.mark.skipif(os.name == "nt", reason="POSIX-specific test") @patch("xulbux.system._os.geteuid") -def test_elevate_posix_already_elevated(mock_geteuid): +def test_elevate_posix_already_elevated(mock_geteuid: MagicMock): """Test elevate on POSIX when already elevated""" mock_geteuid.return_value = 0 result = System.elevate()