[DRAFT] FEAT: Dataset Loading Changes #1451
base: main
Changes from all commits: 857f596, 1aaa978, 15b58e8, 5209a5a, fc43c8c, f4296f0
**seed_dataset_provider.py**

```diff
@@ -10,6 +10,7 @@
 from tqdm import tqdm

 from pyrit.models.seeds import SeedDataset
+from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata

 logger = logging.getLogger(__name__)
@@ -25,6 +26,10 @@ class SeedDatasetProvider(ABC):
     Subclasses must implement:
     - fetch_dataset(): Fetch and return the dataset as a SeedDataset
     - dataset_name property: Human-readable name for the dataset
+
+    All subclasses also have a _metadata property that is optional, to make
+    dataset addition easier, but failing to complete it makes downstream
+    analysis more difficult.
     """

     _registry: dict[str, type["SeedDatasetProvider"]] = {}
@@ -40,6 +45,10 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
         if not inspect.isabstract(cls) and getattr(cls, "should_register", True):
             SeedDatasetProvider._registry[cls.__name__] = cls
             logger.debug(f"Registered dataset provider: {cls.__name__}")
+            # Providing metadata is optional
+            if getattr(cls, "_metadata", False):
+                logger.debug(
+                    f"Dataset provider {cls.__name__} provided metadata.")

     @property
     @abstractmethod
```
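The auto-registration in the hunk above relies on `__init_subclass__` firing at class-creation time for every concrete subclass. A minimal self-contained sketch of the pattern, with illustrative class names (`ProviderBase`, `MyProvider`) that are not part of PyRIT:

```python
import inspect
import logging

logger = logging.getLogger(__name__)


class ProviderBase:
    """Base class that auto-registers concrete subclasses by name."""

    _registry: dict[str, type["ProviderBase"]] = {}

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # Abstract or opted-out subclasses are skipped, as in the PR.
        if not inspect.isabstract(cls) and getattr(cls, "should_register", True):
            ProviderBase._registry[cls.__name__] = cls
            # Metadata is optional; just note when a subclass provides it.
            if getattr(cls, "_metadata", None):
                logger.debug(f"{cls.__name__} provided metadata.")


class MyProvider(ProviderBase):
    _metadata = {"size": 10}


print(sorted(ProviderBase._registry))  # ['MyProvider']
```

Because registration happens in `__init_subclass__`, merely importing a module that defines a subclass is enough to register it; no instantiation is needed.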
```diff
@@ -78,10 +87,13 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]:
         return cls._registry.copy()

     @classmethod
-    def get_all_dataset_names(cls) -> list[str]:
+    def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]:
```
**Contributor**

I might start by going backwards from what will work well. An open-ended dictionary might be tricky to use: as a caller it's not clear how I'd want to use it. I'd potentially have a dedicated filter object instead. Here is what it might look like (ty copilot):

**Problem.** We want to add metadata to each provider and a filter object so users can query datasets. Filtering happens before fetching: providers that don't match are skipped entirely, so nothing is downloaded unnecessarily. There are exactly two things to build: provider metadata and `DatasetProviderFilter`.

**1. Provider metadata: each provider declares what it is.** Every `SeedDatasetProvider` subclass needs to declare four pieces of metadata as class-level attributes (not instance properties, since we need to read them without instantiating).

**2. DatasetProviderFilter: the user-facing filter.** Matching logic in plain English: each specified filter field must be satisfied (AND).

**Fetch datasets update.** In seed_dataset_provider.py, the existing method gains two new parameters.
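The two pieces proposed above might be sketched like this. The field names (`size`, `modalities`, `source`, `rank`) come from the metadata dataclass in this PR; the `matches` method and the `max_size` field are hypothetical, shown only to make the AND semantics concrete:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetMetadata:
    size: int
    modalities: list[str]
    source: str
    rank: str


@dataclass
class DatasetProviderFilter:
    """User-facing filter; every specified field must be satisfied (AND)."""

    max_size: Optional[int] = None          # hypothetical upper bound on size
    modalities: Optional[list[str]] = None  # required modalities
    source: Optional[str] = None
    rank: Optional[str] = None

    def matches(self, meta: Optional[DatasetMetadata]) -> bool:
        # Providers without metadata are skipped when a filter is active.
        if meta is None:
            return False
        if self.max_size is not None and meta.size > self.max_size:
            return False
        if self.modalities is not None and not set(self.modalities) <= set(meta.modalities):
            return False
        if self.source is not None and meta.source != self.source:
            return False
        if self.rank is not None and meta.rank != self.rank:
            return False
        return True


f = DatasetProviderFilter(max_size=1000, modalities=["text"])
meta = DatasetMetadata(size=200, modalities=["text", "image"], source="local", rank="default")
print(f.matches(meta))  # True
```

Because unset fields default to `None`, an empty filter matches every provider that has metadata, which mirrors the "filter before fetch" behavior described above.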
**Author**

I agree with these points. A few things to work on in my opinion:

**Caller interface.** I think we have a few options to narrow it down. We could use a typed dictionary, something like this:

```python
f: DatasetFilters = {
    "sizes": "small",
    "modalities": ["text", "image"],
}
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)
```

This lets us constrain the types allowed in the filters without making it cumbersome for the user. I'll keep iterating on this, but I think it's a good start.

**Class attributes and instantiating.** I don't have a good solution to this, so I'm leaning towards scoping derived attributes out of the PR, but it's worth thinking about the tradeoffs.

**Dataset-specific metadata.** The first attempt I made was a factory method that produced a metadata object. That seemed more cumbersome. The second was a private class attribute, which has hidden state but is more convenient. Neither approach is great. Where I'm leaning now is a private class attribute for metadata that has custom tags as an attribute, but I don't like that approach very much. I'll keep iterating, but I think we should try to keep metadata together as a single object.
```diff
         """
         Get the names of all registered datasets.

+        Args:
+            filters (Optional[dict[str, str]]): Mapping of filter keys to values.
+
         Returns:
             List[str]: List of dataset names from all registered providers.

@@ -93,13 +105,42 @@ def get_all_dataset_names(cls) -> list[str]:
         >>> print(f"Available datasets: {', '.join(names)}")
         """
         dataset_names = set()
+        # 1. Remove invalid filters by checking ground truth in seed_metadata
+        if filters:
+            valid_filters = [f.value for f in SeedMetadata.DatasetFilters]
+            # Prefer a loop over a list/set comprehension so we can raise
+            # ValueError naming the specific unsupported filter
+            for filter_key in filters:
+                if filter_key not in valid_filters:
+                    raise ValueError(
+                        f"Tried to pass invalid filter `{filter_key}` to SeedDatasetProvider.get_all_dataset_names!")
+
         for provider_class in cls._registry.values():
             try:
                 # Instantiate to get dataset name
                 provider = provider_class()

+                if filters:
+                    # 1. Check if it has metadata
+                    # (open question: should the default sentinel be None or False?)
+                    if not getattr(provider, "_metadata", None):
+                        # Skip a dataset without metadata when filters are enabled
+                        continue
+
+                    # 2. Remove invalid filter values by invoking helpers
+                    #    (e.g. size: <100 is fine, size: foobar is not)
+
+                    # 3. Only add the name below if the filter key and value are
+                    #    valid AND the dataset meets the condition
+
+                    # Problem: we don't know size at this point because we're just
+                    # collecting the name. Size and source are tricky for remote
+                    # datasets since we can't check them statically.
+
+                    # Solution: if a filter is dynamic, download or load into
+                    # central memory early to retrieve it, and present a warning
+                    # to the user that this is occurring.
+
                 dataset_names.add(provider.dataset_name)
             except Exception as e:
-                raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e
+                raise ValueError(
+                    f"Could not get dataset name from {provider_class.__name__}: {e}") from e
         return sorted(dataset_names)

     @classmethod
```
```diff
@@ -142,9 +183,11 @@ async def fetch_datasets_async(
         # Validate dataset names if specified
         if dataset_names is not None:
             available_names = cls.get_all_dataset_names()
-            invalid_names = [name for name in dataset_names if name not in available_names]
+            invalid_names = [
+                name for name in dataset_names if name not in available_names]
             if invalid_names:
-                raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}")
+                raise ValueError(
+                    f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}")

         async def fetch_single_dataset(
             provider_name: str, provider_class: type["SeedDatasetProvider"]
@@ -170,7 +213,8 @@ async def fetch_single_dataset(

         # Progress tracking
         total_count = len(cls._registry)
-        pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset")
+        pbar = tqdm(total=total_count,
+                    desc="Loading datasets - this can take a few minutes", unit="dataset")

         async def fetch_with_semaphore(
             provider_name: str, provider_class: type["SeedDatasetProvider"]
```
```diff
@@ -208,10 +252,12 @@ async def fetch_with_semaphore(
                 logger.info(f"Merging multiple sources for {dataset_name}.")

                 existing_dataset = datasets[dataset_name]
-                combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds)
+                combined_seeds = list(
+                    existing_dataset.seeds) + list(dataset.seeds)
                 existing_dataset.seeds = combined_seeds
             else:
                 datasets[dataset_name] = dataset

-        logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers")
+        logger.info(
+            f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers")
         return list(datasets.values())
```
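The merge step in the last hunk (concatenating seeds when two providers report the same dataset name) can be sketched independently of PyRIT. `SimpleDataset` and the sample names below are stand-ins for `SeedDataset` and real provider output:

```python
from dataclasses import dataclass


@dataclass
class SimpleDataset:
    name: str
    seeds: list[str]


def merge_by_name(fetched: list[SimpleDataset]) -> list[SimpleDataset]:
    """Combine datasets that share a name by concatenating their seeds."""
    datasets: dict[str, SimpleDataset] = {}
    for dataset in fetched:
        if dataset.name in datasets:
            # Same dataset from multiple sources: merge the seed lists.
            existing = datasets[dataset.name]
            existing.seeds = list(existing.seeds) + list(dataset.seeds)
        else:
            datasets[dataset.name] = dataset
    return list(datasets.values())


merged = merge_by_name([
    SimpleDataset("dataset_a", ["a", "b"]),
    SimpleDataset("dataset_a", ["c"]),
    SimpleDataset("other", ["d"]),
])
print([(d.name, d.seeds) for d in merged])
# [('dataset_a', ['a', 'b', 'c']), ('other', ['d'])]
```

Since Python dicts preserve insertion order, the merged list keeps the order in which dataset names were first seen.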
**seed_metadata.py (new file)**

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from enum import Enum

"""
TODO: finish docstring.

Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider).

We have one DatasetMetadata dataclass that is our ground truth. As we instantiate
datasets using the subclass call in SeedDatasetProvider, we create DatasetMetadata
and assign it to a private variable there.

Some fields are dynamic (e.g. loading statistics, timestamp, dataset size) and are
left as None until the SeedDatasetProvider actually downloads/parses the dataset
and puts it in CentralMemory.
"""


class DatasetLoadingRank(Enum):
    """Represents the general difficulty of loading a dataset."""

    DEFAULT = "default"
    EXTENDED = "extended"
    SLOW = "slow"


class DatasetModalities(Enum):
    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"


class DatasetSourceType(Enum):
    GENERIC_URL = "generic_url"
    LOCAL = "local"
    HUGGING_FACE = "hugging_face"


@dataclass
class DatasetMetadata:
    # TODO: separate dynamic fields from static fields and mark dynamic fields as None
    size: int
    modalities: list[DatasetModalities]
    source: DatasetSourceType
    rank: DatasetLoadingRank


class DatasetFilters(Enum):
    # TODO: this is a brittle way of extracting the fields from DatasetMetadata.
    # A metaclass or even just calling getattr might be better.
    SIZE = "size"
    MODALITIES = "modalities"
    SOURCE = "source"
    RANK = "rank"


# TODO: these stubs should be moved somewhere, maybe as static methods on the
# metadata dataclass?


def _validate_filter_value(v):
    """Check if the given filter value is valid."""


def _metadata_builder():
    """
    Force-build metadata for all datasets:
    download/load into local memory, add a timestamp, add all derived
    attributes, and make sure every dataset subclass has it.
    """
```
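To show how the pieces in this new file fit together, here is a self-contained sketch that re-declares trimmed versions of the enums and dataclass. `ExampleProvider` and `validate_filter_keys` are hypothetical illustrations of the private-class-attribute convention and the key validation in `get_all_dataset_names`:

```python
from dataclasses import dataclass
from enum import Enum


class DatasetModalities(Enum):
    TEXT = "text"
    IMAGE = "image"


class DatasetSourceType(Enum):
    LOCAL = "local"


class DatasetLoadingRank(Enum):
    DEFAULT = "default"


@dataclass
class DatasetMetadata:
    size: int
    modalities: list[DatasetModalities]
    source: DatasetSourceType
    rank: DatasetLoadingRank


class DatasetFilters(Enum):
    SIZE = "size"
    MODALITIES = "modalities"
    SOURCE = "source"
    RANK = "rank"


# Hypothetical provider declaring its metadata as a private class attribute,
# the convention this draft is leaning towards.
class ExampleProvider:
    _metadata = DatasetMetadata(
        size=500,
        modalities=[DatasetModalities.TEXT],
        source=DatasetSourceType.LOCAL,
        rank=DatasetLoadingRank.DEFAULT,
    )


def validate_filter_keys(filters: dict) -> None:
    """Reject filter keys that are not fields of DatasetMetadata."""
    valid = {f.value for f in DatasetFilters}
    for key in filters:
        if key not in valid:
            raise ValueError(f"Tried to pass invalid filter `{key}`")


validate_filter_keys({"size": 500})  # passes silently
```

Because `_metadata` is a class attribute, `getattr(ExampleProvider, "_metadata", None)` can read it without instantiating the provider, which is the property the review thread asks for.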
**Test file (new)**

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
TODO

Tests for SeedDatasetMetadata.
"""


class TestMetadataParsing:
    def test_invalid_filter_key(self):
        pass

    def test_invalid_filter_value(self):
        pass


class TestMetadataLifecycle:
    def test_static_values_populated(self):
        pass

    def test_dynamic_values_populated(self):
        pass


class TestMetadataPerformance:
    def test_quick_retrieval_for_static_values(self):
        pass

    def test_acceptable_retrieval_for_dynamic_values(self):
        pass
```
**Reviewer**

I think we likely still want to load these; should we use a default harm category here?