diff --git a/codeflash/languages/javascript/frameworks/detector.py b/codeflash/languages/javascript/frameworks/detector.py index 013de47f5..f4905c593 100644 --- a/codeflash/languages/javascript/frameworks/detector.py +++ b/codeflash/languages/javascript/frameworks/detector.py @@ -10,7 +10,10 @@ import logging from dataclasses import dataclass, field from functools import lru_cache -from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path logger = logging.getLogger(__name__) diff --git a/codeflash/languages/javascript/frameworks/react/context.py b/codeflash/languages/javascript/frameworks/react/context.py index d4ebaf3d1..25e2fd097 100644 --- a/codeflash/languages/javascript/frameworks/react/context.py +++ b/codeflash/languages/javascript/frameworks/react/context.py @@ -20,6 +20,8 @@ from codeflash.languages.javascript.frameworks.react.discovery import ReactComponentInfo from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer +_BUILTIN_COMPONENTS = frozenset(("React.Fragment", "Fragment", "Suspense", "React.Suspense")) + HOOK_PATTERN = re.compile(r"\b(use[A-Z]\w*)\s*\(") JSX_COMPONENT_RE = re.compile(r"<([A-Z][a-zA-Z0-9.]*)") @@ -169,12 +171,10 @@ def _extract_hook_usages(component_source: str) -> list[HookUsage]: def _extract_child_components(component_source: str, analyzer: TreeSitterAnalyzer, full_source: str) -> list[str]: """Find child component names rendered in JSX.""" - children = set() - for match in JSX_COMPONENT_RE.finditer(component_source): - name = match.group(1) - # Skip React built-ins like React.Fragment - if name not in ("React.Fragment", "Fragment", "Suspense", "React.Suspense"): - children.add(name) + children = set(JSX_COMPONENT_RE.findall(component_source)) + # Skip React built-ins like React.Fragment + if children: + children.difference_update(_BUILTIN_COMPONENTS) return sorted(children) diff --git a/codeflash/languages/javascript/frameworks/react/discovery.py b/codeflash/languages/javascript/frameworks/react/discovery.py index 9e39de817..d4f30a9e6 100644 --- a/codeflash/languages/javascript/frameworks/react/discovery.py +++ b/codeflash/languages/javascript/frameworks/react/discovery.py @@ -10,10 +10,11 @@ import re from dataclasses import dataclass from enum import Enum -from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: + from pathlib import Path + from tree_sitter import Node from codeflash.languages.javascript.treesitter import FunctionNode, TreeSitterAnalyzer @@ -191,11 +192,7 @@ def _node_contains_jsx(node: Node) -> bool: if _node_contains_jsx(child): return True - for child in node.children: - if _node_contains_jsx(child): - return True - - return False + return any(_node_contains_jsx(child) for child in node.children) def _extract_hooks_used(function_source: str) -> list[str]: diff --git a/codeflash/languages/javascript/frameworks/react/profiler.py b/codeflash/languages/javascript/frameworks/react/profiler.py index 880793c11..e95d6f341 100644 --- a/codeflash/languages/javascript/frameworks/react/profiler.py +++ b/codeflash/languages/javascript/frameworks/react/profiler.py @@ -12,10 +12,11 @@ import logging import re -from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: + from pathlib import Path + from tree_sitter import Node from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer @@ -76,9 +77,7 @@ def instrument_component_with_profiler(source: str, component_name: str, analyze result = _insert_after_imports(result, counter_code, analyzer) # Ensure React is imported - result = _ensure_react_import(result) - - return result + return _ensure_react_import(result) def instrument_all_components_for_tracing(source: str, file_path: Path, analyzer: TreeSitterAnalyzer) -> str: @@ -161,11 +160,12 @@ def walk(node: Node) -> None: def _contains_jsx(node: Node) -> bool: """Check if a tree-sitter node contains JSX elements.""" - if node.type in ("jsx_element", "jsx_self_closing_element", "jsx_fragment"): - return True - for child in node.children: - if _contains_jsx(child): + stack = [node] + while stack: + node = stack.pop() + if node.type in ("jsx_element", "jsx_self_closing_element", "jsx_fragment"): return True + stack.extend(node.children) return False diff --git a/codeflash/languages/javascript/treesitter_utils.py b/codeflash/languages/javascript/treesitter_utils.py index b6126ec9a..75792be6f 100644 --- a/codeflash/languages/javascript/treesitter_utils.py +++ b/codeflash/languages/javascript/treesitter_utils.py @@ -1580,9 +1580,9 @@ def get_analyzer_for_file(file_path: Path) -> TreeSitterAnalyzer: """ suffix = file_path.suffix.lower() - if suffix in (".ts",): + if suffix == ".ts": return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT) - if suffix in (".tsx",): + if suffix == ".tsx": return TreeSitterAnalyzer(TreeSitterLanguage.TSX) # Default to JavaScript for .js, .jsx, .mjs, .cjs return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)