From d4aeb64af7ebda0c9884f4393b4f96b019e79935 Mon Sep 17 00:00:00 2001 From: Sarthak Agarwal Date: Fri, 20 Feb 2026 08:18:15 +0530 Subject: [PATCH 01/37] [WIP] react framework initial commit --- codeflash/languages/base.py | 1 + .../javascript/frameworks/__init__.py | 1 + .../javascript/frameworks/detector.py | 94 + .../javascript/frameworks/react/__init__.py | 1 + .../javascript/frameworks/react/analyzer.py | 161 ++ .../javascript/frameworks/react/context.py | 204 +++ .../javascript/frameworks/react/discovery.py | 251 +++ .../javascript/frameworks/react/profiler.py | 244 +++ .../javascript/frameworks/react/testgen.py | 120 ++ codeflash/languages/javascript/parse.py | 38 + codeflash/languages/javascript/support.py | 74 + .../languages/javascript/treesitter_utils.py | 1588 +++++++++++++++++ codeflash/models/function_types.py | 3 +- 13 files changed, 2779 insertions(+), 1 deletion(-) create mode 100644 codeflash/languages/javascript/frameworks/__init__.py create mode 100644 codeflash/languages/javascript/frameworks/detector.py create mode 100644 codeflash/languages/javascript/frameworks/react/__init__.py create mode 100644 codeflash/languages/javascript/frameworks/react/analyzer.py create mode 100644 codeflash/languages/javascript/frameworks/react/context.py create mode 100644 codeflash/languages/javascript/frameworks/react/discovery.py create mode 100644 codeflash/languages/javascript/frameworks/react/profiler.py create mode 100644 codeflash/languages/javascript/frameworks/react/testgen.py create mode 100644 codeflash/languages/javascript/treesitter_utils.py diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 3e10da319..ce19c536b 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -93,6 +93,7 @@ class CodeContext: read_only_context: str = "" imports: list[str] = field(default_factory=list) language: Language = Language.PYTHON + react_context: str | None = None @dataclass diff --git 
a/codeflash/languages/javascript/frameworks/__init__.py b/codeflash/languages/javascript/frameworks/__init__.py new file mode 100644 index 000000000..c4bf7a8df --- /dev/null +++ b/codeflash/languages/javascript/frameworks/__init__.py @@ -0,0 +1 @@ +"""Framework detection and support for JavaScript/TypeScript projects.""" diff --git a/codeflash/languages/javascript/frameworks/detector.py b/codeflash/languages/javascript/frameworks/detector.py new file mode 100644 index 000000000..013de47f5 --- /dev/null +++ b/codeflash/languages/javascript/frameworks/detector.py @@ -0,0 +1,94 @@ +"""Framework detection for JavaScript/TypeScript projects. + +Detects React (and potentially other frameworks) by inspecting package.json +dependencies. Results are cached per project root. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from functools import lru_cache +from pathlib import Path + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class FrameworkInfo: + """Information about the frontend framework used in a project.""" + + name: str # "react", "vue", "angular", "none" + version: str | None = None # e.g., "18.2.0" + react_version_major: int | None = None # e.g., 18 + has_testing_library: bool = False # @testing-library/react installed + has_react_compiler: bool = False # React 19+ compiler detected + dev_dependencies: frozenset[str] = field(default_factory=frozenset) + + +_EMPTY_FRAMEWORK = FrameworkInfo(name="none") + + +@lru_cache(maxsize=32) +def detect_framework(project_root: Path) -> FrameworkInfo: + """Detect the frontend framework from package.json. + + Reads dependencies and devDependencies to identify React and its ecosystem. + Results are cached per project root path. 
+ """ + package_json_path = project_root / "package.json" + if not package_json_path.exists(): + return _EMPTY_FRAMEWORK + + try: + package_data = json.loads(package_json_path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError) as e: + logger.debug("Failed to read package.json at %s: %s", package_json_path, e) + return _EMPTY_FRAMEWORK + + deps = package_data.get("dependencies", {}) + dev_deps = package_data.get("devDependencies", {}) + all_deps = {**deps, **dev_deps} + + # Detect React + react_version_str = deps.get("react") or dev_deps.get("react") + if not react_version_str: + return _EMPTY_FRAMEWORK + + version = _parse_version_string(react_version_str) + major = _parse_major_version(version) + + has_testing_library = "@testing-library/react" in all_deps + has_react_compiler = ( + "babel-plugin-react-compiler" in all_deps + or "react-compiler-runtime" in all_deps + or (major is not None and major >= 19) + ) + + return FrameworkInfo( + name="react", + version=version, + react_version_major=major, + has_testing_library=has_testing_library, + has_react_compiler=has_react_compiler, + dev_dependencies=frozenset(all_deps.keys()), + ) + + +def _parse_version_string(version_spec: str) -> str | None: + """Extract a clean version from a semver range like ^18.2.0 or ~17.0.0.""" + stripped = version_spec.lstrip("^~>= int | None: + """Extract major version number from a version string.""" + if not version: + return None + try: + return int(version.split(".")[0]) + except (ValueError, IndexError): + return None diff --git a/codeflash/languages/javascript/frameworks/react/__init__.py b/codeflash/languages/javascript/frameworks/react/__init__.py new file mode 100644 index 000000000..c7622b0d6 --- /dev/null +++ b/codeflash/languages/javascript/frameworks/react/__init__.py @@ -0,0 +1 @@ +"""React framework support for component discovery, profiling, and optimization.""" diff --git a/codeflash/languages/javascript/frameworks/react/analyzer.py 
"""Static analysis for React optimization opportunities.

Detects common performance anti-patterns in React components:
- Inline object/array creation in JSX props
- Functions defined inside render body (missing useCallback)
- Expensive computations without useMemo
- Components receiving referentially unstable props
"""

import re
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from codeflash.languages.javascript.frameworks.react.discovery import ReactComponentInfo


class OpportunitySeverity(str, Enum):
    """How impactful a detected anti-pattern is likely to be."""

    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"


class OpportunityType(str, Enum):
    """Categories of React performance anti-patterns we can detect."""

    INLINE_OBJECT_PROP = "inline_object_prop"
    INLINE_ARRAY_PROP = "inline_array_prop"
    MISSING_USECALLBACK = "missing_usecallback"
    MISSING_USEMEMO = "missing_usememo"
    MISSING_REACT_MEMO = "missing_react_memo"
    UNSTABLE_REFERENCE = "unstable_reference"


@dataclass(frozen=True)
class OptimizationOpportunity:
    """A detected optimization opportunity in a React component."""

    type: OpportunityType
    line: int  # 1-based line number in the full source file
    description: str
    severity: OpportunitySeverity


# Heuristic text patterns; these scan raw source lines, not a parse tree.
EXPENSIVE_OPS_RE = re.compile(
    r"\.(filter|map|sort|reduce|flatMap|find|findIndex|every|some)\s*\("
)
INLINE_OBJECT_IN_JSX_RE = re.compile(r"=\{\s*\{")  # ={{ ... }} in JSX
INLINE_ARRAY_IN_JSX_RE = re.compile(r"=\{\s*\[")  # ={[ ... ]} in JSX
FUNCTION_DEF_RE = re.compile(
    r"(?:const|let|var)\s+\w+\s*=\s*(?:async\s+)?(?:\([^)]*\)|[a-zA-Z_]\w*)\s*=>"
    r"|function\s+\w+\s*\("
)
USECALLBACK_RE = re.compile(r"\buseCallback\s*\(")
USEMEMO_RE = re.compile(r"\buseMemo\s*\(")


def detect_optimization_opportunities(
    source: str, component_info: "ReactComponentInfo"
) -> list[OptimizationOpportunity]:
    """Detect optimization opportunities in a React component.

    Scans only the line span belonging to the component and returns one
    OptimizationOpportunity per heuristic hit.
    """
    found: list[OptimizationOpportunity] = []
    all_lines = source.splitlines()

    # Restrict analysis to the component's own line span (1-based, inclusive).
    first_idx = component_info.start_line - 1
    last_idx = min(component_info.end_line, len(all_lines))
    body_lines = all_lines[first_idx:last_idx]
    body_text = "\n".join(body_lines)

    _detect_inline_props(body_lines, first_idx, found)
    _detect_missing_usecallback(body_text, body_lines, first_idx, found)
    _detect_missing_usememo(body_text, body_lines, first_idx, found)

    # Suggest React.memo for any component that is not already wrapped.
    if not component_info.is_memoized:
        found.append(
            OptimizationOpportunity(
                type=OpportunityType.MISSING_REACT_MEMO,
                line=component_info.start_line,
                description=f"Component '{component_info.function_name}' is not wrapped in React.memo(). "
                "If it receives stable props, wrapping can prevent unnecessary re-renders.",
                severity=OpportunitySeverity.MEDIUM,
            )
        )

    return found


def _detect_inline_props(
    lines: list[str], offset: int, opportunities: list[OptimizationOpportunity]
) -> None:
    """Detect inline object/array literals in JSX prop positions."""
    for line_num, text in enumerate(lines, start=offset + 1):
        if INLINE_OBJECT_IN_JSX_RE.search(text):
            opportunities.append(
                OptimizationOpportunity(
                    type=OpportunityType.INLINE_OBJECT_PROP,
                    line=line_num,
                    description="Inline object literal in JSX prop creates a new reference on every render. "
                    "Extract to useMemo or a module-level constant.",
                    severity=OpportunitySeverity.HIGH,
                )
            )
        if INLINE_ARRAY_IN_JSX_RE.search(text):
            opportunities.append(
                OptimizationOpportunity(
                    type=OpportunityType.INLINE_ARRAY_PROP,
                    line=line_num,
                    description="Inline array literal in JSX prop creates a new reference on every render. "
                    "Extract to useMemo or a module-level constant.",
                    severity=OpportunitySeverity.HIGH,
                )
            )


def _detect_missing_usecallback(
    component_source: str,
    lines: list[str],
    offset: int,
    opportunities: list[OptimizationOpportunity],
) -> None:
    """Detect arrow functions or function expressions that could use useCallback."""
    # A component that already calls useCallback is assumed to do so deliberately,
    # so we never flag its remaining inline functions.
    if USECALLBACK_RE.search(component_source):
        return

    for line_num, text in enumerate(lines, start=offset + 1):
        candidate = text.strip()
        if (
            FUNCTION_DEF_RE.search(candidate)
            and "useCallback" not in candidate
            and "useMemo" not in candidate
        ):
            opportunities.append(
                OptimizationOpportunity(
                    type=OpportunityType.MISSING_USECALLBACK,
                    line=line_num,
                    description="Function defined inside render body creates a new reference on every render. "
                    "Wrap with useCallback() if passed as a prop to child components.",
                    severity=OpportunitySeverity.MEDIUM,
                )
            )


def _detect_missing_usememo(
    component_source: str,
    lines: list[str],
    offset: int,
    opportunities: list[OptimizationOpportunity],
) -> None:
    """Detect expensive computations that could benefit from useMemo."""
    for line_num, text in enumerate(lines, start=offset + 1):
        candidate = text.strip()
        if EXPENSIVE_OPS_RE.search(candidate) and "useMemo" not in candidate:
            opportunities.append(
                OptimizationOpportunity(
                    type=OpportunityType.MISSING_USEMEMO,
                    line=line_num,
                    description="Expensive array operation in render body runs on every render. "
                    "Wrap with useMemo() and specify dependencies.",
                    severity=OpportunitySeverity.HIGH,
                )
            )
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from codeflash.languages.javascript.frameworks.react.analyzer import OptimizationOpportunity + from codeflash.languages.javascript.frameworks.react.discovery import ReactComponentInfo + from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer + +logger = logging.getLogger(__name__) + + +@dataclass +class HookUsage: + """Represents a hook call within a component.""" + + name: str + has_dependency_array: bool = False + dependency_count: int = 0 + + +@dataclass +class ReactContext: + """Context information for a React component, used in LLM prompts.""" + + props_interface: str | None = None + hooks_used: list[HookUsage] = field(default_factory=list) + parent_usages: list[str] = field(default_factory=list) + child_components: list[str] = field(default_factory=list) + context_subscriptions: list[str] = field(default_factory=list) + is_already_memoized: bool = False + optimization_opportunities: list[OptimizationOpportunity] = field(default_factory=list) + + def to_prompt_string(self) -> str: + """Format this context for inclusion in an LLM optimization prompt.""" + parts: list[str] = [] + + if self.props_interface: + parts.append(f"Props interface:\n```typescript\n{self.props_interface}\n```") + + if self.hooks_used: + hook_lines = [] + for hook in self.hooks_used: + dep_info = f" (deps: {hook.dependency_count})" if hook.has_dependency_array else " (no deps)" + hook_lines.append(f" - {hook.name}{dep_info}") + parts.append("Hooks used:\n" + "\n".join(hook_lines)) + + if self.child_components: + parts.append("Child components rendered: " + ", ".join(self.child_components)) + + if self.context_subscriptions: + parts.append("Context subscriptions: " + ", ".join(self.context_subscriptions)) + + if self.is_already_memoized: + parts.append("Note: Component is already wrapped in 
React.memo()") + + if self.optimization_opportunities: + opp_lines = [] + for opp in self.optimization_opportunities: + opp_lines.append(f" - [{opp.severity.value}] Line {opp.line}: {opp.description}") + parts.append("Detected optimization opportunities:\n" + "\n".join(opp_lines)) + + return "\n\n".join(parts) + + +def extract_react_context( + component_info: ReactComponentInfo, + source: str, + analyzer: TreeSitterAnalyzer, + module_root: Path, +) -> ReactContext: + """Extract React-specific context for a component. + + Analyzes the component source to find props types, hooks, child components, + and optimization opportunities. + """ + from codeflash.languages.javascript.frameworks.react.analyzer import ( # noqa: PLC0415 + detect_optimization_opportunities, + ) + + context = ReactContext( + props_interface=component_info.props_type, + is_already_memoized=component_info.is_memoized, + ) + + # Extract hook usage details from the component source + lines = source.splitlines() + start = component_info.start_line - 1 + end = min(component_info.end_line, len(lines)) + component_source = "\n".join(lines[start:end]) + + context.hooks_used = _extract_hook_usages(component_source) + context.child_components = _extract_child_components(component_source, analyzer, source) + context.context_subscriptions = _extract_context_subscriptions(component_source) + context.optimization_opportunities = detect_optimization_opportunities(source, component_info) + + # Extract full props interface definition if we have a type name + if component_info.props_type: + full_interface = _find_type_definition(component_info.props_type, source, analyzer) + if full_interface: + context.props_interface = full_interface + + return context + + +def _extract_hook_usages(component_source: str) -> list[HookUsage]: + """Parse hook calls and their dependency arrays from component source.""" + import re # noqa: PLC0415 + + hooks: list[HookUsage] = [] + # Match useXxx( patterns + hook_pattern = 
re.compile(r"\b(use[A-Z]\w*)\s*\(") + + for match in hook_pattern.finditer(component_source): + hook_name = match.group(1) + # Try to determine if there's a dependency array + # Look for ], [ pattern after the hook call (simplified heuristic) + rest_of_line = component_source[match.end():] + has_deps = False + dep_count = 0 + + # Simple heuristic: count brackets to find dependency array + bracket_depth = 1 + for i, char in enumerate(rest_of_line): + if char == "(": + bracket_depth += 1 + elif char == ")": + bracket_depth -= 1 + if bracket_depth == 0: + # Check if the last argument before closing paren is an array + preceding = rest_of_line[:i].rstrip() + if preceding.endswith("]"): + has_deps = True + # Count items in the array (rough: count commas + 1 for non-empty) + array_start = preceding.rfind("[") + if array_start >= 0: + array_content = preceding[array_start + 1:-1].strip() + if array_content: + dep_count = array_content.count(",") + 1 + else: + dep_count = 0 # empty deps [] + has_deps = True + break + + hooks.append(HookUsage( + name=hook_name, + has_dependency_array=has_deps, + dependency_count=dep_count, + )) + + return hooks + + +def _extract_child_components(component_source: str, analyzer: TreeSitterAnalyzer, full_source: str) -> list[str]: + """Find child component names rendered in JSX.""" + import re # noqa: PLC0415 + + # Match JSX tags that start with uppercase (React components) + jsx_component_re = re.compile(r"<([A-Z][a-zA-Z0-9.]*)") + children = set() + for match in jsx_component_re.finditer(component_source): + name = match.group(1) + # Skip React built-ins like React.Fragment + if name not in ("React.Fragment", "Fragment", "Suspense", "React.Suspense"): + children.add(name) + return sorted(children) + + +def _extract_context_subscriptions(component_source: str) -> list[str]: + """Find React context subscriptions via useContext calls.""" + import re # noqa: PLC0415 + + context_re = re.compile(r"\buseContext\s*\(\s*(\w+)") + return 
[match.group(1) for match in context_re.finditer(component_source)] + + +def _find_type_definition(type_name: str, source: str, analyzer: TreeSitterAnalyzer) -> str | None: + """Find the full type/interface definition for a props type.""" + source_bytes = source.encode("utf-8") + tree = analyzer.parse(source_bytes) + + def search_node(node): + if node.type in ("interface_declaration", "type_alias_declaration"): + name_node = node.child_by_field_name("name") + if name_node: + name = source_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8") + if name == type_name: + return source_bytes[node.start_byte:node.end_byte].decode("utf-8") + for child in node.children: + result = search_node(child) + if result: + return result + return None + + return search_node(tree.root_node) diff --git a/codeflash/languages/javascript/frameworks/react/discovery.py b/codeflash/languages/javascript/frameworks/react/discovery.py new file mode 100644 index 000000000..194088885 --- /dev/null +++ b/codeflash/languages/javascript/frameworks/react/discovery.py @@ -0,0 +1,251 @@ +"""React component discovery via tree-sitter analysis. + +Identifies React components (function, arrow, class) and hooks by analyzing +PascalCase naming, JSX returns, and hook usage patterns. 
+""" + +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from codeflash.languages.javascript.treesitter import FunctionNode, TreeSitterAnalyzer + +logger = logging.getLogger(__name__) + +PASCAL_CASE_RE = re.compile(r"^[A-Z][a-zA-Z0-9]*$") +HOOK_CALL_RE = re.compile(r"\buse[A-Z]\w*\s*\(") +HOOK_NAME_RE = re.compile(r"^use[A-Z]\w*$") + +# Built-in React hooks +BUILTIN_HOOKS = frozenset({ + "useState", "useEffect", "useContext", "useReducer", "useCallback", + "useMemo", "useRef", "useImperativeHandle", "useLayoutEffect", + "useInsertionEffect", "useDebugValue", "useDeferredValue", + "useTransition", "useId", "useSyncExternalStore", "useOptimistic", + "useActionState", "useFormStatus", +}) + + +class ComponentType(str, Enum): + FUNCTION = "function" + ARROW = "arrow" + CLASS = "class" + HOOK = "hook" + + +@dataclass(frozen=True) +class ReactComponentInfo: + """Information about a discovered React component or hook.""" + + function_name: str + component_type: ComponentType + uses_hooks: tuple[str, ...] = () + returns_jsx: bool = False + props_type: str | None = None + is_memoized: bool = False + start_line: int = 0 + end_line: int = 0 + + +def is_react_component(func: FunctionNode, source: str, analyzer: TreeSitterAnalyzer) -> bool: + """Check if a function is a React component. 
+ + A React component: + - Has a PascalCase name + - Returns JSX (or could be a hook if named use*) + - Is not a class method (standalone function) + """ + if func.is_method: + return False + + name = func.name + + # Hooks (useXxx) are not components + if HOOK_NAME_RE.match(name): + return False + + if not PASCAL_CASE_RE.match(name): + return False + + return _function_returns_jsx(func, source, analyzer) + + +def is_react_hook(func: FunctionNode) -> bool: + """Check if a function is a custom React hook (useXxx naming).""" + return bool(HOOK_NAME_RE.match(func.name)) and not func.is_method + + +def classify_component(func: FunctionNode, source: str, analyzer: TreeSitterAnalyzer) -> ComponentType | None: + """Classify a function as a React component type, hook, or None.""" + if is_react_hook(func): + return ComponentType.HOOK + + if not is_react_component(func, source, analyzer): + return None + + if func.is_arrow: + return ComponentType.ARROW + + return ComponentType.FUNCTION + + +def find_react_components(source: str, file_path: Path, analyzer: TreeSitterAnalyzer) -> list[ReactComponentInfo]: + """Find all React components and hooks in a source file. + + Skips files with "use server" directive (Next.js Server Components). 
+ """ + # Skip Server Components + if _has_server_directive(source): + logger.debug("Skipping server component file: %s", file_path) + return [] + + functions = analyzer.find_functions( + source, include_methods=False, include_arrow_functions=True, require_name=True + ) + + components: list[ReactComponentInfo] = [] + for func in functions: + comp_type = classify_component(func, source, analyzer) + if comp_type is None: + continue + + hooks_used = _extract_hooks_used(func.source_text) + props_type = _extract_props_type(func, source, analyzer) + is_memoized = _is_wrapped_in_memo(func, source) + + components.append(ReactComponentInfo( + function_name=func.name, + component_type=comp_type, + uses_hooks=tuple(hooks_used), + returns_jsx=comp_type != ComponentType.HOOK and _function_returns_jsx(func, source, analyzer), + props_type=props_type, + is_memoized=is_memoized, + start_line=func.start_line, + end_line=func.end_line, + )) + + return components + + +def _has_server_directive(source: str) -> bool: + """Check for 'use server' directive at the top of the file.""" + for line in source.splitlines()[:5]: + stripped = line.strip() + if stripped in ('"use server"', "'use server'", '"use server";', "'use server';"): + return True + if stripped and not stripped.startswith("//") and not stripped.startswith("/*"): + break + return False + + +def _function_returns_jsx(func: FunctionNode, source: str, analyzer: TreeSitterAnalyzer) -> bool: + """Check if a function returns JSX by looking for jsx_element/jsx_self_closing_element nodes.""" + source_bytes = source.encode("utf-8") + node = func.node + + # For arrow functions with expression body (implicit return), check the body directly + body = node.child_by_field_name("body") + if body: + return _node_contains_jsx(body) + + return False + + +def _node_contains_jsx(node) -> bool: + """Recursively check if a tree-sitter node contains JSX.""" + if node.type in ( + "jsx_element", "jsx_self_closing_element", "jsx_fragment", + 
"jsx_expression", "jsx_opening_element", + ): + return True + + # Check return statements + if node.type == "return_statement": + for child in node.children: + if _node_contains_jsx(child): + return True + + for child in node.children: + if _node_contains_jsx(child): + return True + + return False + + +def _extract_hooks_used(function_source: str) -> list[str]: + """Extract hook names called within a function body.""" + hooks = [] + seen = set() + for match in HOOK_CALL_RE.finditer(function_source): + hook_name = match.group(0).rstrip("( \t") + if hook_name not in seen: + seen.add(hook_name) + hooks.append(hook_name) + return hooks + + +def _extract_props_type(func: FunctionNode, source: str, analyzer: TreeSitterAnalyzer) -> str | None: + """Extract the TypeScript props type annotation from a component's parameters.""" + source_bytes = source.encode("utf-8") + node = func.node + + # Look for formal_parameters -> type_annotation + params = node.child_by_field_name("parameters") + if not params: + return None + + for param in params.children: + # Look for type annotation on first parameter + if param.type in ("required_parameter", "optional_parameter"): + type_node = param.child_by_field_name("type") + if type_node: + # Get the type annotation node (skip the colon) + for child in type_node.children: + if child.type != ":": + return source_bytes[child.start_byte:child.end_byte].decode("utf-8") + # Destructured params with type: { foo, bar }: Props + if param.type == "object_pattern": + # Look for next sibling that is a type_annotation + next_sib = param.next_named_sibling + if next_sib and next_sib.type == "type_annotation": + for child in next_sib.children: + if child.type != ":": + return source_bytes[child.start_byte:child.end_byte].decode("utf-8") + + return None + + +def _is_wrapped_in_memo(func: FunctionNode, source: str) -> bool: + """Check if the component is already wrapped in React.memo or memo().""" + # Check if the variable declaration wrapping this 
function uses memo() + # e.g., const MyComp = React.memo(function MyComp(...) {...}) + # or const MyComp = memo((...) => {...}) + node = func.node + parent = node.parent + + while parent: + if parent.type == "call_expression": + func_node = parent.child_by_field_name("function") + if func_node: + source_bytes = source.encode("utf-8") + func_text = source_bytes[func_node.start_byte:func_node.end_byte].decode("utf-8") + if func_text in ("React.memo", "memo"): + return True + parent = parent.parent + + # Also check for memo wrapping at the export level: + # export default memo(MyComponent) + name = func.name + memo_patterns = [ + f"React.memo({name})", + f"memo({name})", + f"React.memo({name},", + f"memo({name},", + ] + return any(pattern in source for pattern in memo_patterns) diff --git a/codeflash/languages/javascript/frameworks/react/profiler.py b/codeflash/languages/javascript/frameworks/react/profiler.py new file mode 100644 index 000000000..9d273b70b --- /dev/null +++ b/codeflash/languages/javascript/frameworks/react/profiler.py @@ -0,0 +1,244 @@ +"""React Profiler instrumentation for render counting and timing. + +Wraps React components with React.Profiler to capture render count, +phase (mount/update), actualDuration, and baseDuration. Outputs structured +markers parseable by the existing marker-parsing infrastructure. + +Marker format: + !######REACT_RENDER:{component}:{phase}:{actualDuration}:{baseDuration}:{count}######! 
+""" + +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer + +logger = logging.getLogger(__name__) + +MARKER_PREFIX = "REACT_RENDER" + + +def generate_render_counter_code(component_name: str) -> str: + """Generate the onRender callback and counter variable for Profiler instrumentation.""" + safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", component_name) + return f"""\ +let _codeflash_render_count_{safe_name} = 0; +function _codeflashOnRender_{safe_name}(id, phase, actualDuration, baseDuration) {{ + _codeflash_render_count_{safe_name}++; + console.log(`!######{MARKER_PREFIX}:${{id}}:${{phase}}:${{actualDuration}}:${{baseDuration}}:${{_codeflash_render_count_{safe_name}}}######!`); +}}""" + + +def instrument_component_with_profiler(source: str, component_name: str, analyzer: TreeSitterAnalyzer) -> str: + """Instrument a single component with React.Profiler. + + Wraps all JSX return statements with and adds the + onRender callback + counter at module scope. + + Handles: + - Single return statements + - Conditional returns (if/else) + - Fragment returns (<>...) 
def instrument_all_components_for_tracing(source: str, file_path: Path, analyzer: TreeSitterAnalyzer) -> str:
    """Instrument every JSX-returning component in a file for tracing/discovery mode.

    Components are processed bottom-up (descending start line) so earlier
    byte offsets stay valid while later components are rewritten.
    """
    from codeflash.languages.javascript.frameworks.react.discovery import find_react_components  # noqa: PLC0415

    components = find_react_components(source, file_path, analyzer)
    if not components:
        return source

    instrumented = source
    bottom_up = sorted(components, key=lambda c: c.start_line, reverse=True)
    for component in bottom_up:
        if not component.returns_jsx:
            continue
        instrumented = instrument_component_with_profiler(instrumented, component.function_name, analyzer)
    return instrumented
