Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyrit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak
from pyrit.datasets.seed_datasets import local, remote # noqa: F401
from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider
from pyrit.datasets.seed_datasets.seed_metadata import DatasetMetadata, DatasetFilters

__all__ = [
"DatasetMetadata",
"DatasetFilters",
"SeedDatasetProvider",
"TextJailBreak",
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
)
from pyrit.models import SeedDataset, SeedPrompt

from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -107,7 +109,8 @@ def __init__(

# Validate harm categories if provided
if harm_categories:
invalid_categories = {cat for cat in harm_categories if cat not in self.HARM_CATEGORIES}
invalid_categories = {
cat for cat in harm_categories if cat not in self.HARM_CATEGORIES}
if invalid_categories:
raise ValueError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we likely still want to load these; should we use a default harm category here?

f"Invalid harm categories: {invalid_categories}. Valid categories are: {self.HARM_CATEGORIES}"
Expand Down Expand Up @@ -157,7 +160,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
prompt_harm_categories = []
if violated_categories:
# The violated_categories field contains comma-separated category names
categories = [cat.strip() for cat in violated_categories.split(",") if cat.strip()]
categories = [
cat.strip() for cat in violated_categories.split(",") if cat.strip()]
prompt_harm_categories = categories

# Filter by harm_categories if specified
Expand Down Expand Up @@ -186,3 +190,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
)

return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name)

def metadata_factory(self) -> SeedMetadata:
    """Build the metadata object describing this dataset.

    NOTE(review): the original body was incomplete (``size=`` with no value)
    and did not parse. Also, ``SeedMetadata`` is imported from
    ``seed_metadata.py`` but that module defines ``DatasetMetadata`` —
    confirm which class name is intended before wiring this up.

    Returns:
        SeedMetadata: Metadata for this dataset provider.

    Raises:
        NotImplementedError: Always, until size (and the other derived
            fields) can be computed — presumably after the dataset has been
            fetched into memory. TODO: implement.
    """
    raise NotImplementedError(
        "metadata_factory is not implemented yet; dataset size is unknown "
        "until the dataset has been fetched."
    )
60 changes: 53 additions & 7 deletions pyrit/datasets/seed_datasets/seed_dataset_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tqdm import tqdm

from pyrit.models.seeds import SeedDataset
from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata

logger = logging.getLogger(__name__)

Expand All @@ -25,6 +26,10 @@ class SeedDatasetProvider(ABC):
Subclasses must implement:
- fetch_dataset(): Fetch and return the dataset as a SeedDataset
- dataset_name property: Human-readable name for the dataset

All subclasses also have a _metadata property that is optional to make
dataset addition easier, but failing to complete it makes downstream
analysis more difficult.
"""

_registry: dict[str, type["SeedDatasetProvider"]] = {}
Expand All @@ -40,6 +45,10 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
if not inspect.isabstract(cls) and getattr(cls, "should_register", True):
SeedDatasetProvider._registry[cls.__name__] = cls
logger.debug(f"Registered dataset provider: {cls.__name__}")
# Providing metadata is optional
if getattr(cls, "_metadata", False):
logger.debug(
f"Dataset provider {cls.__name__} provided metadata.")

@property
@abstractmethod
Expand Down Expand Up @@ -78,10 +87,13 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]:
return cls._registry.copy()

@classmethod
def get_all_dataset_names(cls) -> list[str]:
def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]:
Copy link
Contributor

@rlundeen2 rlundeen2 Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might start by going backwards from what will work well

An open-ended dictionary might be tricky to use - as a caller it's not clear how I'd want to use it. I'd potentially have a DatasetProviderFilter class. I'd likely make the decision that we want to filter these before fetching the datasets. But both could be valid options.

Here is what it might look like (ty copilot)

Problem

Today, SeedDatasetProvider.fetch_datasets_async() fetches all registered datasets (or a hard-coded list of names). There's no way to say "give me only text datasets" or "give me only small safety-related datasets." Every call potentially downloads 30+ datasets from HuggingFace/GitHub.

We want to add metadata to each provider and a filter object so users can query datasets like this:

# "Give me small text-only datasets tagged as default"
f = DatasetProviderFilter(
    tags={"default"},
    modalities=[DatasetModality.TEXT],
    sizes=[DatasetSize.SMALL],
)
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

# "Give me everything multimodal"
f = DatasetProviderFilter(modalities=[DatasetModality.TEXT, DatasetModality.IMAGE])
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

# "Give me all datasets (no filtering)"
f = DatasetProviderFilter(tags={"all"})
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

Filtering happens before fetching — providers that don't match are skipped entirely, so nothing is downloaded unnecessarily.

There are exactly two things to build: Provider metadata and DatasetProviderFilter

1. Provider metadata — each provider declares what it is

Every SeedDatasetProvider subclass needs to declare four pieces of metadata as class-level attributes (not instance properties — we need to read them without instantiating):

class _HarmBenchDataset(_RemoteDatasetLoader):
    """HarmBench: 504 harmful behaviors across safety categories."""

    harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"]
    modalities: list[DatasetModality] = [DatasetModality.TEXT]
    size: DatasetSize = DatasetSize.LARGE          # 504 seeds
    tags: set[str] = {"default", "safety"}         # "default" means included in curated set

    @property
    def dataset_name(self) -> str:
        return "harmbench"
    # ...
Attribute Type Purpose
harm_categories list[str] Free-form strings like "violence", "cybercrime". No enum — each dataset uses its own vocabulary.
modalities list[DatasetModality] TEXT, IMAGE, AUDIO, VIDEO. Indicates what data types the seeds contain.
size DatasetSize SMALL (<50 seeds), MEDIUM (50–500), LARGE (>500). Self-declared by the provider author.
tags set[str] Flexible labels. "default" means “include me in the curated default set.” Anything else is free-form.
source_type str "local" or "remote". Set once on the base classes: _LocalDatasetLoader returns "local" and _RemoteDatasetLoader returns "remote".

2. DatasetProviderFilter — the user-facing filter

@dataclass
class DatasetProviderFilter:
    """
    Filters dataset providers based on their declared metadata.

    All fields are optional. None means "don't filter on this axis."
    Across axes: AND (all specified conditions must match).
    Within each axis: OR (provider needs at least one overlap).

    Special tag behavior:
    - tags={"all"} → skip tag filtering entirely, return everything
    - tags={"default"} → only providers that have "default" in their tags
    - tags=None → no tag filtering (same as "all")
    """

    harm_categories: Optional[list[str]] = None
    source_type: Optional[Literal["local", "remote"]] = None
    modalities: Optional[list[DatasetModality]] = None
    sizes: Optional[list[DatasetSize]] = None
    tags: Optional[set[str]] = None

    def matches(self, *, provider: SeedDatasetProvider) -> bool:
        """Return True if the provider passes all filter conditions."""

        # Tags: "all" means skip tag check
        if self.tags is not None and "all" not in self.tags:
            if not self.tags & provider.tags:  # set intersection — need at least one overlap
                return False

        # Harm categories: provider must have at least one matching category
        if self.harm_categories is not None:
            if not set(self.harm_categories) & set(provider.harm_categories):
                return False

        # Source type
        if self.source_type is not None:
            if provider.source_type != self.source_type:
                return False

        # Modalities: provider must support at least one requested modality
        if self.modalities is not None:
            if not set(self.modalities) & set(provider.modalities):
                return False

        # Size
        if self.sizes is not None:
            if provider.size not in self.sizes:
                return False

        return True

Matching logic in plain English:

Each specified filter field must be satisfied (AND)
Within a field, the provider only needs to overlap on one value (OR)
None on any field = don't care about that field
tags={"all"} = special: return everything regardless of tags
tags=None = also returns everything (no tag filter applied)

Fetch Datasets Update

In seed_dataset_provider.py, the existing method gains two new parameters, filter and max_seeds. Before creating tasks, datasets are narrowed:

# Apply filter to narrow down which providers to even consider
providers = cls._registry
if filter:
    providers = {
        name: pclass for name, pclass in providers.items()
        if filter.matches(provider=pclass())
    }

max_seeds needs to be implemented in the dataset classes, but would allow us to limit the number of seeds retrieved. This way we can still have integration tests for all datasets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with these points. A few things to work on in my opinion:

Caller Interface
A dictionary literal is definitely too ambiguous like you said. Users have no way of intuiting what is or isn't a valid filter. But I think calling the enums directly is verbose, and where those live and why they work isn't obvious to the user. Users who just want to grab all small datasets have to invoke a new class and several custom types with that approach, and they shouldn't have to dive into the type system of DatasetMetadata to do it.

I think we have a few options to narrow it down. We could use a typed dictionary, something like this:

f: DatasetFilters = {
    "sizes": "small",
    "modalities": ["text", "image"]
}
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

This lets us constrain the types allowed in the filters without making it cumbersome for the user. I'll keep iterating on this, but I think it's a good start.

Class Attributes and Instantiating
My first instinct was to just add the metadata as fields to each SeedDatasetProvider child, but there are some issues with this. The first is that our default implementation instantiates the class in SeedDatasetProvider.__init_subclass__ anyway, which makes that seem like the more natural injection point. The second is that derived attributes like exact size cannot be used as metadata if they're kept as class attributes. And the third is that we run the risk of having out-of-date metadata. We could just do a one-time scan of each dataset and store its size, but for remote datasets especially I feel that drift is an issue.

I don't have a good solution to this, so I'm leaning towards scoping derived attributes out of the PR, but worth thinking about tradeoffs.

Dataset-Specific Metadata
One issue I ran into early was whether or not we need each dataset to explicitly define its own metadata. The answer can definitely be a yes, but I wanted to do it in a way that would make it easier for users to add or remove datasets without spending too much time on it.

The first attempt I did was a factory method that produced a metadata object. That seemed more cumbersome. The second was a private class attribute, which has hidden state, but is more convenient. Neither approach is great.

Where I'm leaning now is private class attribute for metadata that has custom tags as an attribute. But I don't like that approach very much. I'll keep iterating, but I think we should try to keep metadata together as a single object.

"""
Get the names of all registered datasets.

