Skip to content

[DRAFT] FEAT: Dataset Loading Changes#1451

Draft
ValbuenaVC wants to merge 6 commits intoAzure:mainfrom
ValbuenaVC:datasetloader
Draft

[DRAFT] FEAT: Dataset Loading Changes#1451
ValbuenaVC wants to merge 6 commits intoAzure:mainfrom
ValbuenaVC:datasetloader

Conversation

@ValbuenaVC
Copy link
Contributor

@ValbuenaVC ValbuenaVC commented Mar 10, 2026

Description

Features:

  • Addition of filters argument to get_all_dataset_names, which rejects datasets that don't
    meet filter criteria
  • Use of a DatasetMetadata factory enables both static metadata (like loading rank) and size (which at least for remote datasets can only exist after being downloaded)
  • DatasetMetadata dataclass contains size: int, modalities: list[DatasetModalities], source: DatasetSourceType, and loading_rank: DatasetLoadingRank.
  • The fields of DatasetMetadata
  • Each dataset child implements the abstract method metadata_factory which returns the metadata
    and is called in SeedDatasetProvider's subclass call during its init

Problems:

  • Way too complicated for static attributes that could be class variables
  • Forces metadata generation to wait until dataset is downloaded for derived attributes
  • It would be nice to have SQL ability for all datasets; imagine doing a JOIN operation across different datasets using the same harm category
  • Not a lot of interaction with identifiers which seem like a natural overlap point for tracking dataset metadata
  • Use of a factory method is more explicit, but use of a private attribute is more intuitive. It's unclear which should take precendence

Possible Solutions:

  • Separate metadata into dynamic and static subtypes that have different paths
  • Use None values for dynamic attributes and populate them when a dataset actually downloads (if invoked in get_all_dataset_names, force downloads)
  • Save rich querying via SQL for a separate PR
  • Migrate seed_dataset to an identifier, which already makes the crucial distinction of static attributes and dynamic (what ComponentIdentifier calls behavioral) attributes

Tests and Documentation

  • For remote, just test that the file writes to test and connection is served
  • For local, test that one entry makes it into the patched DB

invalid_categories = {
cat for cat in harm_categories if cat not in self.HARM_CATEGORIES}
if invalid_categories:
raise ValueError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we likely still want to load these; should we use a default harm category here?


@classmethod
def get_all_dataset_names(cls) -> list[str]:
def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]:
Copy link
Contributor

@rlundeen2 rlundeen2 Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might start by going backwards from what will work well

An open-ended dictionary might be tricky to use - as a caller it's not clear how to I'd want to use it. I'd Potentially have a DatasetProviderFilter class. I'd likely make the decision that we want to filter these before fetching the datasets. But both could be valid options.

Here is what it might look like (ty copilot)

Problem

Today, SeedDatasetProvider.fetch_datasets_async() fetches all registered datasets (or a hard-coded list of names). There's no way to say "give me only text datasets" or "give me only small safety-related datasets." Every call potentially downloads 30+ datasets from HuggingFace/GitHub.

We want to add metadata to each provider and a filter object so users can query datasets like this:

# "Give me small text-only datasets tagged as default"
f = DatasetProviderFilter(
    tags={"default"},
    modalities=[DatasetModality.TEXT],
    sizes=[DatasetSize.SMALL],
)
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

# "Give me everything multimodal"
f = DatasetProviderFilter(modalities=[DatasetModality.TEXT, DatasetModality.IMAGE])
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

# "Give me all datasets (no filtering)"
f = DatasetProviderFilter(tags={"all"})
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

Filtering happens before fetching — providers that don't match are skipped entirely, so nothing is downloaded unnecessarily.

There are exactly two things to build: Provider metadata and DatasetProviderFilter

1. Provider metadata — each provider declares what it is

Every SeedDatasetProvider subclass needs to declare four pieces of metadata as class-level attributes (not instance properties — we need to read them without instantiating):

class _HarmBenchDataset(_RemoteDatasetLoader):
    """HarmBench: 504 harmful behaviors across safety categories."""

    harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"]
    modalities: list[DatasetModality] = [DatasetModality.TEXT]
    size: DatasetSize = DatasetSize.LARGE          # 504 seeds
    tags: set[str] = {"default", "safety"}         # "default" means included in curated set

    @property
    def dataset_name(self) -> str:
        return "harmbench"
    # ...
