From 857f596e178b99308c366f39ead31f0810edc077 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 10 Mar 2026 23:47:44 +0000 Subject: [PATCH 1/3] scaffolding --- .../remote/aegis_ai_content_safety_dataset.py | 13 +++++- .../seed_datasets/seed_dataset_provider.py | 41 +++++++++++++++---- pyrit/datasets/seed_datasets/seed_metadata.py | 33 +++++++++++++++ 3 files changed, 78 insertions(+), 9 deletions(-) create mode 100644 pyrit/datasets/seed_datasets/seed_metadata.py diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index 4b9004f772..ecc952f35a 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -11,6 +11,8 @@ ) from pyrit.models import SeedDataset, SeedPrompt +from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata + logger = logging.getLogger(__name__) @@ -107,7 +109,8 @@ def __init__( # Validate harm categories if provided if harm_categories: - invalid_categories = {cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} + invalid_categories = { + cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} if invalid_categories: raise ValueError( f"Invalid harm categories: {invalid_categories}. 
Valid categories are: {self.HARM_CATEGORIES}" @@ -157,7 +160,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: prompt_harm_categories = [] if violated_categories: # The violated_categories field contains comma-separated category names - categories = [cat.strip() for cat in violated_categories.split(",") if cat.strip()] + categories = [ + cat.strip() for cat in violated_categories.split(",") if cat.strip()] prompt_harm_categories = categories # Filter by harm_categories if specified @@ -186,3 +190,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: ) return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) + + def metadata_factory(self) -> SeedMetadata: + return SeedMetadata( + size= + ) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 56b61b3996..cb0e7ed11a 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -10,6 +10,7 @@ from tqdm import tqdm from pyrit.models.seeds import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata logger = logging.getLogger(__name__) @@ -51,6 +52,12 @@ def dataset_name(self) -> str: str: The dataset name (e.g., "HarmBench", "JailbreakBench JBB-Behaviors") """ + @abstractmethod + def metadata_factory(self) -> SeedMetadata: + """ + Build metadata from tags and derived fields (e.g. dataset size). + """ + @abstractmethod async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: """ @@ -78,10 +85,13 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls) -> list[str]: + def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]: """ Get the names of all registered datasets. + Args: + filters (Optional[Dict[str, str]]): List of filters to apply. 
+ Returns: List[str]: List of dataset names from all registered providers. @@ -97,9 +107,21 @@ def get_all_dataset_names(cls) -> list[str]: try: # Instantiate to get dataset name provider = provider_class() + + # Injection point for filtering. TODO + + # 1 Remove invalid filters by checking ground truth in seed_metadata + + # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) + + # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition + + # Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets + # since we can't check them statically dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError( + f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -142,9 +164,11 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [name for name in dataset_names if name not in available_names] + invalid_names = [ + name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError( + f"Dataset(s) not found: {invalid_names}. 
Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -170,7 +194,8 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, + desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -208,10 +233,12 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list( + existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info( + f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py new file mode 100644 index 0000000000..8ac0c99fd5 --- /dev/null +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from enum import Enum +from dataclasses import dataclass + + +class DatasetLoadingRank(Enum): + """Represents the general difficulty of loading in a dataset.""" + DEFAULT = "default" + EXTENDED = "extended" + SLOW = "slow" + + +class DatasetModalities(Enum): + TEXT = "text" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + + +class DatasetSourceType(Enum): + GENERIC_URL = "generic_url" + LOCAL = "local" + HUGGING_FACE = "hugging_face" + + +@dataclass +class DatasetMetadata: + size: int + modalities: list[DatasetModalities] + source: DatasetSourceType + loading_rank: DatasetLoadingRank From 15b58e8a47fd9d673b6e58c0e7f5b01e24d52e9d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 11 Mar 2026 19:52:05 +0000 Subject: [PATCH 2/3] more scaffolding --- pyrit/datasets/__init__.py | 3 ++ .../seed_datasets/seed_dataset_provider.py | 43 +++++++++++++------ pyrit/datasets/seed_datasets/seed_metadata.py | 32 +++++++++++++- .../test_seed_dataset_provider_integration.py | 15 +++++-- .../datasets/test_seed_dataset_metadata.py | 32 ++++++++++++++ 5 files changed, 109 insertions(+), 16 deletions(-) create mode 100644 tests/unit/datasets/test_seed_dataset_metadata.py diff --git a/pyrit/datasets/__init__.py b/pyrit/datasets/__init__.py index 5eb89b6f44..c8d8592625 100644 --- a/pyrit/datasets/__init__.py +++ b/pyrit/datasets/__init__.py @@ -8,8 +8,11 @@ from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak from pyrit.datasets.seed_datasets import local, remote # noqa: F401 from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import DatasetMetadata, DatasetFilters __all__ = [ + "DatasetMetadata", + "DatasetFilters", "SeedDatasetProvider", "TextJailBreak", ] diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index cb0e7ed11a..7e11bf5f4c 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ 
b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -26,6 +26,10 @@ class SeedDatasetProvider(ABC): Subclasses must implement: - fetch_dataset(): Fetch and return the dataset as a SeedDataset - dataset_name property: Human-readable name for the dataset + + All subclasses also have a _metadata property that is optional to make + dataset addition easier, but failing to complete it makes downstream + analysis more difficult. """ _registry: dict[str, type["SeedDatasetProvider"]] = {} @@ -41,6 +45,10 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not inspect.isabstract(cls) and getattr(cls, "should_register", True): SeedDatasetProvider._registry[cls.__name__] = cls logger.debug(f"Registered dataset provider: {cls.__name__}") + # Providing metadata is optional + if getattr(cls, "_metadata", False): + logger.debug( + f"Dataset provider {cls.__name__} provided metadata.") @property @abstractmethod @@ -52,12 +60,6 @@ def dataset_name(self) -> str: str: The dataset name (e.g., "HarmBench", "JailbreakBench JBB-Behaviors") """ - @abstractmethod - def metadata_factory(self) -> SeedMetadata: - """ - Build metadata from tags and derived fields (e.g. dataset size). 
- """ - @abstractmethod async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: """ @@ -103,21 +105,38 @@ def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list >>> print(f"Available datasets: {', '.join(names)}") """ dataset_names = set() + # 1 Remove invalid filters by checking ground truth in seed_metadata + if filters: + valid_filters = [f.value for f in SeedMetadata.DatasetFilters] + # Prefer doing this to a list or set comprehension so we can raise ValueError on + # specific unsupported filters + for filter, _ in filters.items(): + if filter not in valid_filters: + raise ValueError( + f"Tried to pass invalid filter `{filter}` to SeedDatasetProvider.get_all_dataset_names!") + for provider_class in cls._registry.values(): try: # Instantiate to get dataset name provider = provider_class() - # Injection point for filtering. TODO + if filters: + # 1 Check if it has metadata + # should this be none or false + if getattr(provider, "_metadata", False): + # Skip a dataset without metadata if we have filters enabled + continue + + # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) - # 1 Remove invalid filters by checking ground truth in seed_metadata + # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition - # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) + # Problem: We don't know size at this point because we're just collecting the name. 
Size and source are tricky for remote datasets + # since we can't check them statically - # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition + # Solution: If filter is dynamic, then just download or load into central memory early to retrieve it + # and present a warning to the user that this is occurring - # Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets - # since we can't check them statically dataset_names.add(provider.dataset_name) except Exception as e: raise ValueError( diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 8ac0c99fd5..8ad37940d4 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -4,6 +4,20 @@ from enum import Enum from dataclasses import dataclass +""" +TODO Finish docstring + +Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). + +We have one DatasetMetadata dataclass that is our ground truth. As we instantiate datasets +using the subclass call in SeedDatasetProvider, we create DatasetMetadata and assign it to +a private variable there. + +Some fields are dynamic (e.g. loading statistics, timestamp, dataset size) and are left as +NoneType until the SeedDatasetProvider actually downloads/parses the dataset and puts it in +CentralMemory. +""" + class DatasetLoadingRank(Enum): """Represents the general difficulty of loading in a dataset.""" @@ -27,7 +41,23 @@ class DatasetSourceType(Enum): @dataclass class DatasetMetadata: + # TODO: separate dynamic fields from static fields and mark dynamic fields as None size: int modalities: list[DatasetModalities] source: DatasetSourceType - loading_rank: DatasetLoadingRank + rank: DatasetLoadingRank + + +class DatasetFilters(Enum): + # TODO: This is a bad way of extracting the fields from DatasetMetadata.
+ # A metaclass or even just calling getattr might be better. + SIZE = "size" + MODALITIES = "modalities" + SOURCE = "source" + RANK = "rank" + +# TODO These stubs should be moved somewhere, maybe as static methods to the metadata dataclass? + + +def _validate_filter_value(v): + """Check if the filter value given is valid.""" diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index a3ede4beab..ceacc2a860 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -37,10 +37,12 @@ async def test_fetch_dataset_integration(self, name, provider_cls): try: # Use max_examples for slow providers that fetch many remote images - provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() + provider = provider_cls( + max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() dataset = await provider.fetch_dataset(cache=False) - assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset" + assert isinstance( + dataset, SeedDataset), f"{name} did not return a SeedDataset" assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" assert dataset.dataset_name, f"{name} has no dataset_name" @@ -51,7 +53,14 @@ async def test_fetch_dataset_integration(self, name, provider_cls): f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}" ) - logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds") + logger.info( + f"Successfully verified {name} with {len(dataset.seeds)} seeds") except Exception as e: pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}") + + @pytest.mark.asyncio + @pytest.mark.parametrize("name,provider_cls", get_dataset_providers()) + async def test_fetch_dataset_with_filtering(self, name, provider_cls): + # TODO + pass diff
--git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py new file mode 100644 index 0000000000..7f38572311 --- /dev/null +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +TODO + +Tests for SeedDatasetMetadata +""" + + +class TestMetadataParsing: + def test_invalid_filter_key(self): + pass + + def test_invalid_filter_value(self): + pass + + +class TestMetadataLifecycle: + def test_static_values_populated(self): + pass + + def test_dynamic_values_populated(self): + pass + + +class TestMetadataPerformance: + def test_quick_retrieval_for_static_values(self): + pass + + def test_acceptable_retrieval_for_dynamic_values(self): + pass From fc43c8c6f7e198ce9669f7cd8dd7047795791977 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 12 Mar 2026 00:32:21 +0000 Subject: [PATCH 3/3] . --- pyrit/datasets/seed_datasets/seed_metadata.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 8ad37940d4..b5a4070f89 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -61,3 +61,13 @@ class DatasetFilters(Enum): def _validate_filter_value(v): """Check if the filter value given is valid.""" + + +def _metadata_builder(): + """ + Force build metadata for all datasets. + Download/load into local memory. + Add a timestamp. + Add all derived attributes. + Make sure every dataset subclass has it. + """