"""Find the tree-sitter node for a named component function.""" + # Check function declarations + if root_node.type == "function_declaration": + name_node = root_node.child_by_field_name("name") + if name_node: + name = source_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8") + if name == component_name: + return root_node + + # Check variable declarators with arrow functions (const MyComp = () => ...) + if root_node.type == "variable_declarator": + name_node = root_node.child_by_field_name("name") + if name_node: + name = source_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8") + if name == component_name: + return root_node + + # Check export statements + if root_node.type in ("export_statement", "lexical_declaration", "variable_declaration"): + for child in root_node.children: + result = _find_component_function(child, component_name, source_bytes) + if result: + return result + + for child in root_node.children: + result = _find_component_function(child, component_name, source_bytes) + if result: + return result + + return None + + +def _find_jsx_returns(func_node, source_bytes: bytes) -> list: + """Find all return statements that contain JSX within a function node.""" + returns = [] + + def walk(node): + # Don't descend into nested functions + if node != func_node and node.type in ( + "function_declaration", "arrow_function", "function", "method_definition", + ): + return + + if node.type == "return_statement": + # Check if return value contains JSX + for child in node.children: + if _contains_jsx(child): + returns.append(node) + break + else: + for child in node.children: + walk(child) + + walk(func_node) + return returns + + +def _contains_jsx(node) -> bool: + """Check if a tree-sitter node contains JSX elements.""" + if node.type in ( + "jsx_element", "jsx_self_closing_element", "jsx_fragment", + ): + return True + for child in node.children: + if _contains_jsx(child): + return True + return False + + +def 
_wrap_return_with_profiler(source: str, return_node, profiler_id: str, safe_name: str) -> str: + """Wrap a return statement's JSX with React.Profiler.""" + source_bytes = source.encode("utf-8") + + # Find the JSX part of the return (skip "return" keyword and whitespace) + jsx_start = None + jsx_end = return_node.end_byte + + for child in return_node.children: + if child.type == "return": + continue + if child.type == ";": + jsx_end = child.start_byte + continue + if _contains_jsx(child): + jsx_start = child.start_byte + jsx_end = child.end_byte + break + + if jsx_start is None: + return source + + jsx_content = source_bytes[jsx_start:jsx_end].decode("utf-8").strip() + + # Check if the return uses parentheses: return (...) + # If so, we need to wrap inside the parens + has_parens = False + for child in return_node.children: + if child.type == "parenthesized_expression": + has_parens = True + jsx_start = child.start_byte + 1 # skip ( + jsx_end = child.end_byte - 1 # skip ) + jsx_content = source_bytes[jsx_start:jsx_end].decode("utf-8").strip() + break + + wrapped = ( + f'' + f"\n{jsx_content}\n" + f"" + ) + + return source[:jsx_start] + wrapped + source[jsx_end:] + + +def _insert_after_imports(source: str, code: str, analyzer: TreeSitterAnalyzer) -> str: + """Insert code after the last import statement.""" + source_bytes = source.encode("utf-8") + tree = analyzer.parse(source_bytes) + + last_import_end = 0 + for child in tree.root_node.children: + if child.type == "import_statement": + last_import_end = child.end_byte + + # Find end of line after last import + insert_pos = last_import_end + while insert_pos < len(source) and source[insert_pos] != "\n": + insert_pos += 1 + if insert_pos < len(source): + insert_pos += 1 # skip the newline + + return source[:insert_pos] + "\n" + code + "\n\n" + source[insert_pos:] + + +def _ensure_react_import(source: str) -> str: + """Ensure React is imported (needed for React.Profiler).""" + if "import React" in source or "import * as 
React" in source: + return source + # Add React import at the top + if "from 'react'" in source or 'from "react"' in source: + # React is imported but maybe not as the default. That's fine for JSX. + # We need React.Profiler so add it + if "React" not in source.split("from")[0] if "from" in source else "": + return 'import React from "react";\n' + source + return source + return 'import React from "react";\n' + source diff --git a/codeflash/languages/javascript/frameworks/react/testgen.py b/codeflash/languages/javascript/frameworks/react/testgen.py new file mode 100644 index 000000000..fd621b05e --- /dev/null +++ b/codeflash/languages/javascript/frameworks/react/testgen.py @@ -0,0 +1,120 @@ +"""React-specific test generation helpers. + +Provides context building for React testgen prompts, re-render counting +test templates, and post-processing for generated React tests. +""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from codeflash.languages.base import CodeContext + from codeflash.languages.javascript.frameworks.react.context import ReactContext + from codeflash.languages.javascript.frameworks.react.discovery import ReactComponentInfo + + +def build_react_testgen_context( + component_info: ReactComponentInfo, + react_context: ReactContext, + code_context: CodeContext, +) -> dict: + """Assemble context dict for the React testgen LLM prompt.""" + return { + "component_name": component_info.function_name, + "component_type": component_info.component_type.value, + "component_source": code_context.target_code, + "props_interface": react_context.props_interface or "", + "hooks_used": [h.name for h in react_context.hooks_used], + "child_components": react_context.child_components, + "context_subscriptions": react_context.context_subscriptions, + "is_memoized": component_info.is_memoized, + "optimization_opportunities": [ + {"type": o.type.value, "line": o.line, "description": o.description} + for o in 
react_context.optimization_opportunities + ], + "read_only_context": code_context.read_only_context, + "imports": code_context.imports, + } + + +def generate_rerender_test_template(component_name: str, props_interface: str | None = None) -> str: + """Generate a template test that counts re-renders for a component. + + This template uses @testing-library/react's render + rerender to verify + that same props don't cause unnecessary re-renders. + """ + props_example = "{ /* same props */ }" if not props_interface else "{ /* fill in props matching interface */ }" + + return f"""\ +import {{ render }} from '@testing-library/react'; +import {{ {component_name} }} from './path-to-component'; + +describe('{component_name} render efficiency', () => {{ + it('should not re-render with same props', () => {{ + let renderCount = 0; + const OriginalComponent = {component_name}; + + // Wrap to count renders + const CountingComponent = (props) => {{ + renderCount++; + return ; + }}; + + const props = {props_example}; + const {{ rerender }} = render(); + + // Initial render + expect(renderCount).toBe(1); + + // Re-render with same props + rerender(); + + // Should not have re-rendered (if properly memoized) + // For non-memoized components, renderCount will be 2 + console.log(`!######REACT_RENDER:{component_name}:rerender_test:0:0:${{renderCount}}######!`); + }}); + + it('should render correctly with props', () => {{ + const props = {props_example}; + const {{ container }} = render(<{component_name} {{...props}} />); + expect(container).toBeTruthy(); + }}); +}}); +""" + + +def post_process_react_tests(test_source: str, component_info: ReactComponentInfo) -> str: + """Post-process LLM-generated React tests. 
def post_process_react_tests(test_source: str, component_info: ReactComponentInfo) -> str:
    """Post-process LLM-generated React tests.

    Ensures:
    - @testing-library/react imports are present
    - `act` is part of the testing-library named imports when act() is used
    - @testing-library/user-event is imported when user interactions are tested

    Args:
        test_source: The generated test file source.
        component_info: Metadata about the component under test (currently
            unused; kept for interface stability).

    Returns:
        The adjusted test source.

    """
    result = test_source

    # Ensure a testing-library import exists at all.
    if "@testing-library/react" not in result:
        result = "import { render, screen, act } from '@testing-library/react';\n" + result

    # Ensure `act` is among the named imports when act() is used. The name
    # must be inserted *inside* the braces; appending text after the closing
    # brace would produce invalid JavaScript.
    if "act(" in result:
        named_import = re.search(
            r"import\s*\{([^}]*)\}\s*from\s*(['\"])@testing-library/react\2", result
        )
        if named_import and not re.search(r"\bact\b", named_import.group(1)):
            names = named_import.group(1).strip()
            quote = named_import.group(2)
            result = (
                result[: named_import.start()]
                + f"import {{ {names}, act }} from {quote}@testing-library/react{quote}"
                + result[named_import.end() :]
            )

    # Ensure user-event import if user interactions are tested (heuristic).
    if (
        "click" in result.lower() or "type" in result.lower() or "userEvent" in result
    ) and "@testing-library/user-event" not in result:
        # Add the user-event import right after the testing-library import.
        result = re.sub(
            r"(import .+ from ['\"]@testing-library/react['\"];?\n)",
            r"\1import userEvent from '@testing-library/user-event';\n",
            result,
            count=1,
        )

    return result
# Matches: !######REACT_RENDER:{component}:{phase}:{actualDuration}:{baseDuration}:{renderCount}######!
REACT_RENDER_MARKER_PATTERN = re.compile(
    r"!######REACT_RENDER:([^:]+):([^:]+):([^:]+):([^:]+):(\d+)######!"
)


@dataclass(frozen=True)
class RenderProfile:
    """Parsed React Profiler render data from a single marker."""

    component_name: str
    phase: str  # "mount" or "update"
    actual_duration_ms: float
    base_duration_ms: float
    render_count: int


def parse_react_render_markers(stdout: str) -> list[RenderProfile]:
    """Parse React Profiler render markers out of captured test output.

    One RenderProfile is produced per well-formed marker, in the order the
    markers appear; malformed markers (non-numeric duration fields) are
    logged and skipped.
    """
    profiles: list[RenderProfile] = []
    for component, phase, actual, base, count in REACT_RENDER_MARKER_PATTERN.findall(stdout):
        try:
            profile = RenderProfile(
                component_name=component,
                phase=phase,
                actual_duration_ms=float(actual),
                base_duration_ms=float(base),
                render_count=int(count),
            )
        except ValueError as e:
            logger.debug("Failed to parse React render marker: %s", e)
            continue
        profiles.append(profile)
    return profiles
diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index e0111c634..45007caaa 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -22,6 +22,7 @@ from collections.abc import Sequence from codeflash.languages.base import ReferenceInfo + from codeflash.languages.javascript.frameworks.detector import FrameworkInfo from codeflash.languages.javascript.treesitter import TypeDefinition from codeflash.models.models import GeneratedTestsList, InvocationId @@ -68,6 +69,22 @@ def comment_prefix(self) -> str: def dir_excludes(self) -> frozenset[str]: return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"}) + _cached_framework_info: FrameworkInfo | None = None + _cached_framework_root: Path | None = None + + def get_framework_info(self, project_root: Path) -> FrameworkInfo: + """Get cached framework info for the project.""" + if self._cached_framework_root != project_root or self._cached_framework_info is None: + from codeflash.languages.javascript.frameworks.detector import detect_framework # noqa: PLC0415 + + self._cached_framework_info = detect_framework(project_root) + self._cached_framework_root = project_root + return self._cached_framework_info + + def is_react_project(self, project_root: Path) -> bool: + """Check if the project uses React.""" + return self.get_framework_info(project_root).name == "react" + # === Discovery === def discover_functions( @@ -99,6 +116,31 @@ def discover_functions( source, include_methods=criteria.include_methods, include_arrow_functions=True, require_name=True ) + # Build React component lookup if this is a React project + react_component_map: dict[str, Any] = {} + project_root = file_path.parent # Will be refined by caller + try: + from codeflash.languages.javascript.frameworks.react.discovery import ( # noqa: PLC0415 + classify_component, + ) + + for func in tree_functions: + comp_type 
= classify_component(func, source, analyzer) + if comp_type is not None: + from codeflash.languages.javascript.frameworks.react.discovery import ( # noqa: PLC0415 + _extract_hooks_used, + _is_wrapped_in_memo, + ) + + react_component_map[func.name] = { + "component_type": comp_type.value, + "hooks_used": _extract_hooks_used(func.source_text), + "is_memoized": _is_wrapped_in_memo(func, source), + "is_react_component": True, + } + except Exception as e: + logger.debug("React detection skipped: %s", e) + functions: list[FunctionToOptimize] = [] for func in tree_functions: # Check for return statement if required @@ -122,6 +164,9 @@ def discover_functions( if func.parent_function: parents.append(FunctionParent(name=func.parent_function, type="FunctionDef")) + # Attach React metadata if this function is a component + metadata = react_component_map.get(func.name) + functions.append( FunctionToOptimize( function_name=func.name, @@ -135,6 +180,7 @@ def discover_functions( is_method=func.is_method, language=str(self.language), doc_start_line=func.doc_start_line, + metadata=metadata, ) ) @@ -423,6 +469,33 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, else: read_only_context = type_definitions_context + # Append React-specific context if this is a React component + react_context_str = "" + if function.metadata and function.metadata.get("is_react_component"): + try: + from codeflash.languages.javascript.frameworks.react.discovery import ( # noqa: PLC0415 + ReactComponentInfo, + find_react_components, + ) + from codeflash.languages.javascript.frameworks.react.context import ( # noqa: PLC0415 + extract_react_context, + ) + + components = find_react_components(source, function.file_path, analyzer) + for comp in components: + if comp.function_name == function.function_name: + react_ctx = extract_react_context(comp, source, analyzer, module_root) + react_context_str = react_ctx.to_prompt_string() + if react_context_str: + react_header = "\n\n// 
=== React Component Context ===\n" + if read_only_context: + read_only_context = read_only_context + react_header + react_context_str + else: + read_only_context = react_context_str + break + except Exception as e: + logger.debug("React context extraction failed: %s", e) + # Validate that the extracted code is syntactically valid # If not, raise an error to fail the optimization early if target_code and not self.validate_syntax(target_code): @@ -440,6 +513,7 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, read_only_context=read_only_context, imports=import_lines, language=Language.JAVASCRIPT, + react_context=react_context_str if react_context_str else None, ) def _find_class_definition( diff --git a/codeflash/languages/javascript/treesitter_utils.py b/codeflash/languages/javascript/treesitter_utils.py new file mode 100644 index 000000000..b6126ec9a --- /dev/null +++ b/codeflash/languages/javascript/treesitter_utils.py @@ -0,0 +1,1588 @@ +"""Tree-sitter utilities for cross-language code analysis. + +This module provides a unified interface for parsing and analyzing code +across multiple languages using tree-sitter. 
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING + +from tree_sitter import Language, Parser + +if TYPE_CHECKING: + from pathlib import Path + + from tree_sitter import Node, Tree + +logger = logging.getLogger(__name__) + + +class TreeSitterLanguage(Enum): + """Supported tree-sitter languages.""" + + JAVASCRIPT = "javascript" + TYPESCRIPT = "typescript" + TSX = "tsx" + + +# Lazy-loaded language instances +_LANGUAGE_CACHE: dict[TreeSitterLanguage, Language] = {} + + +def _get_language(lang: TreeSitterLanguage) -> Language: + """Get a tree-sitter Language instance, with lazy loading.""" + if lang not in _LANGUAGE_CACHE: + if lang == TreeSitterLanguage.JAVASCRIPT: + import tree_sitter_javascript + + _LANGUAGE_CACHE[lang] = Language(tree_sitter_javascript.language()) + elif lang == TreeSitterLanguage.TYPESCRIPT: + import tree_sitter_typescript + + _LANGUAGE_CACHE[lang] = Language(tree_sitter_typescript.language_typescript()) + elif lang == TreeSitterLanguage.TSX: + import tree_sitter_typescript + + _LANGUAGE_CACHE[lang] = Language(tree_sitter_typescript.language_tsx()) + return _LANGUAGE_CACHE[lang] + + +@dataclass +class FunctionNode: + """Represents a function found by tree-sitter analysis.""" + + name: str + node: Node + start_line: int + end_line: int + start_col: int + end_col: int + is_async: bool + is_method: bool + is_arrow: bool + is_generator: bool + class_name: str | None + parent_function: str | None + source_text: str + doc_start_line: int | None = None # Line where JSDoc comment starts (or None if no JSDoc) + + +@dataclass +class ImportInfo: + """Represents an import statement.""" + + module_path: str # The path being imported from + default_import: str | None # Default import name (import X from ...) + named_imports: list[tuple[str, str | None]] # [(name, alias), ...] + namespace_import: str | None # Namespace import (import * as X from ...) 
@dataclass
class ExportInfo:
    """An export statement found in a module."""

    # Named exports as (name, alias) pairs; alias is None when not renamed.
    exported_names: list[tuple[str, str | None]]
    # Name of the default-exported function/class/value, if any.
    default_export: str | None
    # True for re-exports such as: export { x } from './other'
    is_reexport: bool
    # Module path of the re-export source, when is_reexport is True.
    reexport_source: str | None
    start_line: int
    end_line: int


