diff --git a/.pyrit_conf_example b/.pyrit_conf_example index 46014434f8..e32dcd260c 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -30,14 +30,19 @@ memory_db_type: sqlite # # Each initializer can be specified as: # - A simple string (name only) -# - A dictionary with 'name' and optional 'args' for constructor arguments +# - A dictionary with 'name' and optional 'args' for parameters +# +# Parameters are lists of strings. Use the CLI command +# `pyrit_scan --list-initializers` to see available parameters. # # Example: # initializers: # - simple -# - name: airt +# - name: target # args: -# some_param: value +# tags: +# - default +# - scorer initializers: - simple diff --git a/build_scripts/evaluate_scorers.py b/build_scripts/evaluate_scorers.py index c303184a16..2b79ea74cf 100644 --- a/build_scripts/evaluate_scorers.py +++ b/build_scripts/evaluate_scorers.py @@ -35,9 +35,11 @@ async def evaluate_scorers() -> None: 5. Save results to scorer_evals directory """ print("Initializing PyRIT...") + target_init = TargetInitializer() + target_init.params = {"tags": ["default", "scorer"]} await initialize_pyrit_async( memory_db_type=IN_MEMORY, - initializers=[TargetInitializer(tags=["default", "scorer"]), ScorerInitializer()], + initializers=[target_init, ScorerInitializer()], ) registry = ScorerRegistry.get_registry_singleton() diff --git a/doc/code/setup/pyrit_initializer.ipynb b/doc/code/setup/pyrit_initializer.ipynb index 80b2a579cb..d0a81f1498 100644 --- a/doc/code/setup/pyrit_initializer.ipynb +++ b/doc/code/setup/pyrit_initializer.ipynb @@ -61,7 +61,7 @@ " def execution_order(self) -> int:\n", " return 2 # Lower numbers run first (default is 1)\n", "\n", - " async def initialize_async(self) -> None:\n", + " async def initialize_async(self, *, params=None) -> None:\n", " set_default_value(class_type=OpenAIChatTarget, parameter_name=\"temperature\", value=0.9)\n", "\n", " @property\n", @@ -159,7 +159,7 @@ " def execution_order(self) -> int:\n", " return 2 # Lower numbers run first (default is 1)\n", "\n", - " async def initialize_async(self) -> None:\n", + " async def initialize_async(self, *, params=None) -> None:\n", " set_default_value(class_type=OpenAIChatTarget, parameter_name=\"temperature\", value=0.9)\n", "\n", " @property\n", diff --git a/doc/code/setup/pyrit_initializer.py b/doc/code/setup/pyrit_initializer.py index c8a7d0590f..ccf2a94e8f 100644 --- a/doc/code/setup/pyrit_initializer.py +++ b/doc/code/setup/pyrit_initializer.py @@ -43,7 +43,7 @@ def name(self) -> str: def execution_order(self) -> int: return 2 # Lower numbers run first (default is 1) - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: set_default_value(class_type=OpenAIChatTarget, parameter_name="temperature", value=0.9) @property @@ -107,7 +107,7 @@ def name(self) -> str: def execution_order(self) -> int: return 2 # Lower numbers run first (default is 1) - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: set_default_value(class_type=OpenAIChatTarget, parameter_name="temperature", value=0.9) @property diff --git a/doc/setup/pyrit_conf.md b/doc/setup/pyrit_conf.md index 09ce09f77f..199e1a2ed0 100644 --- a/doc/setup/pyrit_conf.md +++ b/doc/setup/pyrit_conf.md @@ -44,7 +44,7 @@ A list of built-in initializers to run during PyRIT initialization. Initializers Each entry can be: - **A simple string** — just the initializer name -- **A dictionary** — with `name` and optional `args` for constructor arguments +- **A dictionary** — with `name` and optional `args` (each arg is a list of strings passed to `initialize_async`) Example: @@ -53,7 +53,9 @@ initializers: - simple - name: airt args: - some_param: value + tags: + - default + - scorer ``` Use `pyrit list initializers` in the CLI to see all registered initializers. See the [initializer documentation notebook](../code/setup/pyrit_initializer.ipynb) for reference. diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 7f6adbe928..d5c88dfb92 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -77,7 +77,7 @@ def __init__( config_file: Optional[Path] = None, database: Optional[str] = None, initialization_scripts: Optional[list[Path]] = None, - initializer_names: Optional[list[str]] = None, + initializer_names: Optional[list[Any]] = None, env_files: Optional[list[Path]] = None, log_level: Optional[int] = None, ): @@ -94,7 +94,9 @@ def __init__( The file uses .pyrit_conf extension but is YAML format. database: Database type (InMemory, SQLite, or AzureSQL). initialization_scripts: Optional list of initialization script paths. - initializer_names: Optional list of built-in initializer names to run. + initializer_names: Optional list of initializer entries. Each entry can be + a string name (e.g., "simple") or a dict with 'name' and optional 'args' + (e.g., {"name": "target", "args": {"tags": "default,scorer"}}). env_files: Optional list of environment file paths to load in order. log_level: Logging level constant (e.g., logging.WARNING). Defaults to logging.WARNING. @@ -130,9 +132,7 @@ def __init__( # Use canonical mapping from configuration_loader self._database = _MEMORY_DB_TYPE_MAP[config.memory_db_type] self._initialization_scripts = config._resolve_initialization_scripts() - self._initializer_names = ( - [ic.name for ic in config._initializer_configs] if config._initializer_configs else None - ) + self._initializer_configs = config._initializer_configs if config._initializer_configs else None self._env_files = config._resolve_env_files() # Lazy-loaded registries @@ -287,15 +287,20 @@ async def run_scenario_async( # Run initializers before scenario initializer_instances = None - if context._initializer_names: - print(f"Running {len(context._initializer_names)} initializer(s)...") + if context._initializer_configs: + print(f"Running {len(context._initializer_configs)} initializer(s)...") sys.stdout.flush() initializer_instances = [] - for name in context._initializer_names: - initializer_class = context.initializer_registry.get_class(name) - initializer_instances.append(initializer_class()) + for config in context._initializer_configs: + initializer_class = context.initializer_registry.get_class(config.name) + instance = initializer_class() + if config.args: + instance.params = { + k: [str(i) for i in v] if isinstance(v, list) else [str(v)] for k, v in config.args.items() + } + initializer_instances.append(instance) # Re-initialize PyRIT with the scenario-specific initializers # This resets memory and applies initializer defaults @@ -477,6 +482,13 @@ def format_initializer_metadata(*, initializer_metadata: InitializerMetadata) -> else: print(" Required Environment Variables: None") + if initializer_metadata.supported_parameters: + print(" Supported Parameters:") + for param_name, param_desc, param_required, param_default in initializer_metadata.supported_parameters: + req_str = " (required)" if param_required else "" + default_str = f" [default: {param_default}]" if param_default else "" + print(f" - {param_name}{req_str}{default_str}: {param_desc}") + if initializer_metadata.class_description: print(" Description:") print(_format_wrapped_text(text=initializer_metadata.class_description, indent=" ")) @@ -773,7 +785,10 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path "initialization scripts, and env files. CLI arguments override config file values. " "If not specified, ~/.pyrit/.pyrit_conf is loaded if it exists." ), - "initializers": "Built-in initializer names to run before the scenario (e.g., openai_objective_target)", + "initializers": ( + "Built-in initializer names to run before the scenario. " + "Supports optional params with name:key=val syntax (e.g., target:tags=default,scorer)" + ), "initialization_scripts": "Paths to custom Python initialization scripts to run before the scenario", "env_files": "Paths to environment files to load in order (e.g., .env.production .env.local). Later files " "override earlier ones.", @@ -790,6 +805,50 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path } +def _parse_initializer_arg(arg: str) -> dict[str, Any]: + """ + Parse an initializer CLI argument into a dict for ConfigurationLoader. + + Supports two formats: + - Simple name: "simple" → {"name": "simple"} + - Name with params: "target:tags=default,scorer" → {"name": "target", "args": {"tags": "default,scorer"}} + + For multiple params, separate with semicolons: "name:key1=val1;key2=val2" + + Args: + arg: The CLI argument string. + + Returns: + dict: A dict with 'name' and optionally 'args' keys. + + Raises: + ValueError: If the argument format is invalid. + """ + if ":" not in arg: + return arg # type: ignore[return-value] + + name, params_str = arg.split(":", 1) + if not name: + raise ValueError(f"Invalid initializer argument '{arg}': missing name before ':'") + + args: dict[str, list[str]] = {} + for pair in params_str.split(";"): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + raise ValueError(f"Invalid initializer parameter '{pair}' in '{arg}': expected key=value format") + key, value = pair.split("=", 1) + key = key.strip() + if not key: + raise ValueError(f"Invalid initializer parameter in '{arg}': empty key") + args[key] = [v.strip() for v in value.split(",")] + + if args: + return {"name": name, "args": args} + return name # type: ignore[return-value] + + def parse_run_arguments(*, args_string: str) -> dict[str, Any]: """ Parse run command arguments from a string (for shell mode). @@ -837,11 +896,11 @@ def parse_run_arguments(*, args_string: str) -> dict[str, Any]: i = 1 while i < len(parts): if parts[i] == "--initializers": - # Collect initializers until next flag + # Collect initializers until next flag, parsing name:key=val syntax result["initializers"] = [] i += 1 while i < len(parts) and not parts[i].startswith("--"): - result["initializers"].append(parts[i]) + result["initializers"].append(_parse_initializer_arg(parts[i])) i += 1 elif parts[i] == "--initialization-scripts": # Collect script paths until next flag diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index 3c35194c30..77522e7996 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -91,7 +91,7 @@ def parse_args(*, args: Optional[list[str]] = None) -> Namespace: parser.add_argument( "--initializers", - type=str, + type=frontend_core._parse_initializer_arg, nargs="+", help=frontend_core.ARG_HELP["initializers"], ) @@ -164,12 +164,17 @@ async def initialize_and_run_async(*, parsed_args: Namespace) -> int: # Run initializers up-front (backend runs them once at startup, not per-scenario) initializer_instances = None - if context._initializer_names: - print(f"Running {len(context._initializer_names)} initializer(s)...") + if context._initializer_configs: + print(f"Running {len(context._initializer_configs)} initializer(s)...") initializer_instances = [] - for name in context._initializer_names: - initializer_class = context.initializer_registry.get_class(name) - initializer_instances.append(initializer_class()) + for config in context._initializer_configs: + initializer_class = context.initializer_registry.get_class(config.name) + instance = initializer_class() + if config.args: + instance.params = { + k: [str(i) for i in v] if isinstance(v, list) else [str(v)] for k, v in config.args.items() + } + initializer_instances.append(instance) # Re-initialize with initializers applied await initialize_pyrit_async( @@ -185,14 +190,14 @@ async def initialize_and_run_async(*, parsed_args: Namespace) -> int: print(f"🚀 Starting PyRIT backend on http://{parsed_args.host}:{parsed_args.port}") print(f" API Docs: http://{parsed_args.host}:{parsed_args.port}/docs") - config = uvicorn.Config( + uvicorn_config = uvicorn.Config( "pyrit.backend.main:app", host=parsed_args.host, port=parsed_args.port, log_level=parsed_args.log_level, reload=parsed_args.reload, ) - server = uvicorn.Server(config) + server = uvicorn.Server(uvicorn_config) await server.serve() return 0 diff --git a/pyrit/cli/pyrit_scan.py b/pyrit/cli/pyrit_scan.py index d73992d7bc..07c1d36c9e 100644 --- a/pyrit/cli/pyrit_scan.py +++ b/pyrit/cli/pyrit_scan.py @@ -94,7 +94,7 @@ def parse_args(args: Optional[list[str]] = None) -> Namespace: parser.add_argument( "--initializers", - type=str, + type=frontend_core._parse_initializer_arg, nargs="+", help=frontend_core.ARG_HELP["initializers"], ) diff --git a/pyrit/cli/pyrit_shell.py b/pyrit/cli/pyrit_shell.py index fdf530da49..6541c33767 100644 --- a/pyrit/cli/pyrit_shell.py +++ b/pyrit/cli/pyrit_shell.py @@ -43,7 +43,7 @@ class PyRITShell(cmd.Cmd): --env-files ... Environment files to load in order - default for all runs Run Command Options: - --initializers ... Built-in initializers to run before the scenario + --initializers ... Built-in initializers (supports name:key=val1,val2 syntax) --initialization-scripts <...> Custom Python scripts to run before the scenario --env-files ... Environment files to load in order (overrides startup default) --strategies, -s ... Strategy names to use @@ -150,7 +150,7 @@ def do_run(self, line: str) -> None: run [options] Options: - --initializers ... Built-in initializers to run before the scenario + --initializers ... Built-in initializers (supports name:key=val1,val2 syntax) --initialization-scripts <...> Custom Python scripts to run before the scenario --env-files ... Environment files to load in order --strategies, -s ... Strategy names to use @@ -375,6 +375,7 @@ def do_help(self, arg: str) -> None: print(f" {frontend_core.ARG_HELP['initializers']}") print(" Every scenario requires at least one initializer") print(" Example: run foundry --initializers openai_objective_target load_default_datasets") + print(" With params: run foundry --initializers target:tags=default,scorer") print() print(" --initialization-scripts [ ...] (Alternative to --initializers)") print(f" {frontend_core.ARG_HELP['initialization_scripts']}") diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 510daae204..cea7e16203 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -50,6 +50,9 @@ class InitializerMetadata(ClassRegistryEntry): # Execution order priority (lower = earlier). execution_order: int = field(kw_only=True) + # Supported parameters as tuples of (name, description, required, default). + supported_parameters: tuple[tuple[str, str, bool, Optional[list[str]]], ...] = field(kw_only=True, default=()) + class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetadata]): """ @@ -223,6 +226,9 @@ def _build_metadata(self, name: str, entry: ClassEntry[PyRITInitializer]) -> Ini display_name=instance.name, required_env_vars=tuple(instance.required_env_vars), execution_order=instance.execution_order, + supported_parameters=tuple( + (p.name, p.description, p.required, p.default) for p in instance.supported_parameters + ), ) except Exception as e: logger.warning(f"Failed to get metadata for {name}: {e}") diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index f31e4640e6..bd472039ff 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -312,8 +312,12 @@ def _resolve_initializers(self) -> Sequence["PyRITInitializer"]: f"Initializer '{config.name}' not found in registry.\nAvailable initializers: {available}" ) - # Instantiate with args if provided - instance = initializer_class(**config.args) if config.args else initializer_class() + # Instantiate and set params if provided + instance = initializer_class() + if config.args: + instance.params = { + k: [str(i) for i in v] if isinstance(v, list) else [str(v)] for k, v in config.args.items() + } resolved.append(instance) diff --git a/pyrit/setup/initializers/__init__.py b/pyrit/setup/initializers/__init__.py index 1df84c897b..d27fc41c2a 100644 --- a/pyrit/setup/initializers/__init__.py +++ b/pyrit/setup/initializers/__init__.py @@ -6,13 +6,14 @@ from pyrit.setup.initializers.airt import AIRTInitializer from pyrit.setup.initializers.components.scorers import ScorerInitializer from pyrit.setup.initializers.components.targets import TargetInitializer -from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer +from pyrit.setup.initializers.pyrit_initializer import InitializerParameter, PyRITInitializer from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets from pyrit.setup.initializers.scenarios.objective_list import ScenarioObjectiveListInitializer from pyrit.setup.initializers.scenarios.openai_objective_target import ScenarioObjectiveTargetInitializer from pyrit.setup.initializers.simple import SimpleInitializer __all__ = [ + "InitializerParameter", "PyRITInitializer", "AIRTInitializer", "ScorerInitializer", diff --git a/pyrit/setup/initializers/airt.py b/pyrit/setup/initializers/airt.py index 96740565d8..dd5980789c 100644 --- a/pyrit/setup/initializers/airt.py +++ b/pyrit/setup/initializers/airt.py @@ -10,6 +10,7 @@ import os from collections.abc import Callable +from typing import Optional from pyrit.auth import get_azure_openai_auth, get_azure_token_provider from pyrit.common.apply_defaults import set_default_value, set_global_variable @@ -92,7 +93,7 @@ def required_env_vars(self) -> list[str]: "AZURE_CONTENT_SAFETY_API_ENDPOINT", ] - async def initialize_async(self) -> None: + async def initialize_async(self, *, params: Optional[dict[str, list[str]]] = None) -> None: """ Execute the complete AIRT initialization. diff --git a/pyrit/setup/initializers/components/scorers.py b/pyrit/setup/initializers/components/scorers.py index d7bc220037..09f2834e0c 100644 --- a/pyrit/setup/initializers/components/scorers.py +++ b/pyrit/setup/initializers/components/scorers.py @@ -31,7 +31,7 @@ TrueFalseQuestionPaths, TrueFalseScoreAggregator, ) -from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer +from pyrit.setup.initializers.pyrit_initializer import InitializerParameter, PyRITInitializer if TYPE_CHECKING: from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget @@ -39,7 +39,9 @@ logger = logging.getLogger(__name__) # Shared tag type with TargetInitializer -ScorerTag = Literal["default"] +ScorerTag = Literal["default", "all"] + +ALL_SCORER_TAGS: list[str] = ["default"] # Target registry names used by scorer configurations. GPT4O_TARGET: str = "azure_openai_gpt4o" @@ -77,6 +79,10 @@ class ScorerInitializer(PyRITInitializer): so this initializer must run after the target initializer (enforced via execution_order). Scorers that fail to initialize (e.g., due to missing targets) are skipped with a warning. + Supported Parameters: + tags: Tags for filtering scorers. Defaults to ["default"]. + "all" registers all scorers regardless of tag. + Example: initializer = ScorerInitializer() await initializer.initialize_async() @@ -84,15 +90,16 @@ class ScorerInitializer(PyRITInitializer): refusal = registry.get_instance_by_name(REFUSAL_GPT4O) """ - def __init__(self, *, tags: list[ScorerTag] | None = None) -> None: - """ - Initialize the Scorer Initializer. - - Args: - tags (list[ScorerTag] | None): Tags for future filtering. Defaults to ["default"]. - """ - super().__init__() - self._tags = tags if tags is not None else ["default"] + @property + def supported_parameters(self) -> list[InitializerParameter]: + """Get the list of parameters this initializer accepts.""" + return [ + InitializerParameter( + name="tags", + description="Tags for filtering (e.g., ['default'] or ['all'])", + default=["default"], + ), + ] @property def name(self) -> str: @@ -127,13 +134,21 @@ def required_env_vars(self) -> list[str]: """ return [] - async def initialize_async(self) -> None: + async def initialize_async(self, *, params: Optional[dict[str, list[str]]] = None) -> None: """ Register available scorers using targets from the TargetRegistry. + Args: + params: Optional parameters. Supports 'tags' (list of tag names). + Raises: RuntimeError: If the TargetRegistry is empty or hasn't been initialized. """ + params = params or {} + tags = params.get("tags", ["default"]) + if "all" in tags: + tags = ALL_SCORER_TAGS + target_registry = TargetRegistry.get_registry_singleton() if len(target_registry) == 0: diff --git a/pyrit/setup/initializers/components/targets.py b/pyrit/setup/initializers/components/targets.py index 3cc42de3ee..5aa0a0081a 100644 --- a/pyrit/setup/initializers/components/targets.py +++ b/pyrit/setup/initializers/components/targets.py @@ -31,13 +31,15 @@ RealtimeTarget, ) from pyrit.registry import TargetRegistry -from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer +from pyrit.setup.initializers.pyrit_initializer import InitializerParameter, PyRITInitializer logger = logging.getLogger(__name__) # Literal type for target tags -TargetTag = Literal["default", "scorer"] +TargetTag = Literal["default", "scorer", "all"] + +ALL_TARGET_TAGS: list[str] = ["default", "scorer"] @dataclass @@ -366,11 +368,11 @@ class TargetInitializer(PyRITInitializer): the corresponding targets into the TargetRegistry. Targets can be filtered by tags to control which targets are registered. - Args: - tags: List of tags to filter which targets to register. + Supported Parameters: + tags: Target tags to register (list of strings). "default" registers the base environment targets. "scorer" registers scorer-specific temperature variant targets. - Pass multiple tags to register targets matching any tag. + "all" registers all targets regardless of tag. If not provided, only "default" targets are registered. Supported Endpoints by Category: @@ -426,20 +428,19 @@ class TargetInitializer(PyRITInitializer): await initializer.initialize_async() # Register scorer temperature variants too - initializer = TargetInitializer(tags=["default", "scorer"]) - await initializer.initialize_async() + await initializer.initialize_async(params={"tags": ["default", "scorer"]}) """ - def __init__(self, *, tags: list[TargetTag] | None = None) -> None: - """ - Initialize the Target Initializer. - - Args: - tags (list[TargetTag] | None): Tags to filter which targets to register. - If None, only "default" targets are registered. - """ - super().__init__() - self._tags = tags if tags is not None else ["default"] + @property + def supported_parameters(self) -> list[InitializerParameter]: + """Get the list of parameters this initializer accepts.""" + return [ + InitializerParameter( + name="tags", + description="Target tags to register (e.g., ['default'], ['default', 'scorer'], or ['all'])", + default=["default"], + ), + ] @property def name(self) -> str: @@ -469,16 +470,24 @@ def required_env_vars(self) -> list[str]: """ return [] - async def initialize_async(self) -> None: + async def initialize_async(self, *, params: Optional[dict[str, list[str]]] = None) -> None: """ Register available targets based on environment variables. Scans for known endpoint environment variables and registers the corresponding targets into the TargetRegistry. Only targets with tags matching the configured tags are registered. + + Args: + params: Optional parameters. Supports 'tags' (list of tag names). """ + params = params or {} + tags = params.get("tags", ["default"]) + if "all" in tags: + tags = ALL_TARGET_TAGS + for config in TARGET_CONFIGS: - if not any(tag in self._tags for tag in config.tags): + if not any(tag in tags for tag in config.tags): continue self._register_target(config) diff --git a/pyrit/setup/initializers/pyrit_initializer.py b/pyrit/setup/initializers/pyrit_initializer.py index e4bff8c6aa..95ecd69d97 100644 --- a/pyrit/setup/initializers/pyrit_initializer.py +++ b/pyrit/setup/initializers/pyrit_initializer.py @@ -12,11 +12,33 @@ from abc import ABC, abstractmethod from collections.abc import Iterator from contextlib import contextmanager, suppress -from typing import Any +from dataclasses import dataclass +from typing import Any, Optional from pyrit.common.apply_defaults import get_global_default_values +@dataclass(frozen=True) +class InitializerParameter: + """ + Describes a parameter that an initializer accepts. + + Each parameter value is a list of strings, which works naturally with + CLI (comma-separated), YAML (lists), and programmatic APIs. + + Args: + name: The parameter name (used as key in the params dict). + description: Human-readable description of the parameter. + required: Whether the parameter must be provided. Defaults to False. + default: Default value if not provided. Defaults to None. + """ + + name: str + description: str + required: bool = False + default: Optional[list[str]] = None + + class PyRITInitializer(ABC): """ Abstract base class for PyRIT configuration initializers. @@ -32,6 +54,7 @@ class PyRITInitializer(ABC): def __init__(self) -> None: # noqa: B027 """Initialize the PyRIT initializer with no parameters.""" + self.params: dict[str, list[str]] = {} @property @abstractmethod @@ -89,25 +112,44 @@ def execution_order(self) -> int: """ return 1 + @property + def supported_parameters(self) -> list[InitializerParameter]: + """ + Get the list of parameters this initializer accepts. + + Override this property to declare what parameters the initializer + supports. Parameters are passed as a dict[str, list[str]] to initialize_async(). + + Returns: + list[InitializerParameter]: List of supported parameters. Defaults to empty list. + """ + return [] + @abstractmethod - async def initialize_async(self) -> None: + async def initialize_async(self, *, params: Optional[dict[str, list[str]]] = None) -> None: """ Execute the initialization logic asynchronously. This method should contain all the configuration logic, including calls to set_default_value() and set_global_variable() as needed. All initializers must implement this as an async method. + + Args: + params: Optional dictionary of string-list parameters. + Use supported_parameters to declare which params are accepted. """ def validate(self) -> None: """ Validate the initializer configuration before execution. - This method checks that all required environment variables are set. + This method checks that all required environment variables are set + and validates any configured parameters against supported_parameters. Subclasses should not override this method. Raises: - ValueError: If required environment variables are not set. + ValueError: If required environment variables are not set or + if configured parameters are invalid. """ import os @@ -118,16 +160,57 @@ def validate(self) -> None: f"{', '.join(missing_vars)}" ) + # Validate configured params + if self.params: + self._validate_params(params=self.params) + + def _validate_params(self, *, params: dict[str, list[str]]) -> None: + """ + Validate parameters against supported_parameters. + + Checks that all provided params are declared in supported_parameters + and that all required params are present. + + Args: + params: The parameters to validate. + + Raises: + ValueError: If unknown parameters are provided or required parameters are missing. + """ + supported = {p.name: p for p in self.supported_parameters} + supported_names = set(supported.keys()) + + # Check for unknown params + unknown = set(params.keys()) - supported_names + if unknown: + raise ValueError( + f"Initializer '{self.name}' received unknown parameter(s): {', '.join(sorted(unknown))}. " + f"Supported parameters: {', '.join(sorted(supported_names)) if supported_names else 'none'}" + ) + + # Check for missing required params + for param_def in self.supported_parameters: + if param_def.required and param_def.name not in params: + raise ValueError( + f"Initializer '{self.name}' requires parameter '{param_def.name}': {param_def.description}" + ) + async def initialize_with_tracking_async(self) -> None: """ Execute initialization while tracking what changes are made. - This method runs initialize_async() and captures information about what - default values and global variables were set. The tracking information - is not cached - it's captured during the actual initialization run. + This method runs initialize_async() with stored params and captures + information about what default values and global variables were set. + The tracking information is not cached - it's captured during the actual + initialization run. """ with self._track_initialization_changes(): - await self.initialize_async() + params = self.params if self.params else None + try: + await self.initialize_async(params=params) + except TypeError: + # Backward compatibility: old-style initializers without params argument + await self.initialize_async() @contextmanager def _track_initialization_changes(self) -> Iterator[dict[str, Any]]: @@ -210,7 +293,11 @@ async def get_dynamic_default_values_info_async(self) -> dict[str, Any]: try: # Run initialization in sandbox with tracking (starting from empty state) with self._track_initialization_changes() as tracking_info: - await self.initialize_async() + params = self.params if self.params else None + try: + await self.initialize_async(params=params) + except TypeError: + await self.initialize_async() return tracking_info @@ -265,6 +352,18 @@ async def get_info_async(cls) -> dict[str, Any]: "execution_order": instance.execution_order, } + # Add supported parameters if any are declared + if instance.supported_parameters: + base_info["supported_parameters"] = [ + { + "name": p.name, + "description": p.description, + "required": p.required, + "default": p.default, + } + for p in instance.supported_parameters + ] + # Add required environment variables if any are defined if instance.required_env_vars: base_info["required_env_vars"] = instance.required_env_vars diff --git a/pyrit/setup/initializers/scenarios/load_default_datasets.py b/pyrit/setup/initializers/scenarios/load_default_datasets.py index 0055736372..c27cca5898 100644 --- a/pyrit/setup/initializers/scenarios/load_default_datasets.py +++ b/pyrit/setup/initializers/scenarios/load_default_datasets.py @@ -10,6 +10,7 @@ import logging import textwrap +from typing import Optional from pyrit.datasets import SeedDatasetProvider from pyrit.memory import CentralMemory @@ -50,7 +51,7 @@ def required_env_vars(self) -> list[str]: """Return the list of required environment variables.""" return [] - async def initialize_async(self) -> None: + async def initialize_async(self, *, params: Optional[dict[str, list[str]]] = None) -> None: """Load default datasets from all registered scenarios.""" # Get ScenarioRegistry to discover all scenarios registry = ScenarioRegistry.get_registry_singleton() diff --git a/pyrit/setup/initializers/scenarios/objective_list.py b/pyrit/setup/initializers/scenarios/objective_list.py index a07ca9024c..61eb460901 100644 --- a/pyrit/setup/initializers/scenarios/objective_list.py +++ b/pyrit/setup/initializers/scenarios/objective_list.py @@ -10,6 +10,8 @@ should prefer using dataset_config in initialize_async for more flexibility. """ +from typing import Optional + from pyrit.common.apply_defaults import set_default_value from pyrit.scenario import Scenario from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer @@ -33,7 +35,7 @@ def required_env_vars(self) -> list[str]: """Return an empty list because this initializer requires no environment variables.""" return [] - async def initialize_async(self) -> None: + async def initialize_async(self, *, params: Optional[dict[str, list[str]]] = None) -> None: """Set default objectives for scenarios that accept them (deprecated).""" # This uses the deprecated 'objectives' parameter which will emit warnings. # Users should prefer using dataset_config in initialize_async instead. diff --git a/pyrit/setup/initializers/scenarios/openai_objective_target.py b/pyrit/setup/initializers/scenarios/openai_objective_target.py index 6fc6ae4315..a0f28ecb1b 100644 --- a/pyrit/setup/initializers/scenarios/openai_objective_target.py +++ b/pyrit/setup/initializers/scenarios/openai_objective_target.py @@ -11,6 +11,7 @@ """ import os +from typing import Optional from pyrit.common.apply_defaults import set_default_value from pyrit.prompt_target import OpenAIChatTarget @@ -48,7 +49,7 @@ def required_env_vars(self) -> list[str]: "DEFAULT_OPENAI_FRONTEND_KEY", ] - async def initialize_async(self) -> None: + async def initialize_async(self, *, params: Optional[dict[str, list[str]]] = None) -> None: """Set default objective target for scenarios that accept them.""" objective_target = OpenAIChatTarget( endpoint=os.getenv("DEFAULT_OPENAI_FRONTEND_ENDPOINT"), diff --git a/pyrit/setup/initializers/simple.py b/pyrit/setup/initializers/simple.py index b50b5f8710..32b1acec10 100644 --- a/pyrit/setup/initializers/simple.py +++ b/pyrit/setup/initializers/simple.py @@ -9,6 +9,7 @@ """ import os +from typing import Optional from pyrit.common.apply_defaults import set_default_value, set_global_variable from pyrit.executor.attack import ( @@ -112,7 +113,7 @@ def _get_api_key(self): # type: ignore[no-untyped-def] return get_azure_openai_auth(endpoint) - async def initialize_async(self) -> None: + async def initialize_async(self, *, params: Optional[dict[str, list[str]]] = None) -> None: """ Execute the complete simple initialization. diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 1ef19b7985..95f1fb8a4e 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -24,7 +24,7 @@ def test_init_with_defaults(self): assert context._database == frontend_core.SQLITE assert context._initialization_scripts is None - assert context._initializer_names is None + assert context._initializer_configs is None assert context._log_level == logging.WARNING assert context._initialized is False @@ -45,7 +45,8 @@ def test_init_with_all_parameters(self): assert context._initialization_scripts is not None assert len(context._initialization_scripts) == 1 assert context._initialization_scripts[0].parts[-2:] == ("test", "script.py") - assert context._initializer_names == initializers + assert context._initializer_configs is not None + assert [ic.name for ic in context._initializer_configs] == initializers assert context._log_level == logging.DEBUG def test_init_with_invalid_database(self): @@ -516,6 +517,49 @@ def test_format_initializer_metadata_with_description(self, capsys) -> None: assert "Test description" in captured.out +class TestParseInitializerArg: + """Tests for _parse_initializer_arg function.""" + + def test_simple_name_returns_string(self) -> None: + """Test that a plain name without ':' returns the string as-is.""" + assert frontend_core._parse_initializer_arg("simple") == "simple" + + def test_name_with_single_param(self) -> None: + """Test name:key=value parsing.""" + result = frontend_core._parse_initializer_arg("target:tags=default") + assert result == {"name": "target", "args": {"tags": ["default"]}} + + def test_name_with_comma_separated_values(self) -> None: + """Test that comma-separated values are split into a list.""" + result = frontend_core._parse_initializer_arg("target:tags=default,scorer") + assert result == {"name": "target", "args": {"tags": ["default", "scorer"]}} + + def test_name_with_multiple_params(self) -> None: + """Test semicolon-separated multiple params.""" + result = frontend_core._parse_initializer_arg("target:tags=default;mode=strict") + assert result == {"name": "target", "args": {"tags": ["default"], "mode": ["strict"]}} + + def test_missing_name_before_colon_raises(self) -> None: + """Test that ':key=val' with no name raises ValueError.""" + with pytest.raises(ValueError, match="missing name before ':'"): + frontend_core._parse_initializer_arg(":tags=default") + + def test_missing_equals_in_param_raises(self) -> None: + """Test that 'name:badparam' without '=' raises ValueError.""" + with pytest.raises(ValueError, match="expected key=value format"): + frontend_core._parse_initializer_arg("target:badparam") + + def test_empty_key_raises(self) -> None: + """Test that 'name:=value' with empty key raises ValueError.""" + with pytest.raises(ValueError, match="empty key"): + frontend_core._parse_initializer_arg("target:=value") + + def test_colon_but_no_params_returns_string(self) -> None: + """Test that 'name:' with trailing colon but no params returns the name string.""" + result = frontend_core._parse_initializer_arg("target:") + assert result == "target" + + class TestParseRunArguments: """Tests for parse_run_arguments function.""" @@ -534,6 +578,31 @@ def test_parse_run_arguments_with_initializers(self): assert result["scenario_name"] == "test_scenario" assert result["initializers"] == ["init1", "init2"] + def test_parse_run_arguments_with_initializer_params(self): + """Test parsing initializers with key=value params.""" + result = frontend_core.parse_run_arguments( + args_string="test_scenario --initializers simple target:tags=default" + ) + + assert result["initializers"][0] == "simple" + assert result["initializers"][1] == {"name": "target", "args": {"tags": ["default"]}} + + def test_parse_run_arguments_with_initializer_multiple_params(self): + """Test parsing initializers with multiple key=value params separated by semicolons.""" + result = frontend_core.parse_run_arguments( + args_string="test_scenario --initializers target:tags=default;mode=strict" + ) + + assert result["initializers"][0] == {"name": "target", "args": {"tags": ["default"], "mode": ["strict"]}} + + def test_parse_run_arguments_with_initializer_comma_list(self): + """Test parsing initializer params with comma-separated values into lists.""" + result = frontend_core.parse_run_arguments( + args_string="test_scenario --initializers target:tags=default,scorer" + ) + + assert result["initializers"][0] == {"name": "target", "args": {"tags": ["default", "scorer"]}} + def test_parse_run_arguments_with_strategies(self): """Test parsing with strategies.""" result = frontend_core.parse_run_arguments(args_string="test_scenario --strategies s1 s2") diff --git a/tests/unit/cli/test_pyrit_backend.py b/tests/unit/cli/test_pyrit_backend.py index d99744f138..a66e0e2422 100644 --- a/tests/unit/cli/test_pyrit_backend.py +++ b/tests/unit/cli/test_pyrit_backend.py @@ -42,7 +42,7 @@ async def test_initialize_and_run_passes_config_file_to_frontend_core(self) -> N ): mock_core = MagicMock() mock_core.initialize_async = AsyncMock() - mock_core._initializer_names = None + mock_core._initializer_configs = None mock_core_class.return_value = mock_core mock_server = MagicMock() diff --git a/tests/unit/setup/test_pyrit_initializer.py b/tests/unit/setup/test_pyrit_initializer.py index d4c22d82d0..c23e66c176 100644 --- a/tests/unit/setup/test_pyrit_initializer.py +++ b/tests/unit/setup/test_pyrit_initializer.py @@ -10,7 +10,7 @@ set_default_value, set_global_variable, ) -from pyrit.setup.initializers import PyRITInitializer +from pyrit.setup.initializers import InitializerParameter, PyRITInitializer class TestPyRITInitializerBase: @@ -46,7 +46,7 @@ def name(self) -> str: def description(self) -> str: return "Concrete initializer" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = ConcreteInitializer() @@ -60,7 +60,7 @@ class MissingName(PyRITInitializer): def description(self) -> str: return "Missing name" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass with pytest.raises(TypeError): @@ -93,7 +93,7 @@ def name(self) -> str: def description(self) -> str: return "Default order" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = DefaultOrder() @@ -115,7 +115,7 @@ def description(self) -> str: def execution_order(self) -> int: return 5 - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = CustomOrder() @@ -133,7 +133,7 @@ def name(self) -> str: def description(self) -> str: return "No env vars" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = NoEnvVars() @@ -155,7 +155,7 @@ def description(self) -> str: def required_env_vars(self): return ["API_KEY", "ENDPOINT"] - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = WithEnvVars() @@ -173,7 +173,7 @@ def name(self) -> str: def description(self) -> str: return "Default validation" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = DefaultValidate() @@ -195,7 +195,7 @@ def description(self) -> str: def validate(self) -> None: raise ValueError("Validation failed") - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = CustomValidate() @@ -231,7 +231,7 @@ def name(self) -> str: def description(self) -> str: return "Trackable init" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: nonlocal executed executed = True @@ -255,7 +255,7 @@ def name(self) -> str: def description(self) -> str: return "Tracking defaults" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: set_default_value(class_type=DummyClass, parameter_name="value", value="tracked") init = TrackingInit() @@ -279,7 +279,7 @@ def name(self) -> str: def description(self) -> str: return "Sets global var" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: set_global_variable(name="tracked_var", value="test_value") init = GlobalVarInit() @@ -313,7 +313,7 @@ def name(self) -> str: def description(self) -> str: return "For testing get_info" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass info = await InfoInit.get_info_async() @@ -331,7 +331,7 @@ def name(self) -> str: def description(self) -> str: return "Basic description" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass info = await BasicInfoInit.get_info_async() @@ -361,7 +361,7 @@ def description(self) -> str: def required_env_vars(self): return ["API_KEY"] - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass info = await EnvVarsInit.get_info_async() @@ -380,7 +380,7 @@ def name(self) -> str: def description(self) -> str: return "No env vars" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass info = await NoEnvVarsInit.get_info_async() @@ -399,7 +399,7 @@ def name(self) -> str: def description(self) -> str: return "For class method test" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass # Should work without creating an instance @@ -418,7 +418,7 @@ def name(self) -> str: def description(self) -> str: return "Sets defaults" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass info = await DefaultsInit.get_info_async() @@ -468,7 +468,7 @@ def name(self) -> str: def description(self) -> str: return "Sets both defaults and globals" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: # Set default values set_default_value(class_type=DummyTarget, parameter_name="endpoint", value="custom_endpoint") set_default_value(class_type=DummyConverter, parameter_name="target", value="custom_target") @@ -503,7 +503,7 @@ def name(self) -> str: def description(self) -> str: return "Sets nothing" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass # Don't set anything info = await EmptyInit.get_info_async() @@ -539,7 +539,7 @@ def name(self) -> str: def description(self) -> str: return "Dynamic info" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = DynamicInit() @@ -558,7 +558,7 @@ def name(self) -> str: def description(self) -> str: return "For keys test" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = KeysInit() @@ -582,7 +582,7 @@ def name(self) -> str: def description(self) -> str: return "Captures defaults" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: set_default_value(class_type=DummyClass, parameter_name="value", value="captured") init = DefaultsInit() @@ -603,7 +603,7 @@ def name(self) -> str: def description(self) -> str: return "Captures globals" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: set_global_variable(name="dynamic_test_var", value="captured") init = GlobalsInit() @@ -630,7 +630,7 @@ def name(self) -> str: def description(self) -> str: return "Restores state" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: set_default_value(class_type=DummyClass, parameter_name="other_value", value="temporary") init = RestoringInit() @@ -663,7 +663,7 @@ def name(self) -> str: def description(self) -> str: return "No memory initialized" - async def initialize_async(self) -> None: + async def initialize_async(self, *, params=None) -> None: pass init = NoMemoryInit() @@ -672,3 +672,164 @@ async def initialize_async(self) -> None: # Should return helpful messages assert "await initialize_pyrit_async()" in str(info["default_values"]) assert "await initialize_pyrit_async()" in str(info["global_variables"]) + + +class TestSupportedParameters: + """Tests for the parameter system on PyRITInitializer.""" + + def test_default_supported_parameters_is_empty(self) -> None: + """Test that base class returns empty supported_parameters.""" + + class NoParamsInit(PyRITInitializer): + @property + def name(self) -> str: + return "no_params" + + async def initialize_async(self, *, params=None) -> None: + pass + + init = NoParamsInit() + assert init.supported_parameters == [] + + def test_supported_parameters_declared(self) -> None: + """Test that subclass can declare supported_parameters.""" + + class WithParamsInit(PyRITInitializer): + @property + def name(self) -> str: + return "with_params" + + @property + def supported_parameters(self) -> list: + return [ + InitializerParameter(name="mode", description="Operation mode", default="fast"), + InitializerParameter(name="count", description="Item count", required=True), + ] + + async def initialize_async(self, *, params=None) -> None: + pass + + init = WithParamsInit() + assert len(init.supported_parameters) == 2 + assert init.supported_parameters[0].name == "mode" + assert init.supported_parameters[1].required is True + + def test_validate_params_raises_on_unknown(self) -> None: + """Test that unknown params raise ValueError.""" + + class StrictInit(PyRITInitializer): + @property + def name(self) -> str: + return "strict" + + @property + def supported_parameters(self) -> list: + return [InitializerParameter(name="level", description="Level")] + + async def initialize_async(self, *, params=None) -> None: + pass + + init = StrictInit() + with pytest.raises(ValueError, match="unknown parameter"): + init._validate_params(params={"bogus": ["value"]}) + + def test_validate_params_raises_on_missing_required(self) -> None: + """Test that missing required params raise ValueError.""" + + class RequiredInit(PyRITInitializer): + @property + def name(self) -> str: + return "required" + + @property + def supported_parameters(self) -> list: + return [InitializerParameter(name="key", description="API key", required=True)] + + async def initialize_async(self, *, params=None) -> None: + pass + + init = RequiredInit() + with pytest.raises(ValueError, match="requires parameter 'key'"): + init._validate_params(params={}) + + def test_validate_params_accepts_valid(self) -> None: + """Test that valid params pass validation.""" + + class ValidInit(PyRITInitializer): + @property + def name(self) -> str: + return "valid" + + @property + def supported_parameters(self) -> list: + return [ + InitializerParameter(name="mode", description="Mode", default="fast"), + InitializerParameter(name="key", description="Key", required=True), + ] + + async def initialize_async(self, *, params=None) -> None: + pass + + init = ValidInit() + # Should not raise + init._validate_params(params={"key": ["abc"], "mode": ["slow"]}) + + def test_validate_checks_params_on_instance(self) -> None: + """Test that validate() checks self.params.""" + + class ParamInit(PyRITInitializer): + @property + def name(self) -> str: + return "param_init" + + @property + def supported_parameters(self) -> list: + return [InitializerParameter(name="x", description="X")] + + async def initialize_async(self, *, params=None) -> None: + pass + + init = ParamInit() + init.params = {"unknown_key": ["val"]} + with pytest.raises(ValueError, match="unknown parameter"): + init.validate() + + @pytest.mark.asyncio + async def test_params_passed_to_initialize_async(self) -> None: + """Test that params are forwarded from initialize_with_tracking_async.""" + + received_params = {} + + class TrackingInit(PyRITInitializer): + @property + def name(self) -> str: + return "tracking" + + async def initialize_async(self, *, params=None) -> None: + if params: + received_params.update(params) + + init = TrackingInit() + init.params = {"tags": ["default", "scorer"]} + await init.initialize_with_tracking_async() + + assert received_params == {"tags": ["default", "scorer"]} + + @pytest.mark.asyncio + async def test_empty_params_passes_none(self) -> None: + """Test that empty _params passes None to initialize_async.""" + + received = {"called_with": "not_set"} + + class EmptyParamsInit(PyRITInitializer): + @property + def name(self) -> str: + return "empty" + + async def initialize_async(self, *, params=None) -> None: + received["called_with"] = params + + init = EmptyParamsInit() + await init.initialize_with_tracking_async() + + assert received["called_with"] is None diff --git a/tests/unit/setup/test_scorer_initializer.py b/tests/unit/setup/test_scorer_initializer.py index 7c8846e127..bdb26cd8a8 100644 --- a/tests/unit/setup/test_scorer_initializer.py +++ b/tests/unit/setup/test_scorer_initializer.py @@ -195,6 +195,30 @@ async def test_gracefully_skips_scorers_with_missing_target(self) -> None: assert registry.get_instance_by_name("inverted_refusal_gpt4o_unsafe_temp9") is None assert registry.get_instance_by_name("refusal_gpt4o") is not None + @pytest.mark.asyncio + async def test_all_tag_registers_all_scorers(self) -> None: + """Test that tags=['all'] registers all scorers (bypasses tag filtering).""" + self._register_all_scorer_targets() + os.environ.update(self.CONTENT_SAFETY_ENV_VARS) + + init = ScorerInitializer() + await init.initialize_async(params={"tags": ["all"]}) + + registry = ScorerRegistry.get_registry_singleton() + assert len(registry) == 24 + + @pytest.mark.asyncio + async def test_default_tag_registers_all_current_scorers(self) -> None: + """Test that tags=['default'] registers all current scorers (all are tagged default).""" + self._register_all_scorer_targets() + os.environ.update(self.CONTENT_SAFETY_ENV_VARS) + + init = ScorerInitializer() + await init.initialize_async(params={"tags": ["default"]}) + + registry = ScorerRegistry.get_registry_singleton() + assert len(registry) == 24 + class TestScorerInitializerGetInfo: """Tests for ScorerInitializer.get_info_async method.""" diff --git a/tests/unit/setup/test_targets_initializer.py b/tests/unit/setup/test_targets_initializer.py index 213268ccc4..a8c831cb46 100644 --- a/tests/unit/setup/test_targets_initializer.py +++ b/tests/unit/setup/test_targets_initializer.py @@ -286,8 +286,8 @@ async def test_default_tag_excludes_scorer_targets(self) -> None: os.environ["AZURE_OPENAI_GPT4O_KEY"] = "test_key" os.environ["AZURE_OPENAI_GPT4O_MODEL"] = "gpt-4o" - init = TargetInitializer(tags=["default"]) - await init.initialize_async() + init = TargetInitializer() + await init.initialize_async(params={"tags": ["default"]}) registry = TargetRegistry.get_registry_singleton() assert registry.get_instance_by_name("azure_openai_gpt4o") is not None @@ -305,8 +305,8 @@ async def test_scorer_tag_only_registers_scorer_targets(self) -> None: os.environ["AZURE_OPENAI_GPT4O_KEY"] = "test_key" os.environ["AZURE_OPENAI_GPT4O_MODEL"] = "gpt-4o" - init = TargetInitializer(tags=["scorer"]) - await init.initialize_async() + init = TargetInitializer() + await init.initialize_async(params={"tags": ["scorer"]}) registry = TargetRegistry.get_registry_singleton() assert registry.get_instance_by_name("azure_openai_gpt4o") is None @@ -324,8 +324,27 @@ async def test_multiple_tags_registers_matching(self) -> None: os.environ["AZURE_OPENAI_GPT4O_KEY"] = "test_key" os.environ["AZURE_OPENAI_GPT4O_MODEL"] = "gpt-4o" - init = TargetInitializer(tags=["default", "scorer"]) - await init.initialize_async() + init = TargetInitializer() + await init.initialize_async(params={"tags": ["default", "scorer"]}) + + registry = TargetRegistry.get_registry_singleton() + assert registry.get_instance_by_name("azure_openai_gpt4o") is not None + assert registry.get_instance_by_name("azure_openai_gpt4o_temp9") is not None + + # Clean up + del os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] + del os.environ["AZURE_OPENAI_GPT4O_KEY"] + del os.environ["AZURE_OPENAI_GPT4O_MODEL"] + + @pytest.mark.asyncio + async def test_all_tag_registers_all_targets(self) -> None: + """Test that tags=['all'] registers both default and scorer targets.""" + os.environ["AZURE_OPENAI_GPT4O_ENDPOINT"] = "https://test.openai.azure.com" + os.environ["AZURE_OPENAI_GPT4O_KEY"] = "test_key" + os.environ["AZURE_OPENAI_GPT4O_MODEL"] = "gpt-4o" + + init = TargetInitializer() + await init.initialize_async(params={"tags": ["all"]}) registry = TargetRegistry.get_registry_singleton() assert registry.get_instance_by_name("azure_openai_gpt4o") is not None