Args:
filters (Optional[dict[str, str]]): Mapping of filter names to filter values to apply.

Returns:
List[str]: List of dataset names from all registered providers.

Expand All @@ -93,13 +105,42 @@ def get_all_dataset_names(cls) -> list[str]:
>>> print(f"Available datasets: {', '.join(names)}")
"""
dataset_names = set()
# 1 Remove invalid filters by checking ground truth in seed_metadata
if filters:
valid_filters = [f.value for f in SeedMetadata.DatasetFilters]
# Prefer doing this to a list or set comprehension so we can raise ValueError on
# specific unsupported filters
for filter, _ in filters.items():
if filter not in valid_filters:
raise ValueError(
f"Tried to pass invalid filter `{filter}` to SeedDatasetProvider.get_all_dataset_names!")

for provider_class in cls._registry.values():
try:
# Instantiate to get dataset name
provider = provider_class()

if filters:
# 1 Check if it has metadata
# should this be none or false
if getattr(provider, "_metadata", False):
# Skip a dataset without metadata if we have filters enabled
continue

# 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not)

# 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition

# Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets
# since we can't check them statically

# Solution: If filter is dynamic, then just download or load into central memory early to retrieve it
# and present a warning to the user that this is occurring

dataset_names.add(provider.dataset_name)
except Exception as e:
raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e
raise ValueError(
f"Could not get dataset name from {provider_class.__name__}: {e}") from e
return sorted(dataset_names)

@classmethod
Expand Down Expand Up @@ -142,9 +183,11 @@ async def fetch_datasets_async(
# Validate dataset names if specified
if dataset_names is not None:
available_names = cls.get_all_dataset_names()
invalid_names = [name for name in dataset_names if name not in available_names]
invalid_names = [
name for name in dataset_names if name not in available_names]
if invalid_names:
raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}")
raise ValueError(
f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}")

async def fetch_single_dataset(
provider_name: str, provider_class: type["SeedDatasetProvider"]
Expand All @@ -170,7 +213,8 @@ async def fetch_single_dataset(

# Progress tracking
total_count = len(cls._registry)
pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset")
pbar = tqdm(total=total_count,
desc="Loading datasets - this can take a few minutes", unit="dataset")

async def fetch_with_semaphore(
provider_name: str, provider_class: type["SeedDatasetProvider"]
Expand Down Expand Up @@ -208,10 +252,12 @@ async def fetch_with_semaphore(
logger.info(f"Merging multiple sources for {dataset_name}.")

existing_dataset = datasets[dataset_name]
combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds)
combined_seeds = list(
existing_dataset.seeds) + list(dataset.seeds)
existing_dataset.seeds = combined_seeds
else:
datasets[dataset_name] = dataset

logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers")
logger.info(
f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers")
return list(datasets.values())
73 changes: 73 additions & 0 deletions pyrit/datasets/seed_datasets/seed_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from enum import Enum
from dataclasses import dataclass

"""
TODO Finish docstring

Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider).

