diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1ec8b590..ff798ee1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,3 +53,44 @@ jobs: python -m pip install pip==26.0.1 pip install -e ".[dev,test,performance]" pip-audit + + schema-updated: + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + + - name: Check for schema changes + id: schema + run: | + CHANGED=$(git diff --name-only origin/${{ github.base_ref }}...HEAD -- \ + 'src/inference_endpoint/config/schema.py' \ + 'src/inference_endpoint/endpoint_client/config.py' \ + 'src/inference_endpoint/commands/benchmark/cli.py' \ + 'scripts/regenerate_templates.py' \ + 'src/inference_endpoint/config/templates/*.yaml') + echo "changed=$([[ -n "$CHANGED" ]] && echo true || echo false)" >> "$GITHUB_OUTPUT" + + - name: Set up Python 3.12 + if: steps.schema.outputs.changed == 'true' + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.12" + + - name: Install dependencies + if: steps.schema.outputs.changed == 'true' + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run schema fuzz tests + if: steps.schema.outputs.changed == 'true' + run: | + pytest -xv -m schema_fuzz + + - name: Check YAML templates are up to date + if: steps.schema.outputs.changed == 'true' + run: | + python scripts/regenerate_templates.py --check diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5f953252..e8872d43 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,7 +35,7 @@ repos: hooks: - id: prettier types_or: [yaml, json, markdown] - exclude: ^(src/inference_endpoint/openai/openai_types_gen.py|src/inference_endpoint/openai/openapi.yaml)$ + exclude: ^(src/inference_endpoint/openai/openai_types_gen.py|src/inference_endpoint/openai/openapi.yaml|src/inference_endpoint/config/templates/) - repo: local hooks: @@ -48,12 +48,12 @@ repos: args: ["--tb=short", "--strict-markers"] stages: [manual] - - id: validate-templates - name: Validate YAML templates against schema - entry: python -c "from pathlib import Path; from inference_endpoint.config.schema import BenchmarkConfig; [BenchmarkConfig.from_yaml_file(f) for f in sorted(Path('src/inference_endpoint/config/templates').glob('*.yaml'))]" + - id: regenerate-templates + name: Regenerate YAML templates from schema defaults + entry: python scripts/regenerate_templates.py language: system pass_filenames: false - files: ^src/inference_endpoint/config/(schema\.py|templates/) + files: ^(src/inference_endpoint/config/(schema\.py|templates/.*)|src/inference_endpoint/endpoint_client/config\.py|scripts/regenerate_templates\.py)$ - id: add-license-header name: Add license headers diff --git a/AGENTS.md b/AGENTS.md index 8bb4a38d..52a3dbb5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -162,7 +162,7 @@ src/inference_endpoint/ │ ├── ruleset_registry.py # Ruleset registry │ ├── user_config.py # UserConfig dataclass for ruleset user overrides │ ├── rulesets/mlcommons/ # MLCommons-specific rules, datasets, models -│ └── templates/ # YAML config templates (offline, online, eval, etc.) +│ └── templates/ # YAML config templates (_template.yaml minimal, _template_full.yaml all defaults) ├── openai/ # OpenAI-compatible API types and adapters │ ├── types.py # OpenAI response types │ ├── openai_adapter.py # Request/response adapter @@ -204,7 +204,16 @@ tests/ - **License headers**: Required on all Python files (enforced by pre-commit hook `scripts/add_license_header.py`) - **Conventional commits**: `feat:`, `fix:`, `docs:`, `test:`, `chore:` -All of these hooks run automatically on commit: trailing-whitespace, end-of-file-fixer, check-yaml, check-merge-conflict, debug-statements, `ruff` (lint + autofix), `ruff-format`, `mypy`, `prettier` (YAML/JSON/Markdown), license header enforcement. +### Pre-commit Hooks + +All of these run automatically on commit: + +- trailing-whitespace, end-of-file-fixer, check-yaml, check-merge-conflict, debug-statements +- `ruff` (lint + autofix) and `ruff-format` +- `mypy` type checking +- `prettier` for YAML/JSON/Markdown +- License header enforcement +- `regenerate-templates`: auto-regenerates YAML config templates from schema defaults when `schema.py`, `config.py`, or `regenerate_templates.py` change **Always run `pre-commit run --all-files` before committing.** diff --git a/docs/DEVELOPMENT.md b/docs/DEVELOPMENT.md index 2fe0e9c2..af32da1d 100644 --- a/docs/DEVELOPMENT.md +++ b/docs/DEVELOPMENT.md @@ -273,6 +273,21 @@ pytest -s -v python -m pdb -m pytest test_file.py ``` +## YAML Config Templates + +Config templates in `src/inference_endpoint/config/templates/` are auto-generated from schema defaults. When you change `config/schema.py`, regenerate them: + +```bash +python scripts/regenerate_templates.py +``` + +The pre-commit hook auto-regenerates templates when `schema.py`, `config.py`, or `regenerate_templates.py` change. CI validates templates are up to date via `--check` mode. + +Two variants are generated per mode (offline, online, concurrency): + +- `_template.yaml` — minimal: only required fields + placeholders +- `_template_full.yaml` — all fields with schema defaults + inline `# options:` comments + ## Package Management ### Adding Dependencies diff --git a/pyproject.toml b/pyproject.toml index 19616ece..19fa129d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,8 @@ test = [ "aiohttp==3.13.5", # Plotting for benchmark sweep mode "matplotlib==3.10.8", + # Property-based testing (CLI fuzz) + "hypothesis==6.151.10", ] performance = [ "pytest-benchmark==5.2.3", @@ -184,6 +186,7 @@ markers = [ "integration: marks tests as integration tests", "unit: marks tests as unit tests", "run_explicitly: mark test to only run explicitly", + "schema_fuzz: hypothesis CLI fuzz tests (run in CI on schema changes)", ] filterwarnings = [ "ignore:Session timeout reached:RuntimeWarning", diff --git a/scripts/regenerate_templates.py b/scripts/regenerate_templates.py new file mode 100644 index 00000000..eb84a6dd --- /dev/null +++ b/scripts/regenerate_templates.py @@ -0,0 +1,398 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Regenerate YAML config templates from Pydantic schema field defaults. + +Used by pre-commit to keep templates in sync when schema.py changes. + +Generates two variants per template: + - ``_template.yaml`` — minimal: only required fields + placeholders + - ``_template_full.yaml`` — all fields with schema defaults + placeholders +""" + +from __future__ import annotations + +import enum +import os +import re +import sys +import types +import typing +from pathlib import Path + +import cyclopts +import yaml +from inference_endpoint.config.schema import ( + BenchmarkConfig, + OfflineBenchmarkConfig, + OnlineBenchmarkConfig, + TestType, +) +from pydantic import BaseModel +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined + +TEMPLATES_DIR = Path(__file__).parent.parent / "src/inference_endpoint/config/templates" + +# Template name → (test_type, extra overrides merged on top). +TEMPLATES: dict[str, tuple[TestType, dict]] = { + "offline": (TestType.OFFLINE, {}), + "online": ( + TestType.ONLINE, + {"settings": {"load_pattern": {"type": "poisson", "target_qps": 10.0}}}, + ), + "concurrency": ( + TestType.ONLINE, + { + "name": "concurrency_benchmark", + "settings": { + "load_pattern": {"type": "concurrency", "target_concurrency": 32} + }, + }, + ), + # TODO(vir): eval/submission raise CLIError in schema, generate templates when support is added +} + +MODEL_FOR_TYPE: dict[TestType, type[BenchmarkConfig]] = { + TestType.OFFLINE: OfflineBenchmarkConfig, + TestType.ONLINE: OnlineBenchmarkConfig, +} + +PERF_DATASET = { + "name": "perf", + "type": "performance", + "path": "", + "parser": {"prompt": "text_input"}, +} + +ACC_DATASET = { + "name": "accuracy", + "type": "accuracy", + "path": "", + "eval_method": "exact_match", + "parser": {"prompt": "question", "system": "system_prompt"}, + "accuracy_config": { + "eval_method": "pass_at_1", + "ground_truth": "ground_truth", + "extractor": "boxed_math_extractor", + "num_repeats": 1, + }, +} + +PLACEHOLDER_MODEL = "" +PLACEHOLDER_ENDPOINT = "" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _unwrap(annotation: object) -> object: + """Unwrap Optional/Annotated/Union/ForwardRef to the core type.""" + # Evaluate string forward refs (e.g. ForwardRef('AccuracyConfig | None')) + if isinstance(annotation, typing.ForwardRef): + return annotation + origin = typing.get_origin(annotation) + if origin is typing.Annotated: + return _unwrap(typing.get_args(annotation)[0]) + if origin is types.UnionType or origin is typing.Union: + args = [a for a in typing.get_args(annotation) if a is not type(None)] + return _unwrap(args[0]) if len(args) == 1 else annotation + return annotation + + +def _resolved_hints(model: type[BaseModel]) -> dict[str, object]: + """Get type hints with forward refs resolved.""" + try: + return typing.get_type_hints(model, include_extras=True) + except Exception: + return {n: i.annotation for n, i in model.model_fields.items()} + + +def _deep_merge(base: dict, override: dict) -> dict: + merged = dict(base) + for k, v in override.items(): + if k in merged and isinstance(merged[k], dict) and isinstance(v, dict): + merged[k] = _deep_merge(merged[k], v) + else: + merged[k] = v + return merged + + +def _dump_defaults(model: type[BaseModel]) -> dict: + """Extract field defaults from a model WITHOUT constructing it. + + Avoids model validators (e.g. num_workers=-1 → CPU count). + Recurses into nested BaseModel fields. Excluded fields are omitted. + """ + hints = _resolved_hints(model) + out: dict[str, object] = {} + for name, info in model.model_fields.items(): + if info.exclude is True: + continue + core = _unwrap(hints.get(name, info.annotation)) + # Get raw default + if info.default is not PydanticUndefined: + default = info.default + elif info.default_factory is not None: + default = info.default_factory() + else: + # Required field — recurse if BaseModel, else None + if isinstance(core, type) and issubclass(core, BaseModel): + out[name] = _dump_defaults(core) + else: + out[name] = None + continue + # Serialize + if isinstance(default, BaseModel): + out[name] = _dump_defaults(type(default)) + elif isinstance(default, list): + out[name] = [ + _dump_defaults(type(i)) if isinstance(i, BaseModel) else i + for i in default + ] + elif isinstance(default, enum.Enum): + out[name] = default.value + elif isinstance(default, Path): + out[name] = str(default) + else: + out[name] = default + return out + + +def _list_item_model(info: object) -> type[BaseModel] | None: + """For a list[SomeModel] field, return SomeModel.""" + if not isinstance(info, FieldInfo): + return None + core = _unwrap(info.annotation) + if typing.get_origin(core) is not list: + return None + args = typing.get_args(core) + if args and isinstance(args[0], type) and issubclass(args[0], BaseModel): + return args[0] + return None + + +# --------------------------------------------------------------------------- +# Inline comments — auto-discovered from Enum/Literal/description +# --------------------------------------------------------------------------- + + +def _collect_comments(model: type[BaseModel]) -> dict[str, str]: + """Walk model tree, build {yaml_key: "# comment"} for described/enum fields. + + For ambiguous field names (same name, different descriptions across models), + falls back to value-specific keys so each enum value gets the right comment. + """ + + def _enum_vals(tp: object) -> list[str] | None: + origin = typing.get_origin(tp) + if origin is typing.Literal: + return [ + a.value if isinstance(a, enum.Enum) else str(a) + for a in typing.get_args(tp) + ] + if isinstance(tp, type) and issubclass(tp, enum.Enum): + return [str(m.value) for m in tp] + return None + + def _help(info: object) -> str | None: + if not isinstance(info, FieldInfo): + return None + if info.description: + return info.description + for m in info.metadata or []: + if isinstance(m, cyclopts.Parameter) and m.help: + return m.help + return None + + result: dict[str, str] = {} + by_name: dict[str, list[str]] = {} + + def _walk(m: type[BaseModel]) -> None: + hints = _resolved_hints(m) + for name, info in m.model_fields.items(): + if info.annotation is None: + continue + core = _unwrap(hints.get(name, info.annotation)) + vals = _enum_vals(core) + parts: list[str] = [] + desc = _help(info) + if desc: + parts.append(desc) + if vals: + parts.append(f"options: {', '.join(vals)}") + if parts: + comment = "# " + " | ".join(parts) + by_name.setdefault(name, []).append(comment) + if vals: + for v in vals: + result[f"{name}: {v}"] = comment + # Recurse into nested models + if isinstance(core, type) and issubclass(core, BaseModel): + _walk(core) + elif typing.get_origin(core) is list: + args = typing.get_args(core) + if ( + args + and isinstance(args[0], type) + and issubclass(args[0], BaseModel) + ): + _walk(args[0]) + + _walk(model) + for name, comments in by_name.items(): + if len(set(comments)) == 1: + result[f"{name}: "] = comments[0] + # Also match block-style (no trailing space, e.g. "parser:\n") + result[f"{name}:"] = comments[0] + + return result + + +def _add_comments(text: str, comments: dict[str, str]) -> str: + """Inject inline # comments into YAML text.""" + for key, comment in sorted(comments.items(), key=lambda x: -len(x[0])): + text = re.sub( + rf"^(\s*{re.escape(key)}.*)$", + lambda m, c=comment: m.group(0) + if "#" in m.group(0) + else f"{m.group(0)} {c}", + text, + count=0, + flags=re.MULTILINE, + ) + return text + + +# --------------------------------------------------------------------------- +# Template builders +# --------------------------------------------------------------------------- + + +def _build_full(model_cls: type[BenchmarkConfig], overrides: dict) -> dict: + """All fields with schema defaults + placeholders. 2 dataset examples.""" + data = _dump_defaults(model_cls) + + # Fill empty list[BaseModel] fields with one default entry + for name, info in model_cls.model_fields.items(): + if isinstance(data.get(name), list) and len(data[name]) == 0: + item_model = _list_item_model(info) + if item_model is not None: + data[name] = [_dump_defaults(item_model)] + + data = _deep_merge( + data, + { + "model_params": {"name": PLACEHOLDER_MODEL}, + "endpoint_config": {"endpoints": [PLACEHOLDER_ENDPOINT]}, + }, + ) + + # 2 dataset examples: perf + accuracy + ds_defaults = _dump_defaults( + _list_item_model(model_cls.model_fields["datasets"]) # type: ignore[arg-type] + ) + data["datasets"] = [ + _deep_merge(ds_defaults, PERF_DATASET), + _deep_merge(ds_defaults, ACC_DATASET), + ] + + if overrides: + data = _deep_merge(data, overrides) + + # Resolve streaming AUTO → off/on (mirrors schema validator) + test_type = data.get("type") + mp = data.get("model_params", {}) + if isinstance(mp, dict) and mp.get("streaming") == "auto": + mp["streaming"] = "off" if test_type == "offline" else "on" + + if not data.get("name") and test_type: + data["name"] = f"{test_type}_benchmark" + + return data + + +def _build_minimal(test_type: TestType, overrides: dict) -> dict: + """Only required fields + placeholders. 1 dataset example.""" + name = overrides.get("name") or f"{test_type.value}_benchmark" + data: dict[str, object] = { + "name": name, + "type": test_type.value, + "model_params": {"name": PLACEHOLDER_MODEL}, + "datasets": [PERF_DATASET], + "settings": { + "runtime": { + "min_duration_ms": 600000, + "max_duration_ms": 0, + "n_samples_to_issue": None, + }, + }, + "endpoint_config": {"endpoints": [PLACEHOLDER_ENDPOINT]}, + } + if overrides: + data = _deep_merge(data, overrides) + return data + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(check_only: bool = False): + """Regenerate templates, or check they're up to date. + + Locally (pre-commit): regenerates files, pre-commit detects the diff. + CI: auto-detects ``CI`` env var and switches to check-only mode. + Explicit: ``--check`` flag forces check-only mode. + """ + if os.environ.get("CI"): + check_only = True + + comments = _collect_comments(BenchmarkConfig) + stale = False + + for name, (test_type, overrides) in TEMPLATES.items(): + model_cls = MODEL_FOR_TYPE[test_type] + variants = { + f"{name}_template.yaml": _build_minimal(test_type, overrides), + f"{name}_template_full.yaml": _build_full(model_cls, overrides), + } + for filename, data in variants.items(): + raw = yaml.dump(data, default_flow_style=False, sort_keys=False) + expected = _add_comments(raw, comments) + path = TEMPLATES_DIR / filename + + if check_only: + current = path.read_text() if path.exists() else "" + if current != expected: + print(f" STALE: {filename}") + stale = True + else: + print(f" OK: {filename}") + else: + path.write_text(expected) + print(f" Generated: {filename}") + + if stale: + print("\nRun: python scripts/regenerate_templates.py") + raise SystemExit(1) + + +if __name__ == "__main__": + main(check_only="--check" in sys.argv) diff --git a/src/inference_endpoint/commands/init.py b/src/inference_endpoint/commands/init.py index d14e1a82..f665bc1a 100644 --- a/src/inference_endpoint/commands/init.py +++ b/src/inference_endpoint/commands/init.py @@ -21,56 +21,79 @@ import yaml -from ..config.schema import TEMPLATE_TYPE_MAP, BenchmarkConfig +from ..config.schema import ( + BenchmarkConfig, + LoadPattern, + LoadPatternType, + OnlineSettings, + TestType, +) from ..exceptions import InputValidationError, SetupError logger = logging.getLogger(__name__) TEMPLATES_DIR = Path(__file__).parent.parent / "config" / "templates" -TEMPLATE_FILES = { - "offline": "offline_template.yaml", - "online": "online_template.yaml", +VALID_TYPES = {"offline", "online", "concurrency", "eval", "submission"} + +# eval/submission not yet supported in create_default_config — copy handwritten templates +_HANDWRITTEN = { "eval": "eval_template.yaml", "submission": "submission_template.yaml", } +_TYPE_MAP = { + "offline": TestType.OFFLINE, + "online": TestType.ONLINE, + "concurrency": TestType.ONLINE, +} + def execute_init(template_type: str) -> None: - """Generate example YAML configuration template.""" - output_path = f"{template_type}_template.yaml" + """Generate YAML config template. + + For offline/online/concurrency: generates via model_dump(exclude_none=True). + For eval/submission: copies handwritten template files. - if template_type not in TEMPLATE_FILES: + Args: + template_type: One of "offline", "online", "concurrency", "eval", "submission". + """ + if template_type not in VALID_TYPES: raise InputValidationError( f"Unknown template type: {template_type}. " - f"Available: {', '.join(TEMPLATE_FILES.keys())}" + f"Available: {', '.join(sorted(VALID_TYPES))}" ) - template_file = TEMPLATES_DIR / TEMPLATE_FILES[template_type] + output_path = f"{template_type}_template.yaml" output_file = Path(output_path) + if output_file.exists(): logger.warning(f"File exists: {output_path} (will be overwritten)") try: - if not template_file.exists(): - logger.info("Generating from BenchmarkConfig.create_default_config...") - config = BenchmarkConfig.create_default_config( - TEMPLATE_TYPE_MAP[template_type] - ) - config_dict = config.model_dump(exclude_none=True) - with open(output_path, "w") as f: - yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False) - logger.info(f"Generated: {output_path}") - else: + # TODO(vir): + # generate these automatically when support is added + # for now just copy over hand-written templates + if template_type in _HANDWRITTEN: + template_file = TEMPLATES_DIR / _HANDWRITTEN[template_type] + if not template_file.exists(): + raise SetupError(f"Template file not found: {template_file}") shutil.copy(template_file, output_path) - logger.info(f"Created from template: {output_path}") - - except NotImplementedError as e: - logger.error(str(e)) - if template_file.exists(): - shutil.copy(template_file, output_path) - logger.info(f"Created from template: {output_path}") else: - raise SetupError(f"Template file not found: {template_file}") from e + config = BenchmarkConfig.create_default_config(_TYPE_MAP[template_type]) + if template_type == "concurrency": + config = config.with_updates( + name="concurrency_benchmark", + settings=OnlineSettings( + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, + target_concurrency=32, + ), + ), + ) + data = config.model_dump(mode="json", exclude_none=True) + with open(output_path, "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False) + logger.info(f"Created: {output_path}") except (OSError, PermissionError) as e: raise SetupError(f"Failed to create template: {e}") from e diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 01f7ce2c..a8fb87ac 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -88,6 +88,16 @@ class EvalMethod(str, Enum): JUDGE = "judge" +class ScorerMethod(str, Enum): + """Registered scorer methods for accuracy evaluation.""" + + PASS_AT_1 = "pass_at_1" + STRING_MATCH = "string_match" + ROUGE = "rouge" + CODE_BENCH = "code_bench_scorer" + SHOPIFY_CATEGORY_F1 = "shopify_category_f1" + + class TestMode(str, Enum): """Test mode determining what to collect. @@ -244,7 +254,9 @@ class Dataset(BaseModel): eval_method: EvalMethod | None = Field( None, description="Accuracy evaluation method" ) - parser: dict[str, str] | None = Field(None, description="Column remapping") + parser: dict[str, str] | None = Field( + None, description="Column remapping: {prompt: , system: }" + ) accuracy_config: AccuracyConfig | None = Field( None, description="Accuracy evaluation settings" ) @@ -260,14 +272,11 @@ def _auto_derive_name(self) -> Self: class AccuracyConfig(BaseModel): """Accuracy configuration. - The eval_method is the method to use to evaluate the accuracy of the model. - Currently only "pass_at_1" is supported. - The ground_truth is the column in the dataset that contains the ground truth. - Defaults to "ground_truth" if not specified. - The extractor is the extractor to use to extract the ground truth from the output. - Currently "boxed_math_extractor" and "abcd_extractor" are supported. - The num_repeats is the number of times to repeat the dataset for evaluation. - Defaults to 1 if not specified. + eval_method: Scorer to use (see ScorerMethod enum for options). + ground_truth: Column in the dataset containing ground truth. Defaults to "ground_truth". + extractor: Post-processor to extract answers from model output + (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor). + num_repeats: Number of times to repeat the dataset for evaluation. Defaults to 1. Example: accuracy_config: @@ -279,10 +288,15 @@ class AccuracyConfig(BaseModel): model_config = ConfigDict(extra="forbid", frozen=True) - eval_method: str | None = None - ground_truth: str | None = None - extractor: str | None = None - num_repeats: int = Field(1, ge=1) + eval_method: ScorerMethod | None = Field(None, description="Scorer method") + ground_truth: str | None = Field(None, description="Ground truth column name") + extractor: str | None = Field( + None, + description="Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor)", + ) + num_repeats: int = Field( + 1, ge=1, description="Repeat dataset N times for evaluation" + ) class RuntimeConfig(BaseModel): @@ -339,6 +353,7 @@ def _validate_durations(self) -> Self: return self +@cyclopts.Parameter(name="*") class LoadPattern(BaseModel): """Load pattern configuration. @@ -352,7 +367,7 @@ class LoadPattern(BaseModel): type: Annotated[ LoadPatternType, - cyclopts.Parameter(alias="--load-pattern", help="Load pattern type"), + cyclopts.Parameter(name="--load-pattern", help="Load pattern type"), ] = LoadPatternType.MAX_THROUGHPUT target_qps: Annotated[ float | None, cyclopts.Parameter(alias="--target-qps", help="Target QPS") @@ -514,7 +529,9 @@ class BenchmarkConfig(WithUpdatesMixin, BaseModel): cyclopts.Parameter(alias="--timeout", help="Global timeout in seconds"), ] = None # verbose is handled by cyclopts meta app (-v flag), not here - verbose: Annotated[bool, cyclopts.Parameter(show=False)] = False + verbose: Annotated[bool, cyclopts.Parameter(show=False)] = Field( + False, description="Enable verbose logging" + ) enable_cpu_affinity: Annotated[ bool, cyclopts.Parameter( diff --git a/src/inference_endpoint/config/templates/concurrency_template.yaml b/src/inference_endpoint/config/templates/concurrency_template.yaml index 4e61bffb..31e1c05f 100644 --- a/src/inference_endpoint/config/templates/concurrency_template.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template.yaml @@ -1,50 +1,21 @@ -# Online Concurrency-Based Benchmark (NOT YET IMPLEMENTED) -# This template shows the future concurrency-based online mode -name: "concurrency-benchmark" -version: "1.0" -type: "online" - +name: concurrency_benchmark +type: online # Test type: offline, online, eval, submission | options: offline, online, eval, submission model_params: - name: "meta-llama/Llama-3.1-8B-Instruct" - temperature: 0.7 - top_p: 0.9 - max_new_tokens: 1024 - -datasets: - - name: "concurrency-test" - type: "performance" - path: "datasets/queries.jsonl" - samples: 500 - + name: '' +datasets: # Dataset configs +- name: perf + type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + parser: # Column remapping: {prompt: , system: } + prompt: text_input settings: runtime: - min_duration_ms: 600000 # 10 minutes - max_duration_ms: 1800000 # 30 minutes - scheduler_random_seed: 42 # For Poisson/distribution sampling - dataloader_random_seed: 42 # For dataset shuffling - + min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) + max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) + n_samples_to_issue: null # Sample count override load_pattern: - type: "concurrency" # NOT YET IMPLEMENTED - target_concurrency: 32 # Maintain 32 concurrent requests - # Note: target_qps is not used in this mode - # QPS will be determined by: concurrency / avg_latency - - client: - num_workers: 4 - -metrics: - collect: - - "throughput" # Will be concurrency / avg_latency - - "latency" # p50, p90, p95, p99, p999 at this concurrency level - - "ttft" - - "tpot" - + type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + target_concurrency: 32 # Concurrent requests endpoint_config: - endpoints: - - "http://localhost:8000" - api_key: null - api_type: "openai" # Options: openai or sglang -# How this differs from Poisson mode: -# - Poisson: Fixed QPS target, concurrency varies based on latency -# - Concurrency: Fixed N requests in-flight, QPS varies based on latency -# - Useful for: Measuring latency at specific concurrency levels + endpoints: # Endpoint URL(s) + - '' diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml new file mode 100644 index 00000000..1e18b3bf --- /dev/null +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -0,0 +1,86 @@ +name: concurrency_benchmark +version: '1.0' # Config version +type: online # Test type: offline, online, eval, submission | options: offline, online, eval, submission +submission_ref: null +benchmark_mode: null # options: offline, online +model_params: + name: '' + temperature: null # Sampling temperature + top_k: null # Top-K sampling + top_p: null # Top-P (nucleus) sampling + repetition_penalty: null # Repetition penalty + max_new_tokens: 1024 # Max output tokens + osl_distribution: null # Output sequence length distribution + streaming: 'on' # Streaming mode: auto/on/off | options: auto, on, off +datasets: # Dataset configs +- name: perf + type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + format: null # Dataset format (auto-detected) + samples: null # Number of samples to use + eval_method: null + parser: # Column remapping: {prompt: , system: } + prompt: text_input + accuracy_config: null # Accuracy evaluation settings +- name: accuracy + type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + format: null # Dataset format (auto-detected) + samples: null # Number of samples to use + eval_method: exact_match # Accuracy evaluation method | options: exact_match, contains, judge + parser: # Column remapping: {prompt: , system: } + prompt: question + system: system_prompt + accuracy_config: # Accuracy evaluation settings + eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1 + ground_truth: ground_truth # Ground truth column name + extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) + num_repeats: 1 # Repeat dataset N times for evaluation +settings: + runtime: + min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) + max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) + n_samples_to_issue: null # Sample count override + scheduler_random_seed: 42 # Scheduler RNG seed + dataloader_random_seed: 42 # Dataloader RNG seed + load_pattern: + type: concurrency # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + target_qps: null # Target QPS + target_concurrency: 32 # Concurrent requests + client: + num_workers: -1 # Worker processes (-1=auto) + record_worker_events: false # Record per-worker events + log_level: INFO # Worker log level + warmup_connections: -1 # Pre-establish TCP connections (-1=auto, 0=disabled) + max_connections: -1 # Max TCP connections (-1=unlimited) + transport: + type: zmq # options: zmq + recv_buffer_size: 16777216 # IPC receive buffer size in bytes (default 16MB). Increase for multimodal payloads. + send_buffer_size: 16777216 # IPC send buffer size in bytes (default 16MB). Increase for multimodal payloads. + io_threads: 4 # ZMQ I/O thread pool size (main process) + worker_io_threads: 1 # ZMQ I/O thread pool size (worker processes) + high_water_mark: 0 # ZMQ HWM (0=unlimited) + linger: -1 # ZMQ linger on close (-1=block until sent) + immediate: 1 # ZMQ IMMEDIATE (1=only enqueue on ready) + stream_all_chunks: false # Stream all chunks to main thread (caution: perf overhead) + worker_initialization_timeout: 60.0 # Worker init timeout (seconds) + worker_graceful_shutdown_wait: 0.5 # Post-run graceful shutdown wait (seconds) + worker_force_kill_timeout: 0.5 # Force kill timeout after graceful wait (seconds) + max_idle_time: 4.0 # Discard connections idle longer than this (seconds) + min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) + worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system +metrics: + collect: + - throughput + - latency + - ttft + - tpot +endpoint_config: + endpoints: # Endpoint URL(s) + - '' + api_key: null # API key + api_type: openai # API type: openai or sglang | options: openai, sglang +report_dir: null # Report output directory +timeout: null # Global timeout in seconds +verbose: false # Enable verbose logging +enable_cpu_affinity: true # NUMA-aware CPU pinning diff --git a/src/inference_endpoint/config/templates/offline_template.yaml b/src/inference_endpoint/config/templates/offline_template.yaml index da0816e7..6531771a 100644 --- a/src/inference_endpoint/config/templates/offline_template.yaml +++ b/src/inference_endpoint/config/templates/offline_template.yaml @@ -1,44 +1,18 @@ -# Offline Throughput Benchmark -name: "offline-benchmark" -version: "1.0" -type: "offline" - +name: offline_benchmark +type: offline # Test type: offline, online, eval, submission | options: offline, online, eval, submission model_params: - name: "meta-llama/Llama-3.1-8B-Instruct" - temperature: 0.7 - top_p: 0.9 - max_new_tokens: 1024 - -datasets: - - name: "perf-test" - type: "performance" - path: "tests/datasets/dummy_1k.jsonl" - samples: 1000 - parser: - prompt: "text_input" - + name: '' +datasets: # Dataset configs +- name: perf + type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + parser: # Column remapping: {prompt: , system: } + prompt: text_input settings: runtime: - min_duration_ms: 60000 # 1 minutes - max_duration_ms: 180000 # 3 minutes - scheduler_random_seed: 42 # For Poisson/distribution sampling - dataloader_random_seed: 42 # For dataset shuffling - - load_pattern: - type: "max_throughput" - - client: - num_workers: 4 - -metrics: - collect: - - "throughput" - - "latency" - - "ttft" - - "tpot" - + min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) + max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) + n_samples_to_issue: null # Sample count override endpoint_config: - endpoints: - - "http://localhost:8000" - api_key: null - api_type: "openai" # Options: openai or sglang + endpoints: # Endpoint URL(s) + - '' diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml new file mode 100644 index 00000000..29a661ed --- /dev/null +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -0,0 +1,86 @@ +name: offline_benchmark +version: '1.0' # Config version +type: offline # Test type: offline, online, eval, submission | options: offline, online, eval, submission +submission_ref: null +benchmark_mode: null # options: offline, online +model_params: + name: '' + temperature: null # Sampling temperature + top_k: null # Top-K sampling + top_p: null # Top-P (nucleus) sampling + repetition_penalty: null # Repetition penalty + max_new_tokens: 1024 # Max output tokens + osl_distribution: null # Output sequence length distribution + streaming: 'off' # Streaming mode: auto/on/off | options: auto, on, off +datasets: # Dataset configs +- name: perf + type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + format: null # Dataset format (auto-detected) + samples: null # Number of samples to use + eval_method: null + parser: # Column remapping: {prompt: , system: } + prompt: text_input + accuracy_config: null # Accuracy evaluation settings +- name: accuracy + type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + format: null # Dataset format (auto-detected) + samples: null # Number of samples to use + eval_method: exact_match # Accuracy evaluation method | options: exact_match, contains, judge + parser: # Column remapping: {prompt: , system: } + prompt: question + system: system_prompt + accuracy_config: # Accuracy evaluation settings + eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1 + ground_truth: ground_truth # Ground truth column name + extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) + num_repeats: 1 # Repeat dataset N times for evaluation +settings: + runtime: + min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) + max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) + n_samples_to_issue: null # Sample count override + scheduler_random_seed: 42 # Scheduler RNG seed + dataloader_random_seed: 42 # Dataloader RNG seed + load_pattern: + type: max_throughput # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + target_qps: null # Target QPS + target_concurrency: null # Concurrent requests + client: + num_workers: -1 # Worker processes (-1=auto) + record_worker_events: false # Record per-worker events + log_level: INFO # Worker log level + warmup_connections: -1 # Pre-establish TCP connections (-1=auto, 0=disabled) + max_connections: -1 # Max TCP connections (-1=unlimited) + transport: + type: zmq # options: zmq + recv_buffer_size: 16777216 # IPC receive buffer size in bytes (default 16MB). Increase for multimodal payloads. + send_buffer_size: 16777216 # IPC send buffer size in bytes (default 16MB). Increase for multimodal payloads. + io_threads: 4 # ZMQ I/O thread pool size (main process) + worker_io_threads: 1 # ZMQ I/O thread pool size (worker processes) + high_water_mark: 0 # ZMQ HWM (0=unlimited) + linger: -1 # ZMQ linger on close (-1=block until sent) + immediate: 1 # ZMQ IMMEDIATE (1=only enqueue on ready) + stream_all_chunks: false # Stream all chunks to main thread (caution: perf overhead) + worker_initialization_timeout: 60.0 # Worker init timeout (seconds) + worker_graceful_shutdown_wait: 0.5 # Post-run graceful shutdown wait (seconds) + worker_force_kill_timeout: 0.5 # Force kill timeout after graceful wait (seconds) + max_idle_time: 4.0 # Discard connections idle longer than this (seconds) + min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) + worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system +metrics: + collect: + - throughput + - latency + - ttft + - tpot +endpoint_config: + endpoints: # Endpoint URL(s) + - '' + api_key: null # API key + api_type: openai # API type: openai or sglang | options: openai, sglang +report_dir: null # Report output directory +timeout: null # Global timeout in seconds +verbose: false # Enable verbose logging +enable_cpu_affinity: true # NUMA-aware CPU pinning diff --git a/src/inference_endpoint/config/templates/online_template.yaml b/src/inference_endpoint/config/templates/online_template.yaml index de81431f..eafac9e9 100644 --- a/src/inference_endpoint/config/templates/online_template.yaml +++ b/src/inference_endpoint/config/templates/online_template.yaml @@ -1,45 +1,21 @@ -# Online Latency Benchmark -name: "online-benchmark" -version: "1.0" -type: "online" - +name: online_benchmark +type: online # Test type: offline, online, eval, submission | options: offline, online, eval, submission model_params: - name: "meta-llama/Llama-3.1-8B-Instruct" - temperature: 0.7 - top_p: 0.9 - max_new_tokens: 1024 - -datasets: - - name: "latency-test" - type: "performance" - path: "cnn_dailymail_train.json" - samples: 500 - parser: - prompt: "article" - + name: '' +datasets: # Dataset configs +- name: perf + type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + parser: # Column remapping: {prompt: , system: } + prompt: text_input settings: runtime: - min_duration_ms: 60000 # 1 minutes - max_duration_ms: 180000 # 3 minutes - scheduler_random_seed: 42 # For Poisson/distribution sampling - dataloader_random_seed: 42 # For dataset shuffling - + min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) + max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) + n_samples_to_issue: null # Sample count override load_pattern: - type: "poisson" - target_qps: 10 - - client: - num_workers: 4 - -metrics: - collect: - - "throughput" - - "latency" - - "ttft" - - "tpot" - + type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + target_qps: 10.0 # Target QPS endpoint_config: - endpoints: - - "http://localhost:8000" - api_key: null - api_type: "openai" # Options: openai or sglang + endpoints: # Endpoint URL(s) + - '' diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml new file mode 100644 index 00000000..ad1a2423 --- /dev/null +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -0,0 +1,86 @@ +name: online_benchmark +version: '1.0' # Config version +type: online # Test type: offline, online, eval, submission | options: offline, online, eval, submission +submission_ref: null +benchmark_mode: null # options: offline, online +model_params: + name: '' + temperature: null # Sampling temperature + top_k: null # Top-K sampling + top_p: null # Top-P (nucleus) sampling + repetition_penalty: null # Repetition penalty + max_new_tokens: 1024 # Max output tokens + osl_distribution: null # Output sequence length distribution + streaming: 'on' # Streaming mode: auto/on/off | options: auto, on, off +datasets: # Dataset configs +- name: perf + type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + format: null # Dataset format (auto-detected) + samples: null # Number of samples to use + eval_method: null + parser: # Column remapping: {prompt: , system: } + prompt: text_input + accuracy_config: null # Accuracy evaluation settings +- name: accuracy + type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy + path: '' # Dataset file path + format: null # Dataset format (auto-detected) + samples: null # Number of samples to use + eval_method: exact_match # Accuracy evaluation method | options: exact_match, contains, judge + parser: # Column remapping: {prompt: , system: } + prompt: question + system: system_prompt + accuracy_config: # Accuracy evaluation settings + eval_method: pass_at_1 # Scorer method | options: pass_at_1, string_match, rouge, code_bench_scorer, shopify_category_f1 + ground_truth: ground_truth # Ground truth column name + extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) + num_repeats: 1 # Repeat dataset N times for evaluation +settings: + runtime: + min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) + max_duration_ms: 0 # Maximum test duration in ms (0 for no limit) + n_samples_to_issue: null # Sample count override + scheduler_random_seed: 42 # Scheduler RNG seed + dataloader_random_seed: 42 # Dataloader RNG seed + load_pattern: + type: poisson # Load pattern type | options: max_throughput, poisson, concurrency, burst, step + target_qps: 10.0 # Target QPS + target_concurrency: null # Concurrent requests + client: + num_workers: -1 # Worker processes (-1=auto) + record_worker_events: false # Record per-worker events + log_level: INFO # Worker log level + warmup_connections: -1 # Pre-establish TCP connections (-1=auto, 0=disabled) + max_connections: -1 # Max TCP connections (-1=unlimited) + transport: + type: zmq # options: zmq + recv_buffer_size: 16777216 # IPC receive buffer size in bytes (default 16MB). Increase for multimodal payloads. + send_buffer_size: 16777216 # IPC send buffer size in bytes (default 16MB). Increase for multimodal payloads. + io_threads: 4 # ZMQ I/O thread pool size (main process) + worker_io_threads: 1 # ZMQ I/O thread pool size (worker processes) + high_water_mark: 0 # ZMQ HWM (0=unlimited) + linger: -1 # ZMQ linger on close (-1=block until sent) + immediate: 1 # ZMQ IMMEDIATE (1=only enqueue on ready) + stream_all_chunks: false # Stream all chunks to main thread (caution: perf overhead) + worker_initialization_timeout: 60.0 # Worker init timeout (seconds) + worker_graceful_shutdown_wait: 0.5 # Post-run graceful shutdown wait (seconds) + worker_force_kill_timeout: 0.5 # Force kill timeout after graceful wait (seconds) + max_idle_time: 4.0 # Discard connections idle longer than this (seconds) + min_required_connections: -1 # Min connections to initialize (-1=auto, 0=disabled) + worker_gc_mode: relaxed # Worker GC strategy | options: disabled, relaxed, system +metrics: + collect: + - throughput + - latency + - ttft + - tpot +endpoint_config: + endpoints: # Endpoint URL(s) + - '' + api_key: null # API key + api_type: openai # API type: openai or sglang | options: openai, sglang +report_dir: null # Report output directory +timeout: null # Global timeout in seconds +verbose: false # Enable verbose logging +enable_cpu_affinity: true # NUMA-aware CPU pinning diff --git a/src/inference_endpoint/endpoint_client/config.py b/src/inference_endpoint/endpoint_client/config.py index 0351d85d..c5509fc2 100644 --- a/src/inference_endpoint/endpoint_client/config.py +++ b/src/inference_endpoint/endpoint_client/config.py @@ -113,19 +113,29 @@ class HTTPClientConfig(WithUpdatesMixin, BaseModel): # NOTE: # - StreamChunk.metadata['first_chunk'] is set for first chunk of every response # - At end of stream, QueryResult is returned with the entire response content - stream_all_chunks: bool = False + stream_all_chunks: bool = Field( + False, description="Stream all chunks to main thread (caution: perf overhead)" + ) # Worker lifecycle timeouts - worker_initialization_timeout: float = 60.0 # init - worker_graceful_shutdown_wait: float = 0.5 # post-run - worker_force_kill_timeout: float = 0.5 # post-run + worker_initialization_timeout: float = Field( + 60.0, description="Worker init timeout (seconds)" + ) + worker_graceful_shutdown_wait: float = Field( + 0.5, description="Post-run graceful shutdown wait (seconds)" + ) + worker_force_kill_timeout: float = Field( + 0.5, description="Force kill timeout after graceful wait (seconds)" + ) # Connection idle timeout - discard connections idle longer than this. # Two fold benefits: # 1. Prevents keep-alive race condition where server closes idle connection # at the exact moment client sends a new request (half-closed TCP). # 2. Early discard connections which are likely disconnected by the server already - max_idle_time: float = 4.0 # seconds + max_idle_time: float = Field( + 4.0, description="Discard connections idle longer than this (seconds)" + ) # Minimum required connections for http-client to initialize. # Will log warning if not enough ephemeral ports are available during warmup. @@ -134,7 +144,9 @@ class HTTPClientConfig(WithUpdatesMixin, BaseModel): # - >0 = explicit minimum required connections # - 0 = disable check (no warning if ports unavailable) # - -1 = auto (defaults to 12.5% of system ephemeral port range) - min_required_connections: int = -1 + min_required_connections: int = Field( + -1, description="Min connections to initialize (-1=auto, 0=disabled)" + ) # GC strategy for worker processes to reduce latency spikes from collection pauses # @@ -142,7 +154,9 @@ class HTTPClientConfig(WithUpdatesMixin, BaseModel): # - "disabled": GC completely disabled (risky for long-running benchmarks) # - "relaxed": GC enabled with 50x higher threshold (less aggressive) # - "system": Standard Python GC with default thresholds - worker_gc_mode: Literal["disabled", "relaxed", "system"] = "relaxed" + worker_gc_mode: Literal["disabled", "relaxed", "system"] = Field( + "relaxed", description="Worker GC strategy" + ) # ========================================================================= # Internal fields (parse=False — set programmatically, not via CLI/YAML) diff --git a/src/inference_endpoint/main.py b/src/inference_endpoint/main.py index 996b09a8..abae5064 100644 --- a/src/inference_endpoint/main.py +++ b/src/inference_endpoint/main.py @@ -102,7 +102,11 @@ def validate_yaml( @app.command(name="init") def init_cmd(template: str): - """Generate config template.""" + """Generate config template. + + Args: + template: Template type (offline, online, concurrency, eval, submission). + """ execute_init(template) diff --git a/tests/integration/commands/test_benchmark_command.py b/tests/integration/commands/test_benchmark_command.py index ddfe874a..a51afe8c 100644 --- a/tests/integration/commands/test_benchmark_command.py +++ b/tests/integration/commands/test_benchmark_command.py @@ -16,8 +16,11 @@ """Integration tests for benchmark commands against echo server.""" import json +import re +from pathlib import Path import pytest +import yaml from inference_endpoint.commands.benchmark.execute import run_benchmark from inference_endpoint.config.schema import ( BenchmarkConfig, @@ -87,7 +90,7 @@ def test_offline_benchmark( @pytest.mark.integration @pytest.mark.parametrize("streaming", [StreamingMode.OFF, StreamingMode.ON]) - def test_online_benchmark( + def test_poisson_benchmark( self, mock_http_echo_server, ds_dataset_path, caplog, streaming ): config = _config( @@ -105,6 +108,32 @@ def test_online_benchmark( assert "PoissonDistributionScheduler" in caplog.text assert "50" in caplog.text + @pytest.mark.integration + @pytest.mark.parametrize("streaming", [StreamingMode.OFF, StreamingMode.ON]) + def test_concurrency_benchmark( + self, mock_http_echo_server, ds_dataset_path, caplog, streaming + ): + config = _config( + mock_http_echo_server.url, + ds_dataset_path, + type=TestType.ONLINE, + model_params=ModelParams(name="echo-server", streaming=streaming), + settings=Settings( + runtime=RuntimeConfig(min_duration_ms=2000), + load_pattern=LoadPattern( + type=LoadPatternType.CONCURRENCY, target_concurrency=4 + ), + client=HTTPClientConfig( + num_workers=1, warmup_connections=0, max_connections=10 + ), + ), + ) + with caplog.at_level("INFO"): + run_benchmark(config, TestMode.PERF) + + assert "Completed in" in caplog.text + assert "successful" in caplog.text + @pytest.mark.integration def test_results_json_output( self, mock_http_echo_server, ds_dataset_path, tmp_path @@ -137,3 +166,59 @@ def test_mode_logging(self, mock_http_echo_server, ds_dataset_path, caplog): assert "Mode:" in caplog.text assert "QPS: 20" in caplog.text assert "Responses: False" in caplog.text + + +TEMPLATE_DIR = ( + Path(__file__).parent.parent.parent.parent + / "src" + / "inference_endpoint" + / "config" + / "templates" +) + +# Templates generated by regenerate_templates.py (excludes handwritten eval/submission) +_GENERATED_TEMPLATES = sorted( + p.name + for p in TEMPLATE_DIR.glob("*_template*.yaml") + if p.name.startswith(("offline_", "online_", "concurrency_")) +) + + +def _resolve_template(template_path: Path, server_url: str) -> dict: + """Load a template YAML, strip wrappers, and patch for testing. + + Only replaces placeholders with working values and caps n_samples_to_issue. + Everything else stays as the template defines it. + """ + raw = template_path.read_text() + # Strip → value (all templates use eg: form) + raw = re.sub(r"<[^>]*eg:\s*([^>]+)>", r"\1", raw) + # Replace endpoint URLs with the test server + raw = re.sub(r"http://localhost:\d+", server_url, raw) + data = yaml.safe_load(raw) + + # Cap total samples so test finishes in seconds + data.setdefault("settings", {}) + data["settings"].setdefault("runtime", {}) + data["settings"]["runtime"]["n_samples_to_issue"] = 10 + + # Accuracy datasets can't run e2e against echo server (no scorer), so keep only performance datasets. + data["datasets"] = [ + ds for ds in data.get("datasets", []) if ds.get("type") != "accuracy" + ] + return data + + +class TestTemplateIntegration: + """Verify generated templates run end-to-end against a local server.""" + + @pytest.mark.integration + @pytest.mark.parametrize("template", _GENERATED_TEMPLATES) + def test_template_runs(self, mock_http_echo_server, tmp_path, caplog, template): + data = _resolve_template(TEMPLATE_DIR / template, mock_http_echo_server.url) + tmp_yaml = tmp_path / template + tmp_yaml.write_text(yaml.dump(data, default_flow_style=False, sort_keys=False)) + config = BenchmarkConfig.from_yaml_file(tmp_yaml) + with caplog.at_level("INFO"): + run_benchmark(config, TestMode.PERF) + assert "Completed in" in caplog.text diff --git a/tests/integration/commands/test_cli.py b/tests/integration/commands/test_cli.py new file mode 100644 index 00000000..2696b25d --- /dev/null +++ b/tests/integration/commands/test_cli.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CLI integration tests. + +Hypothesis fuzzing: auto-discovers all CLI flags from cyclopts +``assemble_argument_collection()`` and tests random combinations through +the parser. E2E tests verify all three execution modes against echo server. +""" + +from __future__ import annotations + +import enum +import json +from typing import Literal, get_args, get_origin + +import pytest +from hypothesis import given +from hypothesis import settings as hyp_settings +from hypothesis import strategies as st +from hypothesis.strategies import DrawFn, composite +from inference_endpoint.commands.benchmark.cli import benchmark_app +from inference_endpoint.main import app +from pydantic import BaseModel + +# --------------------------------------------------------------------------- +# Flag discovery — walks cyclopts argument tree at import time. +# Zero hardcoded knowledge: all flags, types, and valid values are derived +# from the Pydantic models via assemble_argument_collection(). +# --------------------------------------------------------------------------- + + +def _enum_values(hint): + """Extract string values from enum or Literal type hints.""" + if isinstance(hint, type) and issubclass(hint, enum.Enum): + return [e.value for e in hint] + if get_origin(hint) is Literal: + return [a.value if isinstance(a, enum.Enum) else str(a) for a in get_args(hint)] + return None + + +def _discover_flags(cmd_name: str) -> list[tuple[str, list[str]]]: + """Discover (flag, [valid_values]) for every leaf CLI flag in a subcommand. + + Includes both primary names and aliases as separate entries so Hypothesis + exercises them independently. + """ + sub_app = benchmark_app.resolved_commands()[cmd_name] + result = [] + type_vals: dict[type, list[str]] = { + int: ["1", "10"], + float: ["1.0", "10.0"], + str: ["test-val"], + bool: [], + } + for arg in sub_app.assemble_argument_collection(): + flags = [n for n in arg.names if n.startswith("--")] + if not flags or arg.names == ("*",): + continue + if isinstance(arg.hint, type) and issubclass(arg.hint, BaseModel): + continue + enum_vals = _enum_values(arg.hint) + vals: list[str] = ( + [str(v) for v in enum_vals] if enum_vals else type_vals.get(arg.hint, []) + ) + result.append((flags[0], vals)) + # Aliases must also parse correctly — add them separately + for alias in flags[1:]: + result.append((alias, vals)) + return result + + +_OFFLINE_FLAGS = _discover_flags("offline") +_ONLINE_FLAGS = _discover_flags("online") + +# --------------------------------------------------------------------------- +# Hypothesis strategies — build random CLI invocations from discovered flags. +# Covers all three modes: offline, online/poisson, online/concurrency. +# --------------------------------------------------------------------------- + +_OFF = [ + "benchmark", + "offline", + "--endpoints", + "http://h:80", + "--model", + "m", + "--dataset", + "d.pkl", +] +_ON_P = [ + "benchmark", + "online", + "--endpoints", + "http://h:80", + "--model", + "m", + "--dataset", + "d.pkl", + "--load-pattern", + "poisson", + "--target-qps", + "100", +] +_ON_C = [ + "benchmark", + "online", + "--endpoints", + "http://h:80", + "--model", + "m", + "--dataset", + "d.pkl", + "--load-pattern", + "concurrency", + "--concurrency", + "10", +] + + +def _build_tokens( + draw: DrawFn, base: list[str], flags: list[tuple[str, list[str]]] +) -> list[str]: + """Append 1-10 random flags (with valid values) to a base token list.""" + n = draw(st.integers(min_value=1, max_value=10)) + chosen = draw( + st.lists( + st.sampled_from(flags), min_size=n, max_size=n, unique_by=lambda x: x[0] + ) + ) + tokens = list(base) + for flag, vals in chosen: + if flag in tokens: + continue # don't duplicate flags already in base + tokens.extend([flag, draw(st.sampled_from(vals))] if vals else [flag]) + return tokens + + +@composite +def offline_tokens(draw: DrawFn) -> list[str]: + """Random offline CLI invocation.""" + return _build_tokens(draw, _OFF, _OFFLINE_FLAGS) + + +@composite +def online_tokens(draw: DrawFn) -> list[str]: + """Random online CLI invocation — randomly picks poisson or concurrency base.""" + return _build_tokens(draw, draw(st.sampled_from([_ON_P, _ON_C])), _ONLINE_FLAGS) + + +# --------------------------------------------------------------------------- +# Hypothesis parsing fuzz +# --------------------------------------------------------------------------- + + +@pytest.mark.integration +@pytest.mark.schema_fuzz +@pytest.mark.slow +@hyp_settings(max_examples=2000, deadline=5000) +@given(tokens=offline_tokens()) +def test_offline_cli_no_crash(tokens): + """Random offline flag combos must parse or reject cleanly — never crash.""" + try: + app.parse_args(tokens) + except SystemExit: + pass # clean rejection (e.g. invalid value) is fine + + +@pytest.mark.integration +@pytest.mark.schema_fuzz +@pytest.mark.slow +@hyp_settings(max_examples=2000, deadline=5000) +@given(tokens=online_tokens()) +def test_online_cli_no_crash(tokens): + """2000 random online flag combos (poisson + concurrency) — no crashes.""" + try: + app.parse_args(tokens) + except SystemExit: + pass # clean rejection (e.g. missing required arg) is fine + + +# --------------------------------------------------------------------------- +# E2E: CLI tokens → echo server → results.json +# One test per execution mode: offline, poisson, concurrency. +# --------------------------------------------------------------------------- + +_FAST = [ + "--workers", + "1", + "--client.warmup-connections", + "0", + "--client.max-connections", + "10", +] + + +def _run(tokens: list[str]): + """Invoke the full CLI pipeline, swallowing normal SystemExit(0).""" + try: + app.meta(tokens) + except SystemExit as e: + if e.code != 0: + raise + + +def _bench(url, ds, tmp_path, *extra): + """Run a benchmark via CLI and return parsed results.json.""" + _run( + [ + *extra, + "--endpoints", + url, + "--model", + "test-model", + "--dataset", + ds, + "--report-dir", + str(tmp_path), + *_FAST, + ] + ) + return json.loads((tmp_path / "results.json").read_text()) + + +class TestE2E: + """Full CLI → benchmark execution → echo server → results.json.""" + + @pytest.mark.integration + def test_offline(self, mock_http_echo_server, ds_dataset_path, tmp_path): + """Offline (max_throughput): all queries at t=0.""" + r = _bench( + mock_http_echo_server.url, + ds_dataset_path, + tmp_path, + "benchmark", + "offline", + "--duration", + "0", + "--streaming", + "off", + ) + assert r["results"]["total"] > 0 + assert r["results"]["successful"] > 0 + + @pytest.mark.integration + def test_poisson(self, mock_http_echo_server, ds_dataset_path, tmp_path): + """Online (poisson): sustained QPS with Poisson arrival distribution.""" + r = _bench( + mock_http_echo_server.url, + ds_dataset_path, + tmp_path, + "benchmark", + "online", + "--load-pattern", + "poisson", + "--target-qps", + "50", + "--duration", + "2000", + ) + assert r["results"]["total"] > 0 + + @pytest.mark.integration + def test_concurrency(self, mock_http_echo_server, ds_dataset_path, tmp_path): + """Online (concurrency): fixed concurrent requests.""" + r = _bench( + mock_http_echo_server.url, + ds_dataset_path, + tmp_path, + "benchmark", + "online", + "--load-pattern", + "concurrency", + "--concurrency", + "4", + "--duration", + "2000", + ) + assert r["results"]["total"] > 0 diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index 2252c5bf..b99db6d2 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -367,13 +367,16 @@ class TestYAMLTemplateValidation: @pytest.mark.unit @pytest.mark.parametrize( "template", - [ - "concurrency_template.yaml", - "eval_template.yaml", - "offline_template.yaml", - "online_template.yaml", - "submission_template.yaml", - ], + sorted( + p.name + for p in ( + Path(__file__).parent.parent.parent.parent + / "src" + / "inference_endpoint" + / "config" + / "templates" + ).glob("*_template*.yaml") + ), ) def test_valid_templates_parse(self, template): config = BenchmarkConfig.from_yaml_file(TEMPLATE_DIR / template) @@ -381,6 +384,23 @@ def test_valid_templates_parse(self, template): assert config.endpoint_config.endpoints +class TestScorerMethodSync: + """Ensure ScorerMethod enum stays in sync with the scorer registry.""" + + @pytest.mark.unit + def test_scorer_enum_matches_registry(self): + from inference_endpoint.config.schema import ScorerMethod + from inference_endpoint.evaluation.scoring import Scorer + + enum_values = {m.value for m in ScorerMethod} + registry_keys = set(Scorer.PREDEFINED.keys()) + assert enum_values == registry_keys, ( + f"ScorerMethod enum out of sync with Scorer registry.\n" + f" In enum only: {enum_values - registry_keys}\n" + f" In registry only: {registry_keys - enum_values}" + ) + + class TestResponseCollector: @pytest.mark.unit def test_success_response(self): diff --git a/tests/unit/commands/test_util_commands.py b/tests/unit/commands/test_util_commands.py index db4fb826..67d0c791 100644 --- a/tests/unit/commands/test_util_commands.py +++ b/tests/unit/commands/test_util_commands.py @@ -101,7 +101,9 @@ def test_unknown_template(self): execute_init("unknown") @pytest.mark.unit - @pytest.mark.parametrize("template", ["offline", "online", "eval", "submission"]) + @pytest.mark.parametrize( + "template", ["offline", "online", "concurrency", "eval", "submission"] + ) def test_generates_template(self, template): output_file = Path(f"{template}_template.yaml") try: @@ -122,31 +124,13 @@ def test_warns_on_overwrite(self, caplog): output_file.unlink(missing_ok=True) @pytest.mark.unit - def test_fallback_when_template_missing(self, tmp_path, monkeypatch): - """When template file doesn't exist, falls back to create_default_config.""" + def test_missing_template_raises_setup_error(self, tmp_path, monkeypatch): monkeypatch.setattr( "inference_endpoint.commands.init.TEMPLATES_DIR", tmp_path / "nonexistent", ) - output_file = Path("offline_template.yaml") - try: - execute_init("offline") - assert output_file.exists() - finally: - output_file.unlink(missing_ok=True) - - @pytest.mark.unit - def test_os_error_raises_setup_error(self, monkeypatch): - monkeypatch.setattr( - "inference_endpoint.commands.init.TEMPLATES_DIR", - Path("/nonexistent"), - ) - monkeypatch.setattr( - "inference_endpoint.commands.init.BenchmarkConfig.create_default_config", - MagicMock(side_effect=OSError("permission denied")), - ) - with pytest.raises(SetupError, match="Failed to create"): - execute_init("offline") + with pytest.raises(SetupError, match="Template file not found"): + execute_init("eval") class TestProbeConfig: