[DRAFT] FEAT: Dataset Loading Changes #1451
Conversation
```python
invalid_categories = {
    cat for cat in harm_categories if cat not in self.HARM_CATEGORIES}
if invalid_categories:
    raise ValueError(
```
I think we likely still want to load these; should we use a default harm category here?
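One way to do that, as a sketch only: normalize unknown categories to a default instead of raising. The names `HARM_CATEGORIES`, `DEFAULT_HARM_CATEGORY`, and `normalize_harm_categories` below are illustrative placeholders, not existing code.

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical names for illustration; the real set lives on the provider class.
HARM_CATEGORIES = {"violence", "cybercrime", "harassment"}
DEFAULT_HARM_CATEGORY = "unknown"


def normalize_harm_categories(harm_categories: list[str]) -> list[str]:
    """Replace unrecognized categories with a default instead of raising."""
    normalized = []
    for cat in harm_categories:
        if cat in HARM_CATEGORIES:
            normalized.append(cat)
        else:
            logger.warning(
                "Unknown harm category %r; falling back to %r", cat, DEFAULT_HARM_CATEGORY
            )
            normalized.append(DEFAULT_HARM_CATEGORY)
    return normalized
```

This keeps the dataset loadable while still surfacing the mismatch in logs.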
```diff
 @classmethod
-def get_all_dataset_names(cls) -> list[str]:
+def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]:
```
I might start by working backwards from what will work well.
An open-ended dictionary might be tricky to use: as a caller, it's not clear how I'd want to use it. I'd potentially have a DatasetProviderFilter class, and I'd likely make the decision to filter these before fetching the datasets. But both could be valid options.
Here is what it might look like (ty copilot)
Problem
Today, SeedDatasetProvider.fetch_datasets_async() fetches all registered datasets (or a hard-coded list of names). There's no way to say "give me only text datasets" or "give me only small safety-related datasets." Every call potentially downloads 30+ datasets from HuggingFace/GitHub.
We want to add metadata to each provider and a filter object so users can query datasets like this:
```python
# "Give me small text-only datasets tagged as default"
f = DatasetProviderFilter(
    tags={"default"},
    modalities=[DatasetModality.TEXT],
    sizes=[DatasetSize.SMALL],
)
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

# "Give me everything multimodal"
f = DatasetProviderFilter(modalities=[DatasetModality.TEXT, DatasetModality.IMAGE])
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

# "Give me all datasets (no filtering)"
f = DatasetProviderFilter(tags={"all"})
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)
```
Filtering happens before fetching — providers that don't match are skipped entirely, so nothing is downloaded unnecessarily.
There are exactly two things to build: Provider metadata and DatasetProviderFilter
1. Provider metadata — each provider declares what it is
Every SeedDatasetProvider subclass needs to declare four pieces of metadata as class-level attributes (not instance properties — we need to read them without instantiating):
```python
class _HarmBenchDataset(_RemoteDatasetLoader):
    """HarmBench: 504 harmful behaviors across safety categories."""

    harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"]
    modalities: list[DatasetModality] = [DatasetModality.TEXT]
    size: DatasetSize = DatasetSize.LARGE  # 504 seeds
    tags: set[str] = {"default", "safety"}  # "default" means included in curated set

    @property
    def dataset_name(self) -> str:
        return "harmbench"

    # ...
```
| Attribute | Type | Purpose |
|---|---|---|
| `harm_categories` | `list[str]` | Free-form strings like `"violence"`, `"cybercrime"`. No enum; each dataset uses its own vocabulary. |
| `modalities` | `list[DatasetModality]` | `TEXT`, `IMAGE`, `AUDIO`, `VIDEO`. Indicates what data types the seeds contain. |
| `size` | `DatasetSize` | `SMALL` (<50 seeds), `MEDIUM` (50–500), `LARGE` (>500). Self-declared by the provider author. |
| `tags` | `set[str]` | Flexible labels. `"default"` means "include me in the curated default set." Anything else is free-form. |
| `source_type` | `str` | `"local"` or `"remote"`. Set once on the base classes: `_LocalDatasetLoader` returns `"local"` and `_RemoteDatasetLoader` returns `"remote"`. |
2. DatasetProviderFilter — the user-facing filter
```python
@dataclass
class DatasetProviderFilter:
    """
    Filters dataset providers based on their declared metadata.

    All fields are optional. None means "don't filter on this axis."
    Across axes: AND (all specified conditions must match).
    Within each axis: OR (provider needs at least one overlap).

    Special tag behavior:
    - tags={"all"}     -> skip tag filtering entirely, return everything
    - tags={"default"} -> only providers that have "default" in their tags
    - tags=None        -> no tag filtering (same as "all")
    """

    harm_categories: Optional[list[str]] = None
    source_type: Optional[Literal["local", "remote"]] = None
    modalities: Optional[list[DatasetModality]] = None
    sizes: Optional[list[DatasetSize]] = None
    tags: Optional[set[str]] = None

    def matches(self, *, provider: SeedDatasetProvider) -> bool:
        """Return True if the provider passes all filter conditions."""
        # Tags: "all" means skip tag check
        if self.tags is not None and "all" not in self.tags:
            if not self.tags & provider.tags:  # set intersection; need at least one overlap
                return False
        # Harm categories: provider must have at least one matching category
        if self.harm_categories is not None:
            if not set(self.harm_categories) & set(provider.harm_categories):
                return False
        # Source type
        if self.source_type is not None:
            if provider.source_type != self.source_type:
                return False
        # Modalities: provider must support at least one requested modality
        if self.modalities is not None:
            if not set(self.modalities) & set(provider.modalities):
                return False
        # Size
        if self.sizes is not None:
            if provider.size not in self.sizes:
                return False
        return True
```
Matching logic in plain English:
- Each specified filter field must be satisfied (AND)
- Within a field, the provider only needs to overlap on one value (OR)
- `None` on any field = don't care about that field
- `tags={"all"}` = special: return everything regardless of tags
- `tags=None` = also returns everything (no tag filter applied)
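To make the AND/OR semantics above concrete, here is a self-contained worked example. The enum and provider are stubs, and `matches` is a trimmed two-axis version of the full method, for illustration only.

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Optional


class DatasetModality(Enum):
    TEXT = auto()
    IMAGE = auto()


@dataclass
class StubProvider:
    # Stand-in for a SeedDatasetProvider subclass's declared metadata.
    harm_categories: list = field(default_factory=lambda: ["cybercrime", "violence"])
    modalities: list = field(default_factory=lambda: [DatasetModality.TEXT])
    tags: set = field(default_factory=lambda: {"default", "safety"})


def matches(tags: Optional[set], harm_categories: Optional[list], provider: StubProvider) -> bool:
    """Trimmed-down filter covering only the tag and harm-category axes."""
    if tags is not None and "all" not in tags:
        if not tags & provider.tags:  # within-axis OR: any overlap passes
            return False
    if harm_categories is not None:
        if not set(harm_categories) & set(provider.harm_categories):
            return False  # across-axes AND: every specified axis must pass
    return True


p = StubProvider()
assert matches({"default"}, ["cybercrime"], p) is True   # both axes overlap
assert matches({"default", "nonexistent"}, None, p) is True  # one tag suffices
assert matches({"all"}, None, p) is True                 # "all" skips the tag check
assert matches({"default"}, ["fraud"], p) is False       # one failing axis fails the match
```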
Fetch Datasets Update
In seed_dataset_provider.py, the existing method gains two new parameters, filter and max_seeds. Before creating tasks, the set of providers is narrowed:
```python
# Apply filter to narrow down which providers to even consider
providers = cls._registry
if filter:
    providers = {
        name: pclass
        for name, pclass in providers.items()
        if filter.matches(provider=pclass())
    }
```
max_seeds needs to be implemented in the dataset classes, but would allow us to limit the number of seeds retrieved. This way we can still have integration tests for all datasets.
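A minimal sketch of what the per-dataset side of `max_seeds` could look like; `cap_seeds` is a hypothetical helper name, and the real change would live inside each dataset class's loading path.

```python
from typing import Optional


def cap_seeds(seeds: list, max_seeds: Optional[int] = None) -> list:
    """Return at most max_seeds items; None means no limit."""
    if max_seeds is None:
        return seeds
    return seeds[:max_seeds]
```

An integration test could then call `fetch_datasets_async(max_seeds=1)` to touch every dataset cheaply.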
I agree with these points. A few things to work on in my opinion:
Caller Interface
A dictionary literal is definitely too ambiguous, like you said: users have no way of intuiting what is or isn't a valid filter. But I think calling the enums directly is verbose, and where those enums live and why they work isn't obvious to the user. With that approach, users who just want to grab all small datasets have to invoke a new class and several custom types, and they shouldn't have to dive into the type system of DatasetMetadata to do it.
I think we have a few options to narrow it down. We could use a typed dictionary, something like this:
```python
f: DatasetFilters = {
    "sizes": "small",
    "modalities": ["text", "image"],
}
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)
```

This lets us constrain the types allowed in the filters without making it cumbersome for the user. I'll keep iterating on this, but I think it's a good start.
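A hedged sketch of what that typed dictionary could look like. The field names mirror DatasetProviderFilter; the literal values, and whether scalars like `"small"` are accepted alongside lists, are open design questions rather than settled API.

```python
from typing import Literal, TypedDict


class DatasetFilters(TypedDict, total=False):
    # total=False makes every axis optional; an omitted key means "don't filter".
    harm_categories: list[str]
    source_type: Literal["local", "remote"]
    modalities: list[Literal["text", "image", "audio", "video"]]
    sizes: list[Literal["small", "medium", "large"]]
    tags: set[str]


# A static type checker flags unknown keys or wrong value types here,
# while at runtime this is still a plain dict.
f: DatasetFilters = {"sizes": ["small"], "modalities": ["text", "image"]}
```

The upside of `TypedDict` over a dataclass is that callers write a literal dict and still get editor completion and type checking.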
Class Attributes and Instantiating
My first instinct was to just add the metadata as fields to each SeedDatasetProvider child, but there are some issues with this. The first is that our default implementation instantiates the class in SeedDatasetProvider.__init_subclass__ anyway, which makes that seem like the more natural injection point. The second is that derived attributes like exact size can't be used as metadata if they're kept as class attributes. And the third is that we run the risk of having out-of-date metadata: we could do a one-time scan of each dataset and store its size, but for remote datasets especially, I feel that drift is an issue.
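The `__init_subclass__` injection point could look like the following sketch. The stub `SeedDatasetProvider` and the shape of the registry entries are illustrative assumptions, not the existing implementation.

```python
class SeedDatasetProvider:
    _registry: dict = {}

    # Default metadata, overridable per subclass.
    tags: set = set()

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Capture class-level metadata at subclass-definition time,
        # piggybacking on the registration that already happens here.
        SeedDatasetProvider._registry[cls.__name__] = {"tags": cls.tags}


class _ExampleDataset(SeedDatasetProvider):
    tags = {"default"}
```

Since `__init_subclass__` runs when the subclass is defined, the metadata is available for filtering before any provider is fetched or even instantiated by the caller.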
I don't have a good solution to this, so I'm leaning towards scoping derived attributes out of the PR, but worth thinking about tradeoffs.
Dataset-Specific Metadata
One issue I ran into early was whether or not we need each dataset to explicitly define its own metadata. The answer can definitely be a yes, but I wanted to do it in a way that would make it easier for users to add or remove datasets without spending too much time on it.
My first attempt was a factory method that produced a metadata object, but that felt cumbersome. The second was a private class attribute, which hides state but is more convenient. Neither approach is great.
Where I'm leaning now is private class attribute for metadata that has custom tags as an attribute. But I don't like that approach very much. I'll keep iterating, but I think we should try to keep metadata together as a single object.
Description

Features:
- Dataset providers are fetched only if they meet filter criteria
- Metadata is declared per provider and is registered in SeedDatasetProvider's subclass hook during its init

Problems:

Possible Solutions:

Tests and Documentation