diff --git a/src/runpod_flash/cli/commands/_run_server_helpers.py b/src/runpod_flash/cli/commands/_run_server_helpers.py index 9d792425..60898ea5 100644 --- a/src/runpod_flash/cli/commands/_run_server_helpers.py +++ b/src/runpod_flash/cli/commands/_run_server_helpers.py @@ -98,6 +98,11 @@ async def lb_execute(resource_config, func, body: dict): func: The @remote LB route handler function. body: Parsed request body (from FastAPI's automatic JSON parsing). """ + # Extract dependencies before unwrapping the Endpoint facade + dependencies = getattr(resource_config, "dependencies", None) + system_dependencies = getattr(resource_config, "system_dependencies", None) + accelerate_downloads = getattr(resource_config, "accelerate_downloads", False) + # Endpoint facade wraps an internal resource config if hasattr(resource_config, "_build_resource_config"): resource_config = resource_config._build_resource_config() @@ -121,7 +126,9 @@ async def lb_execute(resource_config, func, body: dict): log.info(f"{resource_config} | {route_label}") try: - result = await stub(func, None, None, False, **kwargs) + result = await stub( + func, dependencies, system_dependencies, accelerate_downloads, **kwargs + ) log.info(f"{resource_config} | Execution complete") return result except TimeoutError as e: diff --git a/src/runpod_flash/cli/commands/build.py b/src/runpod_flash/cli/commands/build.py index deaa8e55..65e7e7fe 100644 --- a/src/runpod_flash/cli/commands/build.py +++ b/src/runpod_flash/cli/commands/build.py @@ -60,19 +60,18 @@ PIP_MODULE = "pip" -# Packages pre-installed in base Docker images (runpod/pytorch:*). -# Always excluded from build artifacts to avoid: -# 1. Exceeding the 500 MB tarball limit (torch alone is ~500 MB) -# 2. Redundant copies — these are already in the base Docker image -# NOTE: numpy is excluded because the base Docker image provides it, and -# keeping it out of the tarball saves ~30 MB toward the 500 MB limit. -BASE_IMAGE_PACKAGES: frozenset[str] = frozenset( +# These are CUDA/GPU-oriented packages whose large CUDA builds are already +# provided by the GPU base images (runpod/pytorch:*) and therefore should +# not be bundled into the tarball. +# Do NOT add packages here just because the GPU image ships them (e.g. numpy). +# The blacklist is defined strictly by size constraints, not by whether a +# package happens to be present in a particular base image. +SIZE_PROHIBITIVE_PACKAGES: frozenset[str] = frozenset( { - "torch", - "torchvision", - "torchaudio", - "numpy", - "triton", + "torch", # ~500 MB + "torchvision", # ~50 MB, requires torch + "torchaudio", # ~30 MB, requires torch + "triton", # ~150 MB, CUDA compiler } ) @@ -272,11 +271,11 @@ def run_build( # Create build directory first to ensure clean state before collecting files build_dir = create_build_directory(project_dir, app_name) - # Parse exclusions: merge user-specified with always-excluded base image packages + # Parse exclusions: merge user-specified with always-excluded size-prohibitive packages user_excluded = [] if exclude: user_excluded = [pkg.strip().lower() for pkg in exclude.split(",")] - excluded_packages = list(set(user_excluded) | BASE_IMAGE_PACKAGES) + excluded_packages = list(set(user_excluded) | SIZE_PROHIBITIVE_PACKAGES) spec = load_ignore_patterns(project_dir) files = get_file_tree(project_dir, spec) @@ -370,7 +369,7 @@ def run_build( for req in requirements: if should_exclude_package(req, excluded_packages): pkg_name = extract_package_name(req) - if pkg_name in BASE_IMAGE_PACKAGES: + if pkg_name in SIZE_PROHIBITIVE_PACKAGES: auto_matched.add(pkg_name) if pkg_name in user_excluded: user_matched.add(pkg_name) @@ -381,12 +380,12 @@ def run_build( if auto_matched: console.print( - f"[dim]Auto-excluded base image packages: " + f"[dim]Auto-excluded size-prohibitive packages: " f"{', '.join(sorted(auto_matched))}[/dim]" ) # Only warn about unmatched user-specified packages (not auto-excludes) - user_unmatched = set(user_excluded) - user_matched - BASE_IMAGE_PACKAGES + user_unmatched = set(user_excluded) - user_matched - SIZE_PROHIBITIVE_PACKAGES if user_unmatched: console.print( f"[yellow]Warning:[/yellow] No packages matched exclusions: " @@ -981,7 +980,7 @@ def create_tarball( excluded_packages: list[str] | None = None, ) -> None: """ - Create gzipped tarball of build directory, excluding base image packages. + Create gzipped tarball of build directory, excluding size-prohibitive packages. Filters at tarball creation time rather than constraining pip resolution, because pip constraints (`<0.0.0a0`) break resolution for any package that diff --git a/src/runpod_flash/cli/commands/run.py b/src/runpod_flash/cli/commands/run.py index 063f8918..4d103947 100644 --- a/src/runpod_flash/cli/commands/run.py +++ b/src/runpod_flash/cli/commands/run.py @@ -242,7 +242,7 @@ def _module_parent_subdir(module_path: str) -> str | None: return parts[0].replace(".", "/") -def _make_import_line(module_path: str, name: str) -> str: +def _make_import_line(module_path: str, name: str, alias: str | None = None) -> str: """Build an import statement for *name* from *module_path*. Uses a regular ``from … import …`` when the module path is a valid @@ -250,12 +250,22 @@ def _make_import_line(module_path: str, name: str) -> str: helper in server.py) when any segment starts with a digit. The helper temporarily scopes ``sys.path`` so sibling imports in the target module resolve to the correct directory. + + Args: + module_path: Dotted module path to import from. + name: Symbol name to import. + alias: If provided, assign the import to this variable name instead + of *name*. Prevents collisions when multiple modules export the + same symbol (e.g. multiple files exporting ``api``). """ + target = alias or name if _has_numeric_module_segments(module_path): subdir = _module_parent_subdir(module_path) if subdir: - return f'{name} = _flash_import("{module_path}", "{name}", "{subdir}")' - return f'{name} = _flash_import("{module_path}", "{name}")' + return f'{target} = _flash_import("{module_path}", "{name}", "{subdir}")' + return f'{target} = _flash_import("{module_path}", "{name}")' + if alias: + return f"from {module_path} import {name} as {alias}" return f"from {module_path} import {name}" @@ -390,13 +400,22 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat ) elif worker.worker_type == "LB": # Import the resource config variable (e.g. "api" from api = LiveLoadBalancer(...)) + # Use aliased names to prevent collisions when multiple files export + # the same variable name (e.g. multiple files exporting "api"). config_vars = { r["config_variable"] for r in worker.lb_routes if r.get("config_variable") } for var in sorted(config_vars): - all_imports.append(_make_import_line(worker.module_path, var)) + alias = f"_cfg_{_sanitize_fn_name(worker.resource_name)}" + all_imports.append( + _make_import_line(worker.module_path, var, alias=alias) + ) + # Store the alias so route codegen can reference it + for r in worker.lb_routes: + if r.get("config_variable") == var: + r["_config_alias"] = alias for fn_name in worker.functions: all_imports.append(_make_import_line(worker.module_path, fn_name)) @@ -561,7 +580,7 @@ def _generate_flash_server(project_root: Path, workers: List[WorkerInfo]) -> Pat method = route["method"].lower() sub_path = route["path"].lstrip("/") fn_name = route["fn_name"] - config_var = route["config_variable"] + config_var = route.get("_config_alias") or route["config_variable"] full_path = f"{worker.url_prefix}/{sub_path}" handler_name = _sanitize_fn_name( f"_route_{worker.resource_name}_{fn_name}" diff --git a/src/runpod_flash/core/discovery.py b/src/runpod_flash/core/discovery.py index e77d7c25..728d03ed 100644 --- a/src/runpod_flash/core/discovery.py +++ b/src/runpod_flash/core/discovery.py @@ -95,6 +95,7 @@ def _find_resource_config_vars(self, file_path: Path) -> Set[str]: Detects: - @remote(resource_config=var) / @remote(var) patterns - ep = Endpoint(...) variables used as LB route decorators (@ep.get, @ep.post, etc) + - @Endpoint(...) used directly as a function/class decorator (QB pattern) Args: file_path: Path to Python file to parse @@ -143,6 +144,10 @@ def _find_resource_config_vars(self, file_path: Path) -> Set[str]: if var_name: var_names.add(var_name) + # @Endpoint(name=..., gpu=...) directly on function/class (QB) + elif self._is_endpoint_direct_decorator(decorator): + var_names.add(node.name) + except Exception as e: log.warning(f"Failed to parse {file_path}: {e}") @@ -170,6 +175,21 @@ def _extract_endpoint_var_from_route(self, decorator: ast.Call) -> str: return func.value.id return "" + def _is_endpoint_direct_decorator(self, decorator: ast.expr) -> bool: + """Check if decorator is @Endpoint(...) used directly on a function/class (QB pattern). + + Matches @Endpoint(name=..., gpu=...) but NOT @ep.get()/@ep.post() (which are + attribute calls on an Endpoint variable, handled separately). + """ + if not isinstance(decorator, ast.Call): + return False + func = decorator.func + if isinstance(func, ast.Name) and func.id == "Endpoint": + return True + if isinstance(func, ast.Attribute) and func.attr == "Endpoint": + return True + return False + def _is_remote_decorator(self, decorator: ast.expr) -> bool: """Check if decorator is @remote. @@ -248,8 +268,10 @@ def _import_module(self, file_path: Path): def _resolve_resource_variable(self, module, var_name: str) -> DeployableResource: """Resolve variable name to DeployableResource instance. - Handles both legacy resource config objects (LiveServerless, etc) and - Endpoint facade objects (unwraps via _build_resource_config()). + Handles: + - Legacy resource config objects (LiveServerless, etc) + - Endpoint facade objects (unwraps via _build_resource_config()) + - QB-decorated functions/classes (unwraps __remote_config__["resource_config"]) Args: module: Imported module @@ -270,6 +292,13 @@ def _resolve_resource_variable(self, module, var_name: str) -> DeployableResourc if isinstance(resource, DeployableResource): return resource + # unwrap @Endpoint(...)-decorated function/class (QB pattern). + # Endpoint.__call__ wraps via @remote which attaches __remote_config__ + if obj is not None and hasattr(obj, "__remote_config__"): + resource = obj.__remote_config__.get("resource_config") + if isinstance(resource, DeployableResource): + return resource + if obj is not None: log.warning( f"Resource '{var_name}' failed to resolve to DeployableResource " diff --git a/tests/integration/test_p1_integration_gaps.py b/tests/integration/test_p1_integration_gaps.py index 08ff1404..816c45ad 100644 --- a/tests/integration/test_p1_integration_gaps.py +++ b/tests/integration/test_p1_integration_gaps.py @@ -58,9 +58,9 @@ def test_qb_and_lb_workers_in_same_project(self): assert "_call_with_body(process" in content # LB worker: config + function import + LB route - assert "from api import api_config" in content + assert "from api import api_config as _cfg_api" in content assert "from api import list_items" in content - assert "_lb_execute(api_config, list_items," in content + assert "_lb_execute(_cfg_api, list_items," in content # Both import helpers should be present assert "_call_with_body" in content @@ -113,9 +113,9 @@ def test_qb_class_and_lb_function_in_same_project(self): assert "_instance_TextModel.predict" in content # LB function: config import + route - assert "from health import health_config" in content + assert "from health import health_config as _cfg_health" in content assert "from health import status" in content - assert "_lb_execute(health_config, status," in content + assert "_lb_execute(_cfg_health, status," in content def test_multiple_lb_routes_alongside_qb(self): """Multiple LB routes + QB function all present.""" @@ -160,8 +160,8 @@ def test_multiple_lb_routes_alongside_qb(self): content = server_path.read_text() # Both LB routes registered - assert "_lb_execute(lb_config, create," in content - assert "_lb_execute(lb_config, read," in content + assert "_lb_execute(_cfg_routes, create," in content + assert "_lb_execute(_cfg_routes, read," in content # QB route also present assert '"/worker/runsync"' in content diff --git a/tests/unit/cli/commands/test_build.py b/tests/unit/cli/commands/test_build.py index 09b771cd..4e69789c 100644 --- a/tests/unit/cli/commands/test_build.py +++ b/tests/unit/cli/commands/test_build.py @@ -7,7 +7,7 @@ import typer from runpod_flash.cli.commands.build import ( - BASE_IMAGE_PACKAGES, + SIZE_PROHIBITIVE_PACKAGES, _find_runpod_flash, _resolve_pip_python_version, collect_requirements, @@ -601,15 +601,18 @@ def _stack(): return _stack() def test_constant_contains_expected_packages(self): - """Verify torch ecosystem, numpy, and triton are in BASE_IMAGE_PACKAGES.""" - assert "torch" in BASE_IMAGE_PACKAGES - assert "torchvision" in BASE_IMAGE_PACKAGES - assert "torchaudio" in BASE_IMAGE_PACKAGES - assert "numpy" in BASE_IMAGE_PACKAGES - assert "triton" in BASE_IMAGE_PACKAGES + """Verify CUDA/torch ecosystem packages are in SIZE_PROHIBITIVE_PACKAGES.""" + assert "torch" in SIZE_PROHIBITIVE_PACKAGES + assert "torchvision" in SIZE_PROHIBITIVE_PACKAGES + assert "torchaudio" in SIZE_PROHIBITIVE_PACKAGES + assert "triton" in SIZE_PROHIBITIVE_PACKAGES + + def test_numpy_not_in_size_prohibitive_packages(self): + """NumPy must NOT be excluded — CPU images (python-slim) don't ship it.""" + assert "numpy" not in SIZE_PROHIBITIVE_PACKAGES def test_auto_excludes_torch_without_flag(self, tmp_path): - """Torch and numpy are filtered even with no --exclude flag.""" + """Torch is filtered even with no --exclude flag; numpy passes through.""" project_dir = tmp_path / "project" project_dir.mkdir() (project_dir / "worker.py").write_text( @@ -637,11 +640,11 @@ def fake_install(_build_dir, reqs, _no_deps, target_python_version=None): pkg_names = [extract_package_name(r) for r in installed] assert "torch" not in pkg_names - assert "numpy" not in pkg_names + assert "numpy" in pkg_names assert "requests" in pkg_names def test_user_excludes_merged_with_auto(self, tmp_path): - """User --exclude scipy + auto torch/numpy = all excluded.""" + """User --exclude scipy + auto torch = all excluded; numpy passes through.""" project_dir = tmp_path / "project" project_dir.mkdir() (project_dir / "worker.py").write_text( @@ -669,12 +672,12 @@ def fake_install(_build_dir, reqs, _no_deps, target_python_version=None): pkg_names = [extract_package_name(r) for r in installed] assert "torch" not in pkg_names - assert "numpy" not in pkg_names + assert "numpy" in pkg_names assert "scipy" not in pkg_names assert "pandas" in pkg_names def test_auto_exclude_silent_when_not_in_requirements(self, tmp_path, capsys): - """No auto-exclude message if no base image packages are in requirements.""" + """No auto-exclude message if no size-prohibitive packages are in requirements.""" project_dir = tmp_path / "project" project_dir.mkdir() (project_dir / "worker.py").write_text( @@ -695,7 +698,7 @@ def test_auto_exclude_silent_when_not_in_requirements(self, tmp_path, capsys): run_build(project_dir, "test_app", no_deps=True) captured = capsys.readouterr() - assert "Auto-excluded base image packages" not in captured.out + assert "Auto-excluded size-prohibitive packages" not in captured.out def test_user_unmatched_warning_excludes_base_image_packages( self, tmp_path, capsys diff --git a/tests/unit/cli/test_run.py b/tests/unit/cli/test_run.py index aa94ab7f..ea8df55b 100644 --- a/tests/unit/cli/test_run.py +++ b/tests/unit/cli/test_run.py @@ -469,7 +469,7 @@ def test_post_lb_route_generates_body_param(self, tmp_path): worker = self._make_lb_worker(tmp_path, method) content = _generate_flash_server(tmp_path, [worker]).read_text() assert "body: _api_list_routes_Input" in content - assert "_lb_execute(api_config, list_routes, _to_dict(body))" in content + assert "_lb_execute(_cfg_api, list_routes, _to_dict(body))" in content def test_get_lb_route_uses_query_params(self, tmp_path): """GET LB routes pass query params as a dict.""" @@ -477,15 +477,14 @@ def test_get_lb_route_uses_query_params(self, tmp_path): content = _generate_flash_server(tmp_path, [worker]).read_text() assert "async def _route_api_list_routes(request: Request):" in content assert ( - "_lb_execute(api_config, list_routes, dict(request.query_params))" - in content + "_lb_execute(_cfg_api, list_routes, dict(request.query_params))" in content ) def test_lb_config_var_and_function_imported(self, tmp_path): """LB config vars and functions are both imported for remote dispatch.""" worker = self._make_lb_worker(tmp_path) content = _generate_flash_server(tmp_path, [worker]).read_text() - assert "from api import api_config" in content + assert "from api import api_config as _cfg_api" in content assert "from api import list_routes" in content def test_lb_execute_import_present_when_lb_routes_exist(self, tmp_path): diff --git a/tests/unit/core/api/test_runpod_graphql_extended.py b/tests/unit/core/api/test_runpod_graphql_extended.py index aa3a3104..2c69f126 100644 --- a/tests/unit/core/api/test_runpod_graphql_extended.py +++ b/tests/unit/core/api/test_runpod_graphql_extended.py @@ -226,9 +226,11 @@ class TestGraphQLQueries: lambda r: len(r) == 1, ), ], - ids=lambda x: x - if isinstance(x, str) and not x.startswith("{") and not x.startswith("(") - else "", + ids=lambda x: ( + x + if isinstance(x, str) and not x.startswith("{") and not x.startswith("(") + else "" + ), ) async def test_query_success( self, method_name, call_args, mock_response, assert_fn diff --git a/tests/unit/test_discovery_endpoint.py b/tests/unit/test_discovery_endpoint.py index c03b513d..9de1c8e8 100644 --- a/tests/unit/test_discovery_endpoint.py +++ b/tests/unit/test_discovery_endpoint.py @@ -143,6 +143,128 @@ async def health(): assert "found-api" in resources[0].name +class TestDiscoveryEndpointQB: + """test discovery of @Endpoint(...) decorator on functions (QB pattern).""" + + @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) + def test_discover_endpoint_qb_gpu(self, tmp_path): + """@Endpoint(name=..., gpu=...) on a function is discovered.""" + entry = tmp_path / "worker.py" + entry.write_text( + dedent("""\ + from runpod_flash import Endpoint + from runpod_flash.core.resources.gpu import GpuGroup + + @Endpoint(name="qb-gpu-worker", gpu=GpuGroup.ADA_24, workers=(0, 3)) + async def gpu_task(payload: dict) -> dict: + return {"status": "done"} + """) + ) + + discovery = ResourceDiscovery(str(entry)) + resources = discovery.discover() + + assert len(resources) == 1 + assert isinstance(resources[0], DeployableResource) + assert "qb-gpu-worker" in resources[0].name + + @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) + def test_discover_endpoint_qb_cpu(self, tmp_path): + """@Endpoint(name=..., cpu=...) on a function is discovered.""" + entry = tmp_path / "worker.py" + entry.write_text( + dedent("""\ + from runpod_flash import Endpoint + + @Endpoint(name="qb-cpu-worker", cpu="cpu3c-1-2") + async def cpu_task(payload: dict) -> dict: + return {"status": "done"} + """) + ) + + discovery = ResourceDiscovery(str(entry)) + resources = discovery.discover() + + assert len(resources) == 1 + assert "qb-cpu-worker" in resources[0].name + assert "Cpu" in type(resources[0]).__name__ + + @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) + def test_discover_multiple_qb_endpoints_same_file(self, tmp_path): + """Multiple @Endpoint decorators in the same file are all discovered.""" + entry = tmp_path / "worker.py" + entry.write_text( + dedent("""\ + from runpod_flash import Endpoint + from runpod_flash.core.resources.gpu import GpuGroup + + @Endpoint(name="worker-a", gpu=GpuGroup.ADA_24) + async def task_a(payload: dict) -> dict: + return {"a": True} + + @Endpoint(name="worker-b", cpu="cpu3c-1-2") + async def task_b(payload: dict) -> dict: + return {"b": True} + """) + ) + + discovery = ResourceDiscovery(str(entry)) + resources = discovery.discover() + + assert len(resources) == 2 + names = {r.name for r in resources} + assert any("worker-a" in n for n in names) + assert any("worker-b" in n for n in names) + + @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) + def test_discover_qb_on_class(self, tmp_path): + """@Endpoint(...) on a class is discovered.""" + entry = tmp_path / "worker.py" + entry.write_text( + dedent("""\ + from runpod_flash import Endpoint + from runpod_flash.core.resources.gpu import GpuGroup + + @Endpoint(name="class-worker", gpu=GpuGroup.ADA_24) + class MyWorker: + async def generate(self, prompt: str) -> dict: + return {"result": prompt} + """) + ) + + discovery = ResourceDiscovery(str(entry)) + resources = discovery.discover() + + assert len(resources) == 1 + assert "class-worker" in resources[0].name + + @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) + def test_discover_qb_via_directory_scan(self, tmp_path): + """directory scan fallback finds @Endpoint QB patterns.""" + entry = tmp_path / "main.py" + entry.write_text("import importlib.util\n") + + workers_dir = tmp_path / "workers" + workers_dir.mkdir() + worker = workers_dir / "gpu_worker.py" + worker.write_text( + dedent("""\ + from runpod_flash import Endpoint + from runpod_flash.core.resources.gpu import GpuGroup + + @Endpoint(name="scan-worker", gpu=GpuGroup.ADA_24) + async def gpu_task(payload: dict) -> dict: + return {"ok": True} + """) + ) + + discovery = ResourceDiscovery(str(entry)) + resources = discovery.discover() + + assert len(resources) == 1 + assert "scan-worker" in resources[0].name + + class TestDiscoveryMixed: """test discovery with both legacy @remote and Endpoint patterns.""" @@ -184,3 +306,32 @@ async def process(data): # internal resource configs may modify names (e.g. "live-" prefix, "-fb" suffix) assert any("legacy" in n for n in names) assert any("new-api" in n for n in names) + + @patch.dict(os.environ, {"FLASH_IS_LIVE_PROVISIONING": "true"}) + def test_mixed_qb_and_lb_endpoints(self, tmp_path): + """both QB and LB Endpoint patterns in the same file are discovered.""" + entry = tmp_path / "main.py" + entry.write_text( + dedent("""\ + from runpod_flash import Endpoint + from runpod_flash.core.resources.gpu import GpuGroup + + @Endpoint(name="qb-worker", gpu=GpuGroup.ADA_24) + async def gpu_task(payload: dict) -> dict: + return {"result": "done"} + + api = Endpoint(name="lb-api", cpu="cpu3c-1-2", workers=(1, 3)) + + @api.post("/process") + async def process(data): + return data + """) + ) + + discovery = ResourceDiscovery(str(entry)) + resources = discovery.discover() + + assert len(resources) == 2 + names = {r.name for r in resources} + assert any("qb-worker" in n for n in names) + assert any("lb-api" in n for n in names)