@dataclass
class ModuleLevelDeclaration:
    """A module-level (global) variable or constant declaration."""

    name: str
    # One of: "const", "let", "var", "class", "enum", "type", "interface".
    declaration_type: str
    # Full declaration source code.
    source_code: str
    start_line: int
    end_line: int
    # Whether the declaration is exported.
    is_exported: bool


@dataclass
class TypeDefinition:
    """A type definition: interface, type alias, class, or enum."""

    name: str
    # One of: "interface", "type", "class", "enum".
    definition_type: str
    # Full definition source code.
    source_code: str
    start_line: int
    end_line: int
    # Whether the definition is exported.
    is_exported: bool
    # File the type is defined in, when known.
    file_path: Path | None = None
+ + """ + if isinstance(language, str): + language = TreeSitterLanguage(language) + self.language = language + self._parser: Parser | None = None + + @property + def parser(self) -> Parser: + """Get the parser, creating it lazily.""" + if self._parser is None: + self._parser = Parser(_get_language(self.language)) + return self._parser + + def parse(self, source: str | bytes) -> Tree: + """Parse source code into a tree-sitter tree. + + Args: + source: Source code as string or bytes. + + Returns: + The parsed tree. + + """ + if isinstance(source, str): + source = source.encode("utf8") + return self.parser.parse(source) + + def get_node_text(self, node: Node, source: bytes) -> str: + """Extract the source text for a tree-sitter node. + + Args: + node: The tree-sitter node. + source: The source code as bytes. + + Returns: + The text content of the node. + + """ + return source[node.start_byte : node.end_byte].decode("utf8") + + def find_functions( + self, source: str, include_methods: bool = True, include_arrow_functions: bool = True, require_name: bool = True + ) -> list[FunctionNode]: + """Find all function definitions in source code. + + Args: + source: The source code to analyze. + include_methods: Whether to include class methods. + include_arrow_functions: Whether to include arrow functions. + require_name: Whether to require functions to have names. + + Returns: + List of FunctionNode objects describing found functions. 
    def find_functions(
        self, source: str, include_methods: bool = True, include_arrow_functions: bool = True, require_name: bool = True
    ) -> list[FunctionNode]:
        """Find all function definitions in source code.

        Args:
            source: The source code to analyze.
            include_methods: Whether to include class methods.
            include_arrow_functions: Whether to include arrow functions.
            require_name: Whether to require functions to have names.

        Returns:
            List of FunctionNode objects describing found functions.

        """
        source_bytes = source.encode("utf8")
        tree = self.parse(source_bytes)
        functions: list[FunctionNode] = []

        # Single depth-first pass; results accumulate in document order.
        self._walk_tree_for_functions(
            tree.root_node,
            source_bytes,
            functions,
            include_methods=include_methods,
            include_arrow_functions=include_arrow_functions,
            require_name=require_name,
            current_class=None,
            current_function=None,
        )

        return functions

    def _walk_tree_for_functions(
        self,
        node: Node,
        source_bytes: bytes,
        functions: list[FunctionNode],
        include_methods: bool,
        include_arrow_functions: bool,
        require_name: bool,
        current_class: str | None,
        current_function: str | None,
    ) -> None:
        """Recursively walk the tree to find function definitions.

        current_class / current_function carry the enclosing class and
        function names down the recursion so nested definitions can record
        their context.
        """
        # Function types in JavaScript/TypeScript
        function_types = {
            "function_declaration",
            "function_expression",
            "generator_function_declaration",
            "generator_function",
        }

        if include_arrow_functions:
            function_types.add("arrow_function")

        if include_methods:
            function_types.add("method_definition")

        # Track class context
        new_class = current_class
        new_function = current_function

        if node.type in {"class_declaration", "class"}:
            # Get class name
            name_node = node.child_by_field_name("name")
            if name_node:
                new_class = self.get_node_text(name_node, source_bytes)

        if node.type in function_types:
            func_info = self._extract_function_info(node, source_bytes, current_class, current_function)

            if func_info:
                # Check if we should include this function
                should_include = True

                if require_name and not func_info.name:
                    should_include = False

                if func_info.is_method and not include_methods:
                    should_include = False

                if func_info.is_arrow and not include_arrow_functions:
                    should_include = False

                # Skip arrow functions that are object properties (e.g., { foo: () => {} })
                # These are not standalone functions - they're values in object literals
                if func_info.is_arrow and node.parent and node.parent.type == "pair":
                    should_include = False

                if should_include:
                    functions.append(func_info)

                # Track as current function for nested functions
                if func_info.name:
                    new_function = func_info.name

        # Recurse into children. new_function is only propagated when THIS
        # node is a function — otherwise children keep the caller's
        # current_function, so siblings don't inherit each other's names.
        for child in node.children:
            self._walk_tree_for_functions(
                child,
                source_bytes,
                functions,
                include_methods=include_methods,
                include_arrow_functions=include_arrow_functions,
                require_name=require_name,
                current_class=new_class,
                current_function=new_function if node.type in function_types else current_function,
            )

    def _extract_function_info(
        self, node: Node, source_bytes: bytes, current_class: str | None, current_function: str | None
    ) -> FunctionNode | None:
        """Extract function information (name, flags, position) from a tree-sitter node."""
        name = ""
        is_async = False
        is_generator = False
        is_method = False
        is_arrow = node.type == "arrow_function"

        # Check for async modifier
        # NOTE(review): only direct children are scanned for the "async"
        # token — assumed sufficient for these grammars; confirm for TSX.
        for child in node.children:
            if child.type == "async":
                is_async = True
                break

        # Check for generator
        if "generator" in node.type:
            is_generator = True

        # Get function name based on node type
        if node.type in ("function_declaration", "generator_function_declaration"):
            name_node = node.child_by_field_name("name")
            if name_node:
                name = self.get_node_text(name_node, source_bytes)
            else:
                # Fallback: search for identifier child (some tree-sitter versions)
                for child in node.children:
                    if child.type == "identifier":
                        name = self.get_node_text(child, source_bytes)
                        break
        elif node.type == "method_definition":
            is_method = True
            name_node = node.child_by_field_name("name")
            if name_node:
                name = self.get_node_text(name_node, source_bytes)
        elif node.type in ("function_expression", "generator_function"):
            # Check if assigned to a variable
            name_node = node.child_by_field_name("name")
            if name_node:
                name = self.get_node_text(name_node, source_bytes)
            else:
                # Try to get name from parent assignment
                name = self._get_name_from_assignment(node, source_bytes)
        elif node.type == "arrow_function":
            # Arrow functions get names from variable declarations
            name = self._get_name_from_assignment(node, source_bytes)

        # Get source text
        source_text = self.get_node_text(node, source_bytes)

        # Find preceding JSDoc comment
        doc_start_line = self._find_preceding_jsdoc(node, source_bytes)

        return FunctionNode(
            name=name,
            node=node,
            start_line=node.start_point[0] + 1,  # Convert to 1-indexed
            end_line=node.end_point[0] + 1,
            start_col=node.start_point[1],
            end_col=node.end_point[1],
            is_async=is_async,
            is_method=is_method,
            is_arrow=is_arrow,
            is_generator=is_generator,
            class_name=current_class if is_method else None,
            parent_function=current_function,
            source_text=source_text,
            doc_start_line=doc_start_line,
        )
    def _find_preceding_jsdoc(self, node: Node, source_bytes: bytes) -> int | None:
        """Find JSDoc comment immediately preceding a function node.

        For regular functions, looks at the previous sibling of the function node.
        For arrow functions assigned to variables, looks at the previous sibling
        of the variable declaration.

        Args:
            node: The function node to find JSDoc for.
            source_bytes: The source code as bytes.

        Returns:
            The start line (1-indexed) of the JSDoc, or None if no JSDoc found.

        """
        target_node = node

        # For arrow functions, look at parent variable declaration
        if node.type == "arrow_function":
            parent = node.parent
            if parent and parent.type == "variable_declarator":
                grandparent = parent.parent
                if grandparent and grandparent.type in ("lexical_declaration", "variable_declaration"):
                    target_node = grandparent

        # For function expressions assigned to variables, also look at parent
        if node.type in ("function_expression", "generator_function"):
            parent = node.parent
            if parent and parent.type == "variable_declarator":
                grandparent = parent.parent
                if grandparent and grandparent.type in ("lexical_declaration", "variable_declaration"):
                    target_node = grandparent

        # Get the previous sibling node
        prev_sibling = target_node.prev_named_sibling

        # Check if it's a comment node with JSDoc pattern
        if prev_sibling and prev_sibling.type == "comment":
            comment_text = self.get_node_text(prev_sibling, source_bytes)
            if comment_text.strip().startswith("/**"):
                # Verify it's immediately preceding (no blank lines between)
                comment_end_line = prev_sibling.end_point[0]
                function_start_line = target_node.start_point[0]
                if function_start_line - comment_end_line <= 1:
                    return prev_sibling.start_point[0] + 1  # 1-indexed

        return None

    def _get_name_from_assignment(self, node: Node, source_bytes: bytes) -> str:
        """Try to extract function name from parent variable declaration or assignment.

        Handles patterns like:
        - const foo = () => {}
        - const foo = function() {}
        - let bar = function() {}
        - obj.method = () => {}

        Returns the empty string when no naming construct surrounds the node.
        """
        parent = node.parent
        if parent is None:
            return ""

        # Check for variable declarator: const foo = ...
        if parent.type == "variable_declarator":
            name_node = parent.child_by_field_name("name")
            if name_node:
                return self.get_node_text(name_node, source_bytes)

        # Check for assignment expression: foo = ...
        if parent.type == "assignment_expression":
            left_node = parent.child_by_field_name("left")
            if left_node:
                if left_node.type == "identifier":
                    return self.get_node_text(left_node, source_bytes)
                if left_node.type == "member_expression":
                    # For obj.method = ..., get the property name
                    prop_node = left_node.child_by_field_name("property")
                    if prop_node:
                        return self.get_node_text(prop_node, source_bytes)

        # Check for property in object: { foo: () => {} }
        if parent.type == "pair":
            key_node = parent.child_by_field_name("key")
            if key_node:
                return self.get_node_text(key_node, source_bytes)

        return ""

    def find_imports(self, source: str) -> list[ImportInfo]:
        """Find all import statements in source code.

        Both ES-module import statements and module-level CommonJS require()
        calls are collected.

        Args:
            source: The source code to analyze.

        Returns:
            List of ImportInfo objects describing imports.

        """
        source_bytes = source.encode("utf8")
        tree = self.parse(source_bytes)
        imports: list[ImportInfo] = []

        self._walk_tree_for_imports(tree.root_node, source_bytes, imports)

        return imports