Attribute Type Purpose
harm_categories list[str] Free-form strings like "violence", "cybercrime". No enum — each dataset uses its own vocabulary.
modalities list[DatasetModality] TEXT, IMAGE, AUDIO, VIDEO. Indicates what data types the seeds contain.
size DatasetSize SMALL (<50 seeds), MEDIUM (50–500), LARGE (>500). Self-declared by the provider author.
tags set[str] Flexible labels. "default" means “include me in the curated default set.” Anything else is free-form.
source_type str "local" or "remote". Set once on the base classes: _LocalDatasetLoader returns "local" and _RemoteDatasetLoader returns "remote".

2. DatasetProviderFilter — the user-facing filter

@dataclass
class DatasetProviderFilter:
    """
    Filters dataset providers based on their declared metadata.

    All fields are optional. None means "don't filter on this axis."
    Across axes: AND (all specified conditions must match).
    Within each axis: OR (provider needs at least one overlap).

    Special tag behavior:
    - tags={"all"} → skip tag filtering entirely, return everything
    - tags={"default"} → only providers that have "default" in their tags
    - tags=None → no tag filtering (same as "all")
    """

    harm_categories: Optional[list[str]] = None
    source_type: Optional[Literal["local", "remote"]] = None
    modalities: Optional[list[DatasetModality]] = None
    sizes: Optional[list[DatasetSize]] = None
    tags: Optional[set[str]] = None

    def matches(self, *, provider: SeedDatasetProvider) -> bool:
        """Return True if the provider passes all filter conditions."""

        # Tags: "all" means skip tag check
        if self.tags is not None and "all" not in self.tags:
            if not self.tags & provider.tags:  # set intersection — need at least one overlap
                return False

        # Harm categories: provider must have at least one matching category
        if self.harm_categories is not None:
            if not set(self.harm_categories) & set(provider.harm_categories):
                return False

        # Source type
        if self.source_type is not None:
            if provider.source_type != self.source_type:
                return False

        # Modalities: provider must support at least one requested modality
        if self.modalities is not None:
            if not set(self.modalities) & set(provider.modalities):
                return False

        # Size
        if self.sizes is not None:
            if provider.size not in self.sizes:
                return False

        return True

Matching logic in plain English:

Each specified filter field must be satisfied (AND)
Within a field, the provider only needs to overlap on one value (OR)
None on any field = don't care about that field
tags={"all"} = special: return everything regardless of tags
tags=None = also returns everything (no tag filter applied)

Fetch Datasets Update

In seed_dataset_provider.py, the existing method gains two new parameter, filter and max_seeds. Before creating tasks, datasets are narrowed:

# Apply filter to narrow down which providers to even consider
providers = cls._registry
if filter:
    providers = {
        name: pclass for name, pclass in providers.items()
        if filter.matches(provider=pclass())
    }

max_seeds needs to be implemented in the dataset classes, but would allow us to limit the number of seeds retrieved. This way we can still have integration tests for all datasets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with these points. A few things to work on in my opinion:

Caller Interface
A dictionary literal is definitely too ambiguous like you said. Users have no way of intuiting what is or isn't a valid filter. But I think calling the enums directly is verbose, and where those live and why they work isn't obvious to the user. Users who just want to grab all small datasets have to invoke a new class and several custom types with that approach, and they shouldn't have to dive into the type system of DatasetMetadata to do it.

I think we have a few options to narrow it down. We could use a typed dictionary, something like this:

f: DatasetFilters = {
    "sizes": "small",
    "modalities": ["text", "image"]
}
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

This lets us constrain the types allowed in the filters without making it cumbersome for the user. I'll keep iterating on this, but I think it's a good start.

Class Attributes and Instantiating
My first instinct was to just add the metadata as fields to each SeedDatasetProvider child, but there are some issues with this. The first is that our default implementation instantiates the class in SeedDatasetProvider.__init_subclass__ anyway, which makes that seem like the more natural injection point. The second is that derived attributes like exact size cannot be used as metadata if they're kept as class attributes. And the third is that we run the risk of having out-of-date metadata. We could just do a one-time scan of each dataset and store its size, but for remote datasets especially I feel that drift is an issue.

I don't have a good solution to this, so I'm leaning towards scoping derived attributes out of the PR, but worth thinking about tradeoffs.

Dataset-Specific Metadata
One issue I ran into early was whether or not we need each dataset to explicitly define its own metadata. The answer can definitely be a yes, but I wanted to do it in a way that would make it easier for users to add or remove datasets without spending too much time on it.

The first attempt I did was a factory method that produced a metadata object. That seemed more cumbersome. The second was a private class attribute, which has hidden state, but is more convenient. Neither approach is great.

Where I'm leaning now is private class attribute for metadata that has custom tags as an attribute. But I don't like that approach very much. I'll keep iterating, but I think we should try to keep metadata together as a single object.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants