diff --git a/src/inference_endpoint/dataset_manager/predefined/shopify_product_catalogue/__init__.py b/src/inference_endpoint/dataset_manager/predefined/shopify_product_catalogue/__init__.py index 6bae9f43..e0e3a9ac 100644 --- a/src/inference_endpoint/dataset_manager/predefined/shopify_product_catalogue/__init__.py +++ b/src/inference_endpoint/dataset_manager/predefined/shopify_product_catalogue/__init__.py @@ -32,6 +32,29 @@ logger = getLogger(__name__) +DEFAULT_CALIBRATION_SAMPLE_INDEX = { + 20232, + 21162, + 33584, + 46825, + 45190, + 46143, + 14189, + 16658, + 26406, + 9565, + 33733, + 31057, + 47465, + 33503, + 42293, + 7768, + 1962, + 39746, + 13568, + 22527, +} + def _process_sample_to_row(sample: dict[str, Any]) -> dict[str, Any]: """Convert a single HF dataset sample to a row dict for parquet storage. @@ -101,6 +124,7 @@ def generate( force: bool = False, token: str | None = None, revision: str = "main", + calibration_sample_index: set[int] | None = DEFAULT_CALIBRATION_SAMPLE_INDEX, **kwargs: Any, ) -> pd.DataFrame: """Generate the Shopify product catalogue dataset. @@ -148,6 +172,8 @@ def generate( desc=f"Converting images ({split_key})", unit="rows", ): + if calibration_sample_index is not None and i in calibration_sample_index: + continue sample = ds[i] all_rows.append(_process_sample_to_row(sample))