+ + """ + # Track when we enter function/method bodies + # These node types contain function/method bodies where require() should not be treated as imports + function_body_types = { + "function_declaration", + "method_definition", + "arrow_function", + "function_expression", + "function", # Generic function in some grammars + } + + if node.type == "import_statement": + import_info = self._extract_import_info(node, source_bytes) + if import_info: + imports.append(import_info) + + # Also handle require() calls for CommonJS, but only at module level + # require() inside functions is a dynamic import, not a module import + if node.type == "call_expression" and not in_function: + func_node = node.child_by_field_name("function") + if func_node and self.get_node_text(func_node, source_bytes) == "require": + import_info = self._extract_require_info(node, source_bytes) + if import_info: + imports.append(import_info) + + # Update in_function flag for children + child_in_function = in_function or node.type in function_body_types + + for child in node.children: + self._walk_tree_for_imports(child, source_bytes, imports, child_in_function) + + def _extract_import_info(self, node: Node, source_bytes: bytes) -> ImportInfo | None: + """Extract import information from an import statement node.""" + module_path = "" + default_import = None + named_imports: list[tuple[str, str | None]] = [] + namespace_import = None + is_type_only = False + + # Get the module path (source) + source_node = node.child_by_field_name("source") + if source_node: + # Remove quotes from string + module_path = self.get_node_text(source_node, source_bytes).strip("'\"") + + # Check for type-only import (TypeScript) + for child in node.children: + if child.type == "type" or self.get_node_text(child, source_bytes) == "type": + is_type_only = True + break + + # Process import clause + for child in node.children: + if child.type == "import_clause": + self._process_import_clause(child, source_bytes, default_import, 
named_imports, namespace_import) + # Re-extract after processing + for clause_child in child.children: + if clause_child.type == "identifier": + default_import = self.get_node_text(clause_child, source_bytes) + elif clause_child.type == "named_imports": + for spec in clause_child.children: + if spec.type == "import_specifier": + name_node = spec.child_by_field_name("name") + alias_node = spec.child_by_field_name("alias") + if name_node: + name = self.get_node_text(name_node, source_bytes) + alias = self.get_node_text(alias_node, source_bytes) if alias_node else None + named_imports.append((name, alias)) + elif clause_child.type == "namespace_import": + # import * as X + for ns_child in clause_child.children: + if ns_child.type == "identifier": + namespace_import = self.get_node_text(ns_child, source_bytes) + + if not module_path: + return None + + return ImportInfo( + module_path=module_path, + default_import=default_import, + named_imports=named_imports, + namespace_import=namespace_import, + is_type_only=is_type_only, + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + ) + + def _process_import_clause( + self, + node: Node, + source_bytes: bytes, + default_import: str | None, + named_imports: list[tuple[str, str | None]], + namespace_import: str | None, + ) -> None: + """Process an import clause to extract imports.""" + # This is a helper that modifies the lists in place + # Processing is done inline in _extract_import_info + + def _extract_require_info(self, node: Node, source_bytes: bytes) -> ImportInfo | None: + """Extract import information from a require() call. 
+ + Handles various CommonJS require patterns: + - const foo = require('./module') -> default import + - const { a, b } = require('./module') -> named imports + - const { a: aliasA } = require('./module') -> named imports with alias + - const foo = require('./module').bar -> property access (named import) + - require('./module') -> side effect import + """ + # Handle require().property pattern - the call_expression is inside member_expression + actual_require_node = node + property_access = None + + # Check if this require is part of a member_expression like require('./m').foo + if node.parent and node.parent.type == "member_expression": + member_node = node.parent + prop_node = member_node.child_by_field_name("property") + if prop_node: + property_access = self.get_node_text(prop_node, source_bytes) + # Use the member expression's parent for variable assignment lookup + node = member_node + + args_node = actual_require_node.child_by_field_name("arguments") + if not args_node: + return None + + # Get the first argument (module path) + module_path = "" + for child in args_node.children: + if child.type == "string": + module_path = self.get_node_text(child, source_bytes).strip("'\"") + break + + if not module_path: + return None + + # Try to get the variable name from assignment + default_import = None + named_imports: list[tuple[str, str | None]] = [] + + parent = node.parent + if parent and parent.type == "variable_declarator": + name_node = parent.child_by_field_name("name") + if name_node: + if name_node.type == "identifier": + var_name = self.get_node_text(name_node, source_bytes) + if property_access: + # const foo = require('./module').bar + # This imports 'bar' from the module and assigns to 'foo' + named_imports.append((property_access, var_name if var_name != property_access else None)) + else: + # const foo = require('./module') + default_import = var_name + elif name_node.type == "object_pattern": + # Destructuring: const { a, b } = require('...') + 
named_imports = self._extract_object_pattern_names(name_node, source_bytes) + elif property_access: + # require('./module').foo without assignment - still track the property access + named_imports.append((property_access, None)) + + return ImportInfo( + module_path=module_path, + default_import=default_import, + named_imports=named_imports, + namespace_import=None, + is_type_only=False, + start_line=actual_require_node.start_point[0] + 1, + end_line=actual_require_node.end_point[0] + 1, + ) + + def _extract_object_pattern_names(self, node: Node, source_bytes: bytes) -> list[tuple[str, str | None]]: + """Extract names from an object pattern (destructuring). + + Handles patterns like: + - { a, b } -> [('a', None), ('b', None)] + - { a: aliasA } -> [('a', 'aliasA')] + - { a, b: aliasB } -> [('a', None), ('b', 'aliasB')] + """ + names: list[tuple[str, str | None]] = [] + + for child in node.children: + if child.type == "shorthand_property_identifier_pattern": + # { a } - shorthand, name equals value + name = self.get_node_text(child, source_bytes) + names.append((name, None)) + elif child.type == "pair_pattern": + # { a: aliasA } - renamed import + key_node = child.child_by_field_name("key") + value_node = child.child_by_field_name("value") + if key_node and value_node: + original_name = self.get_node_text(key_node, source_bytes) + alias = self.get_node_text(value_node, source_bytes) + names.append((original_name, alias)) + + return names + + def find_exports(self, source: str) -> list[ExportInfo]: + """Find all export statements in source code. + + Args: + source: The source code to analyze. + + Returns: + List of ExportInfo objects describing exports. 
+ + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + exports: list[ExportInfo] = [] + + self._walk_tree_for_exports(tree.root_node, source_bytes, exports) + + return exports + + def _walk_tree_for_exports(self, node: Node, source_bytes: bytes, exports: list[ExportInfo]) -> None: + """Recursively walk the tree to find export statements.""" + # Handle ES module export statements + if node.type == "export_statement": + export_info = self._extract_export_info(node, source_bytes) + if export_info: + exports.append(export_info) + + # Handle CommonJS exports: module.exports = ... or exports.foo = ... + if node.type == "assignment_expression": + export_info = self._extract_commonjs_export(node, source_bytes) + if export_info: + exports.append(export_info) + + for child in node.children: + self._walk_tree_for_exports(child, source_bytes, exports) + + def _extract_export_info(self, node: Node, source_bytes: bytes) -> ExportInfo | None: + """Extract export information from an export statement node.""" + exported_names: list[tuple[str, str | None]] = [] + default_export: str | None = None + is_reexport = False + reexport_source: str | None = None + + # Check for re-export source (export { x } from './other') + source_node = node.child_by_field_name("source") + if source_node: + is_reexport = True + reexport_source = self.get_node_text(source_node, source_bytes).strip("'\"") + + for child in node.children: + # Handle 'export default' + if child.type == "default": + # Find what's being exported as default + for sibling in node.children: + if sibling.type in {"function_declaration", "class_declaration"}: + name_node = sibling.child_by_field_name("name") + default_export = self.get_node_text(name_node, source_bytes) if name_node else "default" + elif sibling.type == "identifier": + default_export = self.get_node_text(sibling, source_bytes) + elif sibling.type in ("arrow_function", "function_expression", "object", "array"): + default_export = "default" 
+ break + + # Handle named exports: export { a, b as c } + if child.type == "export_clause": + for spec in child.children: + if spec.type == "export_specifier": + name_node = spec.child_by_field_name("name") + alias_node = spec.child_by_field_name("alias") + if name_node: + name = self.get_node_text(name_node, source_bytes) + alias = self.get_node_text(alias_node, source_bytes) if alias_node else None + exported_names.append((name, alias)) + + # Handle direct exports: export function foo() {} + if child.type == "function_declaration": + name_node = child.child_by_field_name("name") + if name_node: + name = self.get_node_text(name_node, source_bytes) + exported_names.append((name, None)) + + # Handle direct class exports: export class Foo {} + if child.type == "class_declaration": + name_node = child.child_by_field_name("name") + if name_node: + name = self.get_node_text(name_node, source_bytes) + exported_names.append((name, None)) + + # Handle variable exports: export const foo = ... + if child.type == "lexical_declaration": + for decl in child.children: + if decl.type == "variable_declarator": + name_node = decl.child_by_field_name("name") + if name_node and name_node.type == "identifier": + name = self.get_node_text(name_node, source_bytes) + exported_names.append((name, None)) + + # Skip if no exports found + if not exported_names and not default_export: + return None + + return ExportInfo( + exported_names=exported_names, + default_export=default_export, + is_reexport=is_reexport, + reexport_source=reexport_source, + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + ) + + def _extract_commonjs_export(self, node: Node, source_bytes: bytes) -> ExportInfo | None: + """Extract export information from CommonJS module.exports or exports.* patterns. 
+ + Handles patterns like: + - module.exports = function() {} -> default export + - module.exports = { foo, bar } -> named exports + - module.exports.foo = function() {} -> named export 'foo' + - exports.foo = function() {} -> named export 'foo' + - module.exports = require('./other') -> re-export + """ + left_node = node.child_by_field_name("left") + right_node = node.child_by_field_name("right") + + if not left_node or not right_node: + return None + + # Check if this is a module.exports or exports.* pattern + if left_node.type != "member_expression": + return None + + left_text = self.get_node_text(left_node, source_bytes) + + exported_names: list[tuple[str, str | None]] = [] + default_export: str | None = None + is_reexport = False + reexport_source: str | None = None + + if left_text == "module.exports": + # module.exports = something + if right_node.type in {"function_expression", "arrow_function"}: + # module.exports = function foo() {} or module.exports = () => {} + name_node = right_node.child_by_field_name("name") + default_export = self.get_node_text(name_node, source_bytes) if name_node else "default" + elif right_node.type == "identifier": + # module.exports = someFunction + default_export = self.get_node_text(right_node, source_bytes) + elif right_node.type == "object": + # module.exports = { foo, bar, baz: qux } + for child in right_node.children: + if child.type == "shorthand_property_identifier": + # { foo } - exports function named foo + name = self.get_node_text(child, source_bytes) + exported_names.append((name, None)) + elif child.type == "pair": + # { baz: qux } - exports qux as baz + key_node = child.child_by_field_name("key") + value_node = child.child_by_field_name("value") + if key_node and value_node: + export_name = self.get_node_text(key_node, source_bytes) + local_name = self.get_node_text(value_node, source_bytes) + # In CommonJS { baz: qux }, baz is the exported name, qux is local + exported_names.append((local_name, export_name)) + 
elif right_node.type == "call_expression": + # module.exports = require('./other') - re-export + func_node = right_node.child_by_field_name("function") + if func_node and self.get_node_text(func_node, source_bytes) == "require": + is_reexport = True + args_node = right_node.child_by_field_name("arguments") + if args_node: + for arg in args_node.children: + if arg.type == "string": + reexport_source = self.get_node_text(arg, source_bytes).strip("'\"") + break + default_export = "default" + else: + # module.exports = something else (class, etc.) + default_export = "default" + + elif left_text.startswith("module.exports."): + # module.exports.foo = something + prop_name = left_text.split(".", 2)[2] # Get 'foo' from 'module.exports.foo' + exported_names.append((prop_name, None)) + + elif left_text.startswith("exports."): + # exports.foo = something + prop_name = left_text.split(".", 1)[1] # Get 'foo' from 'exports.foo' + exported_names.append((prop_name, None)) + + else: + # Not a CommonJS export pattern + return None + + # Skip if no exports found + if not exported_names and not default_export: + return None + + return ExportInfo( + exported_names=exported_names, + default_export=default_export, + is_reexport=is_reexport, + reexport_source=reexport_source, + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + ) + + def is_function_exported( + self, source: str, function_name: str, class_name: str | None = None + ) -> tuple[bool, str | None]: + """Check if a function is exported and get its export name. + + For class methods, also checks if the containing class is exported. + + Args: + source: The source code to analyze. + function_name: The name of the function to check. + class_name: For class methods, the name of the containing class. + + Returns: + Tuple of (is_exported, export_name). export_name may differ from + function_name if exported with an alias. For class methods, + returns the class export name. 
+ + """ + exports = self.find_exports(source) + + # First, check if the function itself is directly exported + for export in exports: + # Check default export + if export.default_export == function_name: + return (True, "default") + + # Check named exports + for name, alias in export.exported_names: + if name == function_name: + return (True, alias or name) + + # For class methods, check if the containing class is exported + if class_name: + for export in exports: + # Check if class is default export + if export.default_export == class_name: + return (True, class_name) + + # Check if class is in named exports + for name, alias in export.exported_names: + if name == class_name: + return (True, alias or name) + + return (False, None) + + def find_function_calls(self, source: str, within_function: FunctionNode) -> list[str]: + """Find all function calls within a specific function's body. + + Args: + source: The full source code. + within_function: The function to search within. + + Returns: + List of function names that are called. 
+ + """ + calls: list[str] = [] + source_bytes = source.encode("utf8") + + # Get the body of the function + body_node = within_function.node.child_by_field_name("body") + if body_node is None: + # For arrow functions, the body might be the last child + for child in within_function.node.children: + if child.type in ("statement_block", "expression_statement") or ( + child.type not in ("identifier", "formal_parameters", "async", "=>") + ): + body_node = child + break + + if body_node: + self._walk_tree_for_calls(body_node, source_bytes, calls) + + return list(set(calls)) # Remove duplicates + + def _walk_tree_for_calls(self, node: Node, source_bytes: bytes, calls: list[str]) -> None: + """Recursively find function calls in a subtree.""" + if node.type == "call_expression": + func_node = node.child_by_field_name("function") + if func_node: + if func_node.type == "identifier": + calls.append(self.get_node_text(func_node, source_bytes)) + elif func_node.type == "member_expression": + # For method calls like obj.method(), get the method name + prop_node = func_node.child_by_field_name("property") + if prop_node: + calls.append(self.get_node_text(prop_node, source_bytes)) + + for child in node.children: + self._walk_tree_for_calls(child, source_bytes, calls) + + def find_module_level_declarations(self, source: str) -> list[ModuleLevelDeclaration]: + """Find all module-level variable/constant declarations. + + This finds global variables, constants, classes, enums, type aliases, + and interfaces defined at the top level of the module (not inside functions). + + Args: + source: The source code to analyze. + + Returns: + List of ModuleLevelDeclaration objects. 
+ + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + declarations: list[ModuleLevelDeclaration] = [] + + # Only look at direct children of the program/module node (top-level) + for child in tree.root_node.children: + self._extract_module_level_declaration(child, source_bytes, declarations) + + return declarations + + def _extract_module_level_declaration( + self, node: Node, source_bytes: bytes, declarations: list[ModuleLevelDeclaration] + ) -> None: + """Extract module-level declarations from a node.""" + is_exported = False + + # Handle export statements - unwrap to get the actual declaration + if node.type == "export_statement": + is_exported = True + # Find the actual declaration inside the export + for child in node.children: + if child.type in ("lexical_declaration", "variable_declaration"): + self._extract_declaration(child, source_bytes, declarations, is_exported, node) + return + if child.type == "class_declaration": + name_node = child.child_by_field_name("name") + if name_node: + declarations.append( + ModuleLevelDeclaration( + name=self.get_node_text(name_node, source_bytes), + declaration_type="class", + source_code=self.get_node_text(node, source_bytes), + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + return + if child.type in ("type_alias_declaration", "interface_declaration", "enum_declaration"): + name_node = child.child_by_field_name("name") + if name_node: + decl_type = child.type.replace("_declaration", "").replace("_alias", "") + declarations.append( + ModuleLevelDeclaration( + name=self.get_node_text(name_node, source_bytes), + declaration_type=decl_type, + source_code=self.get_node_text(node, source_bytes), + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + return + return + + # Handle non-exported declarations + if node.type in ( + "lexical_declaration", # const/let + "variable_declaration", # 
var + ): + self._extract_declaration(node, source_bytes, declarations, is_exported, node) + elif node.type == "class_declaration": + name_node = node.child_by_field_name("name") + if name_node: + declarations.append( + ModuleLevelDeclaration( + name=self.get_node_text(name_node, source_bytes), + declaration_type="class", + source_code=self.get_node_text(node, source_bytes), + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + elif node.type in ("type_alias_declaration", "interface_declaration", "enum_declaration"): + name_node = node.child_by_field_name("name") + if name_node: + decl_type = node.type.replace("_declaration", "").replace("_alias", "") + declarations.append( + ModuleLevelDeclaration( + name=self.get_node_text(name_node, source_bytes), + declaration_type=decl_type, + source_code=self.get_node_text(node, source_bytes), + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + + def _extract_declaration( + self, + node: Node, + source_bytes: bytes, + declarations: list[ModuleLevelDeclaration], + is_exported: bool, + source_node: Node, + ) -> None: + """Extract variable declarations (const/let/var).""" + # Determine declaration type (const, let, var) + decl_type = "var" + for child in node.children: + if child.type in ("const", "let", "var"): + decl_type = child.type + break + + # Find variable declarators + for child in node.children: + if child.type == "variable_declarator": + name_node = child.child_by_field_name("name") + if name_node: + # Handle destructuring patterns + if name_node.type == "identifier": + declarations.append( + ModuleLevelDeclaration( + name=self.get_node_text(name_node, source_bytes), + declaration_type=decl_type, + source_code=self.get_node_text(source_node, source_bytes), + start_line=source_node.start_point[0] + 1, + end_line=source_node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + elif name_node.type in 
("object_pattern", "array_pattern"): + # For destructuring, extract all bound identifiers + identifiers = self._extract_pattern_identifiers(name_node, source_bytes) + for ident in identifiers: + declarations.append( + ModuleLevelDeclaration( + name=ident, + declaration_type=decl_type, + source_code=self.get_node_text(source_node, source_bytes), + start_line=source_node.start_point[0] + 1, + end_line=source_node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + + def _extract_pattern_identifiers(self, pattern_node: Node, source_bytes: bytes) -> list[str]: + """Extract all identifier names from a destructuring pattern.""" + identifiers: list[str] = [] + + def walk(n: Node) -> None: + if n.type in {"identifier", "shorthand_property_identifier_pattern"}: + identifiers.append(self.get_node_text(n, source_bytes)) + for child in n.children: + walk(child) + + walk(pattern_node) + return identifiers + + def find_referenced_identifiers(self, source: str) -> set[str]: + """Find all identifiers referenced in the source code. + + This finds all identifier references, excluding: + - Declaration names (left side of assignments) + - Property names in object literals + - Function/class names at definition site + + Args: + source: The source code to analyze. + + Returns: + Set of referenced identifier names. 
+ + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + references: set[str] = set() + + self._walk_tree_for_references(tree.root_node, source_bytes, references) + + return references + + def _walk_tree_for_references(self, node: Node, source_bytes: bytes, references: set[str]) -> None: + """Walk tree to collect identifier references.""" + if node.type == "identifier": + # Check if this identifier is a reference (not a declaration) + parent = node.parent + if parent is None: + return + + # Skip function/class/method names at definition + if parent.type in ("function_declaration", "class_declaration", "method_definition", "function_expression"): + if parent.child_by_field_name("name") == node: + # Don't recurse into parent's children - the parent will be visited separately + return + + # Skip variable declarator names (left side of declaration) + if parent.type == "variable_declarator" and parent.child_by_field_name("name") == node: + # Don't recurse - the value will be visited when we visit the declarator + return + + # Skip property names in object literals (keys) + if parent.type == "pair" and parent.child_by_field_name("key") == node: + # Don't recurse - the value will be visited when we visit the pair + return + + # Skip property access property names (obj.property - skip 'property') + if parent.type == "member_expression" and parent.child_by_field_name("property") == node: + # Don't recurse - the object will be visited when we visit the member_expression + return + + # Skip import specifier names + if parent.type in ("import_specifier", "import_clause", "namespace_import"): + return + + # Skip export specifier names + if parent.type == "export_specifier": + return + + # Skip parameter names in function definitions (but NOT default values) + if parent.type == "formal_parameters": + return + if parent.type == "required_parameter": + # Only skip if this is the parameter name (pattern field), not the default value + if 
parent.child_by_field_name("pattern") == node: + return + # If it's the value field (default value), it's a reference - don't skip + + # This is a reference + references.add(self.get_node_text(node, source_bytes)) + return + + # Recurse into children + for child in node.children: + self._walk_tree_for_references(child, source_bytes, references) + + def has_return_statement(self, function_node: FunctionNode, source: str) -> bool: + """Check if a function has a return statement. + + Args: + function_node: The function to check. + source: The source code. + + Returns: + True if the function has a return statement. + + """ + source_bytes = source.encode("utf8") + + # Generator functions always implicitly return a Generator/Iterator + if function_node.is_generator: + return True + + # For arrow functions with expression body, there's an implicit return + if function_node.is_arrow: + body_node = function_node.node.child_by_field_name("body") + if body_node and body_node.type != "statement_block": + # Expression body (implicit return) + return True + + return self._node_has_return(function_node.node) + + def _node_has_return(self, node: Node) -> bool: + """Recursively check if a node contains a return statement.""" + if node.type == "return_statement": + return True + + # Don't recurse into nested function definitions + if node.type in ("function_declaration", "function_expression", "arrow_function", "method_definition"): + # Only check the current function, not nested ones + body_node = node.child_by_field_name("body") + if body_node: + for child in body_node.children: + if self._node_has_return(child): + return True + return False + + return any(self._node_has_return(child) for child in node.children) + + def extract_type_annotations(self, source: str, function_name: str, function_line: int) -> set[str]: + """Extract type annotation names from a function's parameters and return type. 
+ + Finds the function by name and line number, then extracts all user-defined type names + from its type annotations (parameters and return type). + + Args: + source: The source code to analyze. + function_name: Name of the function to find. + function_line: Start line of the function (1-indexed). + + Returns: + Set of type names found in the function's annotations. + + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + type_names: set[str] = set() + + # Find the function node + func_node = self._find_function_node(tree.root_node, source_bytes, function_name, function_line) + if not func_node: + return type_names + + # Extract type annotations from parameters + params_node = func_node.child_by_field_name("parameters") + if params_node: + self._extract_type_names_from_node(params_node, source_bytes, type_names) + + # Extract return type annotation + return_type_node = func_node.child_by_field_name("return_type") + if return_type_node: + self._extract_type_names_from_node(return_type_node, source_bytes, type_names) + + return type_names + + def extract_class_field_types(self, source: str, class_name: str) -> set[str]: + """Extract type annotation names from class field declarations. + + Args: + source: The source code to analyze. + class_name: Name of the class to analyze. + + Returns: + Set of type names found in class field annotations. 
+ + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + type_names: set[str] = set() + + # Find the class node + class_node = self._find_class_node(tree.root_node, source_bytes, class_name) + if not class_node: + return type_names + + # Find class body and extract field type annotations + body_node = class_node.child_by_field_name("body") + if body_node: + for child in body_node.children: + # Handle public_field_definition (JS/TS class fields) + if child.type in ("public_field_definition", "field_definition"): + type_annotation = child.child_by_field_name("type") + if type_annotation: + self._extract_type_names_from_node(type_annotation, source_bytes, type_names) + + return type_names + + def _find_function_node( + self, node: Node, source_bytes: bytes, function_name: str, function_line: int + ) -> Node | None: + """Find a function/method node by name and line number.""" + if node.type in ( + "function_declaration", + "method_definition", + "function_expression", + "generator_function_declaration", + ): + name_node = node.child_by_field_name("name") + if name_node: + name = self.get_node_text(name_node, source_bytes) + # Line is 1-indexed, tree-sitter is 0-indexed + if name == function_name and (node.start_point[0] + 1) == function_line: + return node + + # Check arrow functions assigned to variables + if node.type == "lexical_declaration": + for child in node.children: + if child.type == "variable_declarator": + name_node = child.child_by_field_name("name") + value_node = child.child_by_field_name("value") + if name_node and value_node and value_node.type == "arrow_function": + name = self.get_node_text(name_node, source_bytes) + if name == function_name and (node.start_point[0] + 1) == function_line: + return value_node + + # Recurse into children + for child in node.children: + result = self._find_function_node(child, source_bytes, function_name, function_line) + if result: + return result + + return None + + def _find_class_node(self, 
node: Node, source_bytes: bytes, class_name: str) -> Node | None: + """Find a class node by name.""" + if node.type in ("class_declaration", "class"): + name_node = node.child_by_field_name("name") + if name_node: + name = self.get_node_text(name_node, source_bytes) + if name == class_name: + return node + + for child in node.children: + result = self._find_class_node(child, source_bytes, class_name) + if result: + return result + + return None + + def _extract_type_names_from_node(self, node: Node, source_bytes: bytes, type_names: set[str]) -> None: + """Recursively extract type names from a type annotation node. + + Handles various TypeScript type annotation patterns: + - Simple types: number, string, Point + - Generic types: Array, Promise + - Union types: A | B + - Intersection types: A & B + - Array types: T[] + - Tuple types: [A, B] + - Object/mapped types: { key: Type } + + Args: + node: Tree-sitter node to analyze. + source_bytes: Source code as bytes. + type_names: Set to add found type names to. 
+ + """ + # Handle type identifiers (the actual type name references) + if node.type == "type_identifier": + type_name = self.get_node_text(node, source_bytes) + # Skip primitive types + if type_name not in ( + "number", + "string", + "boolean", + "void", + "null", + "undefined", + "any", + "never", + "unknown", + "object", + "symbol", + "bigint", + ): + type_names.add(type_name) + return + + # Handle regular identifiers in type position (can happen in some contexts) + if node.type == "identifier" and node.parent and node.parent.type in ("type_annotation", "generic_type"): + type_name = self.get_node_text(node, source_bytes) + if type_name not in ( + "number", + "string", + "boolean", + "void", + "null", + "undefined", + "any", + "never", + "unknown", + "object", + "symbol", + "bigint", + ): + type_names.add(type_name) + return + + # Handle nested_type_identifier (e.g., Namespace.Type) + if node.type == "nested_type_identifier": + # Get the full qualified name + type_name = self.get_node_text(node, source_bytes) + # Add both the full name and the first part (namespace) + type_names.add(type_name) + # Also extract the module/namespace part + module_node = node.child_by_field_name("module") + if module_node: + type_names.add(self.get_node_text(module_node, source_bytes)) + return + + # Recurse into all children for compound types + for child in node.children: + self._extract_type_names_from_node(child, source_bytes, type_names) + + def find_type_definitions(self, source: str) -> list[TypeDefinition]: + """Find all type definitions (interface, type, class, enum) in source code. + + Args: + source: The source code to analyze. + + Returns: + List of TypeDefinition objects. 
+ + """ + source_bytes = source.encode("utf8") + tree = self.parse(source_bytes) + definitions: list[TypeDefinition] = [] + + # Walk through top-level nodes + for child in tree.root_node.children: + self._extract_type_definition(child, source_bytes, definitions) + + return definitions + + def _extract_type_definition( + self, node: Node, source_bytes: bytes, definitions: list[TypeDefinition], is_exported: bool = False + ) -> None: + """Extract type definitions from a node.""" + # Handle export statements - unwrap to get the actual definition + if node.type == "export_statement": + for child in node.children: + if child.type in ( + "interface_declaration", + "type_alias_declaration", + "class_declaration", + "enum_declaration", + ): + self._extract_type_definition(child, source_bytes, definitions, is_exported=True) + return + + # Extract interface definitions + if node.type == "interface_declaration": + name_node = node.child_by_field_name("name") + if name_node: + # Look for preceding JSDoc comment + jsdoc = "" + prev_sibling = node.prev_named_sibling + if prev_sibling and prev_sibling.type == "comment": + comment_text = self.get_node_text(prev_sibling, source_bytes) + if comment_text.strip().startswith("/**"): + jsdoc = comment_text + "\n" + + definitions.append( + TypeDefinition( + name=self.get_node_text(name_node, source_bytes), + definition_type="interface", + source_code=jsdoc + self.get_node_text(node, source_bytes), + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + + # Extract type alias definitions + elif node.type == "type_alias_declaration": + name_node = node.child_by_field_name("name") + if name_node: + # Look for preceding JSDoc comment + jsdoc = "" + prev_sibling = node.prev_named_sibling + if prev_sibling and prev_sibling.type == "comment": + comment_text = self.get_node_text(prev_sibling, source_bytes) + if comment_text.strip().startswith("/**"): + jsdoc = comment_text + "\n" + + 
definitions.append( + TypeDefinition( + name=self.get_node_text(name_node, source_bytes), + definition_type="type", + source_code=jsdoc + self.get_node_text(node, source_bytes), + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + + # Extract enum definitions + elif node.type == "enum_declaration": + name_node = node.child_by_field_name("name") + if name_node: + # Look for preceding JSDoc comment + jsdoc = "" + prev_sibling = node.prev_named_sibling + if prev_sibling and prev_sibling.type == "comment": + comment_text = self.get_node_text(prev_sibling, source_bytes) + if comment_text.strip().startswith("/**"): + jsdoc = comment_text + "\n" + + definitions.append( + TypeDefinition( + name=self.get_node_text(name_node, source_bytes), + definition_type="enum", + source_code=jsdoc + self.get_node_text(node, source_bytes), + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + + # Extract class definitions (as types) + elif node.type == "class_declaration": + name_node = node.child_by_field_name("name") + if name_node: + # Look for preceding JSDoc comment + jsdoc = "" + prev_sibling = node.prev_named_sibling + if prev_sibling and prev_sibling.type == "comment": + comment_text = self.get_node_text(prev_sibling, source_bytes) + if comment_text.strip().startswith("/**"): + jsdoc = comment_text + "\n" + + definitions.append( + TypeDefinition( + name=self.get_node_text(name_node, source_bytes), + definition_type="class", + source_code=jsdoc + self.get_node_text(node, source_bytes), + start_line=node.start_point[0] + 1, + end_line=node.end_point[0] + 1, + is_exported=is_exported, + ) + ) + + +def get_analyzer_for_file(file_path: Path) -> TreeSitterAnalyzer: + """Get the appropriate TreeSitterAnalyzer for a file based on its extension. + + Args: + file_path: Path to the file. + + Returns: + TreeSitterAnalyzer configured for the file's language. 
+ + """ + suffix = file_path.suffix.lower() + + if suffix in (".ts",): + return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT) + if suffix in (".tsx",): + return TreeSitterAnalyzer(TreeSitterLanguage.TSX) + # Default to JavaScript for .js, .jsx, .mjs, .cjs + return TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT) diff --git a/codeflash/models/function_types.py b/codeflash/models/function_types.py index bea6672b0..dc51aaa92 100644 --- a/codeflash/models/function_types.py +++ b/codeflash/models/function_types.py @@ -7,7 +7,7 @@ from __future__ import annotations from pathlib import Path -from typing import Optional +from typing import Any, Optional from pydantic import Field from pydantic.dataclasses import dataclass @@ -61,6 +61,7 @@ class FunctionToOptimize: is_method: bool = False language: str = "python" doc_start_line: Optional[int] = None + metadata: Optional[dict[str, Any]] = Field(default=None) @property def top_level_parent_name(self) -> str: From 3601a18d608e35babaa53e095ad31f8cdc6399cc Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 02:54:14 +0000 Subject: [PATCH 02/37] style: auto-fix linting issues and resolve mypy type errors --- .../javascript/frameworks/react/analyzer.py | 102 +++++++++--------- .../javascript/frameworks/react/context.py | 38 +++---- .../javascript/frameworks/react/discovery.py | 81 ++++++++------ .../javascript/frameworks/react/profiler.py | 31 +++--- .../javascript/frameworks/react/testgen.py | 14 +-- codeflash/languages/javascript/parse.py | 20 ++-- codeflash/languages/javascript/support.py | 36 +++---- 7 files changed, 158 insertions(+), 164 deletions(-) diff --git a/codeflash/languages/javascript/frameworks/react/analyzer.py b/codeflash/languages/javascript/frameworks/react/analyzer.py index db87c22e6..2de39c802 100644 --- a/codeflash/languages/javascript/frameworks/react/analyzer.py +++ b/codeflash/languages/javascript/frameworks/react/analyzer.py @@ -44,9 
+44,7 @@ class OptimizationOpportunity: # Patterns for expensive operations inside render body -EXPENSIVE_OPS_RE = re.compile( - r"\.(filter|map|sort|reduce|flatMap|find|findIndex|every|some)\s*\(" -) +EXPENSIVE_OPS_RE = re.compile(r"\.(filter|map|sort|reduce|flatMap|find|findIndex|every|some)\s*\(") INLINE_OBJECT_IN_JSX_RE = re.compile(r"=\{\s*\{") # ={{ ... }} in JSX INLINE_ARRAY_IN_JSX_RE = re.compile(r"=\{\s*\[") # ={[ ... ]} in JSX FUNCTION_DEF_RE = re.compile( @@ -57,9 +55,7 @@ class OptimizationOpportunity: USEMEMO_RE = re.compile(r"\buseMemo\s*\(") -def detect_optimization_opportunities( - source: str, component_info: ReactComponentInfo -) -> list[OptimizationOpportunity]: +def detect_optimization_opportunities(source: str, component_info: ReactComponentInfo) -> list[OptimizationOpportunity]: """Detect optimization opportunities in a React component.""" opportunities: list[OptimizationOpportunity] = [] lines = source.splitlines() @@ -81,46 +77,47 @@ def detect_optimization_opportunities( # Check if component should be wrapped in React.memo if not component_info.is_memoized: - opportunities.append(OptimizationOpportunity( - type=OpportunityType.MISSING_REACT_MEMO, - line=component_info.start_line, - description=f"Component '{component_info.function_name}' is not wrapped in React.memo(). " - "If it receives stable props, wrapping can prevent unnecessary re-renders.", - severity=OpportunitySeverity.MEDIUM, - )) + opportunities.append( + OptimizationOpportunity( + type=OpportunityType.MISSING_REACT_MEMO, + line=component_info.start_line, + description=f"Component '{component_info.function_name}' is not wrapped in React.memo(). 
" + "If it receives stable props, wrapping can prevent unnecessary re-renders.", + severity=OpportunitySeverity.MEDIUM, + ) + ) return opportunities -def _detect_inline_props( - lines: list[str], offset: int, opportunities: list[OptimizationOpportunity] -) -> None: +def _detect_inline_props(lines: list[str], offset: int, opportunities: list[OptimizationOpportunity]) -> None: """Detect inline object/array literals in JSX prop positions.""" for i, line in enumerate(lines): line_num = offset + i + 1 if INLINE_OBJECT_IN_JSX_RE.search(line): - opportunities.append(OptimizationOpportunity( - type=OpportunityType.INLINE_OBJECT_PROP, - line=line_num, - description="Inline object literal in JSX prop creates a new reference on every render. " - "Extract to useMemo or a module-level constant.", - severity=OpportunitySeverity.HIGH, - )) + opportunities.append( + OptimizationOpportunity( + type=OpportunityType.INLINE_OBJECT_PROP, + line=line_num, + description="Inline object literal in JSX prop creates a new reference on every render. " + "Extract to useMemo or a module-level constant.", + severity=OpportunitySeverity.HIGH, + ) + ) if INLINE_ARRAY_IN_JSX_RE.search(line): - opportunities.append(OptimizationOpportunity( - type=OpportunityType.INLINE_ARRAY_PROP, - line=line_num, - description="Inline array literal in JSX prop creates a new reference on every render. " - "Extract to useMemo or a module-level constant.", - severity=OpportunitySeverity.HIGH, - )) + opportunities.append( + OptimizationOpportunity( + type=OpportunityType.INLINE_ARRAY_PROP, + line=line_num, + description="Inline array literal in JSX prop creates a new reference on every render. 
" + "Extract to useMemo or a module-level constant.", + severity=OpportunitySeverity.HIGH, + ) + ) def _detect_missing_usecallback( - component_source: str, - lines: list[str], - offset: int, - opportunities: list[OptimizationOpportunity], + component_source: str, lines: list[str], offset: int, opportunities: list[OptimizationOpportunity] ) -> None: """Detect arrow functions or function expressions that could use useCallback.""" has_usecallback = bool(USECALLBACK_RE.search(component_source)) @@ -132,30 +129,31 @@ def _detect_missing_usecallback( if FUNCTION_DEF_RE.search(stripped) and "useCallback" not in stripped and "useMemo" not in stripped: # Skip if the component already uses useCallback extensively if not has_usecallback: - opportunities.append(OptimizationOpportunity( - type=OpportunityType.MISSING_USECALLBACK, - line=line_num, - description="Function defined inside render body creates a new reference on every render. " - "Wrap with useCallback() if passed as a prop to child components.", - severity=OpportunitySeverity.MEDIUM, - )) + opportunities.append( + OptimizationOpportunity( + type=OpportunityType.MISSING_USECALLBACK, + line=line_num, + description="Function defined inside render body creates a new reference on every render. 
" + "Wrap with useCallback() if passed as a prop to child components.", + severity=OpportunitySeverity.MEDIUM, + ) + ) def _detect_missing_usememo( - component_source: str, - lines: list[str], - offset: int, - opportunities: list[OptimizationOpportunity], + component_source: str, lines: list[str], offset: int, opportunities: list[OptimizationOpportunity] ) -> None: """Detect expensive computations that could benefit from useMemo.""" for i, line in enumerate(lines): line_num = offset + i + 1 stripped = line.strip() if EXPENSIVE_OPS_RE.search(stripped) and "useMemo" not in stripped: - opportunities.append(OptimizationOpportunity( - type=OpportunityType.MISSING_USEMEMO, - line=line_num, - description="Expensive array operation in render body runs on every render. " - "Wrap with useMemo() and specify dependencies.", - severity=OpportunitySeverity.HIGH, - )) + opportunities.append( + OptimizationOpportunity( + type=OpportunityType.MISSING_USEMEMO, + line=line_num, + description="Expensive array operation in render body runs on every render. 
" + "Wrap with useMemo() and specify dependencies.", + severity=OpportunitySeverity.HIGH, + ) + ) diff --git a/codeflash/languages/javascript/frameworks/react/context.py b/codeflash/languages/javascript/frameworks/react/context.py index b5dc2b871..0d53e5c8b 100644 --- a/codeflash/languages/javascript/frameworks/react/context.py +++ b/codeflash/languages/javascript/frameworks/react/context.py @@ -12,6 +12,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from tree_sitter import Node + from codeflash.languages.javascript.frameworks.react.analyzer import OptimizationOpportunity from codeflash.languages.javascript.frameworks.react.discovery import ReactComponentInfo from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer @@ -73,24 +75,16 @@ def to_prompt_string(self) -> str: def extract_react_context( - component_info: ReactComponentInfo, - source: str, - analyzer: TreeSitterAnalyzer, - module_root: Path, + component_info: ReactComponentInfo, source: str, analyzer: TreeSitterAnalyzer, module_root: Path ) -> ReactContext: """Extract React-specific context for a component. Analyzes the component source to find props types, hooks, child components, and optimization opportunities. 
""" - from codeflash.languages.javascript.frameworks.react.analyzer import ( # noqa: PLC0415 - detect_optimization_opportunities, - ) + from codeflash.languages.javascript.frameworks.react.analyzer import detect_optimization_opportunities - context = ReactContext( - props_interface=component_info.props_type, - is_already_memoized=component_info.is_memoized, - ) + context = ReactContext(props_interface=component_info.props_type, is_already_memoized=component_info.is_memoized) # Extract hook usage details from the component source lines = source.splitlines() @@ -114,7 +108,7 @@ def extract_react_context( def _extract_hook_usages(component_source: str) -> list[HookUsage]: """Parse hook calls and their dependency arrays from component source.""" - import re # noqa: PLC0415 + import re hooks: list[HookUsage] = [] # Match useXxx( patterns @@ -124,7 +118,7 @@ def _extract_hook_usages(component_source: str) -> list[HookUsage]: hook_name = match.group(1) # Try to determine if there's a dependency array # Look for ], [ pattern after the hook call (simplified heuristic) - rest_of_line = component_source[match.end():] + rest_of_line = component_source[match.end() :] has_deps = False dep_count = 0 @@ -143,7 +137,7 @@ def _extract_hook_usages(component_source: str) -> list[HookUsage]: # Count items in the array (rough: count commas + 1 for non-empty) array_start = preceding.rfind("[") if array_start >= 0: - array_content = preceding[array_start + 1:-1].strip() + array_content = preceding[array_start + 1 : -1].strip() if array_content: dep_count = array_content.count(",") + 1 else: @@ -151,18 +145,14 @@ def _extract_hook_usages(component_source: str) -> list[HookUsage]: has_deps = True break - hooks.append(HookUsage( - name=hook_name, - has_dependency_array=has_deps, - dependency_count=dep_count, - )) + hooks.append(HookUsage(name=hook_name, has_dependency_array=has_deps, dependency_count=dep_count)) return hooks def _extract_child_components(component_source: str, analyzer: 
TreeSitterAnalyzer, full_source: str) -> list[str]: """Find child component names rendered in JSX.""" - import re # noqa: PLC0415 + import re # Match JSX tags that start with uppercase (React components) jsx_component_re = re.compile(r"<([A-Z][a-zA-Z0-9.]*)") @@ -177,7 +167,7 @@ def _extract_child_components(component_source: str, analyzer: TreeSitterAnalyze def _extract_context_subscriptions(component_source: str) -> list[str]: """Find React context subscriptions via useContext calls.""" - import re # noqa: PLC0415 + import re context_re = re.compile(r"\buseContext\s*\(\s*(\w+)") return [match.group(1) for match in context_re.finditer(component_source)] @@ -188,13 +178,13 @@ def _find_type_definition(type_name: str, source: str, analyzer: TreeSitterAnaly source_bytes = source.encode("utf-8") tree = analyzer.parse(source_bytes) - def search_node(node): + def search_node(node: Node) -> str | None: if node.type in ("interface_declaration", "type_alias_declaration"): name_node = node.child_by_field_name("name") if name_node: - name = source_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8") + name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf-8") if name == type_name: - return source_bytes[node.start_byte:node.end_byte].decode("utf-8") + return source_bytes[node.start_byte : node.end_byte].decode("utf-8") for child in node.children: result = search_node(child) if result: diff --git a/codeflash/languages/javascript/frameworks/react/discovery.py b/codeflash/languages/javascript/frameworks/react/discovery.py index 194088885..9e39de817 100644 --- a/codeflash/languages/javascript/frameworks/react/discovery.py +++ b/codeflash/languages/javascript/frameworks/react/discovery.py @@ -8,12 +8,14 @@ import logging import re -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: + from tree_sitter import Node + from 
codeflash.languages.javascript.treesitter import FunctionNode, TreeSitterAnalyzer logger = logging.getLogger(__name__) @@ -23,13 +25,28 @@ HOOK_NAME_RE = re.compile(r"^use[A-Z]\w*$") # Built-in React hooks -BUILTIN_HOOKS = frozenset({ - "useState", "useEffect", "useContext", "useReducer", "useCallback", - "useMemo", "useRef", "useImperativeHandle", "useLayoutEffect", - "useInsertionEffect", "useDebugValue", "useDeferredValue", - "useTransition", "useId", "useSyncExternalStore", "useOptimistic", - "useActionState", "useFormStatus", -}) +BUILTIN_HOOKS = frozenset( + { + "useState", + "useEffect", + "useContext", + "useReducer", + "useCallback", + "useMemo", + "useRef", + "useImperativeHandle", + "useLayoutEffect", + "useInsertionEffect", + "useDebugValue", + "useDeferredValue", + "useTransition", + "useId", + "useSyncExternalStore", + "useOptimistic", + "useActionState", + "useFormStatus", + } +) class ComponentType(str, Enum): @@ -105,9 +122,7 @@ def find_react_components(source: str, file_path: Path, analyzer: TreeSitterAnal logger.debug("Skipping server component file: %s", file_path) return [] - functions = analyzer.find_functions( - source, include_methods=False, include_arrow_functions=True, require_name=True - ) + functions = analyzer.find_functions(source, include_methods=False, include_arrow_functions=True, require_name=True) components: list[ReactComponentInfo] = [] for func in functions: @@ -119,16 +134,18 @@ def find_react_components(source: str, file_path: Path, analyzer: TreeSitterAnal props_type = _extract_props_type(func, source, analyzer) is_memoized = _is_wrapped_in_memo(func, source) - components.append(ReactComponentInfo( - function_name=func.name, - component_type=comp_type, - uses_hooks=tuple(hooks_used), - returns_jsx=comp_type != ComponentType.HOOK and _function_returns_jsx(func, source, analyzer), - props_type=props_type, - is_memoized=is_memoized, - start_line=func.start_line, - end_line=func.end_line, - )) + components.append( + 
ReactComponentInfo( + function_name=func.name, + component_type=comp_type, + uses_hooks=tuple(hooks_used), + returns_jsx=comp_type != ComponentType.HOOK and _function_returns_jsx(func, source, analyzer), + props_type=props_type, + is_memoized=is_memoized, + start_line=func.start_line, + end_line=func.end_line, + ) + ) return components @@ -157,11 +174,14 @@ def _function_returns_jsx(func: FunctionNode, source: str, analyzer: TreeSitterA return False -def _node_contains_jsx(node) -> bool: +def _node_contains_jsx(node: Node) -> bool: """Recursively check if a tree-sitter node contains JSX.""" if node.type in ( - "jsx_element", "jsx_self_closing_element", "jsx_fragment", - "jsx_expression", "jsx_opening_element", + "jsx_element", + "jsx_self_closing_element", + "jsx_fragment", + "jsx_expression", + "jsx_opening_element", ): return True @@ -208,7 +228,7 @@ def _extract_props_type(func: FunctionNode, source: str, analyzer: TreeSitterAna # Get the type annotation node (skip the colon) for child in type_node.children: if child.type != ":": - return source_bytes[child.start_byte:child.end_byte].decode("utf-8") + return source_bytes[child.start_byte : child.end_byte].decode("utf-8") # Destructured params with type: { foo, bar }: Props if param.type == "object_pattern": # Look for next sibling that is a type_annotation @@ -216,7 +236,7 @@ def _extract_props_type(func: FunctionNode, source: str, analyzer: TreeSitterAna if next_sib and next_sib.type == "type_annotation": for child in next_sib.children: if child.type != ":": - return source_bytes[child.start_byte:child.end_byte].decode("utf-8") + return source_bytes[child.start_byte : child.end_byte].decode("utf-8") return None @@ -234,7 +254,7 @@ def _is_wrapped_in_memo(func: FunctionNode, source: str) -> bool: func_node = parent.child_by_field_name("function") if func_node: source_bytes = source.encode("utf-8") - func_text = source_bytes[func_node.start_byte:func_node.end_byte].decode("utf-8") + func_text = 
source_bytes[func_node.start_byte : func_node.end_byte].decode("utf-8") if func_text in ("React.memo", "memo"): return True parent = parent.parent @@ -242,10 +262,5 @@ def _is_wrapped_in_memo(func: FunctionNode, source: str) -> bool: # Also check for memo wrapping at the export level: # export default memo(MyComponent) name = func.name - memo_patterns = [ - f"React.memo({name})", - f"memo({name})", - f"React.memo({name},", - f"memo({name},", - ] + memo_patterns = [f"React.memo({name})", f"memo({name})", f"React.memo({name},", f"memo({name},"] return any(pattern in source for pattern in memo_patterns) diff --git a/codeflash/languages/javascript/frameworks/react/profiler.py b/codeflash/languages/javascript/frameworks/react/profiler.py index 9d273b70b..880793c11 100644 --- a/codeflash/languages/javascript/frameworks/react/profiler.py +++ b/codeflash/languages/javascript/frameworks/react/profiler.py @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + from tree_sitter import Node + from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer logger = logging.getLogger(__name__) @@ -81,7 +83,7 @@ def instrument_component_with_profiler(source: str, component_name: str, analyze def instrument_all_components_for_tracing(source: str, file_path: Path, analyzer: TreeSitterAnalyzer) -> str: """Instrument ALL components in a file for tracing/discovery mode.""" - from codeflash.languages.javascript.frameworks.react.discovery import find_react_components # noqa: PLC0415 + from codeflash.languages.javascript.frameworks.react.discovery import find_react_components components = find_react_components(source, file_path, analyzer) if not components: @@ -96,13 +98,13 @@ def instrument_all_components_for_tracing(source: str, file_path: Path, analyzer return result -def _find_component_function(root_node, component_name: str, source_bytes: bytes): +def _find_component_function(root_node: Node, component_name: str, source_bytes: bytes) -> Node | None: """Find 
the tree-sitter node for a named component function.""" # Check function declarations if root_node.type == "function_declaration": name_node = root_node.child_by_field_name("name") if name_node: - name = source_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8") + name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf-8") if name == component_name: return root_node @@ -110,7 +112,7 @@ def _find_component_function(root_node, component_name: str, source_bytes: bytes if root_node.type == "variable_declarator": name_node = root_node.child_by_field_name("name") if name_node: - name = source_bytes[name_node.start_byte:name_node.end_byte].decode("utf-8") + name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf-8") if name == component_name: return root_node @@ -129,14 +131,17 @@ def _find_component_function(root_node, component_name: str, source_bytes: bytes return None -def _find_jsx_returns(func_node, source_bytes: bytes) -> list: +def _find_jsx_returns(func_node: Node, source_bytes: bytes) -> list[Node]: """Find all return statements that contain JSX within a function node.""" - returns = [] + returns: list[Node] = [] - def walk(node): + def walk(node: Node) -> None: # Don't descend into nested functions if node != func_node and node.type in ( - "function_declaration", "arrow_function", "function", "method_definition", + "function_declaration", + "arrow_function", + "function", + "method_definition", ): return @@ -154,11 +159,9 @@ def walk(node): return returns -def _contains_jsx(node) -> bool: +def _contains_jsx(node: Node) -> bool: """Check if a tree-sitter node contains JSX elements.""" - if node.type in ( - "jsx_element", "jsx_self_closing_element", "jsx_fragment", - ): + if node.type in ("jsx_element", "jsx_self_closing_element", "jsx_fragment"): return True for child in node.children: if _contains_jsx(child): @@ -166,7 +169,7 @@ def _contains_jsx(node) -> bool: return False -def _wrap_return_with_profiler(source: 
str, return_node, profiler_id: str, safe_name: str) -> str: +def _wrap_return_with_profiler(source: str, return_node: Node, profiler_id: str, safe_name: str) -> str: """Wrap a return statement's JSX with React.Profiler.""" source_bytes = source.encode("utf-8") @@ -238,7 +241,7 @@ def _ensure_react_import(source: str) -> str: if "from 'react'" in source or 'from "react"' in source: # React is imported but maybe not as the default. That's fine for JSX. # We need React.Profiler so add it - if "React" not in source.split("from")[0] if "from" in source else "": + if "React" not in source.split("from", maxsplit=1)[0] if "from" in source else "": return 'import React from "react";\n' + source return source return 'import React from "react";\n' + source diff --git a/codeflash/languages/javascript/frameworks/react/testgen.py b/codeflash/languages/javascript/frameworks/react/testgen.py index fd621b05e..4a3eeaf95 100644 --- a/codeflash/languages/javascript/frameworks/react/testgen.py +++ b/codeflash/languages/javascript/frameworks/react/testgen.py @@ -16,9 +16,7 @@ def build_react_testgen_context( - component_info: ReactComponentInfo, - react_context: ReactContext, - code_context: CodeContext, + component_info: ReactComponentInfo, react_context: ReactContext, code_context: CodeContext ) -> dict: """Assemble context dict for the React testgen LLM prompt.""" return { @@ -101,14 +99,12 @@ def post_process_react_tests(test_source: str, component_info: ReactComponentInf # Ensure act import if state updates are detected if "act(" in result and "import" in result and "act" not in result.split("from '@testing-library/react'")[0]: - result = result.replace( - "from '@testing-library/react'", - "act, " + "from '@testing-library/react'", - 1, - ) + result = result.replace("from '@testing-library/react'", "act, " + "from '@testing-library/react'", 1) # Ensure user-event import if user interactions are tested - if ("click" in result.lower() or "type" in result.lower() or "userEvent" in 
result) and "@testing-library/user-event" not in result: + if ( + "click" in result.lower() or "type" in result.lower() or "userEvent" in result + ) and "@testing-library/user-event" not in result: # Add user-event import after testing-library import result = re.sub( r"(import .+ from '@testing-library/react';?\n)", diff --git a/codeflash/languages/javascript/parse.py b/codeflash/languages/javascript/parse.py index 03aee9d38..820deeaec 100644 --- a/codeflash/languages/javascript/parse.py +++ b/codeflash/languages/javascript/parse.py @@ -34,9 +34,7 @@ # React Profiler render marker pattern # Format: !######REACT_RENDER:{component}:{phase}:{actualDuration}:{baseDuration}:{renderCount}######! -REACT_RENDER_MARKER_PATTERN = re.compile( - r"!######REACT_RENDER:([^:]+):([^:]+):([^:]+):([^:]+):(\d+)######!" -) +REACT_RENDER_MARKER_PATTERN = re.compile(r"!######REACT_RENDER:([^:]+):([^:]+):([^:]+):([^:]+):(\d+)######!") @dataclass(frozen=True) @@ -58,13 +56,15 @@ def parse_react_render_markers(stdout: str) -> list[RenderProfile]: profiles: list[RenderProfile] = [] for match in REACT_RENDER_MARKER_PATTERN.finditer(stdout): try: - profiles.append(RenderProfile( - component_name=match.group(1), - phase=match.group(2), - actual_duration_ms=float(match.group(3)), - base_duration_ms=float(match.group(4)), - render_count=int(match.group(5)), - )) + profiles.append( + RenderProfile( + component_name=match.group(1), + phase=match.group(2), + actual_duration_ms=float(match.group(3)), + base_duration_ms=float(match.group(4)), + render_count=int(match.group(5)), + ) + ) except (ValueError, IndexError) as e: logger.debug("Failed to parse React render marker: %s", e) return profiles diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 45007caaa..7ec6bccd0 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -75,7 +75,7 @@ def dir_excludes(self) -> frozenset[str]: def 
get_framework_info(self, project_root: Path) -> FrameworkInfo: """Get cached framework info for the project.""" if self._cached_framework_root != project_root or self._cached_framework_info is None: - from codeflash.languages.javascript.frameworks.detector import detect_framework # noqa: PLC0415 + from codeflash.languages.javascript.frameworks.detector import detect_framework self._cached_framework_info = detect_framework(project_root) self._cached_framework_root = project_root @@ -120,14 +120,12 @@ def discover_functions( react_component_map: dict[str, Any] = {} project_root = file_path.parent # Will be refined by caller try: - from codeflash.languages.javascript.frameworks.react.discovery import ( # noqa: PLC0415 - classify_component, - ) + from codeflash.languages.javascript.frameworks.react.discovery import classify_component for func in tree_functions: comp_type = classify_component(func, source, analyzer) if comp_type is not None: - from codeflash.languages.javascript.frameworks.react.discovery import ( # noqa: PLC0415 + from codeflash.languages.javascript.frameworks.react.discovery import ( _extract_hooks_used, _is_wrapped_in_memo, ) @@ -473,13 +471,8 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, react_context_str = "" if function.metadata and function.metadata.get("is_react_component"): try: - from codeflash.languages.javascript.frameworks.react.discovery import ( # noqa: PLC0415 - ReactComponentInfo, - find_react_components, - ) - from codeflash.languages.javascript.frameworks.react.context import ( # noqa: PLC0415 - extract_react_context, - ) + from codeflash.languages.javascript.frameworks.react.context import extract_react_context + from codeflash.languages.javascript.frameworks.react.discovery import find_react_components components = find_react_components(source, function.file_path, analyzer) for comp in components: @@ -535,7 +528,7 @@ def _find_class_definition( source_bytes = source.encode("utf8") tree = 
analyzer.parse(source_bytes) - def find_class_node(node): + def find_class_node(node: Any) -> Any: """Recursively find a class declaration with the given name.""" if node.type in ("class_declaration", "class"): name_node = node.child_by_field_name("name") @@ -1953,10 +1946,9 @@ def _build_runtime_map( continue key = test_qualified_name + "#" + abs_path_str - parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr] - cur_invid = ( - inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) - ) # type: ignore[union-attr] + iteration_id = inv_id.iteration_id or "" + parts = iteration_id.split("_").__len__() + cur_invid = iteration_id.split("_")[0] if parts < 3 else "_".join(iteration_id.split("_")[:-1]) match_key = key + "#" + cur_invid if match_key not in unique_inv_ids: unique_inv_ids[match_key] = 0 @@ -1967,7 +1959,7 @@ def _build_runtime_map( def compare_test_results( self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None - ) -> tuple[bool, list]: + ) -> tuple[bool, list[Any]]: """Compare test results between original and candidate code. Args: @@ -2084,8 +2076,8 @@ def get_module_path(self, source_file: Path, project_root: Path, tests_root: Pat return rel_path except ValueError: # Fallback if paths are on different drives (Windows) - rel_path = source_file.relative_to(project_root) - return "../" + rel_path.with_suffix("").as_posix() + fallback_path = source_file.relative_to(project_root) + return "../" + fallback_path.with_suffix("").as_posix() def verify_requirements(self, project_root: Path, test_framework: str = "jest") -> tuple[bool, list[str]]: """Verify that all JavaScript requirements are met. 
@@ -2253,7 +2245,7 @@ def instrument_source_for_line_profiler( logger.warning("Failed to instrument source for line profiling: %s", e) return False - def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict: + def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict[str, Any]: from codeflash.languages.javascript.line_profiler import JavaScriptLineProfiler if line_profiler_output_file.exists(): @@ -2265,7 +2257,7 @@ def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict: logger.warning("No line profiler output file found at %s", line_profiler_output_file) return {"timings": {}, "unit": 0, "str_out": ""} - def _format_js_line_profile_output(self, parsed_results: dict) -> str: + def _format_js_line_profile_output(self, parsed_results: dict[str, Any]) -> str: """Format JavaScript line profiler results for display.""" if not parsed_results.get("timings"): return "" From 07ab6186e5827b5891d509b74cb372e2642569d9 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 02:54:55 +0000 Subject: [PATCH 03/37] fix: add missing type parameter to dict return type in testgen --- codeflash/languages/javascript/frameworks/react/testgen.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/languages/javascript/frameworks/react/testgen.py b/codeflash/languages/javascript/frameworks/react/testgen.py index 4a3eeaf95..be09e858e 100644 --- a/codeflash/languages/javascript/frameworks/react/testgen.py +++ b/codeflash/languages/javascript/frameworks/react/testgen.py @@ -7,7 +7,7 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from codeflash.languages.base import CodeContext @@ -17,7 +17,7 @@ def build_react_testgen_context( component_info: ReactComponentInfo, react_context: ReactContext, code_context: CodeContext -) -> dict: +) -> dict[str, Any]: 
"""Assemble context dict for the React testgen LLM prompt.""" return { "component_name": component_info.function_name, From 257ed8cf0f359c5cb8b0b0affd42de45e3246ac0 Mon Sep 17 00:00:00 2001 From: Sarthak Agarwal Date: Fri, 20 Feb 2026 08:26:42 +0530 Subject: [PATCH 04/37] add benchmarking --- codeflash/api/aiservice.py | 8 ++ codeflash/api/schemas.py | 17 +++ .../frameworks/react/benchmarking.py | 107 ++++++++++++++++++ codeflash/result/critic.py | 38 +++++++ codeflash/result/explanation.py | 8 ++ 5 files changed, 178 insertions(+) create mode 100644 codeflash/languages/javascript/frameworks/react/benchmarking.py diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index b8bc9454b..a144a9dd3 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -135,6 +135,8 @@ def optimize_code( is_async: bool = False, n_candidates: int = 5, is_numerical_code: bool | None = None, + is_react_component: bool = False, + react_context: str | None = None, ) -> list[OptimizedCandidate]: """Optimize the given code for performance by making a request to the Django endpoint. 
@@ -188,6 +190,12 @@ def optimize_code( if module_system: payload["module_system"] = module_system + # React-specific fields + if is_react_component: + payload["is_react_component"] = True + if react_context: + payload["react_context"] = react_context + # DEBUG: Print payload language field logger.debug( f"Sending optimize request with language='{payload['language']}' (type: {type(payload['language'])})" diff --git a/codeflash/api/schemas.py b/codeflash/api/schemas.py index 37e2c72a5..db64cb514 100644 --- a/codeflash/api/schemas.py +++ b/codeflash/api/schemas.py @@ -120,6 +120,10 @@ class OptimizeRequest: repo_name: str | None = None current_username: str | None = None + # === React-specific === + is_react_component: bool = False + react_context: str = "" + def to_payload(self) -> dict[str, Any]: """Convert to API payload dict, maintaining backward compatibility.""" payload = { @@ -150,6 +154,12 @@ def to_payload(self) -> dict[str, Any]: if self.language_info.module_system != ModuleSystem.UNKNOWN: payload["module_system"] = self.language_info.module_system.value + # React-specific fields + if self.is_react_component: + payload["is_react_component"] = True + if self.react_context: + payload["react_context"] = self.react_context + return payload @@ -187,6 +197,9 @@ class TestGenRequest: # === Metadata === codeflash_version: str = "" + # === React-specific === + is_react_component: bool = False + def to_payload(self) -> dict[str, Any]: """Convert to API payload dict, maintaining backward compatibility.""" payload = { @@ -218,6 +231,10 @@ def to_payload(self) -> dict[str, Any]: if self.language_info.module_system != ModuleSystem.UNKNOWN: payload["module_system"] = self.language_info.module_system.value + # React-specific fields + if self.is_react_component: + payload["is_react_component"] = True + return payload diff --git a/codeflash/languages/javascript/frameworks/react/benchmarking.py b/codeflash/languages/javascript/frameworks/react/benchmarking.py new file mode 
100644 index 000000000..dfee93ad7 --- /dev/null +++ b/codeflash/languages/javascript/frameworks/react/benchmarking.py @@ -0,0 +1,107 @@ +"""React render benchmarking and comparison. + +Compares original vs optimized render profiles from React Profiler +instrumentation to quantify re-render reduction and render time improvement. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from codeflash.languages.javascript.parse import RenderProfile + + +@dataclass(frozen=True) +class RenderBenchmark: + """Comparison of original vs optimized render metrics.""" + + component_name: str + original_render_count: int + optimized_render_count: int + original_avg_duration_ms: float + optimized_avg_duration_ms: float + + @property + def render_count_reduction_pct(self) -> float: + """Percentage reduction in render count (0-100).""" + if self.original_render_count == 0: + return 0.0 + return ( + (self.original_render_count - self.optimized_render_count) + / self.original_render_count + * 100 + ) + + @property + def duration_reduction_pct(self) -> float: + """Percentage reduction in render duration (0-100).""" + if self.original_avg_duration_ms == 0: + return 0.0 + return ( + (self.original_avg_duration_ms - self.optimized_avg_duration_ms) + / self.original_avg_duration_ms + * 100 + ) + + @property + def render_speedup_x(self) -> float: + """Render time speedup factor (e.g., 2.5x means 2.5 times faster).""" + if self.optimized_avg_duration_ms == 0: + return 0.0 + return self.original_avg_duration_ms / self.optimized_avg_duration_ms + + +def compare_render_benchmarks( + original_profiles: list[RenderProfile], + optimized_profiles: list[RenderProfile], +) -> RenderBenchmark | None: + """Compare original and optimized render profiles. + + Aggregates render counts and durations across all render events + for the same component, then computes the benchmark comparison. 
+ """ + if not original_profiles or not optimized_profiles: + return None + + # Use the first profile's component name + component_name = original_profiles[0].component_name + + # Aggregate original metrics + orig_count = max((p.render_count for p in original_profiles), default=0) + orig_durations = [p.actual_duration_ms for p in original_profiles] + orig_avg_duration = sum(orig_durations) / len(orig_durations) if orig_durations else 0.0 + + # Aggregate optimized metrics + opt_count = max((p.render_count for p in optimized_profiles), default=0) + opt_durations = [p.actual_duration_ms for p in optimized_profiles] + opt_avg_duration = sum(opt_durations) / len(opt_durations) if opt_durations else 0.0 + + return RenderBenchmark( + component_name=component_name, + original_render_count=orig_count, + optimized_render_count=opt_count, + original_avg_duration_ms=orig_avg_duration, + optimized_avg_duration_ms=opt_avg_duration, + ) + + +def format_render_benchmark_for_pr(benchmark: RenderBenchmark) -> str: + """Format render benchmark data for PR comment body.""" + lines = [ + "### React Render Performance", + "", + "| Metric | Before | After | Improvement |", + "|--------|--------|-------|-------------|", + f"| Renders | {benchmark.original_render_count} | {benchmark.optimized_render_count} " + f"| {benchmark.render_count_reduction_pct:.1f}% fewer |", + f"| Avg render time | {benchmark.original_avg_duration_ms:.2f}ms " + f"| {benchmark.optimized_avg_duration_ms:.2f}ms " + f"| {benchmark.duration_reduction_pct:.1f}% faster |", + ] + + if benchmark.render_speedup_x > 1: + lines.append(f"\nRender time improved **{benchmark.render_speedup_x:.1f}x**.") + + return "\n".join(lines) diff --git a/codeflash/result/critic.py b/codeflash/result/critic.py index 600c4a537..e04f01d50 100644 --- a/codeflash/result/critic.py +++ b/codeflash/result/critic.py @@ -21,6 +21,7 @@ class AcceptanceReason(Enum): RUNTIME = "runtime" THROUGHPUT = "throughput" CONCURRENCY = "concurrency" + 
RENDER_COUNT = "render_count" NONE = "none" @@ -208,3 +209,40 @@ def coverage_critic(original_code_coverage: CoverageData | None) -> bool: if original_code_coverage: return original_code_coverage.coverage >= COVERAGE_THRESHOLD return False + + +# Minimum render count reduction percentage to accept a React optimization +MIN_RENDER_COUNT_REDUCTION_PCT = 0.20 # 20% + + +def render_efficiency_critic( + original_render_count: int, + optimized_render_count: int, + original_render_duration: float, + optimized_render_duration: float, + best_render_count_until_now: int | None = None, +) -> bool: + """Evaluate whether a React optimization reduces re-renders or render time sufficiently. + + Accepts if: + - Render count is reduced by >= 20% + - OR render duration is reduced by >= MIN_IMPROVEMENT_THRESHOLD + - AND the candidate is the best seen so far + """ + if original_render_count == 0: + return False + + # Check render count reduction + count_reduction = (original_render_count - optimized_render_count) / original_render_count + count_improved = count_reduction >= MIN_RENDER_COUNT_REDUCTION_PCT + + # Check render duration reduction + duration_improved = False + if original_render_duration > 0: + duration_gain = (original_render_duration - optimized_render_duration) / original_render_duration + duration_improved = duration_gain > MIN_IMPROVEMENT_THRESHOLD + + # Check if this is the best candidate so far + is_best = best_render_count_until_now is None or optimized_render_count <= best_render_count_until_now + + return (count_improved or duration_improved) and is_best diff --git a/codeflash/result/explanation.py b/codeflash/result/explanation.py index f0aff73d0..55e0f31f7 100644 --- a/codeflash/result/explanation.py +++ b/codeflash/result/explanation.py @@ -30,6 +30,7 @@ class Explanation: original_concurrency_metrics: Optional[ConcurrencyMetrics] = None best_concurrency_metrics: Optional[ConcurrencyMetrics] = None acceptance_reason: AcceptanceReason = AcceptanceReason.RUNTIME 
+ render_benchmark_markdown: Optional[str] = None @property def perf_improvement_line(self) -> str: @@ -37,6 +38,7 @@ def perf_improvement_line(self) -> str: AcceptanceReason.RUNTIME: "runtime", AcceptanceReason.THROUGHPUT: "throughput", AcceptanceReason.CONCURRENCY: "concurrency", + AcceptanceReason.RENDER_COUNT: "render count", AcceptanceReason.NONE: "", }.get(self.acceptance_reason, "") @@ -144,10 +146,16 @@ def __str__(self) -> str: else: performance_description = f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n" + # Include React render benchmark if available + render_info = "" + if self.render_benchmark_markdown: + render_info = self.render_benchmark_markdown + "\n\n" + return ( f"Optimized {self.function_name} in {self.file_path}\n" f"{self.perf_improvement_line}\n" + performance_description + + render_info + (benchmark_info if benchmark_info else "") + self.raw_explanation_message + " \n\n" From 63fff653f1a8ffd5f365a2bdf4c8ec5ab75a21a8 Mon Sep 17 00:00:00 2001 From: Sarthak Agarwal Date: Fri, 20 Feb 2026 09:01:18 +0530 Subject: [PATCH 05/37] fix regex --- .../codeflash_benchmark/version.py | 2 +- .../javascript/frameworks/react/context.py | 2 +- .../javascript/frameworks/react/discovery.py | 9 +- codeflash/version.py | 2 +- tests/integration/test_react_e2e.py | 160 ++++++++++++++++ tests/react/__init__.py | 0 tests/react/fixtures/Counter.tsx | 21 +++ tests/react/fixtures/DataTable.tsx | 43 +++++ tests/react/fixtures/MemoizedList.tsx | 29 +++ tests/react/fixtures/ServerComponent.tsx | 17 ++ tests/react/fixtures/TaskList.tsx | 66 +++++++ tests/react/fixtures/UserCard.tsx | 22 +++ tests/react/fixtures/__init__.py | 0 tests/react/fixtures/package.json | 16 ++ tests/react/fixtures/useDebounce.ts | 17 ++ tests/react/test_analyzer.py | 137 ++++++++++++++ tests/react/test_benchmarking.py | 173 ++++++++++++++++++ tests/react/test_context.py | 157 ++++++++++++++++ tests/react/test_detector.py | 158 ++++++++++++++++ 
tests/react/test_discovery.py | 143 +++++++++++++++ tests/react/test_profiler.py | 111 +++++++++++ tests/react/test_testgen.py | 64 +++++++ uv.lock | 1 - 23 files changed, 1343 insertions(+), 7 deletions(-) create mode 100644 tests/integration/test_react_e2e.py create mode 100644 tests/react/__init__.py create mode 100644 tests/react/fixtures/Counter.tsx create mode 100644 tests/react/fixtures/DataTable.tsx create mode 100644 tests/react/fixtures/MemoizedList.tsx create mode 100644 tests/react/fixtures/ServerComponent.tsx create mode 100644 tests/react/fixtures/TaskList.tsx create mode 100644 tests/react/fixtures/UserCard.tsx create mode 100644 tests/react/fixtures/__init__.py create mode 100644 tests/react/fixtures/package.json create mode 100644 tests/react/fixtures/useDebounce.ts create mode 100644 tests/react/test_analyzer.py create mode 100644 tests/react/test_benchmarking.py create mode 100644 tests/react/test_context.py create mode 100644 tests/react/test_detector.py create mode 100644 tests/react/test_discovery.py create mode 100644 tests/react/test_profiler.py create mode 100644 tests/react/test_testgen.py diff --git a/codeflash-benchmark/codeflash_benchmark/version.py b/codeflash-benchmark/codeflash_benchmark/version.py index 18606e8d2..0a00478d2 100644 --- a/codeflash-benchmark/codeflash_benchmark/version.py +++ b/codeflash-benchmark/codeflash_benchmark/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. 
-__version__ = "0.3.0" +__version__ = "0.20.1.post127.dev0+c276ff32" diff --git a/codeflash/languages/javascript/frameworks/react/context.py b/codeflash/languages/javascript/frameworks/react/context.py index 0d53e5c8b..c05fc4003 100644 --- a/codeflash/languages/javascript/frameworks/react/context.py +++ b/codeflash/languages/javascript/frameworks/react/context.py @@ -112,7 +112,7 @@ def _extract_hook_usages(component_source: str) -> list[HookUsage]: hooks: list[HookUsage] = [] # Match useXxx( patterns - hook_pattern = re.compile(r"\b(use[A-Z]\w*)\s*\(") + hook_pattern = re.compile(r"\b(use[A-Z]\w*)\s*(?:<[^>]*>)?\s*\(") for match in hook_pattern.finditer(component_source): hook_name = match.group(1) diff --git a/codeflash/languages/javascript/frameworks/react/discovery.py b/codeflash/languages/javascript/frameworks/react/discovery.py index 9e39de817..4d1a1c552 100644 --- a/codeflash/languages/javascript/frameworks/react/discovery.py +++ b/codeflash/languages/javascript/frameworks/react/discovery.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) PASCAL_CASE_RE = re.compile(r"^[A-Z][a-zA-Z0-9]*$") -HOOK_CALL_RE = re.compile(r"\buse[A-Z]\w*\s*\(") +HOOK_CALL_RE = re.compile(r"\buse[A-Z]\w*\s*(?:<[^>]*>)?\s*\(") HOOK_NAME_RE = re.compile(r"^use[A-Z]\w*$") # Built-in React hooks @@ -198,12 +198,15 @@ def _node_contains_jsx(node: Node) -> bool: return False +HOOK_EXTRACT_RE = re.compile(r"\b(use[A-Z]\w*)\s*(?:<[^>]*>)?\s*\(") + + def _extract_hooks_used(function_source: str) -> list[str]: """Extract hook names called within a function body.""" hooks = [] seen = set() - for match in HOOK_CALL_RE.finditer(function_source): - hook_name = match.group(0).rstrip("( \t") + for match in HOOK_EXTRACT_RE.finditer(function_source): + hook_name = match.group(1) if hook_name not in seen: seen.add(hook_name) hooks.append(hook_name) diff --git a/codeflash/version.py b/codeflash/version.py index 5c0c09b55..0a00478d2 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ 
-1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. -__version__ = "0.20.1" +__version__ = "0.20.1.post127.dev0+c276ff32" diff --git a/tests/integration/test_react_e2e.py b/tests/integration/test_react_e2e.py new file mode 100644 index 000000000..3addedb4b --- /dev/null +++ b/tests/integration/test_react_e2e.py @@ -0,0 +1,160 @@ +"""End-to-end integration test for the React optimization pipeline. + +Tests the full flow: framework detection → component discovery → context extraction +→ profiler marker parsing → benchmarking → critic evaluation. + +Note: This does not invoke the LLM or run actual Jest tests. It validates the +pipeline wiring by running each stage on fixture files and verifying outputs. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +FIXTURES_DIR = Path(__file__).parent.parent / "react" / "fixtures" + + +@pytest.fixture(autouse=True) +def clear_framework_cache(): + from codeflash.languages.javascript.frameworks.detector import detect_framework + + detect_framework.cache_clear() + yield + detect_framework.cache_clear() + + +class TestReactPipelineE2E: + def test_framework_detection_from_fixture(self): + from codeflash.languages.javascript.frameworks.detector import detect_framework + + info = detect_framework(FIXTURES_DIR) + assert info.name == "react" + assert info.react_version_major == 18 + assert info.has_testing_library is True + + def test_component_discovery(self): + from codeflash.languages.javascript.frameworks.react.discovery import find_react_components + from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer + + analyzer = TreeSitterAnalyzer("tsx") + + # Counter.tsx — should find 1 function component + source = FIXTURES_DIR.joinpath("Counter.tsx").read_text(encoding="utf-8") + components = find_react_components(source, FIXTURES_DIR / "Counter.tsx", analyzer) + assert any(c.function_name == "Counter" for c in components) + 
+ # ServerComponent.tsx — should be skipped (use server) + source = FIXTURES_DIR.joinpath("ServerComponent.tsx").read_text(encoding="utf-8") + components = find_react_components(source, FIXTURES_DIR / "ServerComponent.tsx", analyzer) + assert components == [] + + # useDebounce.ts — should be detected as hook, not component + ts_analyzer = TreeSitterAnalyzer("typescript") + source = FIXTURES_DIR.joinpath("useDebounce.ts").read_text(encoding="utf-8") + components = find_react_components(source, FIXTURES_DIR / "useDebounce.ts", ts_analyzer) + hooks = [c for c in components if c.component_type.value == "hook"] + assert len(hooks) == 1 + assert hooks[0].function_name == "useDebounce" + + def test_optimization_opportunity_detection(self): + from codeflash.languages.javascript.frameworks.react.analyzer import ( + OpportunityType, + detect_optimization_opportunities, + ) + from codeflash.languages.javascript.frameworks.react.discovery import ( + ComponentType, + ReactComponentInfo, + find_react_components, + ) + from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer + + analyzer = TreeSitterAnalyzer("tsx") + + # DataTable has expensive operations without useMemo + source = FIXTURES_DIR.joinpath("DataTable.tsx").read_text(encoding="utf-8") + components = find_react_components(source, FIXTURES_DIR / "DataTable.tsx", analyzer) + data_table = [c for c in components if c.function_name == "DataTable"][0] + opps = detect_optimization_opportunities(source, data_table) + opp_types = [o.type for o in opps] + assert OpportunityType.MISSING_USEMEMO in opp_types + + # UserCard has inline objects in JSX + source = FIXTURES_DIR.joinpath("UserCard.tsx").read_text(encoding="utf-8") + components = find_react_components(source, FIXTURES_DIR / "UserCard.tsx", analyzer) + user_card = [c for c in components if c.function_name == "UserCard"][0] + opps = detect_optimization_opportunities(source, user_card) + opp_types = [o.type for o in opps] + assert 
OpportunityType.INLINE_OBJECT_PROP in opp_types + + def test_context_extraction(self): + from codeflash.languages.javascript.frameworks.react.context import extract_react_context + from codeflash.languages.javascript.frameworks.react.discovery import find_react_components + from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer + + analyzer = TreeSitterAnalyzer("tsx") + source = FIXTURES_DIR.joinpath("TaskList.tsx").read_text(encoding="utf-8") + components = find_react_components(source, FIXTURES_DIR / "TaskList.tsx", analyzer) + task_list = [c for c in components if c.function_name == "TaskList"][0] + + context = extract_react_context(task_list, source, analyzer, FIXTURES_DIR) + assert len(context.hooks_used) > 0 + assert len(context.optimization_opportunities) > 0 + + prompt = context.to_prompt_string() + assert "useState" in prompt or "Hooks used" in prompt + + def test_profiler_marker_parsing(self): + from codeflash.languages.javascript.parse import parse_react_render_markers + + stdout = ( + "PASS src/TaskList.test.tsx\n" + "!######REACT_RENDER:TaskList:mount:25.3:40.1:1######!\n" + "!######REACT_RENDER:TaskList:update:5.2:40.1:5######!\n" + "!######REACT_RENDER:TaskList:update:4.8:40.1:10######!\n" + ) + profiles = parse_react_render_markers(stdout) + assert len(profiles) == 3 + assert profiles[0].component_name == "TaskList" + assert profiles[2].render_count == 10 + + def test_benchmarking_and_critic(self): + from codeflash.languages.javascript.frameworks.react.benchmarking import ( + compare_render_benchmarks, + format_render_benchmark_for_pr, + ) + from codeflash.languages.javascript.parse import RenderProfile + from codeflash.result.critic import render_efficiency_critic + + original = [ + RenderProfile("TaskList", "mount", 25.0, 40.0, 1), + RenderProfile("TaskList", "update", 5.0, 40.0, 25), + RenderProfile("TaskList", "update", 4.5, 40.0, 47), + ] + optimized = [ + RenderProfile("TaskList", "mount", 20.0, 35.0, 1), + 
RenderProfile("TaskList", "update", 2.0, 35.0, 3), + ] + + benchmark = compare_render_benchmarks(original, optimized) + assert benchmark is not None + assert benchmark.original_render_count == 47 + assert benchmark.optimized_render_count == 3 + assert benchmark.render_count_reduction_pct > 90 + + # Critic should accept this optimization + accepted = render_efficiency_critic( + original_render_count=benchmark.original_render_count, + optimized_render_count=benchmark.optimized_render_count, + original_render_duration=benchmark.original_avg_duration_ms, + optimized_render_duration=benchmark.optimized_avg_duration_ms, + ) + assert accepted is True + + # PR formatting + pr_output = format_render_benchmark_for_pr(benchmark) + assert "47" in pr_output + assert "3" in pr_output + assert "React Render Performance" in pr_output \ No newline at end of file diff --git a/tests/react/__init__.py b/tests/react/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/react/fixtures/Counter.tsx b/tests/react/fixtures/Counter.tsx new file mode 100644 index 000000000..0d9b5941d --- /dev/null +++ b/tests/react/fixtures/Counter.tsx @@ -0,0 +1,21 @@ +import React, { useState } from 'react'; + +interface CounterProps { + initialCount?: number; + label?: string; +} + +export function Counter({ initialCount = 0, label = 'Count' }: CounterProps) { + const [count, setCount] = useState(initialCount); + + const increment = () => setCount(c => c + 1); + const decrement = () => setCount(c => c - 1); + + return ( +
+ {label}: {count} + + +
+ ); +} diff --git a/tests/react/fixtures/DataTable.tsx b/tests/react/fixtures/DataTable.tsx new file mode 100644 index 000000000..4fdb273f7 --- /dev/null +++ b/tests/react/fixtures/DataTable.tsx @@ -0,0 +1,43 @@ +import React from 'react'; + +interface DataTableProps { + items: Array<{ id: number; name: string; value: number }>; + filterText: string; + sortBy: 'name' | 'value'; +} + +export function DataTable({ items, filterText, sortBy }: DataTableProps) { + // These expensive operations run on every render - should use useMemo + const filteredItems = items.filter(item => + item.name.toLowerCase().includes(filterText.toLowerCase()) + ); + + const sortedItems = filteredItems.sort((a, b) => { + if (sortBy === 'name') return a.name.localeCompare(b.name); + return a.value - b.value; + }); + + const total = sortedItems.reduce((sum, item) => sum + item.value, 0); + + return ( +
+ + + + + + + + + {sortedItems.map(item => ( + + + + + ))} + +
NameValue
{item.name}{item.value}
+
Total: {total}
+
+ ); +} diff --git a/tests/react/fixtures/MemoizedList.tsx b/tests/react/fixtures/MemoizedList.tsx new file mode 100644 index 000000000..8de7234d4 --- /dev/null +++ b/tests/react/fixtures/MemoizedList.tsx @@ -0,0 +1,29 @@ +import React, { memo } from 'react'; + +interface ListItemProps { + text: string; + isSelected: boolean; +} + +const ListItem = memo(function ListItem({ text, isSelected }: ListItemProps) { + return ( +
  • + {text} +
  • + ); +}); + +interface MemoizedListProps { + items: string[]; + selectedIndex: number; +} + +export const MemoizedList = memo(function MemoizedList({ items, selectedIndex }: MemoizedListProps) { + return ( +
      + {items.map((item, index) => ( + + ))} +
    + ); +}); diff --git a/tests/react/fixtures/ServerComponent.tsx b/tests/react/fixtures/ServerComponent.tsx new file mode 100644 index 000000000..180da48f1 --- /dev/null +++ b/tests/react/fixtures/ServerComponent.tsx @@ -0,0 +1,17 @@ +"use server"; + +interface ServerPageProps { + id: string; +} + +export async function ServerPage({ id }: ServerPageProps) { + const data = await fetch(`/api/data/${id}`); + const json = await data.json(); + + return ( +
    +

    {json.title}

    +

    {json.description}

    +
    + ); +} diff --git a/tests/react/fixtures/TaskList.tsx b/tests/react/fixtures/TaskList.tsx new file mode 100644 index 000000000..65d337534 --- /dev/null +++ b/tests/react/fixtures/TaskList.tsx @@ -0,0 +1,66 @@ +import React, { useState, useContext, useCallback } from 'react'; + +interface Task { + id: number; + title: string; + completed: boolean; + priority: 'low' | 'medium' | 'high'; +} + +interface TaskListProps { + tasks: Task[]; + onToggle: (id: number) => void; + onDelete: (id: number) => void; + filter: 'all' | 'active' | 'completed'; +} + +export function TaskList({ tasks, onToggle, onDelete, filter }: TaskListProps) { + const [sortBy, setSortBy] = useState<'title' | 'priority'>('title'); + + // Inline filtering and sorting without useMemo + const filteredTasks = tasks.filter(task => { + if (filter === 'active') return !task.completed; + if (filter === 'completed') return task.completed; + return true; + }); + + const sortedTasks = filteredTasks.sort((a, b) => { + if (sortBy === 'title') return a.title.localeCompare(b.title); + const priority = { low: 0, medium: 1, high: 2 }; + return priority[b.priority] - priority[a.priority]; + }); + + // Inline function defined in render body + const handleToggle = (id: number) => { + onToggle(id); + }; + + return ( +
    +
    + + +
    +
      + {sortedTasks.map(task => ( +
    • + handleToggle(task.id)} + /> + {task.title} + +
    • + ))} +
    +
    Total: {sortedTasks.length} tasks
    +
    + ); +} diff --git a/tests/react/fixtures/UserCard.tsx b/tests/react/fixtures/UserCard.tsx new file mode 100644 index 000000000..3e4bf08b8 --- /dev/null +++ b/tests/react/fixtures/UserCard.tsx @@ -0,0 +1,22 @@ +import React from 'react'; + +interface UserCardProps { + name: string; + email: string; + role: string; + onEdit: (email: string) => void; +} + +export function UserCard({ name, email, role, onEdit }: UserCardProps) { + return ( +
    +

    {name}

    +

    {email}

    + {role} + +
    + ); +} diff --git a/tests/react/fixtures/__init__.py b/tests/react/fixtures/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/react/fixtures/package.json b/tests/react/fixtures/package.json new file mode 100644 index 000000000..76bea756d --- /dev/null +++ b/tests/react/fixtures/package.json @@ -0,0 +1,16 @@ +{ + "name": "test-react-project", + "version": "1.0.0", + "dependencies": { + "react": "^18.2.0", + "react-dom": "^18.2.0" + }, + "devDependencies": { + "@testing-library/react": "^14.0.0", + "@testing-library/user-event": "^14.0.0", + "@testing-library/jest-dom": "^6.0.0", + "typescript": "^5.0.0", + "jest": "^29.0.0", + "@types/react": "^18.0.0" + } +} diff --git a/tests/react/fixtures/useDebounce.ts b/tests/react/fixtures/useDebounce.ts new file mode 100644 index 000000000..0a45ef2ef --- /dev/null +++ b/tests/react/fixtures/useDebounce.ts @@ -0,0 +1,17 @@ +import { useState, useEffect } from 'react'; + +export function useDebounce(value: T, delay: number): T { + const [debouncedValue, setDebouncedValue] = useState(value); + + useEffect(() => { + const timer = setTimeout(() => { + setDebouncedValue(value); + }, delay); + + return () => { + clearTimeout(timer); + }; + }, [value, delay]); + + return debouncedValue; +} diff --git a/tests/react/test_analyzer.py b/tests/react/test_analyzer.py new file mode 100644 index 000000000..4f9bad026 --- /dev/null +++ b/tests/react/test_analyzer.py @@ -0,0 +1,137 @@ +"""Tests for React optimization opportunity detection.""" + +from __future__ import annotations + +from codeflash.languages.javascript.frameworks.react.analyzer import ( + OpportunitySeverity, + OpportunityType, + detect_optimization_opportunities, +) +from codeflash.languages.javascript.frameworks.react.discovery import ( + ComponentType, + ReactComponentInfo, +) + + +def _make_component_info( + name: str = "TestComponent", + start_line: int = 1, + end_line: int = 20, + is_memoized: bool = False, +) -> ReactComponentInfo: + 
return ReactComponentInfo( + function_name=name, + component_type=ComponentType.FUNCTION, + returns_jsx=True, + is_memoized=is_memoized, + start_line=start_line, + end_line=end_line, + ) + + +class TestDetectInlineObjects: + def test_inline_style_prop(self): + source = 'function TestComponent() {\n return
    hello
    ;\n}' + info = _make_component_info(end_line=3) + opps = detect_optimization_opportunities(source, info) + types = [o.type for o in opps] + assert OpportunityType.INLINE_OBJECT_PROP in types + + def test_inline_array_prop(self): + source = "function TestComponent() {\n return