diff --git a/CHANGELOG.md b/CHANGELOG.md index de07d1afd..98686e692 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,15 +9,29 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +### Changed + +### Removed + +### Fixed + + +## [0.47.0] - 2025-12-17 + +### Added + - `MultiBackendJobManager`: add `download_results` option to enable/disable the automated download of job results once completed by the job manager ([#744](https://github.com/Open-EO/openeo-python-client/issues/744)) +- Support UDF based spatial and temporal extents in `load_collection`, `load_stac` and `filter_temporal` ([#831](https://github.com/Open-EO/openeo-python-client/pull/831)) +- `MultiBackendJobManager`: keep number of "queued" jobs below 10 for better CDSE compatibility ([#839](https://github.com/Open-EO/openeo-python-client/pull/839), eu-cdse/openeo-cdse-infra#859) ### Changed - Internal reorganization of `openeo.extra.job_management` submodule to ease future development ([#741](https://github.com/Open-EO/openeo-python-client/issues/741)) +- `openeo.Connection`: add some more HTTP error codes to the default retry list: `502 Bad Gateway`, `503 Service Unavailable` and `504 Gateway Timeout` ([#835](https://github.com/Open-EO/openeo-python-client/issues/835)) ### Removed -### Fixed +- Remove `Connection.load_disk_collection` (wrapper for non-standard `load_disk_data` process), deprecated since version 0.25.0 (related to [Open-EO/openeo-geopyspark-driver#1457](https://github.com/Open-EO/openeo-geopyspark-driver/issues/1457)) ## [0.46.0] - 2025-10-31 diff --git a/openeo/_version.py b/openeo/_version.py index 5eb931bdc..f640a084f 100644 --- a/openeo/_version.py +++ b/openeo/_version.py @@ -1 +1 @@ -__version__ = "0.47.0a3" +__version__ = "0.48.0a1" diff --git a/openeo/extra/job_management/_manager.py b/openeo/extra/job_management/_manager.py index 2ce78ba52..b4b2bc0e5 100644 --- a/openeo/extra/job_management/_manager.py +++ b/openeo/extra/job_management/_manager.py @@ -32,6 +32,7 @@ from openeo.extra.job_management._thread_worker import ( _JobManagerWorkerThreadPool, _JobStartTask, + _JobDownloadTask ) from openeo.rest import OpenEoApiError from openeo.rest.auth.auth import BearerAuth @@ -60,6 +61,9 @@ class _Backend(NamedTuple): # Maximum number of jobs to allow in parallel on a backend parallel_jobs: int + # Maximum number of jobs to allow in queue on a backend + queueing_limit: int = 10 + @dataclasses.dataclass(frozen=True) class _ColumnProperties: @@ -172,6 +176,7 @@ def start_job( .. versionchanged:: 0.47.0 Added ``download_results`` parameter. + """ # Expected columns in the job DB dataframes. @@ -229,7 +234,13 @@ def add_backend( parallel_jobs: int = 2, ): """ - Register a backend with a name and a Connection getter. + Register a backend with a name and a :py:class:`Connection` getter. + + .. note:: + For optimal throughput and responsiveness, it is recommended + to provide a :py:class:`Connection` instance without its own (blocking) retry behavior, + since the job manager will do retries in a non-blocking way, + allowing to take care of other tasks before retrying failed requests. :param name: Name of the backend. @@ -246,7 +257,8 @@ def add_backend( c = connection connection = lambda: c assert callable(connection) - self.backends[name] = _Backend(get_connection=connection, parallel_jobs=parallel_jobs) + # TODO: expose queueing_limit? + self.backends[name] = _Backend(get_connection=connection, parallel_jobs=parallel_jobs, queueing_limit=10) def _get_connection(self, backend_name: str, resilient: bool = True) -> Connection: """Get a connection for the backend and optionally make it resilient (adds retry behavior) @@ -363,6 +375,9 @@ def run_loop(): ).values() ) > 0 + + or (self._worker_pool.number_pending_tasks() > 0) + and not self._stop_thread ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) @@ -388,7 +403,10 @@ def stop_job_thread(self, timeout_seconds: Optional[float] = _UNSET): .. versionadded:: 0.32.0 """ - self._worker_pool.shutdown() + if self._worker_pool is not None or self._worker_pool.number_pending_tasks() > 0: + self._worker_pool.shutdown() + self._worker_pool = None + if self._thread is not None: self._stop_thread = True @@ -494,13 +512,15 @@ def run_jobs( self._worker_pool = _JobManagerWorkerThreadPool() + while ( sum( job_db.count_by_status( statuses=["not_started", "created", "queued_for_start", "queued", "running"] - ).values() - ) - > 0 + ).values()) > 0 + + or (self._worker_pool.number_pending_tasks() > 0) + ): self._job_update_loop(job_db=job_db, start_job=start_job, stats=stats) stats["run_jobs loop"] += 1 @@ -510,8 +530,10 @@ def run_jobs( time.sleep(self.poll_sleep) stats["sleep"] += 1 - # TODO; run post process after shutdown once more to ensure completion? + + self._worker_pool.shutdown() + self._worker_pool = None return stats @@ -534,17 +556,21 @@ def _job_update_loop( not_started = job_db.get_by_status(statuses=["not_started"], max=200).copy() if len(not_started) > 0: - # Check number of jobs running at each backend + # Check number of jobs queued/running at each backend # TODO: should "created" be included in here? Calling this "running" is quite misleading then. + # apparently (see #839/#840) this seemingly simple change makes a lot of MultiBackendJobManager tests flaky running = job_db.get_by_status(statuses=["created", "queued", "queued_for_start", "running"]) - stats["job_db get_by_status"] += 1 - per_backend = running.groupby("backend_name").size().to_dict() - _log.info(f"Running per backend: {per_backend}") + queued = running[running["status"] == "queued"] + running_per_backend = running.groupby("backend_name").size().to_dict() + queued_per_backend = queued.groupby("backend_name").size().to_dict() + _log.info(f"{running_per_backend=} {queued_per_backend=}") + total_added = 0 for backend_name in self.backends: - backend_load = per_backend.get(backend_name, 0) - if backend_load < self.backends[backend_name].parallel_jobs: - to_add = self.backends[backend_name].parallel_jobs - backend_load + queue_capacity = self.backends[backend_name].queueing_limit - queued_per_backend.get(backend_name, 0) + run_capacity = self.backends[backend_name].parallel_jobs - running_per_backend.get(backend_name, 0) + to_add = min(queue_capacity, run_capacity) + if to_add > 0: for i in not_started.index[total_added : total_added + to_add]: self._launch_job(start_job, df=not_started, i=i, backend_name=backend_name, stats=stats) stats["job launch"] += 1 @@ -553,7 +579,9 @@ def _job_update_loop( stats["job_db persist"] += 1 total_added += 1 - self._process_threadworker_updates(self._worker_pool, job_db=job_db, stats=stats) + if self._worker_pool is not None: + self._process_threadworker_updates(worker_pool=self._worker_pool, job_db=job_db, stats=stats) + # TODO: move this back closer to the `_track_statuses` call above, once job done/error handling is also handled in threads? for job, row in jobs_done: @@ -565,6 +593,7 @@ def _job_update_loop( for job, row in jobs_cancel: self.on_job_cancel(job, row) + def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = None): """Helper method for launching jobs @@ -629,7 +658,7 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No df_idx=i, ) _log.info(f"Submitting task {task} to thread pool") - self._worker_pool.submit_task(task) + self._worker_pool.submit_task(task=task, pool_name="job_start") stats["job_queued_for_start"] += 1 df.loc[i, "status"] = "queued_for_start" @@ -675,7 +704,7 @@ def _process_threadworker_updates( :param stats: Dictionary accumulating statistic counters """ # Retrieve completed task results immediately - results, _ = worker_pool.process_futures(timeout=0) + results = worker_pool.process_futures(timeout=0) # Collect update dicts updates: List[Dict[str, Any]] = [] @@ -721,17 +750,28 @@ def on_job_done(self, job: BatchJob, row): :param job: The job that has finished. :param row: DataFrame row containing the job's metadata. """ - # TODO: param `row` is never accessed in this method. Remove it? Is this intended for future use? if self._download_results: - job_metadata = job.describe() - job_dir = self.get_job_dir(job.job_id) - metadata_path = self.get_job_metadata_path(job.job_id) + job_dir = self.get_job_dir(job.job_id) self.ensure_job_dir_exists(job.job_id) - job.get_results().download_files(target=job_dir) - with metadata_path.open("w", encoding="utf-8") as f: - json.dump(job_metadata, f, ensure_ascii=False) + # Proactively refresh bearer token (because task in thread will not be able to do that + job_con = job.connection + self._refresh_bearer_token(connection=job_con) + + task = _JobDownloadTask( + job_id=job.job_id, + df_idx=row.name, #this is going to be the index in the not saterted dataframe; should not be an issue as there is no db update for download task + root_url=job_con.root_url, + bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None, + download_dir=job_dir, + ) + _log.info(f"Submitting download task {task} to download thread pool") + + if self._worker_pool is None: + self._worker_pool = _JobManagerWorkerThreadPool() + + self._worker_pool.submit_task(task=task, pool_name="job_download") def on_job_error(self, job: BatchJob, row): """ @@ -783,6 +823,7 @@ def _cancel_prolonged_job(self, job: BatchJob, row): except Exception as e: _log.error(f"Unexpected error while handling job {job.job_id}: {e}") + #TODO pull this functionality away from the manager to a general utility class? job dir creation could be reused for tje Jobdownload task def get_job_dir(self, job_id: str) -> Path: """Path to directory where job metadata, results and error logs are be saved.""" return self._root_dir / f"job_{job_id}" diff --git a/openeo/extra/job_management/_thread_worker.py b/openeo/extra/job_management/_thread_worker.py index 6040fade1..86610f50f 100644 --- a/openeo/extra/job_management/_thread_worker.py +++ b/openeo/extra/job_management/_thread_worker.py @@ -7,7 +7,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Union +from pathlib import Path +import json import urllib3.util import openeo @@ -99,7 +101,7 @@ def get_connection(self, retry: Union[urllib3.util.Retry, dict, bool, None] = No connection.authenticate_bearer_token(self.bearer_token) return connection - +@dataclass(frozen=True) class _JobStartTask(ConnectedTask): """ Task for starting an openEO batch job (the `POST /jobs//result` request). @@ -139,9 +141,52 @@ def execute(self) -> _TaskResult: db_update={"status": "start_failed"}, stats_update={"start_job error": 1}, ) + +@dataclass(frozen=True) +class _JobDownloadTask(ConnectedTask): + """ + Task for downloading job results and metadata. + :param download_dir: + Root directory where job results and metadata will be downloaded. + """ + download_dir: Path = field(default=None, repr=False) -class _JobManagerWorkerThreadPool: + def execute(self) -> _TaskResult: + + try: + job = self.get_connection(retry=True).job(self.job_id) + + # Download results + job.get_results().download_files(target=self.download_dir) + + # Count assets (files to download) + assets = job.list_results().get('assets', {}) + file_count = len(assets) + + # Download metadata + job_metadata = job.describe() + metadata_path = self.download_dir / f"job_{self.job_id}.json" + with metadata_path.open("w", encoding="utf-8") as f: + json.dump(job_metadata, f, ensure_ascii=False) + + _log.info(f"Job {self.job_id!r} results downloaded successfully") + return _TaskResult( + job_id=self.job_id, + df_idx=self.df_idx, + db_update={}, #TODO consider db updates? + stats_update={"job download": 1, "files downloaded": file_count}, + ) + except Exception as e: + _log.error(f"Failed to download results for job {self.job_id!r}: {e!r}") + return _TaskResult( + job_id=self.job_id, + df_idx=self.df_idx, + db_update={}, + stats_update={"job download error": 1, "files downloaded": 0}, + ) + +class _TaskThreadPool: """ Thread pool-based worker that manages the execution of asynchronous tasks. @@ -150,12 +195,14 @@ class _JobManagerWorkerThreadPool: :param max_workers: Maximum number of concurrent threads to use for execution. - Defaults to 2. + Defaults to 1. """ - def __init__(self, max_workers: int = 2): + def __init__(self, max_workers: int = 1, name: str = 'default'): self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) self._future_task_pairs: List[Tuple[concurrent.futures.Future, Task]] = [] + self._name = name + self._max_workers = max_workers def submit_task(self, task: Task) -> None: """ @@ -206,9 +253,99 @@ def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskRe _log.info("process_futures: %d tasks done, %d tasks remaining", len(results), len(to_keep)) self._future_task_pairs = to_keep - return results, len(to_keep) + return results + + def number_pending_tasks(self) -> int: + """Return the number of tasks that are still pending (not completed).""" + return len(self._future_task_pairs) def shutdown(self) -> None: """Shuts down the thread pool gracefully.""" _log.info("Shutting down thread pool") self._executor.shutdown(wait=True) + + +class _JobManagerWorkerThreadPool: + + """ + Generic wrapper that manages multiple thread pools with a dict. + """ + + def __init__(self, pool_configs: Optional[Dict[str, int]] = None): + """ + :param pool_configs: Dict of pool_name -> max_workers + Example: {"job_start": 1, "download": 2} + """ + self._pools: Dict[str, _TaskThreadPool] = {} + self._pool_configs = pool_configs or {} + + # Create all pools upfront from config + for pool_name, max_workers in self._pool_configs.items(): + self._pools[pool_name] = _TaskThreadPool(max_workers=max_workers) + _log.info(f"Created pool '{pool_name}' with {max_workers} workers") + + def submit_task(self, task: Task, pool_name: str = "default") -> None: + """ + Submit a task to a specific pool. + Creates pool dynamically only if not in config. + """ + if pool_name not in self._pools: + # Check if pool_name is in config but somehow wasn't created + if pool_name in self._pool_configs: + # This shouldn't happen, but create it + max_workers = self._pool_configs[pool_name] + self._pools[pool_name] = _TaskThreadPool(max_workers=max_workers, name=pool_name) + _log.warning(f"Created missing pool '{pool_name}' from config with {max_workers} workers") + else: + # Not in config - create with default + max_workers = 1 + self._pools[pool_name] = _TaskThreadPool(max_workers=max_workers, name=pool_name) + _log.info(f"Created dynamic pool '{pool_name}' with {max_workers} workers") + + self._pools[pool_name].submit_task(task) + + def process_futures(self, timeout: Union[float, None] = 0) -> Tuple[List[_TaskResult], Dict[str, int]]: + """ + Process updates from ALL pools. + Returns: (all_results, dict of remaining tasks per pool) + """ + all_results = [] + + for pool_name, pool in self._pools.items(): + results = pool.process_futures(timeout) + all_results.extend(results) + + return all_results + + def number_pending_tasks(self, pool_name: Optional[str] = None) -> int: + if pool_name: + pool = self._pools.get(pool_name) + return pool.number_pending_tasks() if pool else 0 + else: + return sum(pool.number_pending_tasks() for pool in self._pools.values()) + + def shutdown(self, pool_name: Optional[str] = None) -> None: + """ + Shutdown pools. + If pool_name is None, shuts down all pools. + """ + if pool_name: + if pool_name in self._pools: + self._pools[pool_name].shutdown() + del self._pools[pool_name] + if pool_name in self._pool_configs: + del self._pool_configs[pool_name] + else: + for pool_name, pool in list(self._pools.items()): + pool.shutdown() + del self._pools[pool_name] + if pool_name in self._pool_configs: + del self._pool_configs[pool_name] + + self._pool_configs.clear() + + def list_pools(self) -> List[str]: + """List all active pool names.""" + return list(self._pools.keys()) + + diff --git a/openeo/internal/graph_building.py b/openeo/internal/graph_building.py index 6f5918ea2..10b395fdd 100644 --- a/openeo/internal/graph_building.py +++ b/openeo/internal/graph_building.py @@ -95,6 +95,8 @@ def print_json( class _FromNodeMixin(abc.ABC): """Mixin for classes that want to hook into the generation of a "from_node" reference.""" + # TODO: rename this class: it's more an interface than a mixin, and "from node" might be confusing as explained below. + @abc.abstractmethod def from_node(self) -> PGNode: # TODO: "from_node" is a bit a confusing name: diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 5625ef479..25f9030fb 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -35,7 +35,12 @@ import openeo from openeo.config import config_log, get_config_option from openeo.internal.documentation import openeo_process -from openeo.internal.graph_building import FlatGraphableMixin, PGNode, as_flat_graph +from openeo.internal.graph_building import ( + FlatGraphableMixin, + PGNode, + _FromNodeMixin, + as_flat_graph, +) from openeo.internal.jupyter import VisualDict, VisualList from openeo.internal.processes.builder import ProcessBuilderBase from openeo.internal.warnings import deprecated, legacy_alias @@ -1186,8 +1191,8 @@ def load_collection( self, collection_id: Union[str, Parameter], spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, None] = None, - temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, - bands: Union[Iterable[str], Parameter, str, None] = None, + temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, + bands: Union[Iterable[str], Parameter, str, _FromNodeMixin, None] = None, properties: Union[ Dict[str, Union[PGNode, Callable]], List[CollectionProperty], CollectionProperty, None ] = None, @@ -1287,8 +1292,10 @@ def load_result( def load_stac( self, url: str, - spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, None] = None, - temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, + spatial_extent: Union[ + dict, Parameter, shapely.geometry.base.BaseGeometry, str, Path, _FromNodeMixin, None + ] = None, + temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, bands: Union[Iterable[str], Parameter, str, None] = None, properties: Union[ Dict[str, Union[PGNode, Callable]], List[CollectionProperty], CollectionProperty, None @@ -1873,26 +1880,6 @@ def service(self, service_id: str) -> Service: """ return Service(service_id, connection=self) - @deprecated( - reason="Depends on non-standard process, replace with :py:meth:`openeo.rest.connection.Connection.load_stac` where possible.", - version="0.25.0") - def load_disk_collection( - self, format: str, glob_pattern: str, options: Optional[dict] = None - ) -> DataCube: - """ - Loads image data from disk as a :py:class:`DataCube`. - - This is backed by a non-standard process ('load_disk_data'). This will eventually be replaced by standard options such as - :py:meth:`openeo.rest.connection.Connection.load_stac` or https://processes.openeo.org/#load_uploaded_files - - :param format: the file format, e.g. 'GTiff' - :param glob_pattern: a glob pattern that matches the files to load from disk - :param options: options specific to the file format - """ - return DataCube.load_disk_collection( - self, format, glob_pattern, **(options or {}) - ) - def as_curl( self, data: Union[dict, DataCube, FlatGraphableMixin], diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py index 00d850d09..238d0bd51 100644 --- a/openeo/rest/datacube.py +++ b/openeo/rest/datacube.py @@ -91,7 +91,7 @@ # Type annotation aliases -InputDate = Union[str, datetime.date, Parameter, PGNode, ProcessBuilderBase, None] +InputDate = Union[str, datetime.date, Parameter, PGNode, ProcessBuilderBase, _FromNodeMixin, None] class DataCube(_ProcessGraphAbstraction): @@ -165,8 +165,10 @@ def load_collection( cls, collection_id: Union[str, Parameter], connection: Optional[Connection] = None, - spatial_extent: Union[dict, Parameter, shapely.geometry.base.BaseGeometry, str, pathlib.Path, None] = None, - temporal_extent: Union[Sequence[InputDate], Parameter, str, None] = None, + spatial_extent: Union[ + dict, Parameter, shapely.geometry.base.BaseGeometry, str, pathlib.Path, _FromNodeMixin, None + ] = None, + temporal_extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, bands: Union[Iterable[str], Parameter, str, None] = None, fetch_metadata: bool = True, properties: Union[ @@ -307,31 +309,6 @@ def _build_load_properties_argument( return properties - @classmethod - @deprecated(reason="Depends on non-standard process, replace with :py:meth:`openeo.rest.connection.Connection.load_stac` where possible.",version="0.25.0") - def load_disk_collection(cls, connection: Connection, file_format: str, glob_pattern: str, **options) -> DataCube: - """ - Loads image data from disk as a DataCube. - This is backed by a non-standard process ('load_disk_data'). This will eventually be replaced by standard options such as - :py:meth:`openeo.rest.connection.Connection.load_stac` or https://processes.openeo.org/#load_uploaded_files - - - :param connection: The connection to use to connect with the backend. - :param file_format: the file format, e.g. 'GTiff' - :param glob_pattern: a glob pattern that matches the files to load from disk - :param options: options specific to the file format - :return: the data as a DataCube - """ - pg = PGNode( - process_id='load_disk_data', - arguments={ - 'format': file_format, - 'glob_pattern': glob_pattern, - 'options': options - } - ) - return cls(graph=pg, connection=connection) - @classmethod def load_stac( cls, @@ -505,22 +482,22 @@ def _get_temporal_extent( *args, start_date: InputDate = None, end_date: InputDate = None, - extent: Union[Sequence[InputDate], Parameter, str, None] = None, - ) -> Union[List[Union[str, Parameter, PGNode, None]], Parameter]: + extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, + ) -> Union[List[Union[str, Parameter, PGNode, _FromNodeMixin, None]], Parameter, _FromNodeMixin]: """Parameter aware temporal_extent normalizer""" # TODO: move this outside of DataCube class # TODO: return extent as tuple instead of list - if len(args) == 1 and isinstance(args[0], Parameter): + if len(args) == 1 and isinstance(args[0], (Parameter, _FromNodeMixin)): assert start_date is None and end_date is None and extent is None return args[0] - elif len(args) == 0 and isinstance(extent, Parameter): + elif len(args) == 0 and isinstance(extent, (Parameter, _FromNodeMixin)): assert start_date is None and end_date is None # TODO: warn about unexpected parameter schema return extent else: def convertor(d: Any) -> Any: # TODO: can this be generalized through _FromNodeMixin? - if isinstance(d, Parameter) or isinstance(d, PGNode): + if isinstance(d, Parameter) or isinstance(d, _FromNodeMixin): # TODO: warn about unexpected parameter schema return d elif isinstance(d, ProcessBuilderBase): @@ -556,7 +533,7 @@ def filter_temporal( *args, start_date: InputDate = None, end_date: InputDate = None, - extent: Union[Sequence[InputDate], Parameter, str, None] = None, + extent: Union[Sequence[InputDate], Parameter, str, _FromNodeMixin, None] = None, ) -> DataCube: """ Limit the DataCube to a certain date range, which can be specified in several ways: diff --git a/openeo/testing/stac.py b/openeo/testing/stac.py index 2b8507af1..e64121667 100644 --- a/openeo/testing/stac.py +++ b/openeo/testing/stac.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union class StacDummyBuilder: @@ -117,13 +117,25 @@ def asset( cls, href: str = "https://stac.test/asset.tiff", type: str = "image/tiff; application=geotiff", - roles: Optional[List[str]] = None, + roles: Union[Sequence[str], None] = ("data",), + proj_code: Optional[str] = None, + proj_bbox: Optional[Sequence[float]] = None, + proj_shape: Optional[Sequence[int]] = None, + proj_transform: Optional[Sequence[float]] = None, **kwargs, ): - d = {"href": href, **kwargs} + d: Dict[str, Any] = {"href": href, **kwargs} if type: d["type"] = type - if roles is not None: - d["roles"] = roles + if roles: + d["roles"] = list(roles) + if proj_code: + d["proj:code"] = proj_code + if proj_bbox: + d["proj:bbox"] = list(proj_bbox) + if proj_shape: + d["proj:shape"] = list(proj_shape) + if proj_transform: + d["proj:transform"] = list(proj_transform) return d diff --git a/openeo/utils/http.py b/openeo/utils/http.py index c314e59b8..6c8e7258d 100644 --- a/openeo/utils/http.py +++ b/openeo/utils/http.py @@ -43,6 +43,9 @@ DEFAULT_RETRY_FORCELIST = frozenset( [ HTTP_429_TOO_MANY_REQUESTS, + HTTP_502_BAD_GATEWAY, + HTTP_503_SERVICE_UNAVAILABLE, + HTTP_504_GATEWAY_TIMEOUT, ] ) diff --git a/tests/extra/job_management/test_manager.py b/tests/extra/job_management/test_manager.py index 1d02afb1c..2c4162974 100644 --- a/tests/extra/job_management/test_manager.py +++ b/tests/extra/job_management/test_manager.py @@ -729,7 +729,7 @@ def get_status(job_id, current_status): assert isinstance(rfc3339.parse_datetime(filled_running_start_time), datetime.datetime) def test_process_threadworker_updates(self, tmp_path, caplog): - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() stats = collections.defaultdict(int) # Submit tasks covering all cases @@ -769,7 +769,7 @@ def test_process_threadworker_updates(self, tmp_path, caplog): assert caplog.messages == [] def test_process_threadworker_updates_unknown(self, tmp_path, caplog): - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() stats = collections.defaultdict(int) pool.submit_task(DummyResultTask("j-123", df_idx=0, db_update={"status": "queued"}, stats_update={"queued": 1})) @@ -806,7 +806,7 @@ def test_process_threadworker_updates_unknown(self, tmp_path, caplog): assert caplog.messages == [dirty_equals.IsStr(regex=".*Ignoring unknown.*indices.*4.*")] def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog): - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() stats = collections.defaultdict(int) df_initial = pd.DataFrame({"id": ["j-0"], "status": ["created"]}) @@ -820,7 +820,7 @@ def test_no_results_leaves_db_and_stats_untouched(self, tmp_path, caplog): assert stats == {} def test_logs_on_invalid_update(self, tmp_path, caplog): - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _JobManagerWorkerThreadPool() stats = collections.defaultdict(int) # Malformed db_update (not a dict unpackable via **) diff --git a/tests/extra/job_management/test_thread_worker.py b/tests/extra/job_management/test_thread_worker.py index 52ee833f1..b1f830f77 100644 --- a/tests/extra/job_management/test_thread_worker.py +++ b/tests/extra/job_management/test_thread_worker.py @@ -3,14 +3,18 @@ import time from dataclasses import dataclass from typing import Iterator +from pathlib import Path +from requests_mock import Mocker import pytest from openeo.extra.job_management._thread_worker import ( Task, + _TaskThreadPool, _JobManagerWorkerThreadPool, _JobStartTask, _TaskResult, + _JobDownloadTask ) from openeo.rest._testing import DummyBackend @@ -79,6 +83,130 @@ def test_hide_token(self, serializer): assert "job-123" in serialized assert secret not in serialized +class TestJobDownloadTask: + + + def test_job_download_success(self, requests_mock: Mocker, tmp_path: Path): + """ + Test a successful job download and verify file content and stats update. + """ + job_id = "job-007" + df_idx = 42 + + # We set up a dummy backend to simulate the job results and assert the expected calls are triggered + backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock) + backend.next_result = b"The downloaded file content." + backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"} + + backend._set_job_status(job_id=job_id, status="finished") + backend.batch_jobs[job_id]["status"] = "finished" + + download_dir = tmp_path / job_id / "results" + download_dir.mkdir(parents=True) + + # Create the task instance + task = _JobDownloadTask( + root_url="https://openeo.dummy.test/", + bearer_token="dummy-token-7", + job_id=job_id, + df_idx=df_idx, + download_dir=download_dir, + ) + + # Execute the task + result = task.execute() + + # Verify TaskResult structure + assert isinstance(result, _TaskResult) + assert result.job_id == job_id + assert result.df_idx == df_idx + + # Verify stats update for the MultiBackendJobManager + assert result.stats_update == {'files downloaded': 1, "job download": 1} + + # Verify download content (crucial part of the unit test) + downloaded_file = download_dir / "result.data" + assert downloaded_file.exists() + assert downloaded_file.read_bytes() == b"The downloaded file content." + + # Verify metadata file + metadata_file = download_dir / f"job_{job_id}.json" + assert metadata_file.exists() + metadata = metadata_file.read_text() + assert job_id in metadata + + + def test_job_download_failure(self, requests_mock: Mocker, tmp_path: Path): + """ + Test a failed download (e.g., bad connection) and verify error reporting. + """ + job_id = "job-008" + df_idx = 55 + + # Set up dummy backend to simulate failure during results listing + backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock) + + #simulate and error when downloading the results + requests_mock.get( + f"https://openeo.dummy.test/jobs/{job_id}/results", + status_code=500, + json={"code": "InternalError", "message": "Failed to list results"}) + + backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "created"} + backend._set_job_status(job_id=job_id, status="finished") + backend.batch_jobs[job_id]["finished"] = "error" + + download_dir = tmp_path / job_id / "results" + download_dir.mkdir(parents=True) + + # Create the task instance + task = _JobDownloadTask( + root_url="https://openeo.dummy.test/", + bearer_token="dummy-token-8", + job_id=job_id, + df_idx=df_idx, + download_dir=download_dir, + ) + + # Execute the task + result = task.execute() + + # Verify TaskResult structure + assert isinstance(result, _TaskResult) + assert result.job_id == job_id + assert result.df_idx == df_idx + + # Verify stats update for the MultiBackendJobManager + assert result.stats_update == {'files downloaded': 0, "job download error": 1} + + # Verify no file was created (or only empty/failed files) + assert not any(p.is_file() for p in download_dir.glob("*")) + + def test_download_directory_permission_error(self, requests_mock: Mocker, tmp_path: Path): + """Test download when directory has permission issues.""" + job_id = "job-perm" + df_idx = 99 + + backend = DummyBackend.at_url("https://openeo.dummy.test/", requests_mock=requests_mock) + backend.batch_jobs[job_id] = {"job_id": job_id, "pg": {}, "status": "finished"} + + # Create a read-only directory (simulate permission issue on some systems) + download_dir = tmp_path / "readonly" + download_dir.mkdir() + + task = _JobDownloadTask( + root_url="https://openeo.dummy.test/", + bearer_token="token", + job_id=job_id, + df_idx=df_idx, + download_dir=download_dir, + ) + + # Should handle permission error gracefully + result = task.execute() + assert result.job_id == job_id + assert result.stats_update["job download error"] == 1 + class NopTask(Task): """Do Nothing""" @@ -116,40 +244,37 @@ def execute(self) -> _TaskResult: return _TaskResult(job_id=self.job_id, df_idx=self.df_idx, db_update={"status": "all fine"}) -class TestJobManagerWorkerThreadPool: +class TestTaskThreadPool: @pytest.fixture - def worker_pool(self) -> Iterator[_JobManagerWorkerThreadPool]: + def worker_pool(self) -> Iterator[_TaskThreadPool]: """Fixture for creating and cleaning up a worker thread pool.""" - pool = _JobManagerWorkerThreadPool(max_workers=2) + pool = _TaskThreadPool() yield pool pool.shutdown() def test_no_tasks(self, worker_pool): - results, remaining = worker_pool.process_futures(timeout=10) + results = worker_pool.process_futures(timeout=10) assert results == [] - assert remaining == 0 def test_submit_and_process(self, worker_pool): worker_pool.submit_task(DummyTask(job_id="j-123", df_idx=0)) - results, remaining = worker_pool.process_futures(timeout=10) + results = worker_pool.process_futures(timeout=10) assert results == [ _TaskResult(job_id="j-123", df_idx=0, db_update={"status": "dummified"}, stats_update={"dummy": 1}), ] - assert remaining == 0 def test_submit_and_process_zero_timeout(self, worker_pool): worker_pool.submit_task(DummyTask(job_id="j-123", df_idx=0)) # Trigger context switch time.sleep(0.1) - results, remaining = worker_pool.process_futures(timeout=0) + results = worker_pool.process_futures(timeout=0) assert results == [ _TaskResult(job_id="j-123", df_idx=0, db_update={"status": "dummified"}, stats_update={"dummy": 1}), ] - assert remaining == 0 def test_submit_and_process_with_error(self, worker_pool): worker_pool.submit_task(DummyTask(job_id="j-666", df_idx=0)) - results, remaining = worker_pool.process_futures(timeout=10) + results = worker_pool.process_futures(timeout=10) assert results == [ _TaskResult( job_id="j-666", @@ -158,20 +283,18 @@ def test_submit_and_process_with_error(self, worker_pool): stats_update={"threaded task failed": 1}, ), ] - assert remaining == 0 + def test_submit_and_process_iterative(self, worker_pool): worker_pool.submit_task(NopTask(job_id="j-1", df_idx=1)) - results, remaining = worker_pool.process_futures(timeout=1) + results = worker_pool.process_futures(timeout=1) assert results == [_TaskResult(job_id="j-1", df_idx=1)] - assert remaining == 0 # Add some more worker_pool.submit_task(NopTask(job_id="j-22", df_idx=22)) worker_pool.submit_task(NopTask(job_id="j-222", df_idx=222)) - results, remaining = worker_pool.process_futures(timeout=1) + results = worker_pool.process_futures(timeout=1) assert results == [_TaskResult(job_id="j-22", df_idx=22), _TaskResult(job_id="j-222", df_idx=222)] - assert remaining == 0 def test_submit_multiple_simple(self, worker_pool): # A bunch of dummy tasks @@ -179,7 +302,7 @@ def test_submit_multiple_simple(self, worker_pool): worker_pool.submit_task(NopTask(job_id=f"j-{j}", df_idx=j)) # Process all of them (non-zero timeout, which should be plenty of time for all of them to finish) - results, remaining = worker_pool.process_futures(timeout=1) + results = worker_pool.process_futures(timeout=1) expected = [_TaskResult(job_id=f"j-{j}", df_idx=j) for j in range(5)] assert sorted(results, key=lambda r: r.job_id) == expected @@ -200,25 +323,24 @@ def test_submit_multiple_blocking_and_failing(self, worker_pool): ) # Initial state: nothing happened yet - results, remaining = worker_pool.process_futures(timeout=0) - assert (results, remaining) == ([], n) + results = worker_pool.process_futures(timeout=0) + assert results == [] # No changes even after timeout - results, remaining = worker_pool.process_futures(timeout=0.1) - assert (results, remaining) == ([], n) + results = worker_pool.process_futures(timeout=0.1) + assert results == [] # Set one event and wait for corresponding result events[0].set() - results, remaining = worker_pool.process_futures(timeout=0.1) + results = worker_pool.process_futures(timeout=0.1) assert results == [ _TaskResult(job_id="j-0", df_idx=0, db_update={"status": "all fine"}), ] - assert remaining == n - 1 # Release all but one event for j in range(n - 1): events[j].set() - results, remaining = worker_pool.process_futures(timeout=0.1) + results = worker_pool.process_futures(timeout=0.1) assert results == [ _TaskResult(job_id="j-1", df_idx=1, db_update={"status": "all fine"}), _TaskResult(job_id="j-2", df_idx=2, db_update={"status": "all fine"}), @@ -229,22 +351,20 @@ def test_submit_multiple_blocking_and_failing(self, worker_pool): stats_update={"threaded task failed": 1}, ), ] - assert remaining == 1 # Release all events for j in range(n): events[j].set() - results, remaining = worker_pool.process_futures(timeout=0.1) + results = worker_pool.process_futures(timeout=0.1) assert results == [ _TaskResult(job_id="j-4", df_idx=4, db_update={"status": "all fine"}), ] - assert remaining == 0 def test_shutdown(self, worker_pool): # Before shutdown worker_pool.submit_task(NopTask(job_id="j-123", df_idx=0)) - results, remaining = worker_pool.process_futures(timeout=0.1) - assert (results, remaining) == ([_TaskResult(job_id="j-123", df_idx=0)], 0) + results = worker_pool.process_futures(timeout=0.1) + assert results == [_TaskResult(job_id="j-123", df_idx=0)] worker_pool.shutdown() @@ -258,7 +378,7 @@ def test_job_start_task(self, worker_pool, dummy_backend, caplog): task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token=None) worker_pool.submit_task(task) - results, remaining = worker_pool.process_futures(timeout=1) + results = worker_pool.process_futures(timeout=1) assert results == [ _TaskResult( job_id="job-000", @@ -267,7 +387,6 @@ def test_job_start_task(self, worker_pool, dummy_backend, caplog): stats_update={"job start": 1}, ) ] - assert remaining == 0 assert caplog.messages == [] def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog): @@ -278,13 +397,420 @@ def test_job_start_task_failure(self, worker_pool, dummy_backend, caplog): task = _JobStartTask(job_id=job.job_id, df_idx=0, root_url=dummy_backend.connection.root_url, bearer_token=None) worker_pool.submit_task(task) - results, remaining = worker_pool.process_futures(timeout=1) + results = worker_pool.process_futures(timeout=1) assert results == [ _TaskResult( job_id="job-000", df_idx=0, db_update={"status": "start_failed"}, stats_update={"start_job error": 1} ) ] - assert remaining == 0 assert caplog.messages == [ "Failed to start job 'job-000': OpenEoApiError('[500] Internal: No job starting for you, buddy')" ] + + + +class TestJobManagerWorkerThreadPool: + @pytest.fixture + def thread_pool(self) -> Iterator[_JobManagerWorkerThreadPool]: + """Fixture for creating and cleaning up a thread pool manager.""" + pool = _JobManagerWorkerThreadPool() + yield pool + pool.shutdown() + + + def test_init_empty_config(self): + """Test initialization with empty config.""" + pool = _JobManagerWorkerThreadPool() + assert pool._pools == {} + assert pool._pool_configs == {} + pool.shutdown() + + def test_init_with_config(self): + """Test that pools are created immediately from config (not lazily).""" + pool_configs = {"default": 3, "download": 5} + pool = _JobManagerWorkerThreadPool(pool_configs) + + # Pools should exist immediately (eager initialization) + assert set(pool.list_pools()) == {"default", "download"} + assert pool._pools["default"]._max_workers == 3 + assert pool._pools["download"]._max_workers == 5 + + def test_init_with_duplicate_pool_names(self): + """Test behavior with duplicate pool names in config.""" + pool_configs = {"duplicate": 2, "duplicate": 3} + pool = _JobManagerWorkerThreadPool(pool_configs) + + # Last value should win + assert pool._pools["duplicate"]._max_workers == 3 + + + def test_submit_with_with_config(self): + """Test initialization with pool configurations.""" + pool_configs = { + "default": 2, + "download": 3, + "processing": 4 + } + pool = _JobManagerWorkerThreadPool(pool_configs) + + # Pools should NOT be created until first use + # Pools should NOT be created until first use (lazy initialization) + assert len(pool._pools) == 3 # All 3 pools should exist + assert pool._pool_configs == pool_configs + assert set(pool._pools.keys()) == {"default", "download", "processing"} + + task = NopTask(job_id="j-1", df_idx=1) + pool.submit_task(task = task, pool_name = 'no_task') + + # Verify each pool has correct number of workers + assert pool._pools["default"]._max_workers == 2 + assert pool._pools["download"]._max_workers == 3 + assert pool._pools["processing"]._max_workers == 4 + assert pool._pools["no_task"]._max_workers == 1 + + + + def test_submit_task_creates_pool(self, thread_pool): + """Test that submitting a task creates a pool dynamically.""" + task = NopTask(job_id="j-1", df_idx=1) + + assert thread_pool.list_pools() == [] + + # Submit task - should create pool + thread_pool.submit_task(task) + + # Pool should be created + assert thread_pool.list_pools() == ["default"] + assert "default" in thread_pool._pools + + # Process to complete the task + results = thread_pool.process_futures(timeout=0.1) + assert len(results) == 1 + assert results[0].job_id == "j-1" + + + def test_submit_multiple_task_types(self, thread_pool): + """Test submitting different task types to different pools.""" + # Submit different task types + task1 = NopTask(job_id="j-1", df_idx=1) + task2 = DummyTask(job_id="j-2", df_idx=2) + task3 = DummyTask(job_id="j-3", df_idx=3) + + thread_pool.submit_task(task1) + thread_pool.submit_task(task2) + thread_pool.submit_task(task3, "seperate") + + # Should have 2 pools + pools = sorted(thread_pool.list_pools()) + assert pools == ["default", "seperate"] + + # Check pending tasks + assert thread_pool.number_pending_tasks() == 3 + assert thread_pool.number_pending_tasks("default") == 2 + assert thread_pool.number_pending_tasks("seperate") == 1 + + def test_process_futures_updates_empty(self, thread_pool): + """Test process futures with no pools.""" + results = thread_pool.process_futures(timeout=0) + assert results == [] + + def test_process_futures_updates_multiple_pools(self, thread_pool): + """Test processing updates across multiple pools.""" + # Submit tasks to different pools + thread_pool.submit_task(NopTask(job_id="j-1", df_idx=1)) + thread_pool.submit_task(NopTask(job_id="j-2", df_idx=2)) + thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3), "dummy") + + results = thread_pool.process_futures(timeout=0.1) + + assert len(results) == 3 + + nop_results = [r for r in results if r.job_id in ["j-1", "j-2"]] + dummy_results = [r for r in results if r.job_id == "j-3"] + assert len(nop_results) == 2 + assert len(dummy_results) == 1 + + def test_process_futures_updates_partial_completion(self): + """Test processing when some tasks are still running.""" + # Use a pool with blocking tasks + pool = _JobManagerWorkerThreadPool() + + # Create a blocking task + event = threading.Event() + blocking_task = BlockingTask(job_id="j-block", df_idx=0, event=event, success=True) + + # Create a quick task + quick_task = NopTask(job_id="j-quick", df_idx=1) + + pool.submit_task(blocking_task, "blocking") # BlockingTask pool + pool.submit_task(quick_task, "quick") # NopTask pool + + # Process with timeout=0 - only quick task should complete + results = pool.process_futures(timeout=0) + + # Only quick task completed + assert len(results) == 1 + assert results[0].job_id == "j-quick" + + # Blocking task still pending + assert pool.number_pending_tasks() == 1 + assert pool.number_pending_tasks("blocking") == 1 + + # Release blocking task and process again + event.set() + results2 = pool.process_futures(timeout=0.1) + + assert len(results2) == 1 + assert results2[0].job_id == "j-block" + + pool.shutdown() + + def test_concurrent_submissions(self, thread_pool): + """Test concurrent task submissions to same pool.""" + import concurrent.futures + + def submit_tasks(start_idx: int): + for i in range(5): + thread_pool.submit_task(NopTask(job_id=f"j-{start_idx + i}", df_idx=start_idx + i)) + + # Submit tasks from multiple threads + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(submit_tasks, i * 10) for i in range(3)] + concurrent.futures.wait(futures) + + # Should have all tasks in one pool + assert thread_pool.list_pools() == ["default"] + assert thread_pool.number_pending_tasks() == 15 + + # Process them all + results = thread_pool.process_futures(timeout=0.5) + + assert len(results) == 15 + + def test_num_pending_tasks(self, thread_pool): + """Test counting pending tasks.""" + # Initially empty + assert thread_pool.number_pending_tasks() == 0 + assert thread_pool.number_pending_tasks("default") == 0 + + # Add some tasks + thread_pool.submit_task(NopTask(job_id="j-1", df_idx=1)) + thread_pool.submit_task(NopTask(job_id="j-2", df_idx=2)) + thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3), "dummy") + + # Check totals + assert thread_pool.number_pending_tasks() == 3 + assert thread_pool.number_pending_tasks("default") == 2 + assert thread_pool.number_pending_tasks("dummy") == 1 + + # Process all + thread_pool.process_futures(timeout=0.1) + + # Should be empty + assert thread_pool.number_pending_tasks() == 0 + assert thread_pool.number_pending_tasks("default") == 0 + + + def test_process_futures_returns_remaining_counts(self): + """Test that process_futures returns remaining task counts per pool.""" + # Note: Your current implementation doesn't return remaining counts! + # It returns List[_TaskResult] but the docstring says it returns + # (all_results, dict of remaining tasks per pool) + + pool = _JobManagerWorkerThreadPool({"fast": 1, "slow": 1}) + + # Add blocking and non-blocking tasks + event = threading.Event() + pool.submit_task(BlockingTask(job_id="j-block", df_idx=0, event=event), "slow") + pool.submit_task(NopTask(job_id="j-quick", df_idx=1), "fast") + + # Process with timeout=0 + results = pool.process_futures(timeout=0) + + # Should get only quick task + assert len(results) == 1 + assert results[0].job_id == "j-quick" + + # Check remaining tasks + assert pool.number_pending_tasks("slow") == 1 + assert pool.number_pending_tasks("fast") == 0 + + def test_pool_parallelism_with_blocking_tasks(self): + """Test that multiple workers allow parallel execution.""" + pool = _JobManagerWorkerThreadPool({ + "BlockingTask": 3, # 3 workers for blocking tasks + }) + + # Create multiple blocking tasks + events = [threading.Event() for _ in range(5)] + + for i, event in enumerate(events): + pool.submit_task(BlockingTask( + job_id=f"j-block-{i}", + df_idx=i, + event=event, + success=True + )) + + # Initially all pending + assert pool.number_pending_tasks() == 5 + + # Release all events at once + for event in events: + event.set() + + results = pool.process_futures(timeout=0.5) + assert len(results) == 5 + + for result in results: + assert result.job_id.startswith("j-block-") + + pool.shutdown() + + def test_task_with_error_handling(self, thread_pool): + """Test that task errors are properly handled in the pool.""" + # Submit a failing DummyTask (j-666 fails) + thread_pool.submit_task(DummyTask(job_id="j-666", df_idx=0)) + + # Process it + results = thread_pool.process_futures(timeout=0.1) + + # Should get error result + assert len(results) == 1 + result = results[0] + assert result.job_id == "j-666" + assert result.db_update == {"status": "threaded task failed"} + assert result.stats_update == {"threaded task failed": 1} + + def test_mixed_success_and_error_tasks(self, thread_pool): + """Test mix of successful and failing tasks.""" + # Submit mix of tasks + thread_pool.submit_task(DummyTask(job_id="j-1", df_idx=1)) # Success + thread_pool.submit_task(DummyTask(job_id="j-666", df_idx=2)) # Failure + thread_pool.submit_task(DummyTask(job_id="j-3", df_idx=3)) # Success + + # Process all + results = thread_pool.process_futures(timeout=0.1) + + # Should get 3 results + assert len(results) == 3 + + # Check results + success_results = [r for r in results if r.job_id != "j-666"] + error_results = [r for r in results if r.job_id == "j-666"] + + assert len(success_results) == 2 + assert len(error_results) == 1 + + # Verify success results + for result in success_results: + assert result.db_update == {"status": "dummified"} + assert result.stats_update == {"dummy": 1} + + # Verify error result + error_result = error_results[0] + assert error_result.db_update == {"status": "threaded task failed"} + assert error_result.stats_update == {"threaded task failed": 1} + + def test_submit_same_task_to_multiple_pools(self): + """Test submitting the same task instance to different pools.""" + pool = _JobManagerWorkerThreadPool({"pool1": 1, "pool2": 1}) + task = NopTask(job_id="j-1", df_idx=1) + + # Submit same task to two different pools + pool.submit_task(task, "pool1") + pool.submit_task(task, "pool2") # Same instance! + + assert pool.number_pending_tasks("pool1") == 1 + assert pool.number_pending_tasks("pool2") == 1 + assert pool.number_pending_tasks() == 2 + + def test_shutdown_specific_pool(self): + """Test shutting down a specific pool.""" + # Create fresh pool for destructive test + pool = _JobManagerWorkerThreadPool() + + # Create two pools + pool.submit_task(NopTask(job_id="j-1", df_idx=1), "notask") # NopTask pool + pool.submit_task(DummyTask(job_id="j-2", df_idx=2), "dummy") # DummyTask pool + + assert sorted(pool.list_pools()) == ["dummy", "notask"] + + # Shutdown NopTask pool only + pool.shutdown("notask") + + # Only DummyTask pool should remain + assert pool.list_pools() == ["dummy"] + + # Can't submit to shutdown pool + # Actually, it will create a new pool since we deleted it + pool.submit_task(NopTask(job_id="j-3", df_idx=3)) # Creates new NopTask pool + assert sorted(pool.list_pools()) == [ "default", "dummy"] + + pool.shutdown() + + def test_shutdown_all(self): + """Test shutting down all pools.""" + # Create fresh pool for destructive test + pool = _JobManagerWorkerThreadPool() + + # Create multiple pools + pool.submit_task(NopTask(job_id="j-1", df_idx=1), "notask") # NopTask pool + pool.submit_task(DummyTask(job_id="j-2", df_idx=2), "dummy") + + assert len(pool.list_pools()) == 2 + + # Shutdown all + pool.shutdown() + + assert pool.list_pools() == [] + assert len(pool._pools) == 0 + + + def test_submit_to_nonexistent_pool_after_shutdown(self): + """Test submitting to a pool that was previously shutdown.""" + pool = _JobManagerWorkerThreadPool({"download": 2}) + + # Shutdown the pool + pool.shutdown("download") + assert "download" not in pool.list_pools() + + # Submit again - should create new pool (since dynamic creation allowed) + task = NopTask(job_id="j-1", df_idx=1) + pool.submit_task(task, "download") + + assert "download" in pool.list_pools() + assert pool._pools["download"]._max_workers == 1 + + + def test_number_pending_tasks_for_shutdown_pool(self): + """Test checking pending tasks for a pool that was shutdown.""" + pool = _JobManagerWorkerThreadPool({"download": 2}) + + # Submit and process many tasks + for i in range(10): + pool.submit_task(NopTask(job_id=f"j-{i}", df_idx=i)) + + pool.shutdown("download") + + # Should return 0 for shutdown/nonexistent pool + assert pool.number_pending_tasks("download") == 0 + + def test_future_task_pairs_postprocessing(self): + """Test that completed tasks don't accumulate in _future_task_pairs.""" + pool = _JobManagerWorkerThreadPool() + + # Submit and process many tasks + for i in range(100): + pool.submit_task(NopTask(job_id=f"j-{i}", df_idx=i)) + + # Process all + results = pool.process_futures(timeout=0.1) + assert len(results) == 100 + + # Internal tracking list should be empty + for pool_obj in pool._pools.values(): + assert len(pool_obj._future_task_pairs) == 0 + + pool.shutdown() + diff --git a/tests/rest/datacube/test_datacube.py b/tests/rest/datacube/test_datacube.py index f99f24726..4389d52ce 100644 --- a/tests/rest/datacube/test_datacube.py +++ b/tests/rest/datacube/test_datacube.py @@ -18,8 +18,10 @@ import shapely import shapely.geometry +import openeo.processes from openeo import collection_property from openeo.api.process import Parameter +from openeo.internal.graph_building import PGNode from openeo.metadata import SpatialDimension from openeo.rest import BandMathException, OpenEoClientException from openeo.rest._testing import build_capabilities @@ -698,6 +700,69 @@ def test_filter_temporal_single_arg(s2cube: DataCube, arg, expect_failure): _ = s2cube.filter_temporal(arg) +@pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], +) +def test_filter_temporal_from_udf(s2cube: DataCube, udf_factory): + temporal_extent = udf_factory(data=[1, 2, 3], udf="print('hello time')", runtime="Python") + cube = s2cube.filter_temporal(temporal_extent) + assert get_download_graph(cube, drop_save_result=True) == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello time')", "runtime": "Python"}, + }, + "filtertemporal1": { + "process_id": "filter_temporal", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "extent": {"from_node": "runudf1"}, + }, + }, + } + + +@pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], +) +def test_filter_temporal_start_end_from_udf(s2cube: DataCube, udf_factory): + start = udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python") + end = udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python") + cube = s2cube.filter_temporal(start_date=start, end_date=end) + assert get_download_graph(cube, drop_save_result=True) == { + "loadcollection1": { + "process_id": "load_collection", + "arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None}, + }, + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"}, + }, + "filtertemporal1": { + "process_id": "filter_temporal", + "arguments": { + "data": {"from_node": "loadcollection1"}, + "extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}], + }, + }, + } + + def test_max_time(s2cube, api_version): im = s2cube.max_time() graph = _get_leaf_node(im, force_flat=True) diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py index 97801365b..2788c6a7a 100644 --- a/tests/rest/datacube/test_datacube100.py +++ b/tests/rest/datacube/test_datacube100.py @@ -2375,6 +2375,70 @@ def test_load_collection_parameterized_extents(con100, spatial_extent, temporal_ } +@pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], +) +def test_load_collection_extents_from_udf(con100, udf_factory): + spatial_extent = udf_factory(data=[1, 2, 3], udf="print('hello space')", runtime="Python") + temporal_extent = udf_factory(data=[4, 5, 6], udf="print('hello time')", runtime="Python") + cube = con100.load_collection("S2", spatial_extent=spatial_extent, temporal_extent=temporal_extent) + assert get_download_graph(cube, drop_save_result=True) == { + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello space')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello time')", "runtime": "Python"}, + }, + "loadcollection1": { + "process_id": "load_collection", + "arguments": { + "id": "S2", + "spatial_extent": {"from_node": "runudf1"}, + "temporal_extent": {"from_node": "runudf2"}, + }, + }, + } + + +@pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], +) +def test_load_collection_temporal_extent_from_udf(con100, udf_factory): + temporal_extent = [ + udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python"), + udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python"), + ] + cube = con100.load_collection("S2", temporal_extent=temporal_extent) + assert get_download_graph(cube, drop_save_result=True) == { + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"}, + }, + "loadcollection1": { + "process_id": "load_collection", + "arguments": { + "id": "S2", + "spatial_extent": None, + "temporal_extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}], + }, + }, + } + + def test_apply_dimension_temporal_cumsum_with_target(con100, test_data): cumsum = con100.load_collection("S2").apply_dimension('cumsum', dimension="t", target_dimension="MyNewTime") actual_graph = cumsum.flat_graph() diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 0fc0a3976..6cf05e62a 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -17,6 +17,7 @@ import shapely.geometry import openeo +import openeo.processes from openeo import BatchJob from openeo.api.process import Parameter from openeo.internal.graph_building import FlatGraphableMixin, PGNode @@ -3715,6 +3716,73 @@ def test_load_stac_spatial_extent_vector_cube(self, dummy_backend): }, } + @pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], + ) + def test_load_stac_extents_from_udf(self, dummy_backend, udf_factory): + spatial_extent = udf_factory(data=[1, 2, 3], udf="print('hello space')", runtime="Python") + temporal_extent = udf_factory(data=[4, 5, 6], udf="print('hello time')", runtime="Python") + cube = dummy_backend.connection.load_stac( + "https://stac.test/data", spatial_extent=spatial_extent, temporal_extent=temporal_extent + ) + cube.execute() + assert dummy_backend.get_sync_pg() == { + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello space')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello time')", "runtime": "Python"}, + }, + "loadstac1": { + "process_id": "load_stac", + "arguments": { + "url": "https://stac.test/data", + "spatial_extent": {"from_node": "runudf1"}, + "temporal_extent": {"from_node": "runudf2"}, + }, + "result": True, + }, + } + + @pytest.mark.parametrize( + "udf_factory", + [ + (lambda data, udf, runtime: openeo.processes.run_udf(data=data, udf=udf, runtime=runtime)), + (lambda data, udf, runtime: PGNode(process_id="run_udf", data=data, udf=udf, runtime=runtime)), + ], + ) + def test_load_stac_temporal_extent_from_udf(self, dummy_backend, udf_factory): + temporal_extent = [ + udf_factory(data=[1, 2, 3], udf="print('hello start')", runtime="Python"), + udf_factory(data=[4, 5, 6], udf="print('hello end')", runtime="Python"), + ] + cube = dummy_backend.connection.load_stac("https://stac.test/data", temporal_extent=temporal_extent) + cube.execute() + assert dummy_backend.get_sync_pg() == { + "runudf1": { + "process_id": "run_udf", + "arguments": {"data": [1, 2, 3], "udf": "print('hello start')", "runtime": "Python"}, + }, + "runudf2": { + "process_id": "run_udf", + "arguments": {"data": [4, 5, 6], "udf": "print('hello end')", "runtime": "Python"}, + }, + "loadstac1": { + "process_id": "load_stac", + "arguments": { + "url": "https://stac.test/data", + "temporal_extent": [{"from_node": "runudf1"}, {"from_node": "runudf2"}], + }, + "result": True, + }, + } + @pytest.mark.parametrize( "data", diff --git a/tests/rest/test_job.py b/tests/rest/test_job.py index 19de98ef8..b5cd8d6c4 100644 --- a/tests/rest/test_job.py +++ b/tests/rest/test_job.py @@ -353,8 +353,8 @@ def test_start_and_wait_with_error_require_success(dummy_backend, require_succes pytest.raises(OpenEoApiPlainError, match=re.escape("[500] Internal Server Error")), [23], ), - ( # Default config with a 503 error (skipped by soft error feature of execute_batch poll loop) - None, + ( # Only retry by default on 429, but still handle a 503 error with soft error skipping feature of execute_batch poll loop + {"status_forcelist": [HTTP_429_TOO_MANY_REQUESTS]}, [httpretty.Response(status=HTTP_503_SERVICE_UNAVAILABLE, body="Service Unavailable")], contextlib.nullcontext(), [23, 12.34, 34], diff --git a/tests/testing/test_stac.py b/tests/testing/test_stac.py index 025e5928e..090e009d1 100644 --- a/tests/testing/test_stac.py +++ b/tests/testing/test_stac.py @@ -87,6 +87,32 @@ def test_asset_default(self): assert asset == { "href": "https://stac.test/asset.tiff", "type": "image/tiff; application=geotiff", + "roles": ["data"], + } + pystac.Asset.from_dict(asset) + + def test_asset_no_roles(self): + asset = StacDummyBuilder.asset(roles=None) + assert asset == { + "href": "https://stac.test/asset.tiff", + "type": "image/tiff; application=geotiff", + } + pystac.Asset.from_dict(asset) + + def test_asset_proj_fields(self): + asset = StacDummyBuilder.asset( + proj_code="EPSG:4326", + proj_bbox=(3, 51, 4, 52), + proj_shape=(100, 100), + proj_transform=(0.01, 0.0, 3.0, 0.0, -0.01, 52.0), + ) + assert asset == { + "href": "https://stac.test/asset.tiff", + "type": "image/tiff; application=geotiff", + "roles": ["data"], + "proj:code": "EPSG:4326", + "proj:bbox": [3, 51, 4, 52], + "proj:shape": [100, 100], + "proj:transform": [0.01, 0.0, 3.0, 0.0, -0.01, 52.0], } - # Check if the default asset validates pystac.Asset.from_dict(asset)