diff --git a/codeflash/languages/java/formatter.py b/codeflash/languages/java/formatter.py index 23a178f7e..1ae562131 100644 --- a/codeflash/languages/java/formatter.py +++ b/codeflash/languages/java/formatter.py @@ -14,6 +14,8 @@ import tempfile from pathlib import Path +_FORMATTER_CACHE: dict[str | None, JavaFormatter] = {} + logger = logging.getLogger(__name__) @@ -236,7 +238,7 @@ def format_java_code(source: str, project_root: Path | None = None) -> str: Formatted source code. """ - formatter = JavaFormatter(project_root) + formatter = _get_cached_formatter(project_root) return formatter.format_code(source) @@ -327,3 +329,12 @@ def normalize_java_code(source: str) -> str: normalized_lines.append(stripped) return "\n".join(normalized_lines) + + +def _get_cached_formatter(project_root: Path | None) -> JavaFormatter: + key = str(project_root) if project_root is not None else None + fmt = _FORMATTER_CACHE.get(key) + if fmt is None: + fmt = JavaFormatter(project_root) + _FORMATTER_CACHE[key] = fmt + return fmt diff --git a/codeflash/languages/registry.py b/codeflash/languages/registry.py index 38688cab6..214e11c1c 100644 --- a/codeflash/languages/registry.py +++ b/codeflash/languages/registry.py @@ -50,13 +50,13 @@ def _ensure_languages_registered() -> None: import contextlib with contextlib.suppress(ImportError): - from codeflash.languages.python import support as _ + from codeflash.languages.python import support as _python_support # noqa: F401 with contextlib.suppress(ImportError): - from codeflash.languages.javascript import support as _ + from codeflash.languages.javascript import support as _js_support # noqa: F401 with contextlib.suppress(ImportError): - from codeflash.languages.java import support as _ + from codeflash.languages.java import support as _java_support # noqa: F401 _languages_registered = True