diff --git a/src/runpod_flash/cli/commands/build.py b/src/runpod_flash/cli/commands/build.py index 65e7e7fe..d1d6f0af 100644 --- a/src/runpod_flash/cli/commands/build.py +++ b/src/runpod_flash/cli/commands/build.py @@ -31,7 +31,7 @@ from .build_utils.lb_handler_generator import LBHandlerGenerator from .build_utils.manifest import ManifestBuilder from .build_utils.resource_config_generator import generate_all_resource_configs -from .build_utils.scanner import RemoteDecoratorScanner +from .build_utils.scanner import RuntimeScanner logger = logging.getLogger(__name__) @@ -302,9 +302,16 @@ def run_build( copy_project_files(files, project_dir, build_dir) try: - scanner = RemoteDecoratorScanner(build_dir) + scanner = RuntimeScanner(build_dir) remote_functions = scanner.discover_remote_functions() + if scanner.import_errors: + console.print("\n[red bold]Failed to load:[/red bold]") + for filename, err in scanner.import_errors.items(): + console.print(f" [red]{filename}[/red]: {err}") + if not remote_functions: + raise typer.Exit(1) + manifest_builder = ManifestBuilder( app_name, remote_functions, @@ -326,6 +333,8 @@ def run_build( deployment_manifest_path = flash_dir / "flash_manifest.json" shutil.copy2(manifest_path, deployment_manifest_path) + except typer.Exit: + raise except (ImportError, SyntaxError) as e: console.print(f"[red]Error:[/red] Code analysis failed: {e}") logger.exception("Code analysis failed") @@ -534,7 +543,7 @@ def validate_project_structure(project_dir: Path) -> bool: Validate that directory is a Flash project. A Flash project is any directory containing Python files. The - RemoteDecoratorScanner validates that @remote functions exist. + RuntimeScanner validates that @remote functions exist. Args: project_dir: Directory to validate diff --git a/src/runpod_flash/cli/commands/build_utils/manifest.py b/src/runpod_flash/cli/commands/build_utils/manifest.py index d01ca0dd..64f0af80 100644 --- a/src/runpod_flash/cli/commands/build_utils/manifest.py +++ b/src/runpod_flash/cli/commands/build_utils/manifest.py @@ -73,9 +73,7 @@ def __init__( ): self.project_name = project_name self.remote_functions = remote_functions - self.scanner = ( - scanner # Optional: RemoteDecoratorScanner with resource config info - ) + self.scanner = scanner # Optional: RuntimeScanner with resource config info self.build_dir = build_dir self.python_version = ( python_version or f"{sys.version_info.major}.{sys.version_info.minor}" @@ -205,6 +203,12 @@ def _extract_config_properties(config: Dict[str, Any], resource_config) -> None: if hasattr(resource_config, "workersMax"): config["workersMax"] = resource_config.workersMax + if ( + hasattr(resource_config, "idleTimeout") + and resource_config.idleTimeout is not None + ): + config["idleTimeout"] = resource_config.idleTimeout + if ( hasattr(resource_config, "scalerType") and resource_config.scalerType is not None diff --git a/src/runpod_flash/cli/commands/build_utils/scanner.py b/src/runpod_flash/cli/commands/build_utils/scanner.py index 5beb563d..06c021fa 100644 --- a/src/runpod_flash/cli/commands/build_utils/scanner.py +++ b/src/runpod_flash/cli/commands/build_utils/scanner.py @@ -1,69 +1,45 @@ -"""AST scanner for discovering @remote decorated functions and classes.""" +"""project scanner for discovering @remote decorated functions and classes. + +imports user modules and inspects live objects via their __remote_config__ +attribute (stamped by @remote and Endpoint). this handles all python +language constructs without needing to rebuild an interpreter via AST. +""" import ast +import importlib.util +import inspect import logging import os -import re +import sys +import types from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Set from runpod_flash.cli.utils.ignore import get_file_tree, load_ignore_patterns +from runpod_flash.core.resources.load_balancer_sls_resource import ( + LoadBalancerSlsResource, +) +from runpod_flash.core.resources.serverless import ServerlessResource +from runpod_flash.endpoint import Endpoint logger = logging.getLogger(__name__) def file_to_url_prefix(file_path: Path, project_root: Path) -> str: - """Derive the local dev server URL prefix from a source file path. - - Args: - file_path: Absolute path to the Python source file - project_root: Absolute path to the project root directory - - Returns: - URL prefix starting with "/" (e.g., /longruns/stage1) - - Example: - longruns/stage1.py → /longruns/stage1 - """ + """e.g. longruns/stage1.py -> /longruns/stage1""" rel = file_path.relative_to(project_root).with_suffix("") return "/" + str(rel).replace(os.sep, "/") def file_to_resource_name(file_path: Path, project_root: Path) -> str: - """Derive the manifest resource name from a source file path. - - Slashes and hyphens are replaced with underscores to produce a valid - Python identifier suitable for use as a resource name. - - Args: - file_path: Absolute path to the Python source file - project_root: Absolute path to the project root directory - - Returns: - Resource name using underscores (e.g., longruns_stage1) - - Example: - longruns/stage1.py → longruns_stage1 - my-worker.py → my_worker - """ + """e.g. longruns/stage1.py -> longruns_stage1, my-worker.py -> my_worker""" rel = file_path.relative_to(project_root).with_suffix("") return str(rel).replace(os.sep, "_").replace("/", "_").replace("-", "_") def file_to_module_path(file_path: Path, project_root: Path) -> str: - """Derive the Python dotted module path from a source file path. - - Args: - file_path: Absolute path to the Python source file - project_root: Absolute path to the project root directory - - Returns: - Dotted module path (e.g., longruns.stage1) - - Example: - longruns/stage1.py → longruns.stage1 - """ + """e.g. longruns/stage1.py -> longruns.stage1""" rel = file_path.relative_to(project_root).with_suffix("") return str(rel).replace(os.sep, ".").replace("/", ".") @@ -79,1168 +55,492 @@ class RemoteFunctionMetadata: is_async: bool is_class: bool file_path: Path - http_method: Optional[str] = None # HTTP method for LB endpoints: GET, POST, etc. - http_path: Optional[str] = None # HTTP path for LB endpoints: /api/process - is_load_balanced: bool = False # LoadBalancerSlsResource or LiveLoadBalancer - is_live_resource: bool = ( - False # LiveLoadBalancer (vs deployed LoadBalancerSlsResource) - ) - config_variable: Optional[str] = None # Variable name like "gpu_config" - calls_remote_functions: bool = ( - False # Does this function call other @remote functions? - ) - called_remote_functions: List[str] = field( - default_factory=list - ) # Names of @remote functions called - is_lb_route_handler: bool = ( - False # LB @remote with method= and path= — runs directly as HTTP handler - ) - class_methods: List[str] = field( - default_factory=list - ) # Public methods for @remote classes - param_names: List[str] = field( - default_factory=list - ) # Function params excluding self - class_method_params: Dict[str, List[str]] = field( - default_factory=dict - ) # method_name -> param_names (for classes) - docstring: Optional[str] = None # First line of function/class docstring - class_method_docstrings: Dict[str, Optional[str]] = field( - default_factory=dict - ) # method_name -> first line of docstring - local: bool = False # Execute locally instead of remote dispatch - - -class RemoteDecoratorScanner: - """Scans Python files for @remote decorators and extracts metadata.""" - - def __init__(self, project_dir: Path): - self.project_dir = project_dir - self.py_files: List[Path] = [] - self.resource_configs: Dict[str, str] = {} # name -> name - self.resource_types: Dict[str, str] = {} # name -> type - self.resource_flags: Dict[str, Dict[str, bool]] = {} # name -> {flag: bool} - self.resource_variables: Dict[str, str] = {} # name -> variable_name - # tracks Endpoint(...) variable assignments for LB route detection. - # maps variable_name -> resource_name - self._endpoint_variables: Dict[str, str] = {} - # maps module_path:variable_name -> resource_name for cross-module lookups - self._endpoint_variables_qualified: Dict[str, str] = {} + http_method: Optional[str] = None + http_path: Optional[str] = None + is_load_balanced: bool = False + is_live_resource: bool = False + config_variable: Optional[str] = None + calls_remote_functions: bool = False + called_remote_functions: List[str] = field(default_factory=list) + is_lb_route_handler: bool = False + class_methods: List[str] = field(default_factory=list) + param_names: List[str] = field(default_factory=list) + class_method_params: Dict[str, List[str]] = field(default_factory=dict) + docstring: Optional[str] = None + class_method_docstrings: Dict[str, Optional[str]] = field(default_factory=dict) + local: bool = False + + +def _first_docstring_line(obj: Any) -> Optional[str]: + """extract the first line of an object's docstring, or None.""" + doc = inspect.getdoc(obj) + if doc: + return doc.split("\n")[0].strip() + return None - def discover_remote_functions(self) -> List[RemoteFunctionMetadata]: - """Discover all @remote decorated functions and classes.""" - functions = [] - # Use .gitignore / .flashignore aware file walker with early directory pruning. - # This avoids descending into .venv, __pycache__, .flash, etc. - spec = load_ignore_patterns(self.project_dir) - all_files = get_file_tree(self.project_dir, spec) - self.py_files = [ - f for f in all_files if f.suffix == ".py" and f.name != "__init__.py" +def _get_param_names(func: Any) -> List[str]: + """extract parameter names from a callable, excluding 'self'.""" + try: + sig = inspect.signature(func) + return [ + name + for name, param in sig.parameters.items() + if name != "self" + and param.kind + not in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ) ] + except (ValueError, TypeError): + return [] + + +def _unwrap_to_original(obj: Any) -> Any: + """follow __wrapped__ to get the original function for signature/docstring.""" + return inspect.unwrap(obj, stop=lambda f: not hasattr(f, "__wrapped__")) + + +def _resource_type_name(resource_config: Any) -> str: + """get the class name of a resource config, unwrapping Endpoint if needed.""" + if isinstance(resource_config, Endpoint): + inner = resource_config._build_resource_config() + return type(inner).__name__ + return type(resource_config).__name__ + + +def _resource_name(resource_config: Any) -> str: + """get the name from a resource config object.""" + if isinstance(resource_config, Endpoint): + return resource_config.name or "" + return getattr(resource_config, "name", "") or "" + + +def _is_lb_type(resource_config: Any) -> bool: + """check if a resource config is a load-balanced type.""" + if isinstance(resource_config, Endpoint): + return resource_config.is_load_balanced + if isinstance(resource_config, LoadBalancerSlsResource): + return True + type_name = type(resource_config).__name__ + return type_name in ( + "LoadBalancerSlsResource", + "CpuLoadBalancerSlsResource", + "LiveLoadBalancer", + "CpuLiveLoadBalancer", + ) - # First pass: extract all resource configs from all files - for py_file in self.py_files: - try: - content = py_file.read_text(encoding="utf-8") - tree = ast.parse(content) - self._extract_resource_configs(tree, py_file) - except UnicodeDecodeError: - pass - except SyntaxError as e: - logger.warning(f"Syntax error in {py_file}: {e}") - except Exception: - pass - - # Second pass: extract @remote decorated functions - for py_file in self.py_files: - try: - content = py_file.read_text(encoding="utf-8") - tree = ast.parse(content) - functions.extend(self._extract_remote_functions(tree, py_file)) - except UnicodeDecodeError: - pass - except SyntaxError as e: - logger.warning(f"Syntax error in {py_file}: {e}") - except Exception: - pass - # Third pass: analyze function call graphs - remote_function_names = {f.function_name for f in functions} - for py_file in self.py_files: - try: - content = py_file.read_text(encoding="utf-8") - tree = ast.parse(content) - - # Build a map of function/class name -> AST node for this file. - # class nodes are included so that cross-endpoint calls inside - # class methods are detected when the endpoint wraps a class. - func_node_map: Dict[str, ast.AST] = {} - for node in ast.walk(tree): - if isinstance( - node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef) - ): - func_node_map.setdefault(node.name, node) - - # Find each @remote function and analyze its calls - for func_meta in [f for f in functions if f.file_path == py_file]: - node = func_node_map.get(func_meta.function_name) - if node is not None: - self._analyze_function_calls( - node, func_meta, remote_function_names - ) - except UnicodeDecodeError: - pass - except SyntaxError as e: - logger.warning(f"Syntax error in {py_file}: {e}") - except Exception: - pass +def _is_live_type(resource_config: Any) -> bool: + """check if a resource config is a live (on-demand) type.""" + if isinstance(resource_config, Endpoint): + return True + type_name = type(resource_config).__name__ + return type_name in ( + "LiveServerless", + "CpuLiveServerless", + "LiveLoadBalancer", + "CpuLiveLoadBalancer", + ) - return functions - def _extract_resource_configs(self, tree: ast.AST, py_file: Path) -> None: - """Extract resource config variable assignments and determine type flags. +def _extract_class_info( + cls: type, +) -> tuple[List[str], Dict[str, List[str]], Dict[str, Optional[str]]]: + """extract public methods, their params, and their docstrings from a class. - This method extracts resource configurations and determines is_load_balanced - and is_live_resource flags using string-based type matching. Handles both - legacy resource classes (LiveServerless, etc.) and the unified Endpoint class. - """ - module_path = self._get_module_path(py_file) - - for node in ast.walk(tree): - if isinstance(node, ast.Assign): - # Look for assignments like: gpu_config = LiveServerless(...) or api = LiveLoadBalancer(...) - for target in node.targets: - if isinstance(target, ast.Name): - variable_name = target.id - config_type = self._get_call_type(node.value) - - # handle Endpoint(...) assignments (LB pattern) - if config_type == "Endpoint": - self._register_endpoint_variable( - node.value, variable_name, module_path - ) - continue - - # Accept any class that looks like a resource config (DeployableResource) - if config_type and self._is_resource_config_type(config_type): - # Extract the resource's name parameter (the actual identifier) - # If extraction fails, fall back to variable name - resource_name = self._extract_resource_name(node.value) - if not resource_name: - resource_name = variable_name - - # Store mapping using the resource's name (or variable name as fallback) - self.resource_configs[resource_name] = resource_name - self.resource_types[resource_name] = config_type - - # Store variable name for config discovery during provisioning - self.resource_variables[resource_name] = variable_name - - # Also store variable name mapping for local lookups in same module - var_key = f"{module_path}:{variable_name}" - self.resource_configs[var_key] = resource_name - self.resource_types[var_key] = config_type - self.resource_variables[var_key] = variable_name - - # Determine boolean flags using string-based type checking - # This is determined by isinstance() at scan time in production, - # but we use string matching for reliability - is_load_balanced = config_type in [ - "LoadBalancerSlsResource", - "LiveLoadBalancer", - "CpuLiveLoadBalancer", - ] - is_live_resource = config_type in [ - "LiveLoadBalancer", - "CpuLiveLoadBalancer", - ] - - # Store flags for this resource - self.resource_flags[resource_name] = { - "is_load_balanced": is_load_balanced, - "is_live_resource": is_live_resource, - } - # Also store for variable key - self.resource_flags[var_key] = { - "is_load_balanced": is_load_balanced, - "is_live_resource": is_live_resource, - } - - def _register_endpoint_variable( - self, call_node: ast.Call, variable_name: str, module_path: str - ) -> None: - """Register an Endpoint(...) variable assignment for LB route detection. - - ep = Endpoint(name="my-api", ...) is stored so that @ep.get("/path") - decorators in pass 2 can be resolved back to this endpoint config. - """ - resource_name = self._extract_resource_name(call_node) - if not resource_name: - resource_name = variable_name - - var_key = f"{module_path}:{variable_name}" - - # track as endpoint variable for route decorator resolution - self._endpoint_variables[variable_name] = resource_name - self._endpoint_variables_qualified[var_key] = resource_name - - # also register in the standard resource tracking dicts so that - # downstream code (manifest, server gen) can find it - self.resource_configs[resource_name] = resource_name - self.resource_types[resource_name] = "Endpoint" - self.resource_variables[resource_name] = variable_name - - self.resource_configs[var_key] = resource_name - self.resource_types[var_key] = "Endpoint" - self.resource_variables[var_key] = variable_name - - # endpoint variable assignments are LB (route-based) by default. - # if someone uses @ep directly as a decorator (QB), they would - # use the inline @Endpoint(...) pattern instead of an assignment. - flags = {"is_load_balanced": True, "is_live_resource": True} - self.resource_flags[resource_name] = flags - self.resource_flags[var_key] = flags - - def _extract_remote_functions( - self, tree: ast.AST, py_file: Path - ) -> List[RemoteFunctionMetadata]: - """Extract @remote, @Endpoint(...), and @ep.get/post/... decorated functions and classes.""" - module_path = self._get_module_path(py_file) - functions = [] - - for node in ast.walk(tree): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): - # check for @ep.get("/path") endpoint route decorator (LB mode) - route_info = self._find_endpoint_route_decorator( - node.decorator_list, module_path - ) - if route_info is not None: - functions.append( - self._build_endpoint_route_metadata( - node, route_info, module_path, py_file - ) - ) - continue + uses vars(cls) to preserve source-definition order rather than + inspect.getmembers which sorts alphabetically. + """ + methods: List[str] = [] + method_params: Dict[str, List[str]] = {} + method_docstrings: Dict[str, Optional[str]] = {} - # check for @Endpoint(...) decorator (QB mode) - endpoint_decorator = self._find_endpoint_qb_decorator( - node.decorator_list - ) - if endpoint_decorator is not None: - functions.append( - self._build_endpoint_qb_metadata( - node, endpoint_decorator, module_path, py_file - ) - ) - continue + for name, member in vars(cls).items(): + if name.startswith("_"): + continue + if not (inspect.isfunction(member) or inspect.iscoroutinefunction(member)): + continue + methods.append(name) + method_params[name] = _get_param_names(member) + method_docstrings[name] = _first_docstring_line(member) - # Check if this node has @remote decorator - remote_decorator = self._find_remote_decorator(node.decorator_list) - - if remote_decorator: - # Extract resource config name from decorator - resource_config_name = self._extract_resource_config_name( - remote_decorator, module_path - ) - - if resource_config_name: - is_async = isinstance(node, ast.AsyncFunctionDef) - is_class = isinstance(node, ast.ClassDef) - - # Get resource type for this config - resource_type = self._get_resource_type(resource_config_name) - - # Extract HTTP routing metadata (for LB endpoints) - http_method, http_path = self._extract_http_routing( - remote_decorator - ) - - # Extract local execution flag - local = self._extract_local_flag(remote_decorator) - - # Get flags for this resource - flags = self.resource_flags.get( - resource_config_name, - {"is_load_balanced": False, "is_live_resource": False}, - ) - - # An LB route handler is an LB @remote function that has - # both method= and path= declared. Its body runs directly - # on the LB endpoint — it is NOT a remote dispatch stub. - is_lb_route_handler = ( - flags["is_load_balanced"] - and http_method is not None - and http_path is not None - ) - - # Extract docstring (first line only) - raw_docstring = ast.get_docstring(node) - docstring: Optional[str] = None - if raw_docstring: - docstring = raw_docstring.split("\n")[0].strip() - - # Extract public methods for @remote classes - class_methods: List[str] = [] - class_method_params: Dict[str, List[str]] = {} - class_method_docstrings: Dict[str, Optional[str]] = {} - if is_class: - for n in node.body: - if isinstance( - n, (ast.FunctionDef, ast.AsyncFunctionDef) - ) and not n.name.startswith("_"): - class_methods.append(n.name) - class_method_params[n.name] = [ - arg.arg - for arg in n.args.args - if arg.arg != "self" - ] - raw_method_doc = ast.get_docstring(n) - if raw_method_doc: - class_method_docstrings[n.name] = ( - raw_method_doc.split("\n")[0].strip() - ) - else: - class_method_docstrings[n.name] = None - - # Extract param names for functions (not classes) - param_names: List[str] = [] - if not is_class and isinstance( - node, (ast.FunctionDef, ast.AsyncFunctionDef) - ): - param_names = [ - arg.arg for arg in node.args.args if arg.arg != "self" - ] - - metadata = RemoteFunctionMetadata( - function_name=node.name, - module_path=module_path, - resource_config_name=resource_config_name, - resource_type=resource_type, - is_async=is_async, - is_class=is_class, - file_path=py_file, - http_method=http_method, - http_path=http_path, - is_load_balanced=flags["is_load_balanced"], - is_live_resource=flags["is_live_resource"], - config_variable=self.resource_variables.get( - resource_config_name - ), - is_lb_route_handler=is_lb_route_handler, - class_methods=class_methods, - param_names=param_names, - class_method_params=class_method_params, - docstring=docstring, - class_method_docstrings=class_method_docstrings, - local=local, - ) - functions.append(metadata) + return methods, method_params, method_docstrings - return functions - def _find_remote_decorator(self, decorators: List[ast.expr]) -> Optional[ast.expr]: - """Find @remote decorator in a list of decorators.""" - for decorator in decorators: - # Handle @remote or @remote(...) - if isinstance(decorator, ast.Name): - if decorator.id == "remote": - return decorator - elif isinstance(decorator, ast.Call): - if isinstance(decorator.func, ast.Name): - if decorator.func.id == "remote": - return decorator - elif isinstance(decorator.func, ast.Attribute): - if decorator.func.attr == "remote": - return decorator +def _import_module_from_file(file_path: Path, module_name: str) -> Any: + """import a python file as a module. returns the module or None on failure. + temporarily injects into sys.modules for the duration of exec_module + (so relative imports within the file resolve), then restores the + previous entry to avoid leaking user modules into the cli process. + """ + spec = importlib.util.spec_from_file_location(module_name, file_path) + if not spec or not spec.loader: return None - def _find_endpoint_qb_decorator( - self, decorators: List[ast.expr] - ) -> Optional[ast.Call]: - """Find @Endpoint(...) decorator used in QB (direct decoration) mode. + module = importlib.util.module_from_spec(spec) + old_module = sys.modules.get(module_name) + sys.modules[module_name] = module - Returns the Call node if found, None otherwise. - """ - for decorator in decorators: - if isinstance(decorator, ast.Call): - if ( - isinstance(decorator.func, ast.Name) - and decorator.func.id == "Endpoint" - ): - return decorator - if ( - isinstance(decorator.func, ast.Attribute) - and decorator.func.attr == "Endpoint" - ): - return decorator - return None - - def _find_endpoint_route_decorator( - self, decorators: List[ast.expr], module_path: str - ) -> Optional[Dict[str, str]]: - """Find @ep.get("/path") style route decorators on Endpoint instances. + try: + spec.loader.exec_module(module) + return module + except Exception as e: + logger.debug("failed to import %s: %s", file_path.name, e) + raise + finally: + _restore_module(module_name, old_module) - Returns dict with keys: variable_name, resource_name, method, path. - Returns None if no matching decorator found. - """ - for decorator in decorators: - if not isinstance(decorator, ast.Call): - continue - if not isinstance(decorator.func, ast.Attribute): - continue - method_name = decorator.func.attr - if method_name not in ("get", "post", "put", "delete", "patch"): - continue - if not isinstance(decorator.func.value, ast.Name): - continue - variable_name = decorator.func.value.id +def _restore_module(module_name: str, old_module: Any) -> None: + if old_module is not None: + sys.modules[module_name] = old_module + else: + sys.modules.pop(module_name, None) - # resolve variable to a known Endpoint instance - var_key = f"{module_path}:{variable_name}" - resource_name = self._endpoint_variables_qualified.get(var_key) - if resource_name is None: - resource_name = self._endpoint_variables.get(variable_name) - if resource_name is None: - continue - # extract path from first positional arg - path = None - if decorator.args and isinstance(decorator.args[0], ast.Constant): - path = decorator.args[0].value - if not path: - continue +def _metadata_from_remote_config( + obj: Any, + attr_name: str, + module_path: str, + file_path: Path, + variable_name: Optional[str] = None, +) -> Optional[RemoteFunctionMetadata]: + """build RemoteFunctionMetadata from an object with __remote_config__. - return { - "variable_name": variable_name, - "resource_name": resource_name, - "method": method_name.upper(), - "path": path, - } + obj is the decorated function/class (or wrapper). attr_name is the + module-level attribute name it was found under. + """ + config = getattr(obj, "__remote_config__", None) + if not isinstance(config, dict): return None - def _build_endpoint_qb_metadata( - self, - node: ast.AST, - decorator: ast.Call, - module_path: str, - py_file: Path, - ) -> RemoteFunctionMetadata: - """Build metadata for a @Endpoint(...) QB-decorated function or class.""" - resource_name = self._extract_resource_name(decorator) - if not resource_name: - resource_name = node.name - - is_async = isinstance(node, ast.AsyncFunctionDef) - is_class = isinstance(node, ast.ClassDef) - - docstring = None - raw_docstring = ast.get_docstring(node) - if raw_docstring: - docstring = raw_docstring.split("\n")[0].strip() - - class_methods: List[str] = [] - class_method_params: Dict[str, List[str]] = {} - class_method_docstrings: Dict[str, Optional[str]] = {} - param_names: List[str] = [] - - if is_class: - for n in node.body: - if isinstance( - n, (ast.FunctionDef, ast.AsyncFunctionDef) - ) and not n.name.startswith("_"): - class_methods.append(n.name) - class_method_params[n.name] = [ - arg.arg for arg in n.args.args if arg.arg != "self" - ] - raw_method_doc = ast.get_docstring(n) - class_method_docstrings[n.name] = ( - raw_method_doc.split("\n")[0].strip() - if raw_method_doc - else None - ) - elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - param_names = [arg.arg for arg in node.args.args if arg.arg != "self"] - - return RemoteFunctionMetadata( - function_name=node.name, - module_path=module_path, - resource_config_name=resource_name, - resource_type="Endpoint", - is_async=is_async, - is_class=is_class, - file_path=py_file, - http_method=None, - http_path=None, - is_load_balanced=False, - is_live_resource=True, - config_variable=None, - is_lb_route_handler=False, - class_methods=class_methods, - param_names=param_names, - class_method_params=class_method_params, - docstring=docstring, - class_method_docstrings=class_method_docstrings, - ) - - def _build_endpoint_route_metadata( - self, - node: ast.AST, - route_info: Dict[str, str], - module_path: str, - py_file: Path, - ) -> RemoteFunctionMetadata: - """Build metadata for a @ep.get("/path") LB route handler.""" - is_async = isinstance(node, ast.AsyncFunctionDef) - - docstring = None - raw_docstring = ast.get_docstring(node) - if raw_docstring: - docstring = raw_docstring.split("\n")[0].strip() - - param_names: List[str] = [] - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - param_names = [arg.arg for arg in node.args.args if arg.arg != "self"] - - return RemoteFunctionMetadata( - function_name=node.name, - module_path=module_path, - resource_config_name=route_info["resource_name"], - resource_type="Endpoint", - is_async=is_async, - is_class=False, - file_path=py_file, - http_method=route_info["method"], - http_path=route_info["path"], - is_load_balanced=True, - is_live_resource=True, - config_variable=route_info["variable_name"], - is_lb_route_handler=True, - param_names=param_names, - docstring=docstring, - ) - - def _extract_resource_config_name( - self, decorator: ast.expr, module_path: str - ) -> Optional[str]: - """Extract resource_config name from @remote decorator.""" - if isinstance(decorator, ast.Name): - # @remote without arguments - return None - - if isinstance(decorator, ast.Call): - # @remote(...) with arguments - # Look for resource_config= or first positional arg - for keyword in decorator.keywords: - if keyword.arg == "resource_config": - return self._extract_name_from_expr(keyword.value, module_path) - - # Try first positional argument - if decorator.args: - return self._extract_name_from_expr(decorator.args[0], module_path) - + resource_config = config.get("resource_config") + if resource_config is None: return None - def _extract_name_from_expr( - self, expr: ast.expr, module_path: str - ) -> Optional[str]: - """Extract config name from an expression (Name or Call). + method = config.get("method") + path = config.get("path") + is_lb_route = config.get("is_lb_route_handler", False) + + # determine if the decorated target is a class. + # RemoteClassWrapper has _wrapped_class pointing at the original class. + # for plain classes (e.g. local=True), use inspect.isclass. + original = _unwrap_to_original(obj) + is_class = False + target_class = None + + if inspect.isclass(obj) and hasattr(obj, "_wrapped_class"): + is_class = True + target_class = obj._wrapped_class + elif inspect.isclass(original): + is_class = True + target_class = original + elif inspect.isclass(obj): + is_class = True + target_class = obj + + is_async = False + if not is_class: + is_async = inspect.iscoroutinefunction(original) or inspect.iscoroutinefunction( + obj + ) - Returns the resource's name (from the name= parameter), not the variable name. - """ - if isinstance(expr, ast.Name): - # Variable reference: @remote(gpu_config) - variable_name = expr.id - - # Try module-scoped lookup first (current module) - var_key = f"{module_path}:{variable_name}" - if var_key in self.resource_configs: - # Return the actual resource name (mapped from variable) - return self.resource_configs[var_key] - - # Try simple name lookup - if variable_name in self.resource_configs: - return self.resource_configs[variable_name] - - # Fall back to the variable name itself (unresolved reference) - return variable_name - - elif isinstance(expr, ast.Call): - # Direct instantiation: @remote(LiveServerless(name="gpu_config")) - # Extract the name= parameter - resource_name = self._extract_resource_name(expr) - if resource_name: - return resource_name + # function/class name: for classes, use the original class name. + # for functions, use __name__ from the unwrapped function. + if is_class and target_class is not None: + func_name = target_class.__name__ + else: + func_name = getattr(original, "__name__", None) or attr_name - return None + res_name = _resource_name(resource_config) + res_type = _resource_type_name(resource_config) + is_lb = _is_lb_type(resource_config) or (method is not None and path is not None) + is_live = _is_live_type(resource_config) - # All ServerlessResource subclasses exported by runpod_flash.__init__.py. - # Checked at test time by test_resource_config_types_matches_exports(). - _RESOURCE_CONFIG_TYPES = frozenset( - { - "ServerlessEndpoint", - "CpuServerlessEndpoint", - "LoadBalancerSlsResource", - "CpuLoadBalancerSlsResource", - "LiveServerless", - "CpuLiveServerless", - "LiveLoadBalancer", - "CpuLiveLoadBalancer", - } + docstring_source = ( + target_class if is_class and target_class is not None else original ) + docstring = _first_docstring_line(docstring_source) - def _is_resource_config_type(self, type_name: str) -> bool: - """Check if a type name is a known ServerlessResource subclass.""" - return type_name in self._RESOURCE_CONFIG_TYPES - - def _get_call_type(self, expr: ast.expr) -> Optional[str]: - """Get the type name of a call expression.""" - if isinstance(expr, ast.Call): - if isinstance(expr.func, ast.Name): - return expr.func.id - elif isinstance(expr.func, ast.Attribute): - return expr.func.attr - - return None - - def _extract_resource_name(self, expr: ast.expr) -> Optional[str]: - """Extract the 'name' parameter from a resource config instantiation. - - For example, from LiveServerless(name="01_01_gpu_worker", ...) - returns "01_01_gpu_worker". - """ - if isinstance(expr, ast.Call): - for keyword in expr.keywords: - if keyword.arg == "name": - if isinstance(keyword.value, ast.Constant): - return keyword.value.value - return None + class_methods: List[str] = [] + class_method_params: Dict[str, List[str]] = {} + class_method_docstrings: Dict[str, Optional[str]] = {} + param_names: List[str] = [] - def _analyze_function_calls( - self, - func_node: ast.AST, - function_metadata: RemoteFunctionMetadata, - remote_function_names: set[str], - ) -> None: - """Analyze if a function calls other @remote functions. - - Only matches direct calls (e.g. ``generate(prompt)``) — not attribute - calls (e.g. ``model.generate(prompt)``). In flash, ``@remote`` - functions are always invoked as direct calls after import, never via - attribute access. Matching on bare method names would cause false - positives whenever an unrelated object happens to have a method with - the same name as a ``@remote`` function. - - Args: - func_node: AST node for the function - function_metadata: Metadata to update with call information - remote_function_names: Set of all @remote function names - """ - # Walk AST looking for function calls - for node in ast.walk(func_node): - if isinstance(node, ast.Call): - # Handle direct calls: some_function() - if isinstance(node.func, ast.Name): - called_name = node.func.id - if called_name in remote_function_names: - function_metadata.calls_remote_functions = True - if called_name not in function_metadata.called_remote_functions: - function_metadata.called_remote_functions.append( - called_name - ) - - def _get_resource_type(self, resource_config_name: str) -> str: - """Get the resource type for a given config name.""" - if resource_config_name in self.resource_types: - return self.resource_types[resource_config_name] - # Default to LiveServerless if type not found - return "LiveServerless" - - def _sanitize_resource_name(self, name: str) -> str: - """Sanitize resource config name for use in filenames. - - Replaces invalid filename characters with underscores and ensures - the name starts with a letter or underscore (valid for Python identifiers). - - Args: - name: Raw resource config name - - Returns: - Sanitized name safe for use in filenames and as Python identifiers - """ - # Replace invalid characters with underscores - sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name) - - # Ensure it starts with a letter or underscore - if sanitized and not (sanitized[0].isalpha() or sanitized[0] == "_"): - sanitized = f"_{sanitized}" + if is_class and target_class is not None: + class_methods, class_method_params, class_method_docstrings = ( + _extract_class_info(target_class) + ) + elif not is_class: + param_names = _get_param_names(original) + + local_flag = getattr(obj, "__flash_local__", False) + + return RemoteFunctionMetadata( + function_name=func_name, + module_path=module_path, + resource_config_name=res_name, + resource_type=res_type, + is_async=is_async, + is_class=is_class, + file_path=file_path, + http_method=method, + http_path=path, + is_load_balanced=is_lb, + is_live_resource=is_live, + config_variable=variable_name, + is_lb_route_handler=is_lb_route, + class_methods=class_methods, + param_names=param_names, + class_method_params=class_method_params, + docstring=docstring, + class_method_docstrings=class_method_docstrings, + local=local_flag, + ) - return sanitized or "_" - def _get_module_path(self, py_file: Path) -> str: - """Convert file path to module path.""" +def _find_endpoint_instances(module: Any) -> Dict[str, Endpoint]: + """find all Endpoint instances in a module's namespace.""" + endpoints: Dict[str, Endpoint] = {} + for name in dir(module): try: - # Get relative path from project directory - rel_path = py_file.relative_to(self.project_dir) - - # Remove .py extension and convert / to . - module = str(rel_path.with_suffix("")).replace("/", ".").replace("\\", ".") - - return module - except ValueError: - # If relative_to fails, just use filename - return py_file.stem - - def _extract_http_routing( - self, decorator: ast.expr - ) -> tuple[Optional[str], Optional[str]]: - """Extract HTTP method and path from @remote decorator. - - Returns: - Tuple of (method, path) or (None, None) if not found. - method: GET, POST, PUT, DELETE, PATCH - path: /api/endpoint routes - - Raises: - ValueError: If method is not a valid HTTP verb - """ - if not isinstance(decorator, ast.Call): - return None, None - - http_method = None - http_path = None - - # Extract keyword arguments: method="POST", path="/api/process" - for keyword in decorator.keywords: - if keyword.arg == "method": - if isinstance(keyword.value, ast.Constant): - http_method = keyword.value.value - elif keyword.arg == "path": - if isinstance(keyword.value, ast.Constant): - http_path = keyword.value.value - - # Validate HTTP method if provided - valid_methods = {"GET", "POST", "PUT", "DELETE", "PATCH"} - if http_method is not None and http_method.upper() not in valid_methods: - raise ValueError( - f"Invalid HTTP method '{http_method}'. Must be one of: {', '.join(valid_methods)}" - ) - - return http_method, http_path - - def _extract_local_flag(self, decorator: ast.expr) -> bool: - """Extract local=True/False from @remote decorator. - - Returns True if the decorator has local=True, False otherwise. - """ - if not isinstance(decorator, ast.Call): - return False - - for keyword in decorator.keywords: - if keyword.arg == "local" and isinstance(keyword.value, ast.Constant): - return bool(keyword.value.value) - - return False - - -def detect_main_app( - project_root: Path, explicit_lb_exists: bool = False -) -> Optional[dict]: - """Detect main.py FastAPI app and return load balancer config. - - Searches for main.py/app.py/server.py and parses AST to find FastAPI app. - Only returns config if app has custom routes (not just @remote calls). - - Args: - project_root: Root directory of Flash project - explicit_lb_exists: If True, skip auto-detection (explicit config takes precedence) - - Returns: - Dict with app metadata: { - 'file_path': Path, - 'app_variable': str, - 'has_routes': bool, - } - Returns None if no FastAPI app found with custom routes or explicit_lb_exists is True. - """ - if explicit_lb_exists: - # Explicit load balancer config exists, skip auto-detection - return None - for filename in ["main.py", "app.py", "server.py"]: - main_path = project_root / filename - if not main_path.exists(): + obj = getattr(module, name) + except Exception: continue + if isinstance(obj, Endpoint) and not obj.is_client: + endpoints[name] = obj + return endpoints - try: - content = main_path.read_text(encoding="utf-8") - tree = ast.parse(content) - - # Find FastAPI app instantiation - for node in ast.walk(tree): - if isinstance(node, ast.Assign): - if isinstance(node.value, ast.Call): - call_type = None - if isinstance(node.value.func, ast.Name): - call_type = node.value.func.id - elif isinstance(node.value.func, ast.Attribute): - call_type = node.value.func.attr - - if call_type == "FastAPI": - app_variable = None - for target in node.targets: - if isinstance(target, ast.Name): - app_variable = target.id - break - - if app_variable: - # Check for custom routes (not just @remote) - has_routes = _has_custom_routes(tree, app_variable) - - # 1. Extract direct app routes - module_path = filename.replace(".py", "") - app_routes = _extract_fastapi_routes( - tree, app_variable, module_path - ) - - # 2. Find included routers - included_routers = _find_included_routers( - tree, app_variable - ) - - # 3. Extract routes from each included router - router_routes = [] - for router_var, prefix in included_routers: - router_file = _resolve_router_import( - router_var, tree, main_path, project_root - ) - if router_file and router_file.exists(): - try: - router_content = router_file.read_text( - encoding="utf-8" - ) - router_tree = ast.parse(router_content) - router_module = router_file.stem - - routes = _extract_router_routes( - router_tree, - router_var, - router_module, - prefix, - router_file, - ) - router_routes.extend(routes) - except ( - UnicodeDecodeError, - SyntaxError, - PermissionError, - OSError, - ) as e: - logger.debug( - f"Failed to parse router file {router_file}: {e}" - ) - - # 4. Combine all routes - all_fastapi_routes = app_routes + router_routes - - # 5. Update file_path for all routes - for route in all_fastapi_routes: - if route.file_path == Path(module_path): - route.file_path = main_path - - return { - "file_path": main_path, - "app_variable": app_variable, - "has_routes": has_routes - or bool(all_fastapi_routes), - "fastapi_routes": all_fastapi_routes, - } - except UnicodeDecodeError: - logger.debug(f"Skipping non-UTF-8 file: {main_path}") - except SyntaxError as e: - logger.debug(f"Syntax error in {main_path}: {e}") - except Exception as e: - logger.debug(f"Failed to parse {main_path}: {e}") - - return None - - -def _has_custom_routes(tree: ast.AST, app_variable: str) -> bool: - """Check if FastAPI app has custom routes (beyond @remote). - - Args: - tree: AST tree of the file - app_variable: Name of the FastAPI app variable - - Returns: - True if app has route decorators (app.get, app.post, etc.) - """ - for node in ast.walk(tree): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - for decorator in node.decorator_list: - # Look for app.get(), app.post(), app.put(), etc. - if isinstance(decorator, ast.Call): - if isinstance(decorator.func, ast.Attribute): - if ( - isinstance(decorator.func.value, ast.Name) - and decorator.func.value.id == app_variable - and decorator.func.attr - in ["get", "post", "put", "delete", "patch"] - ): - return True - # Also check for @app.get without parentheses (decorator without Call) - elif isinstance(decorator, ast.Attribute): - if ( - isinstance(decorator.value, ast.Name) - and decorator.value.id == app_variable - and decorator.attr in ["get", "post", "put", "delete", "patch"] - ): - return True - - return False +def _find_remote_decorated(module: Any) -> Dict[str, Any]: + """find all objects with __remote_config__ in a module's namespace.""" + results: Dict[str, Any] = {} + for name in dir(module): + try: + obj = getattr(module, name) + except Exception: + continue + if hasattr(obj, "__remote_config__"): + results[name] = obj + return results -def _find_included_routers(tree: ast.AST, app_variable: str) -> List[Tuple[str, str]]: - """Find all routers included in FastAPI app via include_router(). - Args: - tree: AST of the main file - app_variable: FastAPI app variable name (e.g., 'app') +def _analyze_cross_calls_ast( + file_path: Path, + function_names: Set[str], + remote_function_names: Set[str], +) -> Dict[str, List[str]]: + """find which functions call other @remote functions via AST. - Returns: - List of (router_variable_name, prefix) tuples - Example: [('user_router', '/users'), ('api_router', '/api')] + returns a dict of function_name -> list of called remote function names. + only matches direct calls (foo()), not attribute calls (obj.foo()). """ - included_routers = [] + result: Dict[str, List[str]] = {} + try: + tree = ast.parse(file_path.read_text(encoding="utf-8")) + except Exception: + return result for node in ast.walk(tree): - # Look for: app.include_router(router_var, prefix="/path", ...) - if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call): - call = node.value - - # Check if it's app.include_router - if ( - isinstance(call.func, ast.Attribute) - and call.func.attr == "include_router" - and isinstance(call.func.value, ast.Name) - and call.func.value.id == app_variable - ): - # Extract router variable from first positional arg - if call.args and isinstance(call.args[0], ast.Name): - router_var = call.args[0].id - - # Extract prefix from keyword args - prefix = "" - for keyword in call.keywords: - if keyword.arg == "prefix" and isinstance( - keyword.value, ast.Constant - ): - prefix = keyword.value.value - - included_routers.append((router_var, prefix)) + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + continue + if node.name not in function_names: + continue - return included_routers + called: List[str] = [] + for child in ast.walk(node): + if isinstance(child, ast.Call) and isinstance(child.func, ast.Name): + if child.func.id in remote_function_names: + if child.func.id not in called: + called.append(child.func.id) + if called: + result[node.name] = called + return result -def _resolve_router_import( - router_variable: str, tree: ast.AST, main_file_path: Path, project_root: Path -) -> Optional[Path]: - """Resolve router variable to its source file via import statements. - Args: - router_variable: Name like 'user_router' - tree: AST of main.py - main_file_path: Path to main.py - project_root: Project root directory +class RuntimeScanner: + """discovers @remote decorated functions and Endpoint instances by importing modules. - Returns: - Path to module defining the router, or None if not found + imports each python file in the project, inspects live objects for + __remote_config__ attributes and Endpoint instances, and produces + RemoteFunctionMetadata for the build system. """ - for node in ast.walk(tree): - # Handle: from module import router_variable - if isinstance(node, ast.ImportFrom): - for alias in node.names: - imported_name = alias.asname if alias.asname else alias.name - if imported_name == router_variable: - # Found the import! - module_path = node.module - if not module_path: - continue - - # Handle relative imports - if node.level > 0: - # Relative import (e.g., from .routers import user_router) - parent_dir = main_file_path.parent - for _ in range(node.level - 1): - parent_dir = parent_dir.parent - module_parts = module_path.split(".") if module_path else [] - target_path = parent_dir.joinpath(*module_parts) - else: - # Absolute import (e.g., from routers.users import user_router) - module_parts = module_path.split(".") - target_path = project_root.joinpath(*module_parts) - - # Try both .py file and __init__.py in package - py_file = target_path.with_suffix(".py") - if py_file.exists(): - return py_file - - init_file = target_path / "__init__.py" - if init_file.exists(): - return init_file - - return None - - -# Supported HTTP methods for route extraction -SUPPORTED_HTTP_METHODS = ["get", "post", "put", "delete", "patch", "options", "head"] - -def _normalize_route_path(url_prefix: str, http_path: str) -> str: - """Normalize URL path by combining prefix and path with proper slashes. - - Args: - url_prefix: URL prefix (e.g., '/users' or '') - http_path: HTTP path (e.g., '/' or '/list') - - Returns: - Normalized path starting with '/' (e.g., '/users/list') - """ - # Handle empty cases - if not url_prefix and not http_path: - return "/" - if not url_prefix: - return f"/{http_path.lstrip('/')}" if http_path else "/" - if not http_path or http_path == "/": - # When path is just "/", append it to prefix - return f"/{url_prefix.strip('/')}/" + def __init__(self, project_dir: Path): + self.project_dir = project_dir + self.resource_configs: Dict[str, str] = {} + self.resource_types: Dict[str, str] = {} + self.resource_flags: Dict[str, Dict[str, bool]] = {} + self.resource_variables: Dict[str, str] = {} + # populated after discover_remote_functions() runs + self.import_errors: Dict[str, str] = {} - # Combine prefix and path - prefix_part = url_prefix.strip("/") - path_part = http_path.lstrip("/") + def discover_remote_functions(self) -> List[RemoteFunctionMetadata]: + """discover all @remote decorated functions and classes by importing modules.""" + spec = load_ignore_patterns(self.project_dir) + all_files = get_file_tree(self.project_dir, spec) + py_files = sorted( + f for f in all_files if f.suffix == ".py" and f.name != "__init__.py" + ) - return f"/{prefix_part}/{path_part}" + root_str = str(self.project_dir) + added_to_path = root_str not in sys.path + if added_to_path: + sys.path.insert(0, root_str) + synthetic_packages = self._register_parent_packages(py_files) -def _extract_routes_from_decorator( - tree: ast.AST, - decorator_object: str, - module_path: str, - url_prefix: str, - file_path: Path, -) -> List[RemoteFunctionMetadata]: - """Extract routes from decorators (@app.get, @router.post, etc.). + functions: List[RemoteFunctionMetadata] = [] + seen_functions: Set[str] = set() + failed_files: List[Path] = [] - Unified implementation for both FastAPI app and APIRouter route extraction. + try: + for py_file in py_files: + module_path = file_to_module_path(py_file, self.project_dir) + try: + module = _import_module_from_file(py_file, module_path) + except Exception as e: + failed_files.append(py_file) + rel_path = os.path.relpath(py_file, self.project_dir) + self.import_errors[rel_path] = f"{type(e).__name__}: {e}" + continue + if module is None: + failed_files.append(py_file) + continue - Args: - tree: AST of the file - decorator_object: Variable name ('app', 'user_router', etc.) - module_path: Module path for imports - url_prefix: URL prefix to prepend (empty string for no prefix) - file_path: Path to the source file + file_functions = self._extract_from_module( + module, module_path, py_file, seen_functions + ) + functions.extend(file_functions) + + # cross-call analysis + remote_names = {f.function_name for f in functions} + files_with_functions: Dict[Path, Set[str]] = {} + for f in functions: + files_with_functions.setdefault(f.file_path, set()).add(f.function_name) + + for file_path, func_names in files_with_functions.items(): + calls = _analyze_cross_calls_ast(file_path, func_names, remote_names) + for f in functions: + if f.file_path == file_path and f.function_name in calls: + f.calls_remote_functions = True + f.called_remote_functions = calls[f.function_name] + + finally: + if added_to_path: + try: + sys.path.remove(root_str) + except ValueError: + pass + for pkg_name in synthetic_packages: + sys.modules.pop(pkg_name, None) + + self._populate_resource_dicts(functions) + return functions - Returns: - List of RemoteFunctionMetadata for discovered routes - """ - routes = [] + def _register_parent_packages(self, py_files: List[Path]) -> List[str]: + """register synthetic parent packages in sys.modules for dotted imports. - for node in ast.walk(tree): - if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - continue + when a file lives in a subdirectory (e.g. workers/gpu.py), python + needs a 'workers' package in sys.modules for the dotted module name + 'workers.gpu' to resolve. instead of creating __init__.py files on + disk, we inject empty module objects into sys.modules. - for decorator in node.decorator_list: - # Look for @decorator_object.METHOD(...) pattern - if not isinstance(decorator, ast.Call): - continue + returns the list of package names that were added so the caller + can remove them after scanning. + """ + added: List[str] = [] + for f in py_files: + rel = f.relative_to(self.project_dir) + parts = rel.parent.parts + for i in range(len(parts)): + pkg_name = ".".join(parts[: i + 1]) + if pkg_name not in sys.modules: + pkg = types.ModuleType(pkg_name) + pkg.__path__ = [str(self.project_dir / Path(*parts[: i + 1]))] + pkg.__package__ = pkg_name + sys.modules[pkg_name] = pkg + added.append(pkg_name) + return added + + def _extract_from_module( + self, + module: Any, + module_path: str, + file_path: Path, + seen: Set[str], + ) -> List[RemoteFunctionMetadata]: + """extract RemoteFunctionMetadata from a single imported module.""" + results: List[RemoteFunctionMetadata] = [] - if not isinstance(decorator.func, ast.Attribute): - continue + remote_objects = _find_remote_decorated(module) + endpoint_instances = _find_endpoint_instances(module) - # Check decorator object match - if not ( - isinstance(decorator.func.value, ast.Name) - and decorator.func.value.id == decorator_object - ): + # map resource_config object id -> variable name + config_to_varname: Dict[int, str] = {} + for member_name in dir(module): + try: + member = getattr(module, member_name) + except Exception: continue - - # Get HTTP method - method = decorator.func.attr - if method not in SUPPORTED_HTTP_METHODS: + if isinstance(member, ServerlessResource): + config_to_varname[id(member)] = member_name + elif isinstance(member, Endpoint) and not member.is_client: + config_to_varname[id(member)] = member_name + # also map the internal cached resource config so that + # functions decorated via @ep.get("/path") can trace back + # to this variable name through their __remote_config__ + cached = getattr(member, "_cached_resource_config", None) + if cached is not None: + config_to_varname[id(cached)] = member_name + + for attr_name, obj in remote_objects.items(): + dedup_key = f"{module_path}:{attr_name}" + if dedup_key in seen: continue - # Extract path from first positional argument - http_path = None - if decorator.args and isinstance(decorator.args[0], ast.Constant): - http_path = decorator.args[0].value - - if not http_path: - continue + config = getattr(obj, "__remote_config__", {}) + resource_config = config.get("resource_config") + + var_name: Optional[str] = None + if resource_config is not None: + var_name = config_to_varname.get(id(resource_config)) + if var_name is None and isinstance(resource_config, Endpoint): + var_name = config_to_varname.get(id(resource_config)) + if var_name is None and hasattr(resource_config, "name"): + for ep_name, ep in endpoint_instances.items(): + if ep is resource_config or ( + hasattr(ep, "name") + and ep.name == getattr(resource_config, "name", None) + ): + var_name = ep_name + break - # Normalize path with prefix - full_path = _normalize_route_path(url_prefix, http_path) - - routes.append( - RemoteFunctionMetadata( - function_name=node.name, - module_path=module_path, - resource_config_name="load_balancer", - resource_type="CpuLiveLoadBalancer", - is_async=isinstance(node, ast.AsyncFunctionDef), - is_class=False, - file_path=file_path, - http_method=method.upper(), - http_path=full_path, - is_load_balanced=True, - is_live_resource=True, - config_variable=None, - ) + meta = _metadata_from_remote_config( + obj, attr_name, module_path, file_path, variable_name=var_name ) - break # Only process first matching decorator per function + if meta is not None: + seen.add(dedup_key) + results.append(meta) - return routes + return results + def _populate_resource_dicts(self, functions: List[RemoteFunctionMetadata]) -> None: + """populate resource tracking dicts for ManifestBuilder compatibility.""" + for f in functions: + name = f.resource_config_name + if name in self.resource_configs: + continue -def _extract_router_routes( - tree: ast.AST, - router_variable: str, - module_path: str, - url_prefix: str, - router_file: Path, -) -> List[RemoteFunctionMetadata]: - """Extract routes from APIRouter decorators (@router.get, etc.). - - Args: - tree: AST of module containing the router - router_variable: Router variable name (e.g., 'user_router') - module_path: Module path for imports - url_prefix: URL prefix from include_router (e.g., '/users') - router_file: Path to router file - - Returns: - Routes with prefixed paths - """ - return _extract_routes_from_decorator( - tree, router_variable, module_path, url_prefix, router_file - ) - - -def _extract_fastapi_routes( - tree: ast.AST, app_variable: str, module_path: str -) -> List[RemoteFunctionMetadata]: - """Extract routes from FastAPI decorators (@app.get, @app.post, etc.). - - Args: - tree: AST tree of the file - app_variable: FastAPI app variable name (e.g., 'app', 'router') - module_path: Module import path (e.g., 'main') - - Returns: - List of RemoteFunctionMetadata for each FastAPI route found - """ - # Use unified extraction with no prefix and placeholder path - # The caller will update file_path with the actual path - return _extract_routes_from_decorator( - tree, app_variable, module_path, "", Path(module_path) - ) + self.resource_configs[name] = name + self.resource_types[name] = f.resource_type + self.resource_flags[name] = { + "is_load_balanced": f.is_load_balanced, + "is_live_resource": f.is_live_resource, + } + if f.config_variable: + self.resource_variables[name] = f.config_variable diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 4d103947..1990f0c8 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -32,8 +32,9 @@ def __init__(self, **_kw): pass +from ..utils.ignore import get_file_tree, load_ignore_patterns from .build_utils.scanner import ( - RemoteDecoratorScanner, + RuntimeScanner, file_to_module_path, file_to_resource_name, file_to_url_prefix, @@ -92,23 +93,14 @@ class WorkerInfo: ) # fn_or_method_name -> first line of docstring -def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: - """Scan the project for all @remote decorated functions. +def _scan_project_workers( + project_root: Path, +) -> tuple[List[WorkerInfo], RuntimeScanner]: + """scan the project for all @remote decorated functions. - Walks all .py files (excluding .flash/, __pycache__, __init__.py) and - builds WorkerInfo for each file that contains @remote functions. - - Files with QB functions produce one WorkerInfo per file (QB type). - Files with LB functions produce one WorkerInfo per file (LB type). - A file can have both QB and LB functions (unusual but supported). - - Args: - project_root: Root directory of the Flash project - - Returns: - List of WorkerInfo, one entry per discovered source file + returns (workers, scanner) so callers can inspect import_errors. """ - scanner = RemoteDecoratorScanner(project_root) + scanner = RuntimeScanner(project_root) remote_functions = scanner.discover_remote_functions() # Group by file path @@ -189,7 +181,7 @@ def _scan_project_workers(project_root: Path) -> List[WorkerInfo]: ) ) - return workers + return workers, scanner def _ensure_gitignore(project_root: Path) -> None: @@ -875,7 +867,7 @@ def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> No if not py_changed: continue try: - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) _generate_flash_server(project_root, workers) logger.debug("server.py regenerated (%d changed)", len(py_changed)) except Exception as e: @@ -890,8 +882,9 @@ def _watch_and_regenerate(project_root: Path, stop_event: threading.Event) -> No def _discover_resources(project_root: Path): """Discover deployable resources in project files. - Uses ResourceDiscovery to find all DeployableResource instances by - parsing @remote decorators and importing the referenced config variables. + Imports each python file and inspects module-level objects for + DeployableResource instances and Endpoint facades. Endpoint facades + are unwrapped via _build_resource_config() to get the inner resource. Args: project_root: Root directory of the Flash project @@ -899,35 +892,65 @@ def _discover_resources(project_root: Path): Returns: List of discovered DeployableResource instances """ - from ...core.discovery import ResourceDiscovery + from ...core.resources.base import DeployableResource + from ...endpoint import Endpoint + + from .build_utils.scanner import _import_module_from_file + spec = load_ignore_patterns(project_root) + all_files = get_file_tree(project_root, spec) py_files = sorted( - p - for p in project_root.rglob("*.py") - if not any( - skip in p.parts - for skip in (".flash", ".venv", "venv", "__pycache__", ".git") - ) + f for f in all_files if f.suffix == ".py" and f.name != "__init__.py" ) - # Add project root to sys.path so cross-module imports resolve - # (e.g. api/routes.py doing "from longruns.stage1 import stage1_process"). root_str = str(project_root) added_to_path = root_str not in sys.path if added_to_path: sys.path.insert(0, root_str) resources = [] + seen_names: set[str] = set() try: for py_file in py_files: try: - discovery = ResourceDiscovery(str(py_file), max_depth=0) - resources.extend(discovery.discover()) + module_path = file_to_module_path(py_file, project_root) + module = _import_module_from_file(py_file, module_path) + if module is None: + continue + + for name in dir(module): + try: + obj = getattr(module, name) + except Exception: + continue + + resource = None + if isinstance(obj, DeployableResource): + resource = obj + elif isinstance(obj, Endpoint) and not obj.is_client: + resource = obj._build_resource_config() + elif hasattr(obj, "__remote_config__"): + cfg = getattr(obj, "__remote_config__", {}) + rc = cfg.get("resource_config") + if isinstance(rc, Endpoint) and not rc.is_client: + resource = rc._build_resource_config() + elif isinstance(rc, DeployableResource): + resource = rc + + if resource is not None: + res_name = getattr(resource, "name", None) or name + if res_name not in seen_names: + seen_names.add(res_name) + resources.append(resource) + except Exception as e: logger.debug("Discovery failed for %s: %s", py_file, e) finally: if added_to_path: - sys.path.remove(root_str) + try: + sys.path.remove(root_str) + except ValueError: + pass if resources: console.print(f"\n[dim]Discovered {len(resources)} resource(s):[/dim]") @@ -1016,27 +1039,27 @@ def run_command( ) # Discover @remote functions - workers = _scan_project_workers(project_root) + workers, scanner = _scan_project_workers(project_root) if not workers: - console.print("[red]Error:[/red] No endpoints found.") - console.print("Decorate your functions with @Endpoint to get started.") - console.print("\nQueue-based (one function per endpoint):") - console.print( - " from runpod_flash import Endpoint, GpuGroup\n" - "\n" - " @Endpoint(name='my-worker', gpu=GpuGroup.ANY)\n" - " async def process(input_data: dict) -> dict:\n" - " return {'result': input_data}" - ) - console.print("\nLoad-balanced (multiple routes, shared workers):") - console.print( - " api = Endpoint(name='my-api', cpu='cpu3g-2-8', workers=(1, 3))\n" - "\n" - " @api.post('/compute')\n" - " async def compute(data: dict) -> dict:\n" - " return {'result': data}" - ) + if scanner.import_errors: + console.print("\n[red bold]Failed to load:[/red bold]") + for filename, err in scanner.import_errors.items(): + console.print(f" [red]{filename}[/red]: {err}") + console.print() + else: + console.print( + "\n[red bold]No endpoints found.[/red bold]\n" + "\n" + " [dim]Queue-based:[/dim]\n" + " @Endpoint(name='worker', gpu=GpuGroup.ANY)\n" + " async def process(input_data: dict) -> dict: ...\n" + "\n" + " [dim]Load-balanced:[/dim]\n" + " api = Endpoint(name='api', cpu='cpu3g-2-8')\n" + " @api.post('/compute')\n" + " async def compute(data: dict) -> dict: ...\n" + ) raise typer.Exit(1) # find a free port, counting up from the requested one diff --git a/src/runpod_flash/client.py b/src/runpod_flash/client.py index dadaa13f..04d9ed64 100644 --- a/src/runpod_flash/client.py +++ b/src/runpod_flash/client.py @@ -241,6 +241,7 @@ def decorator(func_or_class): # Local execution mode - execute without provisioning remote servers if local: func_or_class.__remote_config__ = routing_config + func_or_class.__flash_local__ = True return func_or_class # Determine if we should execute locally or create a stub diff --git a/src/runpod_flash/core/discovery.py b/src/runpod_flash/core/discovery.py deleted file mode 100644 index 728d03ed..00000000 --- a/src/runpod_flash/core/discovery.py +++ /dev/null @@ -1,509 +0,0 @@ -"""Resource discovery for auto-provisioning during flash run startup.""" - -import ast -import importlib.util -import logging -import sys -from pathlib import Path -from typing import Dict, List, Set - -from .resources.base import DeployableResource - -log = logging.getLogger(__name__) - - -class ResourceDiscovery: - """Discovers DeployableResource instances by parsing @remote decorators.""" - - def __init__(self, entry_point: str, max_depth: int = 2): - """Initialize resource discovery. - - Args: - entry_point: Path to entry point file (e.g., "main.py") - max_depth: Maximum depth for recursive module scanning - """ - self.entry_point = Path(entry_point) - self.max_depth = max_depth - self._cache: Dict[str, List[DeployableResource]] = {} - self._scanned_modules: Set[str] = set() - - def discover(self) -> List[DeployableResource]: - """Discover all DeployableResource instances in entry point and imports. - - Returns: - List of discovered deployable resources - """ - if str(self.entry_point) in self._cache: - return self._cache[str(self.entry_point)] - - resources = [] - - try: - # Parse entry point to find @remote decorators - resource_var_names = self._find_resource_config_vars(self.entry_point) - - # Import entry point module to resolve variables (if any found) - if resource_var_names: - module = self._import_module(self.entry_point) - - if module: - # Resolve variable names to actual DeployableResource objects - for var_name in resource_var_names: - resource = self._resolve_resource_variable(module, var_name) - if resource: - resources.append(resource) - else: - log.warning(f"Failed to import {self.entry_point}") - - log.info(f"[Discovery] After entry point: {len(resources)} resource(s)") - - # Recursively scan imported modules (static imports) - imported_resources = self._scan_imports(self.entry_point, depth=1) - resources.extend(imported_resources) - - log.info(f"[Discovery] After static imports: {len(resources)} resource(s)") - - # Fallback: Scan project directory for Python files with @remote decorators - # This handles dynamic imports (importlib.util) that AST parsing misses - if not resources: - log.debug( - "No resources found via static imports, scanning project directory" - ) - directory_resources = self._scan_project_directory() - resources.extend(directory_resources) - log.info( - f"[Discovery] After directory scan: {len(resources)} resource(s)" - ) - - log.info(f"[Discovery] Total: {len(resources)} resource(s) discovered") - for res in resources: - res_name = getattr(res, "name", "Unknown") - res_type = res.__class__.__name__ - log.info(f"[Discovery] • {res_name} ({res_type})") - - # Cache results - self._cache[str(self.entry_point)] = resources - - except Exception as e: - log.error(f"Error discovering resources in {self.entry_point}: {e}") - - return resources - - def _find_resource_config_vars(self, file_path: Path) -> Set[str]: - """Find variable names used in @remote or @Endpoint decorators via AST parsing. - - Detects: - - @remote(resource_config=var) / @remote(var) patterns - - ep = Endpoint(...) variables used as LB route decorators (@ep.get, @ep.post, etc) - - @Endpoint(...) used directly as a function/class decorator (QB pattern) - - Args: - file_path: Path to Python file to parse - - Returns: - Set of variable names referenced in decorators - """ - var_names = set() - - try: - tree = ast.parse(file_path.read_text(encoding="utf-8")) - - # pass 1: find Endpoint variable assignments (ep = Endpoint(...)) - endpoint_vars = set() - for node in ast.walk(tree): - if isinstance(node, ast.Assign) and len(node.targets) == 1: - target = node.targets[0] - if isinstance(target, ast.Name) and isinstance( - node.value, ast.Call - ): - func = node.value.func - call_name = None - if isinstance(func, ast.Name): - call_name = func.id - elif isinstance(func, ast.Attribute): - call_name = func.attr - if call_name == "Endpoint": - endpoint_vars.add(target.id) - - # pass 2: find decorator references - for node in ast.walk(tree): - if isinstance( - node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef) - ): - for decorator in node.decorator_list: - if self._is_remote_decorator(decorator): - var_name = self._extract_resource_config_var(decorator) - if var_name: - var_names.add(var_name) - - # @ep.get("/path"), @ep.post("/path"), etc - elif self._is_endpoint_route_decorator( - decorator, endpoint_vars - ): - var_name = self._extract_endpoint_var_from_route(decorator) - if var_name: - var_names.add(var_name) - - # @Endpoint(name=..., gpu=...) directly on function/class (QB) - elif self._is_endpoint_direct_decorator(decorator): - var_names.add(node.name) - - except Exception as e: - log.warning(f"Failed to parse {file_path}: {e}") - - return var_names - - def _is_endpoint_route_decorator( - self, decorator: ast.expr, endpoint_vars: Set[str] - ) -> bool: - """Check if decorator is @ep.get/post/put/delete/patch for a known Endpoint variable.""" - if not isinstance(decorator, ast.Call): - return False - func = decorator.func - if not isinstance(func, ast.Attribute): - return False - if func.attr not in ("get", "post", "put", "delete", "patch"): - return False - if isinstance(func.value, ast.Name) and func.value.id in endpoint_vars: - return True - return False - - def _extract_endpoint_var_from_route(self, decorator: ast.Call) -> str: - """Extract the Endpoint variable name from @ep.get("/path") decorator.""" - func = decorator.func - if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): - return func.value.id - return "" - - def _is_endpoint_direct_decorator(self, decorator: ast.expr) -> bool: - """Check if decorator is @Endpoint(...) used directly on a function/class (QB pattern). - - Matches @Endpoint(name=..., gpu=...) but NOT @ep.get()/@ep.post() (which are - attribute calls on an Endpoint variable, handled separately). - """ - if not isinstance(decorator, ast.Call): - return False - func = decorator.func - if isinstance(func, ast.Name) and func.id == "Endpoint": - return True - if isinstance(func, ast.Attribute) and func.attr == "Endpoint": - return True - return False - - def _is_remote_decorator(self, decorator: ast.expr) -> bool: - """Check if decorator is @remote. - - Args: - decorator: AST decorator node - - Returns: - True if decorator is @remote - """ - if isinstance(decorator, ast.Call): - func_name = None - if isinstance(decorator.func, ast.Name): - func_name = decorator.func.id - elif isinstance(decorator.func, ast.Attribute): - func_name = decorator.func.attr - - return func_name == "remote" - - return False - - def _extract_resource_config_var(self, decorator: ast.Call) -> str: - """Extract resource_config variable name from @remote decorator. - - Handles both: - - @remote(resource_config=my_config) - - @remote(my_config) (positional argument) - - Args: - decorator: AST Call node for @remote decorator - - Returns: - Variable name or empty string - """ - # Check keyword argument: resource_config=var_name - for keyword in decorator.keywords: - if keyword.arg == "resource_config": - if isinstance(keyword.value, ast.Name): - return keyword.value.id - - # Check positional argument: @remote(var_name) - if decorator.args and isinstance(decorator.args[0], ast.Name): - return decorator.args[0].id - - return "" - - def _import_module(self, file_path: Path): - """Import a Python module from file path. - - Args: - file_path: Path to Python file - - Returns: - Imported module or None if import fails - """ - try: - # Create module spec - module_name = file_path.stem - spec = importlib.util.spec_from_file_location(module_name, file_path) - - if not spec or not spec.loader: - return None - - # Load module - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - - self._scanned_modules.add(module_name) - - return module - - except Exception as e: - log.warning(f"Failed to import {file_path}: {e}") - return None - - def _resolve_resource_variable(self, module, var_name: str) -> DeployableResource: - """Resolve variable name to DeployableResource instance. - - Handles: - - Legacy resource config objects (LiveServerless, etc) - - Endpoint facade objects (unwraps via _build_resource_config()) - - QB-decorated functions/classes (unwraps __remote_config__["resource_config"]) - - Args: - module: Imported module - var_name: Variable name to resolve - - Returns: - DeployableResource instance or None - """ - try: - obj = getattr(module, var_name, None) - - if obj and isinstance(obj, DeployableResource): - return obj - - # unwrap Endpoint facade to its internal resource config - if obj and hasattr(obj, "_build_resource_config"): - resource = obj._build_resource_config() - if isinstance(resource, DeployableResource): - return resource - - # unwrap @Endpoint(...)-decorated function/class (QB pattern). - # Endpoint.__call__ wraps via @remote which attaches __remote_config__ - if obj is not None and hasattr(obj, "__remote_config__"): - resource = obj.__remote_config__.get("resource_config") - if isinstance(resource, DeployableResource): - return resource - - if obj is not None: - log.warning( - f"Resource '{var_name}' failed to resolve to DeployableResource " - f"(found type: {type(obj).__name__}). " - f"Check that '{var_name}' is defined as a ServerlessResource, " - f"Endpoint, or other DeployableResource type." - ) - - except Exception as e: - log.warning(f"Failed to resolve variable '{var_name}': {e}") - - return None - - def _scan_imports(self, file_path: Path, depth: int) -> List[DeployableResource]: - """Recursively scan imported modules for resources. - - Args: - file_path: Path to Python file - depth: Current recursion depth - - Returns: - List of discovered resources from imports - """ - if depth > self.max_depth: - return [] - - resources = [] - - try: - tree = ast.parse(file_path.read_text(encoding="utf-8")) - - # Find import statements - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - module_name = alias.name - if module_name not in self._scanned_modules: - imported_resources = self._scan_imported_module( - module_name, depth - ) - resources.extend(imported_resources) - - elif isinstance(node, ast.ImportFrom): - if node.module and node.module not in self._scanned_modules: - imported_resources = self._scan_imported_module( - node.module, depth - ) - resources.extend(imported_resources) - - except Exception as e: - log.debug(f"Failed to scan imports in {file_path}: {e}") - - return resources - - def _scan_imported_module( - self, module_name: str, depth: int - ) -> List[DeployableResource]: - """Scan an imported module for resources. - - Args: - module_name: Name of module to scan - depth: Current recursion depth - - Returns: - List of discovered resources - """ - resources = [] - - try: - # Try to find module file - module_path = self._resolve_module_path(module_name) - - if not module_path or not module_path.exists(): - return [] - - # Mark as scanned to avoid cycles - self._scanned_modules.add(module_name) - - # Find resources in this module - resource_vars = self._find_resource_config_vars(module_path) - - if resource_vars: - # Import module and resolve variables - module = self._import_module(module_path) - if module: - for var_name in resource_vars: - resource = self._resolve_resource_variable(module, var_name) - if resource: - resources.append(resource) - - # Recursively scan imports - imported_resources = self._scan_imports(module_path, depth + 1) - resources.extend(imported_resources) - - except Exception as e: - log.debug(f"Failed to scan imported module '{module_name}': {e}") - - return resources - - def _resolve_module_path(self, module_name: str) -> Path: - """Resolve module name to file path. - - Args: - module_name: Name of module (e.g., "workers.gpu") - - Returns: - Path to module file or None - """ - try: - # Handle relative imports from entry point directory - parts = module_name.split(".") - current_dir = self.entry_point.parent - - # Try as relative path first - module_path = current_dir.joinpath(*parts) - - # Check for .py file - if module_path.with_suffix(".py").exists(): - return module_path.with_suffix(".py") - - # Check for package (__init__.py) - if (module_path / "__init__.py").exists(): - return module_path / "__init__.py" - - except Exception as e: - log.debug(f"Failed to resolve module path for '{module_name}': {e}") - - return None - - def _scan_project_directory(self) -> List[DeployableResource]: - """Scan project directory for Python files with @remote decorators. - - This is a fallback for projects that use dynamic imports (importlib.util) - which cannot be detected via static AST import scanning. - - Returns: - List of discovered resources - """ - resources = [] - project_root = self.entry_point.parent - - try: - # Find all Python files in project (excluding common ignore patterns) - python_files = [] - for pattern in ["**/*.py"]: - for file_path in project_root.glob(pattern): - # Skip entry point (already processed) - if file_path == self.entry_point: - continue - - # Skip common directories - rel_path = str(file_path.relative_to(project_root)) - if any( - skip in rel_path - for skip in [ - ".venv/", - "venv/", - "__pycache__/", - ".git/", - "site-packages/", - ".pytest_cache/", - "build/", - "dist/", - ".tox/", - "node_modules/", - ".flash/", - ] - ): - continue - - python_files.append(file_path) - - log.debug(f"Scanning {len(python_files)} Python files in {project_root}") - - # Check each file for @remote or Endpoint decorators - for file_path in python_files: - try: - # quick check: does file contain relevant patterns? - content = file_path.read_text(encoding="utf-8") - if "@remote" not in content and "Endpoint(" not in content: - continue - - # Find resource config variables via AST - resource_vars = self._find_resource_config_vars(file_path) - if not resource_vars: - continue - - # Import module and resolve variables - module = self._import_module(file_path) - if module: - for var_name in resource_vars: - resource = self._resolve_resource_variable(module, var_name) - if resource: - resources.append(resource) - - except Exception as e: - log.debug(f"Failed to scan {file_path}: {e}") - continue - - except Exception as e: - log.warning(f"Failed to scan project directory: {e}") - - return resources - - def clear_cache(self): - """Clear discovery cache (for reload mode).""" - self._cache.clear() - self._scanned_modules.clear() diff --git a/src/runpod_flash/endpoint.py b/src/runpod_flash/endpoint.py index 5d3f51dc..9c0c3f7c 100644 --- a/src/runpod_flash/endpoint.py +++ b/src/runpod_flash/endpoint.py @@ -354,9 +354,11 @@ def __init__( "id and image are mutually exclusive. id= connects to an " "existing endpoint, image= deploys a new one." ) - if name is None and id is None: - raise ValueError("name or id is required.") + if name is None and id is None and image is not None: + raise ValueError("name or id is required when image= is set.") + # name can be None here for QB decorator mode (@Endpoint(gpu=...)). + # it gets derived from the decorated function/class in __call__(). self.name = name self.id = id self._gpu = _normalize_gpu(gpu) @@ -566,6 +568,10 @@ async def process(data: dict) -> dict: ... "routes with .get()/.post()/etc. use one pattern or the other." ) + # auto-derive name from the decorated function/class if not provided + if self.name is None: + self.name = func_or_class.__name__ + self._qb_target = func_or_class resource_config = self._build_resource_config() @@ -583,6 +589,11 @@ async def process(data: dict) -> dict: ... def _route(self, method: str, path: str): """register an http route on this endpoint (lb mode).""" + if self.name is None: + raise ValueError( + "name is required for load-balanced endpoints. " + "use Endpoint(name='my-api', ...) when registering routes." + ) method = method.upper() if method not in _VALID_HTTP_METHODS: raise ValueError( diff --git a/src/runpod_flash/execute_class.py b/src/runpod_flash/execute_class.py index 643bc378..1ef809d9 100644 --- a/src/runpod_flash/execute_class.py +++ b/src/runpod_flash/execute_class.py @@ -202,6 +202,9 @@ def create_remote_class( raise ValueError("Class must have a __name__ attribute") class RemoteClassWrapper: + # store a reference to the original class for introspection + _wrapped_class = cls + def __init__(self, *args, **kwargs): self._class_type = cls self._resource_config = resource_config diff --git a/src/runpod_flash/runtime/resource_provisioner.py b/src/runpod_flash/runtime/resource_provisioner.py index 356f32f2..0c58ebbe 100644 --- a/src/runpod_flash/runtime/resource_provisioner.py +++ b/src/runpod_flash/runtime/resource_provisioner.py @@ -143,6 +143,8 @@ def create_resource_from_manifest( deployment_kwargs["scalerValue"] = resource_data["scalerValue"] if "instanceIds" in resource_data: deployment_kwargs["instanceIds"] = resource_data["instanceIds"] + if "idleTimeout" in resource_data: + deployment_kwargs["idleTimeout"] = resource_data["idleTimeout"] # Reconstruct NetworkVolume from manifest data if present if "networkVolume" in resource_data: diff --git a/tests/integration/test_build_pipeline.py b/tests/integration/test_build_pipeline.py index 8a644308..8ba1624e 100644 --- a/tests/integration/test_build_pipeline.py +++ b/tests/integration/test_build_pipeline.py @@ -1,6 +1,6 @@ """Integration tests for the build pipeline. -Exercises: scanner → manifest → handler generation with real code paths. +Exercises: scanner -> manifest -> handler generation with real code paths. Only external I/O (install_dependencies) is mocked. """ @@ -10,9 +10,7 @@ from runpod_flash.cli.commands.build_utils.handler_generator import HandlerGenerator from runpod_flash.cli.commands.build_utils.manifest import ManifestBuilder -from runpod_flash.cli.commands.build_utils.scanner import ( - RemoteDecoratorScanner, -) +from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner @pytest.fixture() @@ -21,7 +19,7 @@ def build_project(tmp_path): project_dir = tmp_path / "myproject" project_dir.mkdir() - # Create a .flashignore so scanner doesn't try .gitignore + # create a .flashignore so scanner doesn't try .gitignore (project_dir / ".flashignore").write_text("") return project_dir @@ -35,10 +33,10 @@ def _write_worker_file(project_dir, filename, content): class TestBuildProducesValidManifest: - """Scanner → ManifestBuilder → valid manifest JSON.""" + """Scanner -> ManifestBuilder -> valid manifest JSON.""" def test_build_produces_valid_manifest(self, build_project): - """Minimal @remote file → scanner → manifest with correct structure.""" + """Minimal @remote file -> scanner -> manifest with correct structure.""" _write_worker_file( build_project, "worker.py", @@ -53,12 +51,13 @@ def process(data): """, ) - scanner = RemoteDecoratorScanner(build_project) + scanner = RuntimeScanner(build_project) functions = scanner.discover_remote_functions() assert len(functions) == 1 assert functions[0].function_name == "process" - assert functions[0].resource_config_name == "test-gpu" + # LiveServerless appends -fb suffix + assert functions[0].resource_config_name == "test-gpu-fb" builder = ManifestBuilder( project_name="test-project", @@ -70,14 +69,14 @@ def process(data): assert "version" in manifest assert "resources" in manifest assert "function_registry" in manifest - assert "test-gpu" in manifest["resources"] + assert "test-gpu-fb" in manifest["resources"] - resource = manifest["resources"]["test-gpu"] + resource = manifest["resources"]["test-gpu-fb"] assert len(resource["functions"]) == 1 assert resource["functions"][0]["name"] == "process" def test_build_with_multiple_resources(self, build_project): - """File with GPU + CPU resources → both appear in manifest.""" + """File with GPU + CPU resources -> both appear in manifest.""" _write_worker_file( build_project, "worker.py", @@ -97,7 +96,7 @@ def cpu_task(x): """, ) - scanner = RemoteDecoratorScanner(build_project) + scanner = RuntimeScanner(build_project) functions = scanner.discover_remote_functions() assert len(functions) == 2 @@ -109,8 +108,9 @@ def cpu_task(x): ) manifest = builder.build() - assert "my-gpu" in manifest["resources"] - assert "my-cpu" in manifest["resources"] + # Live* classes append -fb suffix + assert "my-gpu-fb" in manifest["resources"] + assert "my-cpu-fb" in manifest["resources"] assert len(manifest["function_registry"]) == 2 diff --git a/tests/integration/test_lb_remote_execution.py b/tests/integration/test_lb_remote_execution.py index 11a3d14a..296714a3 100644 --- a/tests/integration/test_lb_remote_execution.py +++ b/tests/integration/test_lb_remote_execution.py @@ -160,7 +160,7 @@ async def qb_func(): def test_scanner_discovers_load_balancer_resources(self): """Test that scanner can discover LiveLoadBalancer and LoadBalancerSlsResource.""" - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner from pathlib import Path import tempfile @@ -188,7 +188,7 @@ def get_status(): py_file = project_dir / "test_api.py" py_file.write_text(code) - scanner = RemoteDecoratorScanner(project_dir) + scanner = RuntimeScanner(project_dir) functions = scanner.discover_remote_functions() # Verify both resources were discovered @@ -200,7 +200,9 @@ def get_status(): assert "LoadBalancerSlsResource" in resource_types # Verify resource configs were extracted - assert "test-api" in scanner.resource_types - assert scanner.resource_types["test-api"] == "LiveLoadBalancer" - assert "deployed-api" in scanner.resource_types - assert scanner.resource_types["deployed-api"] == "LoadBalancerSlsResource" + assert "test-api-fb" in scanner.resource_types + assert scanner.resource_types["test-api-fb"] == "LiveLoadBalancer" + assert "deployed-api-fb" in scanner.resource_types + assert ( + scanner.resource_types["deployed-api-fb"] == "LoadBalancerSlsResource" + ) diff --git a/tests/unit/cli/commands/build_utils/test_path_utilities.py b/tests/unit/cli/commands/build_utils/test_path_utilities.py index 73ec3557..8f993385 100644 --- a/tests/unit/cli/commands/build_utils/test_path_utilities.py +++ b/tests/unit/cli/commands/build_utils/test_path_utilities.py @@ -140,7 +140,7 @@ class TestIsLbRouteHandlerField: def test_lb_function_with_method_and_path_is_handler(self, tmp_path): """An LB @remote function with method= and path= is marked as LB route handler.""" - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner (tmp_path / "routes.py").write_text( """ @@ -154,7 +154,7 @@ async def compute(data: dict) -> dict: """ ) - scanner = RemoteDecoratorScanner(tmp_path) + scanner = RuntimeScanner(tmp_path) functions = scanner.discover_remote_functions() assert len(functions) == 1 @@ -162,7 +162,7 @@ async def compute(data: dict) -> dict: def test_qb_function_is_not_handler(self, tmp_path): """A QB @remote function is NOT marked as LB route handler.""" - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner (tmp_path / "worker.py").write_text( """ @@ -176,7 +176,7 @@ async def process(data: dict) -> dict: """ ) - scanner = RemoteDecoratorScanner(tmp_path) + scanner = RuntimeScanner(tmp_path) functions = scanner.discover_remote_functions() assert len(functions) == 1 @@ -184,7 +184,7 @@ async def process(data: dict) -> dict: def test_init_py_files_excluded(self, tmp_path): """__init__.py files are excluded from scanning.""" - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner (tmp_path / "__init__.py").write_text( """ @@ -209,7 +209,7 @@ async def process(data: dict) -> dict: """ ) - scanner = RemoteDecoratorScanner(tmp_path) + scanner = RuntimeScanner(tmp_path) functions = scanner.discover_remote_functions() # Only the worker.py function should be discovered, not __init__.py diff --git a/tests/unit/cli/commands/build_utils/test_scanner.py b/tests/unit/cli/commands/build_utils/test_scanner.py deleted file mode 100644 index 1343b1e6..00000000 --- a/tests/unit/cli/commands/build_utils/test_scanner.py +++ /dev/null @@ -1,1480 +0,0 @@ -"""Tests for RemoteDecoratorScanner.""" - -import tempfile -from pathlib import Path - -import runpod_flash -from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner -from runpod_flash.core.resources.serverless import ServerlessResource - - -def test_discover_simple_function(): - """Test discovering a simple @remote function.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - # Create a simple test file - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu_config = LiveServerless(name="test_gpu") - -@remote(gpu_config) -async def my_function(data): - return processed_data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].function_name == "my_function" - assert functions[0].resource_config_name == "test_gpu" - assert functions[0].is_async is True - assert functions[0].is_class is False - - -def test_discover_class(): - """Test discovering a @remote class.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu_config = LiveServerless(name="test_gpu") - -@remote(gpu_config) -class MyModel: - def __init__(self): - pass - - def process(self, data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].function_name == "MyModel" - assert functions[0].is_class is True - - -def test_discover_multiple_functions_same_config(): - """Test discovering multiple functions with same resource config.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu_config = LiveServerless(name="gpu_worker") - -@remote(gpu_config) -async def process_data(data): - return data - -@remote(gpu_config) -async def analyze_data(data): - return analysis -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 2 - assert all(f.resource_config_name == "gpu_worker" for f in functions) - assert functions[0].function_name in ["process_data", "analyze_data"] - - -def test_discover_functions_different_configs(): - """Test discovering functions with different resource configs.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, CpuLiveServerless, remote - -gpu_config = LiveServerless(name="gpu_worker") -cpu_config = CpuLiveServerless(name="cpu_worker") - -@remote(gpu_config) -async def gpu_task(data): - return data - -@remote(cpu_config) -async def cpu_task(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 2 - resource_configs = {f.resource_config_name for f in functions} - assert resource_configs == {"gpu_worker", "cpu_worker"} - - -def test_discover_nested_module(): - """Test discovering functions in nested modules.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - # Create nested structure - workers_dir = project_dir / "workers" / "gpu" - workers_dir.mkdir(parents=True) - - test_file = workers_dir / "inference.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="gpu_inference") - -@remote(config) -async def inference(model, data): - return results -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].module_path == "workers.gpu.inference" - assert functions[0].function_name == "inference" - - -def test_discover_inline_config(): - """Test discovering with inline resource config.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -@remote(LiveServerless(name="inline_config")) -async def my_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].resource_config_name == "inline_config" - - -def test_ignore_non_remote_functions(): - """Test that non-decorated functions are ignored.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -async def normal_function(data): - return data - -class NormalClass: - pass -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 0 - - -def test_discover_sync_function(): - """Test discovering synchronous @remote function.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="cpu_sync") - -@remote(config) -def sync_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].is_async is False - - -def test_exclude_venv_directory(): - """Test that .venv directory is excluded from scanning.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - # Create .venv directory with Python files - venv_dir = project_dir / ".venv" / "lib" / "python3.11" - venv_dir.mkdir(parents=True) - venv_file = venv_dir / "test_module.py" - venv_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="venv_config") - -@remote(config) -async def venv_function(data): - return data -""" - ) - - # Create legitimate project file - project_file = project_dir / "main.py" - project_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="project_config") - -@remote(config) -async def project_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - # Should only find the project function, not the venv one - assert len(functions) == 1 - assert functions[0].resource_config_name == "project_config" - - -def test_exclude_flash_directory(): - """Test that .flash directory is excluded from scanning.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - # Create .flash directory with Python files - flash_dir = project_dir / ".flash" / "build" - flash_dir.mkdir(parents=True) - flash_file = flash_dir / "generated.py" - flash_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="flash_config") - -@remote(config) -async def flash_function(data): - return data -""" - ) - - # Create legitimate project file - project_file = project_dir / "main.py" - project_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="project_config") - -@remote(config) -async def project_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - # Should only find the project function, not the flash one - assert len(functions) == 1 - assert functions[0].resource_config_name == "project_config" - - -def test_exclude_runpod_directory(): - """Test that .runpod directory is excluded from scanning.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - # Create .runpod directory with Python files - runpod_dir = project_dir / ".runpod" / "cache" - runpod_dir.mkdir(parents=True) - runpod_file = runpod_dir / "cached.py" - runpod_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="runpod_config") - -@remote(config) -async def runpod_function(data): - return data -""" - ) - - # Create legitimate project file - project_file = project_dir / "main.py" - project_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="project_config") - -@remote(config) -async def project_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - # Should only find the project function, not the runpod one - assert len(functions) == 1 - assert functions[0].resource_config_name == "project_config" - - -def test_exclude_nested_venv_directory(): - """Test that nested .venv directories (not just root-level) are excluded.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - # Create a nested .venv inside a subdirectory (the original bug) - nested_venv = project_dir / "subproject" / ".venv" / "lib" / "python3.11" - nested_venv.mkdir(parents=True) - venv_file = nested_venv / "site_package.py" - venv_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="nested_venv_config") - -@remote(config) -async def nested_venv_function(data): - return data -""" - ) - - # Create legitimate project file - project_file = project_dir / "worker.py" - project_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="project_config") - -@remote(config) -async def project_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - # Should only find the project function, not the nested venv one - assert len(functions) == 1 - assert functions[0].resource_config_name == "project_config" - - -def test_fallback_to_variable_name_when_name_parameter_missing(): - """Test that variable name is used when resource config has no name= parameter.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu_config = LiveServerless() - -@remote(gpu_config) -async def my_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - # Should fall back to variable name when name parameter is missing - assert functions[0].resource_config_name == "gpu_config" - - -def test_ignore_non_serverless_classes_with_serverless_in_name(): - """Test that helper classes with 'Serverless' in name are ignored.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -class MyServerlessHelper: - def __init__(self): - pass - -helper = MyServerlessHelper() -config = LiveServerless(name="real_config") - -@remote(config) -async def my_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - # Should find function with real config but ignore helper class - assert len(functions) == 1 - assert functions[0].resource_config_name == "real_config" - - -def test_extract_resource_name_with_special_characters(): - """Test that resource names with special characters are extracted correctly.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "test_module.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="01_gpu-worker.v1") - -@remote(config) -async def my_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - # Should preserve special characters in resource name - assert functions[0].resource_config_name == "01_gpu-worker.v1" - - -def test_scanner_extracts_config_variable_names(): - """Test that scanner captures config variable names.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - test_file = project_dir / "endpoint.py" - - test_file.write_text( - """ -from runpod_flash import LiveLoadBalancer, remote - -gpu_config = LiveLoadBalancer(name="my-endpoint") - -@remote(gpu_config, method="GET", path="/health") -async def health(): - return {"status": "ok"} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].config_variable == "gpu_config" - assert functions[0].resource_config_name == "my-endpoint" - - -def test_cpu_live_load_balancer_flags(): - """Test that CpuLiveLoadBalancer is correctly flagged as load-balanced and live.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - test_file = project_dir / "cpu_endpoint.py" - - test_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer, remote - -cpu_config = CpuLiveLoadBalancer(name="cpu_worker") - -@remote(cpu_config, method="POST", path="/validate") -async def validate_data(text): - return {"valid": True} - -@remote(cpu_config, method="GET", path="/health") -async def health(): - return {"status": "ok"} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 2 - - # Check that both functions have the correct flags - for func in functions: - assert func.resource_config_name == "cpu_worker" - assert func.is_load_balanced is True, ( - "CpuLiveLoadBalancer should be marked as load-balanced" - ) - assert func.is_live_resource is True, ( - "CpuLiveLoadBalancer should be marked as live resource" - ) - assert func.resource_type == "CpuLiveLoadBalancer" - - # Check specific HTTP metadata for each function - validate_func = next(f for f in functions if f.function_name == "validate_data") - assert validate_func.http_method == "POST" - assert validate_func.http_path == "/validate" - - health_func = next(f for f in functions if f.function_name == "health") - assert health_func.http_method == "GET" - assert health_func.http_path == "/health" - - -def test_extract_fastapi_routes(): - """Test that FastAPI routes are extracted from decorators.""" - import ast - from runpod_flash.cli.commands.build_utils.scanner import _extract_fastapi_routes - - code = """ -from fastapi import FastAPI - -app = FastAPI() - -@app.get("/") -def home(): - return {"hello": "world"} - -@app.post("/api/users") -async def create_user(): - return {} - -@app.put("/api/users/{id}") -async def update_user(id: int): - return {"id": id} -""" - - tree = ast.parse(code) - routes = _extract_fastapi_routes(tree, "app", "main") - - assert len(routes) == 3 - - # Check first route (GET /) - home_route = next(r for r in routes if r.function_name == "home") - assert home_route.http_method == "GET" - assert home_route.http_path == "/" - assert home_route.is_async is False - assert home_route.is_load_balanced is True - assert home_route.is_live_resource is True - - # Check second route (POST /api/users) - create_route = next(r for r in routes if r.function_name == "create_user") - assert create_route.http_method == "POST" - assert create_route.http_path == "/api/users" - assert create_route.is_async is True - - # Check third route (PUT /api/users/{id}) - update_route = next(r for r in routes if r.function_name == "update_user") - assert update_route.http_method == "PUT" - assert update_route.http_path == "/api/users/{id}" - assert update_route.is_async is True - - -def test_extract_fastapi_routes_with_different_app_variable(): - """Test that FastAPI routes work with different app variable names.""" - import ast - from runpod_flash.cli.commands.build_utils.scanner import _extract_fastapi_routes - - code = """ -from fastapi import FastAPI - -router = FastAPI() - -@router.get("/health") -def health_check(): - return {"status": "ok"} -""" - - tree = ast.parse(code) - routes = _extract_fastapi_routes(tree, "router", "main") - - assert len(routes) == 1 - assert routes[0].function_name == "health_check" - assert routes[0].http_method == "GET" - assert routes[0].http_path == "/health" - - -def test_extract_fastapi_routes_ignores_non_matching(): - """Test that only matching app variable routes are extracted.""" - import ast - from runpod_flash.cli.commands.build_utils.scanner import _extract_fastapi_routes - - code = """ -from fastapi import FastAPI - -app = FastAPI() -other_app = FastAPI() - -@app.get("/") -def home(): - return {} - -@other_app.get("/other") -def other(): - return {} -""" - - tree = ast.parse(code) - routes = _extract_fastapi_routes(tree, "app", "main") - - # Should only extract routes from 'app', not 'other_app' - assert len(routes) == 1 - assert routes[0].function_name == "home" - - -def test_extract_routes_from_single_included_router(): - """Test that routes from a single included router are discovered.""" - from runpod_flash.cli.commands.build_utils.scanner import detect_main_app - - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create main.py with router include - main_file = project_root / "main.py" - main_file.write_text(""" -from fastapi import FastAPI -from routers import user_router - -app = FastAPI() -app.include_router(user_router, prefix="/users") - -@app.get("/") -def home(): - return {"msg": "Home"} -""") - - # Create routers/__init__.py - routers_dir = project_root / "routers" - routers_dir.mkdir() - (routers_dir / "__init__.py").write_text(""" -from fastapi import APIRouter - -user_router = APIRouter() - -@user_router.get("/") -def list_users(): - return [] - -@user_router.post("/") -def create_user(): - return {} -""") - - # Detect routes - main_app_config = detect_main_app(project_root, explicit_lb_exists=False) - - assert main_app_config is not None - routes = main_app_config["fastapi_routes"] - - # Should find 3 routes total: app.get + 2 router routes - assert len(routes) == 3 - - # Check home route - home = next(r for r in routes if r.function_name == "home") - assert home.http_path == "/" - assert home.http_method == "GET" - - # Check router routes have prefix applied - list_users = next(r for r in routes if r.function_name == "list_users") - assert list_users.http_path == "/users/" - assert list_users.http_method == "GET" - - create_user = next(r for r in routes if r.function_name == "create_user") - assert create_user.http_path == "/users/" - assert create_user.http_method == "POST" - - -def test_multiple_included_routers(): - """Test multiple routers with different prefixes.""" - from runpod_flash.cli.commands.build_utils.scanner import detect_main_app - - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - main_file = project_root / "main.py" - main_file.write_text(""" -from fastapi import FastAPI -from routers import user_router, admin_router - -app = FastAPI() -app.include_router(user_router, prefix="/users") -app.include_router(admin_router, prefix="/admin") -""") - - routers_dir = project_root / "routers" - routers_dir.mkdir() - (routers_dir / "__init__.py").write_text(""" -from fastapi import APIRouter - -user_router = APIRouter() -admin_router = APIRouter() - -@user_router.get("/list") -def list_users(): - return [] - -@admin_router.get("/dashboard") -def admin_dashboard(): - return {} -""") - - main_app_config = detect_main_app(project_root, explicit_lb_exists=False) - - routes = main_app_config["fastapi_routes"] - assert len(routes) == 2 - - user_route = next(r for r in routes if "users" in r.http_path) - assert user_route.http_path == "/users/list" - - admin_route = next(r for r in routes if "admin" in r.http_path) - assert admin_route.http_path == "/admin/dashboard" - - -def test_router_import_not_found(): - """Test that missing router files are handled gracefully.""" - from runpod_flash.cli.commands.build_utils.scanner import detect_main_app - - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - main_file = project_root / "main.py" - main_file.write_text(""" -from fastapi import FastAPI -from nonexistent import some_router - -app = FastAPI() -app.include_router(some_router, prefix="/api") - -@app.get("/") -def home(): - return {} -""") - - main_app_config = detect_main_app(project_root, explicit_lb_exists=False) - - # Should still work, just skip the missing router - routes = main_app_config["fastapi_routes"] - assert len(routes) == 1 # Only the home route - assert routes[0].http_path == "/" - - -def test_router_with_no_prefix(): - """Test router included without a prefix.""" - from runpod_flash.cli.commands.build_utils.scanner import detect_main_app - - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - main_file = project_root / "main.py" - main_file.write_text(""" -from fastapi import FastAPI -from routers import api_router - -app = FastAPI() -app.include_router(api_router) - -@app.get("/") -def home(): - return {} -""") - - routers_dir = project_root / "routers" - routers_dir.mkdir() - (routers_dir / "__init__.py").write_text(""" -from fastapi import APIRouter - -api_router = APIRouter() - -@api_router.get("/data") -def get_data(): - return {} -""") - - main_app_config = detect_main_app(project_root, explicit_lb_exists=False) - - routes = main_app_config["fastapi_routes"] - assert len(routes) == 2 - - # Router route should not have prefix - data_route = next(r for r in routes if r.function_name == "get_data") - assert data_route.http_path == "/data" - - -def test_router_in_separate_module_file(): - """Test router defined in a separate .py file (not __init__.py).""" - from runpod_flash.cli.commands.build_utils.scanner import detect_main_app - - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - main_file = project_root / "main.py" - main_file.write_text(""" -from fastapi import FastAPI -from routers.users import user_router - -app = FastAPI() -app.include_router(user_router, prefix="/users") -""") - - routers_dir = project_root / "routers" - routers_dir.mkdir() - (routers_dir / "__init__.py").write_text("") # Empty init - - # Router in separate file - (routers_dir / "users.py").write_text(""" -from fastapi import APIRouter - -user_router = APIRouter() - -@user_router.get("/{user_id}") -def get_user(user_id: int): - return {} -""") - - main_app_config = detect_main_app(project_root, explicit_lb_exists=False) - - routes = main_app_config["fastapi_routes"] - assert len(routes) == 1 - assert routes[0].http_path == "/users/{user_id}" - assert routes[0].function_name == "get_user" - - -def test_router_with_async_handlers(): - """Test router with async route handlers.""" - from runpod_flash.cli.commands.build_utils.scanner import detect_main_app - - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - main_file = project_root / "main.py" - main_file.write_text(""" -from fastapi import FastAPI -from routers import async_router - -app = FastAPI() -app.include_router(async_router, prefix="/api") -""") - - routers_dir = project_root / "routers" - routers_dir.mkdir() - (routers_dir / "__init__.py").write_text(""" -from fastapi import APIRouter - -async_router = APIRouter() - -@async_router.post("/process") -async def process_data(): - return {} -""") - - main_app_config = detect_main_app(project_root, explicit_lb_exists=False) - - routes = main_app_config["fastapi_routes"] - assert len(routes) == 1 - assert routes[0].http_path == "/api/process" - assert routes[0].is_async is True - assert routes[0].http_method == "POST" - - -def test_class_methods_extraction(): - """Test that public methods are extracted from @remote classes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "gpu_worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu_config = LiveServerless(name="gpu_worker") - -@remote(gpu_config) -class SimpleSD: - def __init__(self): - self.model = None - - def generate_image(self, prompt): - return {"image": "base64..."} - - def upscale(self, image): - return {"image": "upscaled..."} - - def _load_model(self): - pass - - def __repr__(self): - return "SimpleSD" -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - meta = functions[0] - assert meta.function_name == "SimpleSD" - assert meta.is_class is True - assert meta.class_methods == ["generate_image", "upscale"] - - -def test_class_methods_excludes_private_and_dunder(): - """Test that _private and __dunder__ methods are excluded from class_methods.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -class MyWorker: - def __init__(self): - pass - - def __repr__(self): - return "MyWorker" - - def _internal_helper(self): - pass - - async def process(self, data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].class_methods == ["process"] - - -def test_class_with_no_public_methods(): - """Test @remote class with only private/dunder methods has empty class_methods.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -class EmptyWorker: - def __init__(self): - pass - - def __call__(self, data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].class_methods == [] - - -def test_function_has_empty_class_methods(): - """Test that regular @remote functions have empty class_methods list.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -async def my_function(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].is_class is False - assert functions[0].class_methods == [] - - -def test_param_names_single_param(): - """Test that param_names extracts a single parameter.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -async def process(data): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].param_names == ["data"] - - -def test_param_names_zero_params(): - """Test that param_names is empty for zero-parameter functions.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -async def list_images() -> dict: - return {"images": []} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].param_names == [] - - -def test_param_names_multiple_params(): - """Test that param_names extracts multiple parameters.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -async def transform(text: str, operation: str = "uppercase") -> dict: - return {"result": text} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].param_names == ["text", "operation"] - - -def test_class_method_params_extraction(): - """Test that class_method_params extracts params for each public method.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -class ImageProcessor: - def __init__(self): - pass - - def generate(self, prompt: str, width: int = 512): - return {} - - def list_models(self): - return [] - - def _internal(self): - pass -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - meta = functions[0] - assert meta.is_class is True - assert meta.class_methods == ["generate", "list_models"] - assert meta.class_method_params == { - "generate": ["prompt", "width"], - "list_models": [], - } - # Classes should have empty param_names - assert meta.param_names == [] - - -def test_calls_remote_functions_direct_call(): - """Direct call to another @remote function sets calls_remote_functions=True.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "workers.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu_config = LiveServerless(name="gpu_worker") -cpu_config = LiveServerless(name="cpu_worker") - -@remote(gpu_config) -async def generate(prompt): - return {"image": "..."} - -@remote(cpu_config) -async def orchestrate(prompt): - result = generate(prompt) - return result -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - orchestrate = next(f for f in functions if f.function_name == "orchestrate") - assert orchestrate.calls_remote_functions is True - assert "generate" in orchestrate.called_remote_functions - - -def test_attribute_call_does_not_trigger_calls_remote_functions(): - """Attribute call obj.remote_name() must not flag calls_remote_functions. - - Regression test: if a @remote function named 'generate' exists, then - model.generate() should NOT trigger a false positive. @remote functions - are always invoked as direct calls after import, never via attribute access. - """ - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "workers.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu_config = LiveServerless(name="gpu_worker") -cpu_config = LiveServerless(name="cpu_worker") - -@remote(gpu_config) -async def generate(prompt): - return {"image": "..."} - -@remote(cpu_config) -async def run_pipeline(prompt): - model = load_model() - result = model.generate(prompt) - return result -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - run_pipeline = next(f for f in functions if f.function_name == "run_pipeline") - assert run_pipeline.calls_remote_functions is False - assert run_pipeline.called_remote_functions == [] - - -def test_no_remote_calls_sets_calls_remote_functions_false(): - """Standalone function with no remote calls stays False.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "worker.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="worker") - -@remote(config) -async def process(data): - return {"result": data.upper()} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].calls_remote_functions is False - assert functions[0].called_remote_functions == [] - - -def test_calls_remote_functions_multiple_cross_calls(): - """Multiple cross-calls tracked in called_remote_functions list.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "workers.py" - test_file.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu_a = LiveServerless(name="gpu_a") -gpu_b = LiveServerless(name="gpu_b") -cpu = LiveServerless(name="cpu") - -@remote(gpu_a) -async def transcribe(audio): - return {"text": "..."} - -@remote(gpu_b) -async def translate(text): - return {"translated": "..."} - -@remote(cpu) -async def pipeline(audio): - text = transcribe(audio) - translated = translate(text) - return translated -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - pipeline = next(f for f in functions if f.function_name == "pipeline") - assert pipeline.calls_remote_functions is True - assert set(pipeline.called_remote_functions) == {"transcribe", "translate"} - - -def test_calls_remote_functions_from_class_method(): - """Cross-call from a class method is detected when the endpoint wraps a class.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - helpers = project_dir / "helpers.py" - helpers.write_text( - """ -from runpod_flash import CpuLiveServerless, remote - -cpu = CpuLiveServerless(name="tokenizer") - -@remote(cpu) -async def tokenize(text): - return {"tokens": text.split()} -""" - ) - - model = project_dir / "model.py" - model.write_text( - """ -from runpod_flash import LiveServerless, remote - -gpu = LiveServerless(name="model") - -@remote(gpu) -class MyModel: - def __init__(self): - self.ready = True - - async def predict(self, text): - from helpers import tokenize - tokens = tokenize(text) - return {"tokens": tokens} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - my_model = next(f for f in functions if f.function_name == "MyModel") - assert my_model.calls_remote_functions is True - assert "tokenize" in my_model.called_remote_functions - - -def test_calls_remote_functions_endpoint_class_method(): - """Cross-call from an @Endpoint class method is detected.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - helpers = project_dir / "helpers.py" - helpers.write_text( - """ -from runpod_flash import Endpoint, CpuInstanceType - -@Endpoint(name="normalizer", cpu=CpuInstanceType.CPU3C_1_2) -async def normalize(text): - return {"normalized": text.lower()} -""" - ) - - model = project_dir / "model.py" - model.write_text( - """ -from runpod_flash import Endpoint, GpuGroup - -@Endpoint(name="nlp_model", gpu=GpuGroup.ADA_24) -class NLPModel: - def __init__(self): - self.ready = True - - async def predict(self, text): - from helpers import normalize - result = normalize(text) - return {"prediction": result} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - nlp_model = next(f for f in functions if f.function_name == "NLPModel") - assert nlp_model.calls_remote_functions is True - assert "normalize" in nlp_model.called_remote_functions - - -def test_local_flag_extracted_from_lb_route(): - """Test that local=True on an LB @remote is captured in metadata.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "endpoint.py" - test_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer, remote - -lb_config = CpuLiveLoadBalancer(name="lb_worker") - -@remote(lb_config, local=True, method="GET", path="/scripts") -async def list_scripts(): - return {"scripts": []} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].local is True - assert functions[0].is_lb_route_handler is True - - -def test_local_flag_defaults_false(): - """Test that local defaults to False when not specified.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "endpoint.py" - test_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer, remote - -lb_config = CpuLiveLoadBalancer(name="lb_worker") - -@remote(lb_config, method="POST", path="/classify") -async def classify(text: str): - return {"label": "positive"} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].local is False - - -def test_local_flag_on_lb_route_handler(): - """Test local=True on LB route with method and path is an LB route handler.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - test_file = project_dir / "api.py" - test_file.write_text( - """ -from runpod_flash import CpuLiveLoadBalancer, remote - -api = CpuLiveLoadBalancer(name="api_endpoint") - -@remote(api, local=True, method="POST", path="/process") -async def process(data: dict): - return {"result": data} - -@remote(api, method="GET", path="/health") -async def health(): - return {"status": "ok"} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 2 - - process_fn = next(f for f in functions if f.function_name == "process") - assert process_fn.local is True - assert process_fn.is_lb_route_handler is True - assert process_fn.http_method == "POST" - assert process_fn.http_path == "/process" - - health_fn = next(f for f in functions if f.function_name == "health") - assert health_fn.local is False - assert health_fn.is_lb_route_handler is True - - -def test_resource_config_types_matches_exports(): - """Static _RESOURCE_CONFIG_TYPES must include all ServerlessResource subclasses from runpod_flash.""" - for type_name in RemoteDecoratorScanner._RESOURCE_CONFIG_TYPES: - cls = getattr(runpod_flash, type_name, None) - assert cls is not None, f"{type_name} not found in runpod_flash exports" - assert issubclass(cls, ServerlessResource), ( - f"{type_name} is not a ServerlessResource subclass" - ) diff --git a/tests/unit/cli/commands/build_utils/test_scanner_endpoint.py b/tests/unit/cli/commands/build_utils/test_scanner_endpoint.py deleted file mode 100644 index be496a4e..00000000 --- a/tests/unit/cli/commands/build_utils/test_scanner_endpoint.py +++ /dev/null @@ -1,600 +0,0 @@ -"""Tests for scanner recognition of the new Endpoint API patterns.""" - -import tempfile -from pathlib import Path - -from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner - - -# -- QB mode: @Endpoint(...) on a function -- - - -class TestEndpointQBFunction: - def test_discover_simple_qb_function(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "worker.py").write_text( - """ -from runpod_flash import Endpoint, GpuGroup - -@Endpoint(name="gpu-worker", gpu=GpuGroup.ADA_24, workers=(0, 3)) -async def gpu_hello(input_data: dict) -> dict: - return {"result": input_data} -""" - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - f = functions[0] - assert f.function_name == "gpu_hello" - assert f.resource_config_name == "gpu-worker" - assert f.resource_type == "Endpoint" - assert f.is_async is True - assert f.is_class is False - assert f.is_load_balanced is False - assert f.is_lb_route_handler is False - assert f.is_live_resource is True - assert f.param_names == ["input_data"] - - def test_discover_sync_qb_function(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "worker.py").write_text( - """ -from runpod_flash import Endpoint - -@Endpoint(name="sync-worker") -def process(data: dict) -> dict: - return data -""" - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].is_async is False - assert functions[0].resource_config_name == "sync-worker" - - def test_qb_zero_params(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "worker.py").write_text( - """ -from runpod_flash import Endpoint - -@Endpoint(name="zero-params") -async def get_status() -> dict: - return {"status": "ok"} -""" - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].param_names == [] - - def test_qb_multiple_params(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "worker.py").write_text( - """ -from runpod_flash import Endpoint - -@Endpoint(name="multi-param") -async def transform(text: str, operation: str = "upper") -> dict: - return {} -""" - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].param_names == ["text", "operation"] - - def test_qb_docstring_extracted(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "worker.py").write_text( - ''' -from runpod_flash import Endpoint - -@Endpoint(name="documented") -async def process(data: dict) -> dict: - """Process incoming data and return results.""" - return data -''' - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].docstring == "Process incoming data and return results." - - -# -- QB mode: @Endpoint(...) on a class -- - - -class TestEndpointQBClass: - def test_discover_qb_class(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "worker.py").write_text( - """ -from runpod_flash import Endpoint, GpuGroup - -@Endpoint(name="model-worker", gpu=GpuGroup.ANY, workers=(0, 2)) -class TextModel: - def __init__(self): - self.ready = True - - async def predict(self, text: str) -> dict: - return {"text": text} - - async def info(self) -> dict: - return {"ready": self.ready} - - def _internal(self): - pass -""" - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - f = functions[0] - assert f.function_name == "TextModel" - assert f.resource_config_name == "model-worker" - assert f.is_class is True - assert f.is_load_balanced is False - assert f.class_methods == ["predict", "info"] - assert f.class_method_params == { - "predict": ["text"], - "info": [], - } - - def test_qb_class_docstrings(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "worker.py").write_text( - ''' -from runpod_flash import Endpoint - -@Endpoint(name="doc-class") -class Worker: - """A documented worker.""" - - def process(self, data): - """Process data.""" - return data - - def info(self): - return {} -''' - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - f = functions[0] - assert f.docstring == "A documented worker." - assert f.class_method_docstrings["process"] == "Process data." - assert f.class_method_docstrings["info"] is None - - -# -- LB mode: ep = Endpoint(...) + @ep.get/post/... -- - - -class TestEndpointLBRoutes: - def test_discover_single_lb_route(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "api.py").write_text( - """ -from runpod_flash import Endpoint - -api = Endpoint(name="api-service", cpu="cpu3g-2-8", workers=(1, 3)) - -@api.post("/compute") -async def compute(data: dict) -> dict: - return {"result": data} -""" - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - f = functions[0] - assert f.function_name == "compute" - assert f.resource_config_name == "api-service" - assert f.resource_type == "Endpoint" - assert f.is_load_balanced is True - assert f.is_lb_route_handler is True - assert f.is_live_resource is True - assert f.http_method == "POST" - assert f.http_path == "/compute" - assert f.config_variable == "api" - assert f.param_names == ["data"] - - def test_discover_multiple_lb_routes(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "api.py").write_text( - """ -from runpod_flash import Endpoint - -api = Endpoint(name="my-api", cpu="cpu3g-2-8", workers=(1, 3)) - -@api.get("/health") -async def health(): - return {"status": "ok"} - -@api.post("/echo") -async def echo(message: str = "hello") -> dict: - return {"echo": message} - -@api.post("/transform") -async def transform(text: str, operation: str = "uppercase") -> dict: - return {"result": text} - -@api.get("/info") -async def info(): - return {"service": "api"} -""" - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 4 - - by_name = {f.function_name: f for f in functions} - - assert by_name["health"].http_method == "GET" - assert by_name["health"].http_path == "/health" - assert by_name["health"].param_names == [] - - assert by_name["echo"].http_method == "POST" - assert by_name["echo"].http_path == "/echo" - assert by_name["echo"].param_names == ["message"] - - assert by_name["transform"].http_method == "POST" - assert by_name["transform"].http_path == "/transform" - assert by_name["transform"].param_names == ["text", "operation"] - - assert by_name["info"].http_method == "GET" - assert by_name["info"].http_path == "/info" - - # all routes share the same resource - assert all(f.resource_config_name == "my-api" for f in functions) - assert all(f.config_variable == "api" for f in functions) - assert all(f.is_lb_route_handler for f in functions) - - def test_all_http_methods(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "api.py").write_text( - """ -from runpod_flash import Endpoint - -api = Endpoint(name="rest-api") - -@api.get("/items") -async def list_items(): - return [] - -@api.post("/items") -async def create_item(data: dict): - return data - -@api.put("/items") -async def update_item(data: dict): - return data - -@api.delete("/items") -async def delete_item(id: str): - return {"deleted": id} - -@api.patch("/items") -async def patch_item(data: dict): - return data -""" - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - methods = {f.function_name: f.http_method for f in functions} - assert methods == { - "list_items": "GET", - "create_item": "POST", - "update_item": "PUT", - "delete_item": "DELETE", - "patch_item": "PATCH", - } - - def test_lb_route_docstrings(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "api.py").write_text( - ''' -from runpod_flash import Endpoint - -api = Endpoint(name="doc-api") - -@api.get("/health") -async def health(): - """Check service health.""" - return {"status": "ok"} - -@api.post("/process") -async def process(data: dict): - return data -''' - ) - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - by_name = {f.function_name: f for f in functions} - assert by_name["health"].docstring == "Check service health." - assert by_name["process"].docstring is None - - -# -- mixed patterns: QB + LB in same project -- - - -class TestEndpointMixedPatterns: - def test_qb_and_lb_in_different_files(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - (project_dir / "worker.py").write_text( - """ -from runpod_flash import Endpoint, GpuGroup - -@Endpoint(name="gpu-worker", gpu=GpuGroup.ADA_24) -async def process(data: dict) -> dict: - return data -""" - ) - (project_dir / "api.py").write_text( - """ -from runpod_flash import Endpoint - -api = Endpoint(name="cpu-api", cpu="cpu3g-2-8") - -@api.get("/health") -async def health(): - return {"status": "ok"} - -@api.post("/compute") -async def compute(data: dict): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - qb_funcs = [f for f in functions if not f.is_load_balanced] - lb_funcs = [f for f in functions if f.is_load_balanced] - - assert len(qb_funcs) == 1 - assert qb_funcs[0].function_name == "process" - assert qb_funcs[0].resource_config_name == "gpu-worker" - - assert len(lb_funcs) == 2 - assert all(f.resource_config_name == "cpu-api" for f in lb_funcs) - - def test_endpoint_and_legacy_remote_coexist(self): - """new Endpoint API and legacy @remote coexist in the same project.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - - # new style - (project_dir / "new_worker.py").write_text( - """ -from runpod_flash import Endpoint - -@Endpoint(name="new-style") -async def new_process(data: dict): - return data -""" - ) - - # old style - (project_dir / "old_worker.py").write_text( - """ -from runpod_flash import LiveServerless, remote - -config = LiveServerless(name="old-style") - -@remote(config) -async def old_process(data: dict): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - names = {f.function_name for f in functions} - assert names == {"new_process", "old_process"} - - new_f = next(f for f in functions if f.function_name == "new_process") - assert new_f.resource_type == "Endpoint" - - old_f = next(f for f in functions if f.function_name == "old_process") - assert old_f.resource_type == "LiveServerless" - - def test_multiple_lb_endpoints_in_same_file(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "apis.py").write_text( - """ -from runpod_flash import Endpoint - -public_api = Endpoint(name="public") -admin_api = Endpoint(name="admin") - -@public_api.get("/items") -async def list_items(): - return [] - -@admin_api.post("/users") -async def create_user(data: dict): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 2 - by_name = {f.function_name: f for f in functions} - - assert by_name["list_items"].resource_config_name == "public" - assert by_name["list_items"].config_variable == "public_api" - - assert by_name["create_user"].resource_config_name == "admin" - assert by_name["create_user"].config_variable == "admin_api" - - -# -- edge cases -- - - -class TestEndpointScannerEdgeCases: - def test_non_endpoint_attribute_call_ignored(self): - """@app.get() from regular FastAPI should not match as Endpoint route.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "app.py").write_text( - """ -from fastapi import FastAPI - -app = FastAPI() - -@app.get("/health") -async def health(): - return {"status": "ok"} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - # FastAPI routes are detected via detect_main_app, not as Endpoint routes - assert len(functions) == 0 - - def test_endpoint_variable_name_as_fallback(self): - """if name= is missing, variable name is used as resource name.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "api.py").write_text( - """ -from runpod_flash import Endpoint - -my_api = Endpoint() - -@my_api.get("/health") -async def health(): - return {"status": "ok"} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].resource_config_name == "my_api" - assert functions[0].config_variable == "my_api" - - def test_unregistered_variable_route_ignored(self): - """@x.get() where x is not a known Endpoint is silently skipped.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "app.py").write_text( - """ -from somewhere import router - -@router.get("/stuff") -async def stuff(): - return {} -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 0 - - def test_endpoint_in_nested_dir(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - nested = project_dir / "workers" / "gpu" - nested.mkdir(parents=True) - - (nested / "inference.py").write_text( - """ -from runpod_flash import Endpoint, GpuGroup - -@Endpoint(name="gpu-inference", gpu=GpuGroup.ADA_24) -async def infer(data: dict) -> dict: - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 1 - assert functions[0].module_path == "workers.gpu.inference" - assert functions[0].resource_config_name == "gpu-inference" - - def test_cross_call_detection_with_endpoint(self): - """cross-call analysis works for @Endpoint-decorated functions.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "workers.py").write_text( - """ -from runpod_flash import Endpoint, GpuGroup - -@Endpoint(name="worker-a", gpu=GpuGroup.ANY) -async def generate(prompt: str): - return {"text": prompt} - -@Endpoint(name="worker-b", gpu=GpuGroup.ANY) -async def pipeline(prompt: str): - result = generate(prompt) - return result -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - pipeline = next(f for f in functions if f.function_name == "pipeline") - assert pipeline.calls_remote_functions is True - assert "generate" in pipeline.called_remote_functions - - def test_multiple_qb_endpoints_same_file(self): - with tempfile.TemporaryDirectory() as tmpdir: - project_dir = Path(tmpdir) - (project_dir / "workers.py").write_text( - """ -from runpod_flash import Endpoint, GpuGroup - -@Endpoint(name="worker-a", gpu=GpuGroup.ADA_24) -async def process_a(data: dict): - return data - -@Endpoint(name="worker-b", gpu=GpuGroup.ANY) -async def process_b(data: dict): - return data -""" - ) - - scanner = RemoteDecoratorScanner(project_dir) - functions = scanner.discover_remote_functions() - - assert len(functions) == 2 - configs = {f.resource_config_name for f in functions} - assert configs == {"worker-a", "worker-b"} diff --git a/tests/unit/cli/commands/build_utils/test_scanner_load_balancer.py b/tests/unit/cli/commands/build_utils/test_scanner_load_balancer.py deleted file mode 100644 index 9b6e72d5..00000000 --- a/tests/unit/cli/commands/build_utils/test_scanner_load_balancer.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Tests for main.py FastAPI app detection in scanner.""" - -import tempfile -from pathlib import Path - -from runpod_flash.cli.commands.build_utils.scanner import ( - detect_main_app, -) - - -class TestDetectMainApp: - """Test main.py FastAPI app detection.""" - - def test_detect_main_app_with_fastapi_and_routes(self): - """Test detection of main.py with FastAPI app and custom routes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI - -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - result = detect_main_app(project_root) - - assert result is not None - assert result["app_variable"] == "app" - assert result["has_routes"] is True - assert result["file_path"] == main_file - - def test_detect_main_app_with_app_py(self): - """Test detection works with app.py instead of main.py.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - app_file = project_root / "app.py" - app_file.write_text( - """ -from fastapi import FastAPI - -api = FastAPI() - -@api.post("/process") -async def process(): - return {"status": "ok"} -""" - ) - - result = detect_main_app(project_root) - - assert result is not None - assert result["app_variable"] == "api" - assert result["has_routes"] is True - - def test_detect_main_app_with_server_py(self): - """Test detection works with server.py instead of main.py.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - server_file = project_root / "server.py" - server_file.write_text( - """ -from fastapi import FastAPI - -server = FastAPI() - -@server.get("/health") -def health(): - return {"status": "healthy"} -""" - ) - - result = detect_main_app(project_root) - - assert result is not None - assert result["app_variable"] == "server" - assert result["has_routes"] is True - - def test_detect_main_app_no_routes(self): - """Test skipping main.py with FastAPI but no custom routes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -from runpod_flash import remote - -app = FastAPI() - -@remote(resource_config=gpu_config) -def process(data): - return data -""" - ) - - result = detect_main_app(project_root) - - # Should detect app but has_routes should be False - assert result is not None - assert result["has_routes"] is False - - def test_detect_main_app_no_fastapi(self): - """Test returns None if no FastAPI app found.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - main_file = project_root / "main.py" - main_file.write_text( - """ -def main(): - print("hello") -""" - ) - - result = detect_main_app(project_root) - - assert result is None - - def test_detect_main_app_no_file(self): - """Test returns None if no main.py exists.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - result = detect_main_app(project_root) - - assert result is None - - def test_detect_main_app_syntax_error(self): - """Test gracefully handles syntax errors.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI( - # Missing closing parenthesis -""" - ) - - result = detect_main_app(project_root) - - assert result is None - - def test_detect_main_app_priority_main_over_app(self): - """Test main.py takes priority over app.py.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - - # Create both main.py and app.py - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/main") -def main_route(): - return {"from": "main"} -""" - ) - - app_file = project_root / "app.py" - app_file.write_text( - """ -from fastapi import FastAPI -app = FastAPI() - -@app.get("/app") -def app_route(): - return {"from": "app"} -""" - ) - - result = detect_main_app(project_root) - - # Should use main.py - assert result is not None - assert result["file_path"] == main_file - - def test_detect_main_app_async_route(self): - """Test detection works with async routes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI - -app = FastAPI() - -@app.get("/") -async def root(): - return {"msg": "Hello"} -""" - ) - - result = detect_main_app(project_root) - - assert result is not None - assert result["has_routes"] is True - - def test_detect_main_app_multiple_routes(self): - """Test detection with multiple routes.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI - -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} - -@app.post("/items") -def create_item(item: dict): - return item - -@app.put("/items/{item_id}") -async def update_item(item_id: int, item: dict): - return {"id": item_id, "item": item} - -@app.delete("/items/{item_id}") -def delete_item(item_id: int): - return {"deleted": item_id} -""" - ) - - result = detect_main_app(project_root) - - assert result is not None - assert result["has_routes"] is True - - def test_detect_main_app_respects_explicit_lb_flag(self): - """Test that explicit_lb_exists flag prevents auto-detection.""" - with tempfile.TemporaryDirectory() as tmpdir: - project_root = Path(tmpdir) - main_file = project_root / "main.py" - main_file.write_text( - """ -from fastapi import FastAPI - -app = FastAPI() - -@app.get("/") -def root(): - return {"msg": "Hello"} -""" - ) - - # Without flag, should detect - result = detect_main_app(project_root, explicit_lb_exists=False) - assert result is not None - assert result["has_routes"] is True - - # With flag, should not detect - result = detect_main_app(project_root, explicit_lb_exists=True) - assert result is None diff --git a/tests/unit/cli/commands/test_build.py b/tests/unit/cli/commands/test_build.py index 4e69789c..dc435f6f 100644 --- a/tests/unit/cli/commands/test_build.py +++ b/tests/unit/cli/commands/test_build.py @@ -162,7 +162,7 @@ def test_extracts_dependencies_from_async_remote_function(self, tmp_path): worker_file.write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu, dependencies=['torch', 'transformers'])\n" "async def my_async_func(prompt: str) -> str:\n" " return prompt\n" @@ -183,7 +183,7 @@ def test_collect_requirements_scans_full_build_dir(self, tmp_path): api_example = build_dir / "api_example.py" api_example.write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu, dependencies=['transformers'])\n" "async def smoke(prompt: str) -> str:\n" " return prompt\n" @@ -238,7 +238,7 @@ def test_extracts_dependencies_from_mixed_patterns(self, tmp_path): f1 = workers_dir / "remote_worker.py" f1.write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu, dependencies=['torch'])\n" "async def train(data): return data\n" ) @@ -303,7 +303,7 @@ def test_run_build_calls_handler_generator(self, tmp_path): worker_file = project_dir / "worker.py" worker_file.write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu)\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" @@ -340,7 +340,7 @@ def test_run_build_produces_qb_handler_files(self, tmp_path): worker_file = project_dir / "worker.py" worker_file.write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu)\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" @@ -371,7 +371,7 @@ def test_run_build_manifest_includes_handler_file(self, tmp_path): worker_file = project_dir / "worker.py" worker_file.write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu)\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" @@ -511,7 +511,7 @@ def test_exits_when_runpod_flash_not_found(self, tmp_path): project_dir.mkdir() (project_dir / "worker.py").write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu)\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" @@ -542,7 +542,7 @@ def test_build_bundles_runpod_flash(self, tmp_path): project_dir.mkdir() (project_dir / "worker.py").write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu)\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" @@ -617,7 +617,7 @@ def test_auto_excludes_torch_without_flag(self, tmp_path): project_dir.mkdir() (project_dir / "worker.py").write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu, dependencies=['torch', 'numpy', 'requests'])\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" @@ -649,7 +649,7 @@ def test_user_excludes_merged_with_auto(self, tmp_path): project_dir.mkdir() (project_dir / "worker.py").write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu, dependencies=['torch', 'numpy', 'scipy', 'pandas'])\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" @@ -682,7 +682,7 @@ def test_auto_exclude_silent_when_not_in_requirements(self, tmp_path, capsys): project_dir.mkdir() (project_dir / "worker.py").write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu, dependencies=['requests'])\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" @@ -708,7 +708,7 @@ def test_user_unmatched_warning_excludes_base_image_packages( project_dir.mkdir() (project_dir / "worker.py").write_text( "from runpod_flash import remote, LiveServerless\n" - "gpu = LiveServerless()\n" + "gpu = LiveServerless(name='gpu_worker')\n" "@remote(gpu, dependencies=['requests'])\n" "def my_func(prompt: str) -> str:\n" " return prompt\n" diff --git a/tests/unit/cli/commands/test_run.py b/tests/unit/cli/commands/test_run.py index b53cd6e0..c1e86dc0 100644 --- a/tests/unit/cli/commands/test_run.py +++ b/tests/unit/cli/commands/test_run.py @@ -44,7 +44,7 @@ def upscale(self, image): """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) assert len(workers) == 1 worker = workers[0] @@ -77,7 +77,7 @@ def generate(self, prompt): """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) assert len(workers) == 1 worker = workers[0] @@ -423,7 +423,7 @@ async def multi_params(text: str, mode: str = "default") -> dict: """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) assert len(workers) == 1 worker = workers[0] @@ -456,7 +456,7 @@ def status(self): """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) assert len(workers) == 1 worker = workers[0] diff --git a/tests/unit/cli/commands/test_run_endpoint.py b/tests/unit/cli/commands/test_run_endpoint.py index d41dd55f..6203e537 100644 --- a/tests/unit/cli/commands/test_run_endpoint.py +++ b/tests/unit/cli/commands/test_run_endpoint.py @@ -20,7 +20,7 @@ async def process(data: dict) -> dict: """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) assert len(workers) == 1 w = workers[0] assert w.worker_type == "QB" @@ -46,7 +46,7 @@ async def compute(data: dict) -> dict: """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) assert len(workers) == 1 w = workers[0] assert w.worker_type == "LB" @@ -86,7 +86,7 @@ async def health(): """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) assert len(workers) == 2 qb = [w for w in workers if w.worker_type == "QB"] @@ -114,7 +114,7 @@ async def info(self): """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) assert len(workers) == 1 w = workers[0] assert w.worker_type == "QB" @@ -137,7 +137,7 @@ async def process(data: dict) -> dict: """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) server_path = _generate_flash_server(project_root, workers) assert server_path.exists() @@ -167,7 +167,7 @@ async def compute(data: dict) -> dict: """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) server_path = _generate_flash_server(project_root, workers) assert server_path.exists() @@ -210,7 +210,7 @@ async def api_process(text: str) -> dict: """ ) - workers = _scan_project_workers(project_root) + workers, _ = _scan_project_workers(project_root) server_path = _generate_flash_server(project_root, workers) content = server_path.read_text() diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index ea8df55b..e1cbbdc0 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -389,7 +389,8 @@ def test_regenerates_server_py_on_py_file_change(self, tmp_path): stop = threading.Event() with patch( - "runpod_flash.cli.commands.run._scan_project_workers", return_value=[] + "runpod_flash.cli.commands.run._scan_project_workers", + return_value=([], None), ) as mock_scan: with patch( "runpod_flash.cli.commands.run._generate_flash_server" diff --git a/tests/unit/test_discovery.py b/tests/unit/test_discovery.py deleted file mode 100644 index b338603d..00000000 --- a/tests/unit/test_discovery.py +++ /dev/null @@ -1,336 +0,0 @@ -"""Unit tests for ResourceDiscovery.""" - -import pytest -from textwrap import dedent - -from runpod_flash.core.discovery import ResourceDiscovery -from runpod_flash.core.resources.serverless import ServerlessResource - - -class TestResourceDiscovery: - """Test ResourceDiscovery functionality.""" - - @pytest.fixture - def temp_entry_point(self, tmp_path): - """Create temporary entry point file for testing.""" - entry_file = tmp_path / "main.py" - return entry_file - - @pytest.fixture - def sample_resource_config(self): - """Create sample resource config for testing.""" - return ServerlessResource( - name="test-gpu", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - def test_discover_no_remote_decorators(self, temp_entry_point): - """Test discovery when no @remote decorators exist.""" - temp_entry_point.write_text( - dedent( - """ - from fastapi import FastAPI - - app = FastAPI() - - @app.get("/") - def root(): - return {"message": "Hello"} - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - resources = discovery.discover() - - assert resources == [] - - def test_discover_single_remote_decorator(self, temp_entry_point): - """Test discovery of single @remote decorator.""" - temp_entry_point.write_text( - dedent( - """ - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - gpu_config = ServerlessResource( - name="test-gpu", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(resource_config=gpu_config) - async def gpu_task(): - return "result" - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - resources = discovery.discover() - - assert len(resources) == 1 - assert isinstance(resources[0], ServerlessResource) - assert resources[0].name == "test-gpu" - - def test_discover_multiple_remote_decorators(self, temp_entry_point): - """Test discovery of multiple @remote decorators.""" - temp_entry_point.write_text( - dedent( - """ - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - gpu_config = ServerlessResource( - name="gpu-endpoint", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - cpu_config = ServerlessResource( - name="cpu-endpoint", - gpuCount=0, - workersMax=5, - workersMin=1, - flashboot=False, - ) - - @remote(resource_config=gpu_config) - async def gpu_task(): - return "gpu result" - - @remote(resource_config=cpu_config) - async def cpu_task(): - return "cpu result" - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - resources = discovery.discover() - - assert len(resources) == 2 - names = {r.name for r in resources} - assert names == {"gpu-endpoint", "cpu-endpoint"} - - def test_discover_positional_argument(self, temp_entry_point): - """Test discovery with positional argument @remote(config).""" - temp_entry_point.write_text( - dedent( - """ - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - my_config = ServerlessResource( - name="test-endpoint", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(my_config) - async def my_task(): - return "result" - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - resources = discovery.discover() - - assert len(resources) == 1 - assert resources[0].name == "test-endpoint" - - def test_discover_invalid_import(self, temp_entry_point): - """Test discovery handles invalid imports gracefully.""" - temp_entry_point.write_text( - dedent( - """ - import nonexistent_module - - from runpod_flash.client import remote - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - resources = discovery.discover() - - # Should handle import error gracefully - assert isinstance(resources, list) - - def test_discover_cache(self, temp_entry_point): - """Test that discovery results are cached.""" - temp_entry_point.write_text( - dedent( - """ - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - config = ServerlessResource( - name="cached-endpoint", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(config) - async def task(): - return "result" - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - - # First call - resources1 = discovery.discover() - assert len(resources1) == 1 - - # Second call should use cache - resources2 = discovery.discover() - assert resources1 == resources2 - - def test_clear_cache(self, temp_entry_point): - """Test clearing discovery cache.""" - temp_entry_point.write_text( - dedent( - """ - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - config = ServerlessResource( - name="test-endpoint", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(config) - async def task(): - return "result" - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - resources = discovery.discover() - assert len(resources) == 1 - - # Clear cache - discovery.clear_cache() - assert discovery._cache == {} - - def test_discover_with_syntax_error(self, temp_entry_point): - """Test discovery handles syntax errors gracefully.""" - temp_entry_point.write_text( - dedent( - """ - def invalid_syntax( - # Missing closing parenthesis - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - resources = discovery.discover() - - # Should handle parse error gracefully - assert isinstance(resources, list) - - def test_discover_non_deployable_resource(self, temp_entry_point): - """Test discovery skips non-DeployableResource objects.""" - temp_entry_point.write_text( - dedent( - """ - from runpod_flash.client import remote - - # Not a DeployableResource - config = {"name": "not-a-resource"} - - @remote(resource_config=config) - async def task(): - return "result" - """ - ) - ) - - discovery = ResourceDiscovery(str(temp_entry_point)) - resources = discovery.discover() - - # Should skip non-DeployableResource - assert resources == [] - - def test_max_depth_limiting(self, tmp_path): - """Test that recursive scanning respects max_depth.""" - # Create nested module structure - entry_file = tmp_path / "main.py" - level1_file = tmp_path / "level1.py" - level2_file = tmp_path / "level2.py" - level3_file = tmp_path / "level3.py" - - entry_file.write_text("import level1") - level1_file.write_text("import level2") - level2_file.write_text("import level3") - level3_file.write_text("# Too deep") - - discovery = ResourceDiscovery(str(entry_file), max_depth=2) - resources = discovery.discover() - - # Should respect max_depth and not crash - assert isinstance(resources, list) - - def test_discover_with_directory_scan(self, tmp_path): - """Test directory scanning fallback for dynamic imports.""" - # Create entry point without @remote decorators - entry_file = tmp_path / "main.py" - entry_file.write_text( - dedent( - """ - # Dynamic imports using importlib.util - import importlib.util - """ - ) - ) - - # Create worker file in subdirectory with @remote decorator - workers_dir = tmp_path / "workers" - workers_dir.mkdir() - worker_file = workers_dir / "gpu_worker.py" - worker_file.write_text( - dedent( - """ - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - - gpu_config = ServerlessResource( - name="test-gpu-worker", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(resource_config=gpu_config) - async def gpu_task(): - return "result" - """ - ) - ) - - discovery = ResourceDiscovery(str(entry_file)) - resources = discovery.discover() - - # Should find resource via directory scanning - assert len(resources) == 1 - assert resources[0].name == "test-gpu-worker" diff --git a/tests/unit/test_discovery_endpoint.py b/tests/unit/test_discovery_endpoint.py deleted file mode 100644 index 9de1c8e8..00000000 --- a/tests/unit/test_discovery_endpoint.py +++ /dev/null @@ -1,337 +0,0 @@ -"""tests for ResourceDiscovery with Endpoint patterns.""" - -import os -from textwrap import dedent -from unittest.mock import patch - - -from runpod_flash.core.discovery import ResourceDiscovery -from runpod_flash.core.resources.base import DeployableResource - - -class TestDiscoveryEndpointLB: - """test discovery of Endpoint LB patterns (ep = Endpoint(...) + @ep.get).""" - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_discover_endpoint_lb_variable(self, tmp_path): - entry = tmp_path / "api.py" - entry.write_text( - dedent("""\ - from runpod_flash.endpoint import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - api = Endpoint(name="my-api", gpu=GpuGroup.ADA_24, workers=(1, 3)) - - @api.get("/health") - async def health(): - return {"status": "ok"} - - @api.post("/compute") - async def compute(data): - return data - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 1 - assert isinstance(resources[0], DeployableResource) - # the internal resource config prepends "live-" and appends "-fb" - assert "my-api" in resources[0].name - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_discover_endpoint_lb_cpu(self, tmp_path): - entry = tmp_path / "api.py" - entry.write_text( - dedent("""\ - from runpod_flash.endpoint import Endpoint - from runpod_flash.core.resources.cpu import CpuInstanceType - - api = Endpoint(name="cpu-api", cpu=CpuInstanceType.CPU3G_2_8) - - @api.post("/process") - async def process(data): - return data - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 1 - assert "cpu-api" in resources[0].name - # should resolve to a CpuLiveLoadBalancer - assert "Cpu" in type(resources[0]).__name__ - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_discover_endpoint_lb_no_routes_not_discovered(self, tmp_path): - """an Endpoint variable with no route decorators is not found by the - route-based AST scan (no @ep.get/post decorators to trigger detection).""" - entry = tmp_path / "api.py" - entry.write_text( - dedent("""\ - from runpod_flash.endpoint import Endpoint - - api = Endpoint(name="unused-api") - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert resources == [] - - -class TestDiscoveryEndpointResolve: - """test _resolve_resource_variable with Endpoint objects.""" - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_resolve_endpoint_to_deployable(self, tmp_path): - entry = tmp_path / "worker.py" - entry.write_text( - dedent("""\ - from runpod_flash.endpoint import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - ep = Endpoint(name="worker", gpu=GpuGroup.ADA_24) - - @ep.post("/run") - async def run(data): - return data - """) - ) - - discovery = ResourceDiscovery(str(entry)) - module = discovery._import_module(entry) - - result = discovery._resolve_resource_variable(module, "ep") - assert result is not None - assert isinstance(result, DeployableResource) - assert "worker" in result.name - - -class TestDiscoveryEndpointDirectoryScan: - """test directory scanning fallback finds Endpoint patterns.""" - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_directory_scan_finds_endpoint(self, tmp_path): - """directory scan fallback detects files with Endpoint patterns.""" - entry = tmp_path / "main.py" - entry.write_text("import importlib.util\n") - - workers_dir = tmp_path / "workers" - workers_dir.mkdir() - worker = workers_dir / "api.py" - worker.write_text( - dedent("""\ - from runpod_flash.endpoint import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - api = Endpoint(name="found-api", gpu=GpuGroup.ADA_24) - - @api.get("/health") - async def health(): - return {"ok": True} - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 1 - assert "found-api" in resources[0].name - - -class TestDiscoveryEndpointQB: - """test discovery of @Endpoint(...) decorator on functions (QB pattern).""" - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_discover_endpoint_qb_gpu(self, tmp_path): - """@Endpoint(name=..., gpu=...) on a function is discovered.""" - entry = tmp_path / "worker.py" - entry.write_text( - dedent("""\ - from runpod_flash import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - @Endpoint(name="qb-gpu-worker", gpu=GpuGroup.ADA_24, workers=(0, 3)) - async def gpu_task(payload: dict) -> dict: - return {"status": "done"} - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 1 - assert isinstance(resources[0], DeployableResource) - assert "qb-gpu-worker" in resources[0].name - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_discover_endpoint_qb_cpu(self, tmp_path): - """@Endpoint(name=..., cpu=...) on a function is discovered.""" - entry = tmp_path / "worker.py" - entry.write_text( - dedent("""\ - from runpod_flash import Endpoint - - @Endpoint(name="qb-cpu-worker", cpu="cpu3c-1-2") - async def cpu_task(payload: dict) -> dict: - return {"status": "done"} - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 1 - assert "qb-cpu-worker" in resources[0].name - assert "Cpu" in type(resources[0]).__name__ - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_discover_multiple_qb_endpoints_same_file(self, tmp_path): - """Multiple @Endpoint decorators in the same file are all discovered.""" - entry = tmp_path / "worker.py" - entry.write_text( - dedent("""\ - from runpod_flash import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - @Endpoint(name="worker-a", gpu=GpuGroup.ADA_24) - async def task_a(payload: dict) -> dict: - return {"a": True} - - @Endpoint(name="worker-b", cpu="cpu3c-1-2") - async def task_b(payload: dict) -> dict: - return {"b": True} - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 2 - names = {r.name for r in resources} - assert any("worker-a" in n for n in names) - assert any("worker-b" in n for n in names) - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_discover_qb_on_class(self, tmp_path): - """@Endpoint(...) on a class is discovered.""" - entry = tmp_path / "worker.py" - entry.write_text( - dedent("""\ - from runpod_flash import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - @Endpoint(name="class-worker", gpu=GpuGroup.ADA_24) - class MyWorker: - async def generate(self, prompt: str) -> dict: - return {"result": prompt} - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 1 - assert "class-worker" in resources[0].name - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_discover_qb_via_directory_scan(self, tmp_path): - """directory scan fallback finds @Endpoint QB patterns.""" - entry = tmp_path / "main.py" - entry.write_text("import importlib.util\n") - - workers_dir = tmp_path / "workers" - workers_dir.mkdir() - worker = workers_dir / "gpu_worker.py" - worker.write_text( - dedent("""\ - from runpod_flash import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - @Endpoint(name="scan-worker", gpu=GpuGroup.ADA_24) - async def gpu_task(payload: dict) -> dict: - return {"ok": True} - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 1 - assert "scan-worker" in resources[0].name - - -class TestDiscoveryMixed: - """test discovery with both legacy @remote and Endpoint patterns.""" - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_mixed_remote_and_endpoint(self, tmp_path): - entry = tmp_path / "main.py" - entry.write_text( - dedent("""\ - from runpod_flash.client import remote - from runpod_flash.core.resources.serverless import ServerlessResource - from runpod_flash.endpoint import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - legacy_config = ServerlessResource( - name="legacy", - gpuCount=1, - workersMax=3, - workersMin=0, - flashboot=False, - ) - - @remote(resource_config=legacy_config) - async def legacy_task(): - return "legacy" - - api = Endpoint(name="new-api", gpu=GpuGroup.ADA_24) - - @api.post("/process") - async def process(data): - return data - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 2 - names = {r.name for r in resources} - # internal resource configs may modify names (e.g. "live-" prefix, "-fb" suffix) - assert any("legacy" in n for n in names) - assert any("new-api" in n for n in names) - - @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) - def test_mixed_qb_and_lb_endpoints(self, tmp_path): - """both QB and LB Endpoint patterns in the same file are discovered.""" - entry = tmp_path / "main.py" - entry.write_text( - dedent("""\ - from runpod_flash import Endpoint - from runpod_flash.core.resources.gpu import GpuGroup - - @Endpoint(name="qb-worker", gpu=GpuGroup.ADA_24) - async def gpu_task(payload: dict) -> dict: - return {"result": "done"} - - api = Endpoint(name="lb-api", cpu="cpu3c-1-2", workers=(1, 3)) - - @api.post("/process") - async def process(data): - return data - """) - ) - - discovery = ResourceDiscovery(str(entry)) - resources = discovery.discover() - - assert len(resources) == 2 - names = {r.name for r in resources} - assert any("qb-worker" in n for n in names) - assert any("lb-api" in n for n in names) diff --git a/tests/unit/test_endpoint.py b/tests/unit/test_endpoint.py index a3a8cfd1..ffc45f8d 100644 --- a/tests/unit/test_endpoint.py +++ b/tests/unit/test_endpoint.py @@ -670,9 +670,14 @@ def test_id_and_image_mutually_exclusive(self): with pytest.raises(ValueError, match="id and image are mutually exclusive"): Endpoint(name="test", id="abc123", image="img:latest") - def test_name_or_id_required(self): + def test_name_or_id_required_for_image_mode(self): with pytest.raises(ValueError, match="name or id is required"): - Endpoint() + Endpoint(image="my-image:latest") + + def test_nameless_decorator_mode_allowed(self): + """Endpoint() without name is valid for QB decorator mode.""" + ep = Endpoint(gpu=GpuGroup.ANY) + assert ep.name is None def test_id_no_default_gpu(self): """client-only endpoints dont default to GPU ANY.""" diff --git a/tests/unit/test_p2_gaps.py b/tests/unit/test_p2_gaps.py index 430e293d..708ea2ee 100644 --- a/tests/unit/test_p2_gaps.py +++ b/tests/unit/test_p2_gaps.py @@ -226,9 +226,9 @@ class TestScannerDocstring: """AST scanner extracts function docstrings.""" def test_scanner_extracts_docstring(self, tmp_path): - """SCAN-010: RemoteDecoratorScanner extracts first line of docstring.""" + """SCAN-010: RuntimeScanner extracts first line of docstring.""" from runpod_flash.cli.commands.build_utils.scanner import ( - RemoteDecoratorScanner, + RuntimeScanner, ) worker_file = tmp_path / "worker.py" @@ -241,7 +241,7 @@ def test_scanner_extracts_docstring(self, tmp_path): " return x * 2\n" ) - scanner = RemoteDecoratorScanner(tmp_path) + scanner = RuntimeScanner(tmp_path) functions = scanner.discover_remote_functions() # Scanner should find at least one function diff --git a/tests/unit/test_p2_remaining_gaps.py b/tests/unit/test_p2_remaining_gaps.py index 8c9d5c3c..e12394bb 100644 --- a/tests/unit/test_p2_remaining_gaps.py +++ b/tests/unit/test_p2_remaining_gaps.py @@ -5,8 +5,8 @@ REM-CLS-013 – extract_class_code_simple fallback when inspect.getsource fails RES-LS-008 – ServerlessResource.env default populated from .env file VOL-006 – NetworkVolume with empty name still constructs (no validator guards it) - SCAN-016 – RemoteDecoratorScanner handles @remote on nested class (class in function) - SCAN-017 – RemoteDecoratorScanner handles conditional @remote gracefully + SCAN-016 – RuntimeScanner handles @remote on nested class (class in function) + SCAN-017 – RuntimeScanner handles conditional @remote gracefully STUB-STACK-004 – detect_remote_dependencies terminates on circular dependency graph SRVGEN-008 – RemoteClassWrapper stores _class_type for Pydantic introspection LB-ROUTE-003 – LoadBalancer random strategy selects from endpoint pool @@ -284,12 +284,12 @@ async def test_find_existing_volume_skips_empty_name(self): class TestScannerNestedClass: - """RemoteDecoratorScanner does not crash when a class is defined inside a function.""" + """RuntimeScanner does not crash when a class is defined inside a function.""" def _make_scanner(self, tmp_path: Path): - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner - return RemoteDecoratorScanner(tmp_path) + return RuntimeScanner(tmp_path) def test_nested_class_does_not_cause_scanner_error(self, tmp_path): """SCAN-016: Scanner processes a file containing a @remote on a nested class without error.""" @@ -342,12 +342,12 @@ def work(self): ... class TestScannerConditionalRemote: - """RemoteDecoratorScanner handles or skips conditional decorators gracefully.""" + """RuntimeScanner handles or skips conditional decorators gracefully.""" def _make_scanner(self, tmp_path: Path): - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner - return RemoteDecoratorScanner(tmp_path) + return RuntimeScanner(tmp_path) def test_conditional_decorator_does_not_crash_scanner(self, tmp_path): """SCAN-017: File with conditional @remote is scanned without exception.""" diff --git a/tests/unit/test_p2_remaining_gaps_2.py b/tests/unit/test_p2_remaining_gaps_2.py index f195cf34..ca9ec884 100644 --- a/tests/unit/test_p2_remaining_gaps_2.py +++ b/tests/unit/test_p2_remaining_gaps_2.py @@ -8,7 +8,7 @@ CLI-ENV-006 delete_command bubbles error when env fetch fails CLI-UNDEPLOY-001 list_command uses ResourceManager.list_all_resources() SRVGEN-017 file_to_url_prefix includes directory path (no collision) - BUILD-015 RemoteDecoratorScanner returns empty list for no @remote + BUILD-015 RuntimeScanner returns empty list for no @remote FILE-005 _build_file_upload_wrapper handles multiple bytes params LOG-003 JobOutput.model_post_init logs delayTime and executionTime """ @@ -401,19 +401,19 @@ def test_prefix_starts_with_slash(self, tmp_path): class TestScannerEmptyProject: - """RemoteDecoratorScanner handles projects with no @remote functions.""" + """RuntimeScanner handles projects with no @remote functions.""" def test_empty_project_directory_returns_empty_list(self, tmp_path): """BUILD-015: project with no Python files returns empty list.""" - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner - scanner = RemoteDecoratorScanner(tmp_path) + scanner = RuntimeScanner(tmp_path) functions = scanner.discover_remote_functions() assert functions == [] def test_py_files_with_no_remote_decorator(self, tmp_path): """BUILD-015: .py files without @remote return empty list.""" - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner (tmp_path / "worker.py").write_text( "def compute(x):\n return x * 2\n", @@ -424,27 +424,27 @@ def test_py_files_with_no_remote_decorator(self, tmp_path): encoding="utf-8", ) - scanner = RemoteDecoratorScanner(tmp_path) + scanner = RuntimeScanner(tmp_path) functions = scanner.discover_remote_functions() assert functions == [] def test_scanner_with_syntax_error_file_does_not_raise(self, tmp_path): """BUILD-015: a file with a syntax error is skipped gracefully.""" - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner (tmp_path / "broken.py").write_text( "def foo(\n", # unclosed parenthesis -> SyntaxError encoding="utf-8", ) - scanner = RemoteDecoratorScanner(tmp_path) + scanner = RuntimeScanner(tmp_path) # Should not raise; broken file is silently skipped functions = scanner.discover_remote_functions() assert functions == [] def test_scanner_finds_remote_in_non_empty_project(self, tmp_path): """BUILD-015: contrast test – scanner finds @remote when present.""" - from runpod_flash.cli.commands.build_utils.scanner import RemoteDecoratorScanner + from runpod_flash.cli.commands.build_utils.scanner import RuntimeScanner (tmp_path / "worker.py").write_text( "from runpod_flash import remote, LiveServerless\n" @@ -455,7 +455,7 @@ def test_scanner_finds_remote_in_non_empty_project(self, tmp_path): encoding="utf-8", ) - scanner = RemoteDecoratorScanner(tmp_path) + scanner = RuntimeScanner(tmp_path) functions = scanner.discover_remote_functions() assert len(functions) == 1 assert functions[0].function_name == "process"