We have one DatasetMetadata dataclass that is our ground truth. As we instantiate datasets
using the subclass call in SeedDatasetProvider, we create DatasetMetadata and assign it to
a private variable there.

Some fields are dynamic (e.g. loading statistics, timestamp, dataset size) and are left as
NoneType until the SeedDatasetProvider actually downloads/parses the dataset and puts it in
CentralMemory.
"""


class DatasetLoadingRank(Enum):
    """Coarse, self-declared indicator of how costly a dataset is to load.

    Used to let callers opt out of slow datasets before fetching them.
    """

    # Cheap/fast to load; safe to include in the curated default set.
    DEFAULT = "default"
    # Noticeably slower than default, but still reasonable to fetch routinely.
    EXTENDED = "extended"
    # Expensive to load (e.g. large remote downloads); fetch only on request.
    SLOW = "slow"


class DatasetModalities(Enum):
    """Data modalities a dataset's seeds may contain.

    A dataset may declare more than one modality (e.g. text + image for
    multimodal datasets).
    """

    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"


class DatasetSourceType(Enum):
    """Where a dataset's raw data is loaded from."""

    # Fetched from an arbitrary remote URL (e.g. a GitHub raw file).
    GENERIC_URL = "generic_url"
    # Shipped with the package / read from the local filesystem.
    LOCAL = "local"
    # Downloaded from the Hugging Face hub.
    HUGGING_FACE = "hugging_face"


@dataclass
class DatasetMetadata:
    """Ground-truth metadata describing a single seed dataset provider.

    Intended to be attached to each ``SeedDatasetProvider`` subclass so
    callers can filter datasets without fetching them first.
    """

    # TODO: separate dynamic fields from static fields and mark dynamic fields as None
    # Number of seeds in the dataset. NOTE(review): per the module docstring
    # this is a dynamic field, unknown for remote datasets until they are
    # downloaded — presumably it should be Optional[int]; confirm.
    size: int
    # Modalities present in the seeds (a dataset may declare several).
    modalities: list[DatasetModalities]
    # Where the dataset is loaded from (local, Hugging Face, generic URL).
    source: DatasetSourceType
    # Coarse loading-cost rank used to skip slow datasets up front.
    rank: DatasetLoadingRank


