From 7e986d183dff93b00b303e64ad1b3a900b9ad777 Mon Sep 17 00:00:00 2001 From: Gustavo Hidalgo Date: Tue, 21 Apr 2026 19:19:05 -0400 Subject: [PATCH] stacgeoparquet improvements --- .../stac-geoparquet/pc_stac_geoparquet.py | 266 +++++++++++++----- datasets/stac-geoparquet/workflow.yaml | 1 + datasets/stac-geoparquet/workflow_manual.yaml | 40 +++ datasets/stac-geoparquet/workflow_test.yaml | 20 -- 4 files changed, 229 insertions(+), 98 deletions(-) create mode 100644 datasets/stac-geoparquet/workflow_manual.yaml delete mode 100644 datasets/stac-geoparquet/workflow_test.yaml diff --git a/datasets/stac-geoparquet/pc_stac_geoparquet.py b/datasets/stac-geoparquet/pc_stac_geoparquet.py index be235a60..3161dd61 100644 --- a/datasets/stac-geoparquet/pc_stac_geoparquet.py +++ b/datasets/stac-geoparquet/pc_stac_geoparquet.py @@ -2,6 +2,7 @@ import argparse import collections.abc +import concurrent.futures import dataclasses import datetime import hashlib @@ -9,6 +10,7 @@ import json import logging import os +import threading import time import urllib from typing import Any, Set, Union @@ -22,11 +24,10 @@ import pandas as pd import pystac import requests -from stac_geoparquet.arrow import to_parquet +from stac_geoparquet.arrow import parse_stac_items_to_arrow, to_parquet from stac_geoparquet.pgstac_reader import ( get_pgstac_partitions, Partition, - pgstac_to_arrow, pgstac_to_iter, ) @@ -46,7 +47,7 @@ logger.addHandler(handler) logger.setLevel(logging.DEBUG) -CHUNK_SIZE = 8192 +CHUNK_SIZE = 32768 PARTITION_FREQUENCIES = { "3dep-lidar-classification": "YS", @@ -105,6 +106,14 @@ "modis-21A2-061": "MS", "modis-43A4-061": "MS", "modis-64A1-061": "MS", + "met-office-global-deterministic-height": "YS", + "met-office-global-deterministic-near-surface": "YS", + "met-office-global-deterministic-pressure": "YS", + "met-office-global-deterministic-whole-atmosphere": "YS", + "met-office-uk-deterministic-height": "YS", + "met-office-uk-deterministic-near-surface": "YS", + "met-office-uk-deterministic-pressure": "YS", + "met-office-uk-deterministic-whole-atmosphere": "YS", "mtbs": None, "naip": "YS", "nasa-nex-gddp-cmip6": None, @@ -147,6 +156,7 @@ "cil-gdpcir-cc-by-sa", } + def _pairwise( iterable: collections.abc.Iterable, ) -> Any: @@ -155,6 +165,7 @@ def _pairwise( next(b, None) return zip(a, b) + def _build_output_path( base_output_path: str, part_number: int | None, @@ -179,6 +190,7 @@ def _build_output_path( ) return output_path + def inject_links(item: dict[str, Any]) -> dict[str, Any]: item["links"] = [ { @@ -235,15 +247,19 @@ def inject_assets(item: dict[str, Any], render_config: str | None) -> dict[str, } return item + def naip_year_to_int(item: dict[str, Any]) -> dict[str, Any]: """Convert the year to an integer.""" - if "naip:year" in item["properties"] and isinstance(item["properties"]["naip:year"], str): + if "naip:year" in item["properties"] and isinstance( + item["properties"]["naip:year"], str + ): item["properties"]["naip:year"] = int(item["properties"]["naip:year"]) return item + def clean_item(item: dict[str, Any], render_config: str | None) -> dict[str, Any]: - """Clean items by making sure that naip:year is an int and injecting links and assets.""" - item = inject_links(inject_assets(item, render_config)) + """Clean items by making sure that naip:year is an int and injecting links.""" + item = inject_links(item) if "proj:epsg" in item["properties"] and not item["properties"]["proj:epsg"]: # This cannot be null @@ -253,6 +269,7 @@ def clean_item(item: dict[str, Any], render_config: str | None) 
-> dict[str, Any item = naip_year_to_int(item) return item + @dataclasses.dataclass class CollectionConfig: """ @@ -265,6 +282,9 @@ class CollectionConfig: stac_api: str = "https://planetarycomputer.microsoft.com/api/stac/v1" should_inject_dynamic_properties: bool = True render_config: str | None = None + temporal_extent: ( + tuple[datetime.datetime | None, datetime.datetime | None] | None + ) = None def __post_init__(self) -> None: self._collection: pystac.Collection | None = None @@ -284,7 +304,10 @@ def generate_endpoints( if self.partition_frequency is None: raise ValueError("Set partition_frequency") - start_datetime, end_datetime = self.collection.extent.temporal.intervals[0] + if self.temporal_extent is not None: + start_datetime, end_datetime = self.temporal_extent + else: + start_datetime, end_datetime = self.collection.extent.temporal.intervals[0] # https://github.com/dateutil/dateutil/issues/349 if start_datetime and start_datetime.tzinfo == dateutil.tz.tz.tzlocal(): @@ -323,41 +346,36 @@ def export_partition( end_datetime: datetime.datetime | None = None, storage_options: dict[str, Any] | None = None, rewrite: bool = False, + fs: fsspec.AbstractFileSystem | None = None, ) -> str | None: - # pass - fs = fsspec.filesystem(output_protocol, **storage_options) # type: ignore + if fs is None: + fs = fsspec.filesystem(output_protocol, **(storage_options or {})) if fs.exists(output_path) and not rewrite: logger.debug("Path %s already exists.", output_path) return output_path def _row_func(item: dict[str, Any]) -> dict[str, Any]: return clean_item(item, self.render_config) - if any( - pgstac_to_iter( - conninfo=conninfo, - collection=self.collection_id, - start_datetime=start_datetime, - end_datetime=end_datetime, - row_func=_row_func, - ) - ): + + items = pgstac_to_iter( + conninfo=conninfo, + collection=self.collection_id, + start_datetime=start_datetime, + end_datetime=end_datetime, + row_func=_row_func, + ) + first = next(items, None) + if first is not None: logger.info(f"Running parquet export with chunk size of {CHUNK_SIZE}") with tempfile.TemporaryDirectory() as tmpdir: - arrow = pgstac_to_arrow( - conninfo=conninfo, - collection=self.collection_id, - start_datetime=start_datetime, - end_datetime=end_datetime, - row_func=_row_func, + arrow = parse_stac_items_to_arrow( + itertools.chain([first], items), schema="ChunksToDisk", tmpdir=tmpdir, - chunk_size=CHUNK_SIZE + chunk_size=CHUNK_SIZE, ) - to_parquet( - arrow, - output_path, - filesystem=fs) + to_parquet(arrow, output_path, filesystem=fs) return output_path def export_partition_for_endpoints( @@ -371,6 +389,7 @@ def export_partition_for_endpoints( total: int | None = None, rewrite: bool = False, skip_empty_partitions: bool = False, + fs: fsspec.AbstractFileSystem | None = None, ) -> str | None: """ Export results for a pair of endpoints. 
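Note on the export_partition hunk above: the rewrite peeks at the first pgstac row to decide whether the partition is empty, then re-attaches that row with itertools.chain so parse_stac_items_to_arrow still sees the full stream; the old version consumed one iterator just to evaluate any(...) and then queried pgstac a second time. A minimal sketch of the peek-and-rechain idiom, standard library only (the helper name "peek" is illustrative, not part of this patch):

import itertools
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def peek(items: Iterable[T]) -> tuple[T | None, Iterator[T]]:
    # Pull one element to test for emptiness without a second pass or query.
    it = iter(items)
    first = next(it, None)
    if first is None:
        return None, iter(())
    # Stitch the consumed element back onto the front of the stream.
    return first, itertools.chain([first], it)

first, stream = peek(x for x in range(3))
assert first == 0
assert list(stream) == [0, 1, 2]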
@@ -385,6 +404,7 @@ def export_partition_for_endpoints(
             end_datetime=end,
             storage_options=storage_options,
             rewrite=rewrite,
+            fs=fs,
         )
 
     def export_exists(
@@ -404,8 +424,10 @@ def _partition_needs_to_be_rewritten(
         output_path: str,
         storage_options: dict[str, Any],
         partition: Partition,
+        fs: fsspec.AbstractFileSystem | None = None,
     ) -> bool:
-        fs = fsspec.filesystem(output_protocol, **storage_options)
+        if fs is None:
+            fs = fsspec.filesystem(output_protocol, **storage_options)
         if output_protocol:
             output_path = f"{output_protocol}://{output_path}"
         if not fs.exists(output_path):
@@ -419,7 +441,7 @@ def _partition_needs_to_be_rewritten(
         else:
             # Assume it's a timestamp (int/float)
             file_modified_time = datetime.datetime.fromtimestamp(last_modified)
-
+
         partition_modified_time = partition.last_updated
         return file_modified_time < partition_modified_time
@@ -433,6 +455,7 @@ def export_collection(
         rewrite: bool = False,
         skip_empty_partitions: bool = False,
     ) -> list[str | None]:
+        fs = fsspec.filesystem(output_protocol, **storage_options)
 
         if not self.partition_frequency:
             logger.info("Exporting single-partition collection %s", self.collection_id)
@@ -442,14 +465,22 @@
                     conninfo,
                     output_protocol,
                     output_path,
-                    storage_options=storage_options)
+                    storage_options=storage_options,
+                    fs=fs,
+                )
             ]
-        elif self.partition_frequency and len(pgstac_partitions[self.collection_id]) == 1:
+        elif (
+            self.partition_frequency and len(pgstac_partitions[self.collection_id]) == 1
+        ):
+            fs.makedirs(output_path, exist_ok=True)
             endpoints = self.generate_endpoints()
             total = len(endpoints)
             logger.info(
-                "Exporting %d partitions for collection %s with frequency %s", total, self.collection_id, self.partition_frequency
+                "Exporting %d partitions for collection %s with frequency %s",
+                total,
+                self.collection_id,
+                self.partition_frequency,
             )
 
             results = []
@@ -465,6 +496,7 @@ def export_collection(
                         skip_empty_partitions=skip_empty_partitions,
                         part_number=i,
                         total=total,
+                        fs=fs,
                     )
                 )
         else:
@@ -476,17 +508,23 @@
             # either None/Monthly/Yearly in the collections table.
             # Ideal size is 10M to 20M rows per partition, but that is dataset dependent.
logger.info( - "Exporting %d partitions for collection %s using pgstac partitions", total, self.collection_id + "Exporting %d partitions for collection %s using pgstac partitions", + total, + self.collection_id, ) + fs.makedirs(output_path, exist_ok=True) results = [] for i, partition in tqdm.auto.tqdm(enumerate(partitions), total=total): - partition_path = _build_output_path(output_path, i, total, partition.start, partition.end) + partition_path = _build_output_path( + output_path, i, total, partition.start, partition.end + ) if self._partition_needs_to_be_rewritten( output_protocol=output_protocol, output_path=partition_path, storage_options=storage_options, partition=partition, + fs=fs, ): results.append( self.export_partition( @@ -496,9 +534,10 @@ def export_collection( start_datetime=partition.start, end_datetime=partition.end, storage_options=storage_options, - rewrite=rewrite + rewrite=rewrite, + fs=fs, ) - ) + ) else: logger.info( "Partition %s already exists and was last updated at %s, skipping", @@ -509,6 +548,7 @@ def export_collection( return results + def build_render_config(render_params: dict[str, Any], assets: dict[str, Any]) -> str: flat = [] if assets: @@ -554,8 +594,18 @@ def generate_configs_from_api(url: str) -> dict[str, CollectionConfig]: .get("partition_frequency", None) ) + interval = ( + collection.get("extent", {}) + .get("temporal", {}) + .get("interval", [[None, None]])[0] + ) + start_dt = pd.Timestamp(interval[0]).to_pydatetime() if interval[0] else None + end_dt = pd.Timestamp(interval[1]).to_pydatetime() if interval[1] else None + configs[collection["id"]] = CollectionConfig( - collection["id"], partition_frequency=partition_frequency + collection["id"], + partition_frequency=partition_frequency, + temporal_extent=(start_dt, end_dt), ) return configs @@ -588,11 +638,13 @@ class StacGeoparquetTaskInput(PCBaseModel): storage_options_credential: str | None = None extra_skip: Set[str] | None = None collections: str | Set[str] | None = None + output_dir: str | None = None class StacGeoparquetTaskOutput(PCBaseModel): n_failures: int + class StacGeoparquetTask(Task[StacGeoparquetTaskInput, StacGeoparquetTaskOutput]): _input_model = StacGeoparquetTaskInput _output_model = StacGeoparquetTaskOutput @@ -614,6 +666,7 @@ def run( storage_options_credential=input.storage_options_credential, extra_skip=input.extra_skip, collections=input.collections, + output_dir=input.output_dir, ) return StacGeoparquetTaskOutput(n_failures=result) @@ -683,7 +736,10 @@ def list_planetary_computer_collection_configs( return configs -def get_configs(table_client: azure.data.tables.TableClient) -> dict[str, CollectionConfig]: + +def get_configs( + table_client: azure.data.tables.TableClient, +) -> dict[str, CollectionConfig]: table_configs = generate_configs_from_storage_table(table_client) api_configs = generate_configs_from_api( "https://planetarycomputer.microsoft.com/api/stac/v1/collections" @@ -695,6 +751,40 @@ def get_configs(table_client: azure.data.tables.TableClient) -> dict[str, Collec return configs + +def _export_one( + config: CollectionConfig, + i: int, + N: int, + connection_info: str, + output_protocol: str, + output_dir: str | None, + storage_options: dict[str, Any], + recent_collection_updates: dict[str, list[Partition]], +) -> tuple[str, Exception | None]: + if output_dir is not None: + output_path = os.path.join(output_dir, f"{config.collection_id}.parquet") + else: + output_path = f"items/{config.collection_id}.parquet" + try: + t0 = time.monotonic() + 
config.export_collection( + connection_info, + output_protocol, + output_path, + storage_options, + pgstac_partitions=recent_collection_updates, + skip_empty_partitions=True, + rewrite=True, + ) + t1 = time.monotonic() + logger.info(f"Completed {config.collection_id} [{i}/{N}] in {t1 - t0:.2f}s") + return config.collection_id, None + except Exception as e: + logger.exception(f"Failed processing {config.collection_id}") + return config.collection_id, e + + def run( output_protocol: str = "abfs", connection_info: str | None = None, @@ -713,6 +803,7 @@ def run( extra_skip: Set[str] | None = None, collections: str | Set[str] | None = None, configs: dict[str, CollectionConfig] | None = None, + output_dir: str | None = None, ) -> int: if configs is None: configs = list_planetary_computer_collection_configs( @@ -734,22 +825,26 @@ def run( "STAC_GEOPARQUET_CONNECTION_INFO must be set if not explicitly provided" ) from e table_credential = table_credential or os.environ.get( - "STAC_GEOPARQUET_TABLE_CREDENTIAL", azure.identity.ManagedIdentityCredential() + "STAC_GEOPARQUET_TABLE_CREDENTIAL", get_credential() ) assert table_credential is not None - storage_options_account_name = ( - storage_options_account_name - or os.environ["STAC_GEOPARQUET_STORAGE_OPTIONS_ACCOUNT_NAME"] - ) - storage_options_credential = storage_options_credential or os.environ.get( - "STAC_GEOPARQUET_STORAGE_OPTIONS_CREDENTIAL", - azure.identity.ManagedIdentityCredential(), - ) - storage_options = { - "account_name": storage_options_account_name, - "credential": storage_options_credential, - } + if output_dir is not None: + output_protocol = "" + storage_options: dict[str, Any] = {} + else: + storage_options_account_name = ( + storage_options_account_name + or os.environ["STAC_GEOPARQUET_STORAGE_OPTIONS_ACCOUNT_NAME"] + ) + storage_options_credential = storage_options_credential or os.environ.get( + "STAC_GEOPARQUET_STORAGE_OPTIONS_CREDENTIAL", + get_credential(), + ) + storage_options = { + "account_name": storage_options_account_name, + "credential": storage_options_credential, + } N = len(configs) success = [] @@ -761,29 +856,34 @@ def run( recent_collection_updates.setdefault(partition.collection, []).append(partition) logger.info(f"Found {len(collection_partitions)} pgstac partitions") - for i, config in enumerate(configs.values(), 1): - output_path = f"items/{config.collection_id}.parquet" - try: - t0 = time.monotonic() - config.export_collection( + config_list = list(configs.values()) + lock = threading.Lock() + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = { + executor.submit( + _export_one, + config, + i, + N, connection_info, output_protocol, - output_path, + output_dir, storage_options, - pgstac_partitions=recent_collection_updates, - skip_empty_partitions=True, - rewrite=True - ) - t1 = time.monotonic() - logger.info(f"Completed {config.collection_id} [{i}/{N}] in {t1-t0:.2f}s") - except Exception as e: - failure.append((config.collection_id, e)) - logger.exception(f"Failed processing {config.collection_id}") - else: - success.append(config.collection_id) + recent_collection_updates, + ): config + for i, config in enumerate(config_list, 1) + } + for future in concurrent.futures.as_completed(futures): + collection_id, exc = future.result() + with lock: + if exc is None: + success.append(collection_id) + else: + failure.append((collection_id, exc)) return len(failure) + if __name__ == "__main__": # Remove all handlers associated with the root logger object. 
for h in logging.root.handlers[:]: @@ -792,26 +892,36 @@ def run( logging.basicConfig(handlers=[handler], level=logging.DEBUG, force=True) logging.getLogger().setLevel(logging.WARNING) logger.setLevel(logging.DEBUG) - logging.getLogger("stac_geoparquet").setLevel(logging.DEBUG) - parser = argparse.ArgumentParser(description="Export STAC collection to GeoParquet.") + logging.getLogger("stac_geoparquet").setLevel(logging.INFO) + parser = argparse.ArgumentParser( + description="Export STAC collection to GeoParquet." + ) parser.add_argument( - "--collection", + "--collection", type=str, required=False, help="The collection ID to export." + ) + parser.add_argument( + "--output-dir", type=str, required=False, - help="The collection ID to export." + default=None, + help="Write output to this local directory instead of Azure blob storage. Use '.' for the current directory.", ) args = parser.parse_args() configs = list_planetary_computer_collection_configs( connection_info=os.environ["STAC_GEOPARQUET_CONNECTION_INFO"], - table_credential=azure.identity.ManagedIdentityCredential(), + table_credential=get_credential(), table_name=os.environ["STAC_GEOPARQUET_TABLE_NAME"], table_account_url=os.environ["STAC_GEOPARQUET_TABLE_ACCOUNT_URL"], - storage_options_account_name=os.environ["STAC_GEOPARQUET_STORAGE_OPTIONS_ACCOUNT_NAME"], - storage_options_credential=azure.identity.ManagedIdentityCredential(), + storage_options_account_name=os.environ[ + "STAC_GEOPARQUET_STORAGE_OPTIONS_ACCOUNT_NAME" + ], + storage_options_credential=get_credential(), extra_skip=SKIP, collections=args.collection, ) - n_failures = run(collections=args.collection, configs=configs) + n_failures = run( + collections=args.collection, configs=configs, output_dir=args.output_dir + ) if n_failures == 0: logger.info("Export completed successfully.") else: diff --git a/datasets/stac-geoparquet/workflow.yaml b/datasets/stac-geoparquet/workflow.yaml index 21bee30f..0deaee69 100644 --- a/datasets/stac-geoparquet/workflow.yaml +++ b/datasets/stac-geoparquet/workflow.yaml @@ -28,3 +28,4 @@ jobs: environment: APPLICATIONINSIGHTS_CONNECTION_STRING: ${{ secrets.task-application-insights-connection-string }} STAC_GEOPARQUET_CONNECTION_INFO: ${{secrets.pgstac-connection-string}} + TMPDIR: /mnt/resource diff --git a/datasets/stac-geoparquet/workflow_manual.yaml b/datasets/stac-geoparquet/workflow_manual.yaml new file mode 100644 index 00000000..25224853 --- /dev/null +++ b/datasets/stac-geoparquet/workflow_manual.yaml @@ -0,0 +1,40 @@ +name: stac-geoparquet-manual +dataset: stac-geoparquet-manual +id: stac-geoparquet-manual + +jobs: + geoparquet: + tasks: + - id: update + image: pccomponents.azurecr.io/pctasks-stac-geoparquet:2026.01.05.1 + code: + src: ${{ local.path(pc_stac_geoparquet.py) }} + requirements: ${{ local.path(./requirements.txt) }} + task: pc_stac_geoparquet:StacGeoparquetTask + tags: + # # temporarily set this when creating the geoparquet files + # # for a new collection to avoid OOM kills + batch_pool_id: tsk_db_gr + args: + table_account_url: "https://pcapi.table.core.windows.net" + table_name: "greencollectionconfig" + storage_options_account_name: "pcstacitems" + # collections: "chesapeake-lu" # Set if you want to generate only one geoparquet file + collections: + - "met-office-global-deterministic-height" + - "met-office-global-deterministic-near-surface" + - "met-office-global-deterministic-pressure" + - "met-office-global-deterministic-whole-atmosphere" + - "met-office-uk-deterministic-height" + - 
"met-office-uk-deterministic-near-surface" + - "met-office-uk-deterministic-pressure" + - "met-office-uk-deterministic-whole-atmosphere" + extra_skip: + - "chesapeake-lc-13" + - "chesapeake-lc-7" + - "chesapeake-lu" + - "drcog-lulc" + environment: + APPLICATIONINSIGHTS_CONNECTION_STRING: ${{ secrets.task-application-insights-connection-string }} + STAC_GEOPARQUET_CONNECTION_INFO: ${{secrets.pgstac-connection-string}} + TMPDIR: /mnt/resource diff --git a/datasets/stac-geoparquet/workflow_test.yaml b/datasets/stac-geoparquet/workflow_test.yaml deleted file mode 100644 index abd21125..00000000 --- a/datasets/stac-geoparquet/workflow_test.yaml +++ /dev/null @@ -1,20 +0,0 @@ -name: stac-geoparquet -dataset: microsoft/stac-geoparquet -id: stac-geoparquet - -jobs: - stac: - tasks: - - id: create - image: pccomponentstest.azurecr.io/pctasks-stac-geoparquet:2025.10.28.5 - code: - src: ${{ local.path(pc_stac_geoparquet.py) }} - task: pc_stac_geoparquet:StacGeoparquetTask - args: - table_account_url: "https://pctapisstagingsa.table.core.windows.net" - table_name: "collectionconfig" - storage_options_account_name: "pcstacitems" - collections: "sentinel-2-l2a" - environment: - APPLICATIONINSIGHTS_CONNECTION_STRING: ${{ secrets.task-application-insights-connection-string }} - STAC_GEOPARQUET_CONNECTION_INFO: ${{secrets.pgstac-connection-string}}