class DatasetFilters(Enum):
    """Names of the ``DatasetMetadata`` fields that callers may filter on.

    Each member's value must match the corresponding ``DatasetMetadata``
    attribute name exactly — these are duplicated by hand, so the two
    definitions can drift apart if a field is added or renamed.
    """

    # TODO: This is a bad way of extracting the fields from DatasetMetadata.
    # A metaclass or even just calling getattr might be better.
    SIZE = "size"
    MODALITIES = "modalities"
    SOURCE = "source"
    RANK = "rank"

# TODO These stubs should be moved somewhere, maybe as static methods to the metadata dataclass?


def _validate_filter_value(v):
    """Check if the filter value given is valid.

    Unimplemented stub: currently accepts everything and returns None.
    NOTE(review): the intended contract (raise vs. return bool, and what
    counts as valid per filter — e.g. a size expression like "<100") is not
    decided yet; see the TODO above about relocating these helpers.
    """


def _metadata_builder():
    """Force-build metadata for all datasets (unimplemented stub).

    Planned responsibilities:
    - Download/load each dataset into local memory.
    - Add a timestamp.
    - Add all derived attributes (e.g. exact size).
    - Make sure every dataset subclass has metadata populated.

    NOTE(review): currently a no-op returning None; where this should live
    (static method on the metadata dataclass?) is still an open TODO.
    """
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ async def test_fetch_dataset_integration(self, name, provider_cls):

try:
# Use max_examples for slow providers that fetch many remote images
provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls()
provider = provider_cls(
max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls()
dataset = await provider.fetch_dataset(cache=False)

assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset"
assert isinstance(
dataset, SeedDataset), f"{name} did not return a SeedDataset"
assert len(dataset.seeds) > 0, f"{name} returned an empty dataset"
assert dataset.dataset_name, f"{name} has no dataset_name"

Expand All @@ -51,7 +53,14 @@ async def test_fetch_dataset_integration(self, name, provider_cls):
f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}"
)

logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds")
logger.info(
f"Successfully verified {name} with {len(dataset.seeds)} seeds")

except Exception as e:
pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}")

@pytest.mark.asyncio
# Fixed: decorator was misspelled "parameterize"; pytest's mark is
# "parametrize", and the misspelling breaks test collection (missing
# "name"/"provider_cls" fixtures).
@pytest.mark.parametrize("name,provider_cls", get_dataset_providers())
async def test_fetch_dataset_with_filtering(self, name, provider_cls):
    """Placeholder: verify fetching with metadata filters applied.

    TODO: exercise filter-based fetching once DatasetMetadata/DatasetFilters
    support lands in SeedDatasetProvider.
    """
    pass
32 changes: 32 additions & 0 deletions tests/unit/datasets/test_seed_dataset_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
TODO

Tests for SeedDatasetMetadata
"""


class TestMetadataParsing:
    """Placeholder tests for validating filter keys/values against metadata."""

    def test_invalid_filter_key(self):
        # TODO: assert a ValueError is raised for a filter key not present
        # in DatasetFilters.
        pass

    def test_invalid_filter_value(self):
        # TODO: assert invalid values (e.g. unknown modality) are rejected.
        pass


class TestMetadataLifecycle:
    """Placeholder tests for when metadata fields become available."""

    def test_static_values_populated(self):
        # TODO: static fields (modalities, source, rank) should be set at
        # class-registration time, before any dataset is fetched.
        pass

    def test_dynamic_values_populated(self):
        # TODO: dynamic fields (e.g. size, timestamp) should be populated
        # after the dataset is downloaded/loaded.
        pass


class TestMetadataPerformance:
    """Placeholder tests for metadata access cost."""

    def test_quick_retrieval_for_static_values(self):
        # TODO: static metadata must be readable without fetching datasets.
        pass

    def test_acceptable_retrieval_for_dynamic_values(self):
        # TODO: dynamic metadata may trigger a fetch; bound how long that
        # is allowed to take.
        pass
Loading