diff --git a/src/eo_api/datasets/registry.py b/src/eo_api/datasets/registry.py
index 03ea31f..fae4a85 100644
--- a/src/eo_api/datasets/registry.py
+++ b/src/eo_api/datasets/registry.py
@@ -1,5 +1,6 @@
 """Dataset registry backed by YAML config files."""
 
+import functools
 import logging
 from pathlib import Path
 from typing import Any
@@ -12,6 +13,7 @@
 CONFIGS_DIR = SCRIPT_DIR / "registry"
 
 
+@functools.lru_cache(maxsize=1)
 def list_datasets() -> list[dict[str, Any]]:
     """Load all YAML files in the registry folder and return a flat list of datasets."""
     datasets: list[dict[str, Any]] = []
@@ -25,7 +27,7 @@ def list_datasets() -> list[dict[str, Any]]:
             with open(file_path, encoding="utf-8") as f:
                 file_datasets = yaml.safe_load(f)
             datasets.extend(file_datasets)
-        except Exception:
+        except (yaml.YAMLError, OSError):
             logger.exception("Error loading %s", file_path.name)
 
     return datasets
diff --git a/src/eo_api/datasets/serialize.py b/src/eo_api/datasets/serialize.py
index e7761d8..c6266a6 100644
--- a/src/eo_api/datasets/serialize.py
+++ b/src/eo_api/datasets/serialize.py
@@ -38,7 +38,8 @@ def dataframe_to_preview(df: pd.DataFrame, dataset: dict[str, Any], period_type:
     temp_df = df[[time_dim, "id", varname]]
     temp_df[time_dim] = pandas_period_string(temp_df[time_dim], period_type)
-    assert len(temp_df[time_dim].unique()) == 1
+    if len(temp_df[time_dim].unique()) != 1:
+        raise ValueError("dataframe_to_preview expects exactly one timestep")
 
     org_units = gpd.read_file(json.dumps(constants.ORG_UNITS_GEOJSON))
     org_units_with_temp = org_units.merge(temp_df, on="id", how="left")
 
@@ -67,7 +68,8 @@ def xarray_to_preview(ds: xr.Dataset, dataset: dict[str, Any], period_type: str)
     temp_ds = ds[[time_dim, varname]]
     temp_ds = temp_ds.assign_coords({time_dim: lambda x: numpy_period_array(x[time_dim].values, period_type)})
-    assert len(temp_ds[time_dim].values) == 1
+    if len(temp_ds[time_dim].values) != 1:
+        raise ValueError("xarray_to_preview expects exactly one timestep")
 
     fig = Figure()
     ax = fig.subplots()
 
@@ -89,7 +91,12 @@ def xarray_to_temporary_netcdf(ds: xr.Dataset) -> str:
     fd = tempfile.NamedTemporaryFile(suffix=".nc", delete=False)
     path = fd.name
     fd.close()
-    ds.to_netcdf(path)
+    try:
+        ds.to_netcdf(path)
+    except Exception:
+        import os  # local import: serialize.py has no module-level "import os"
+        os.remove(path)
+        raise
 
     return path
 
diff --git a/src/eo_api/main.py b/src/eo_api/main.py
index f99667a..e707022 100644
--- a/src/eo_api/main.py
+++ b/src/eo_api/main.py
@@ -12,7 +12,6 @@
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
-    allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )
diff --git a/src/eo_api/prefect_flows/tasks.py b/src/eo_api/prefect_flows/tasks.py
index d25ce95..2a045ff 100644
--- a/src/eo_api/prefect_flows/tasks.py
+++ b/src/eo_api/prefect_flows/tasks.py
@@ -5,6 +5,7 @@
 """
 
 import logging
+import os
 from pathlib import Path
 
 import httpx
@@ -14,7 +15,7 @@
 
 logger = logging.getLogger(__name__)
 
-OGCAPI_BASE_URL = "http://localhost:8000/ogcapi"
+OGCAPI_BASE_URL = os.getenv("OGCAPI_BASE_URL", "http://localhost:8000/ogcapi")
 PROCESS_TIMEOUT_SECONDS = 600
 
 
diff --git a/src/eo_api/routers/ogcapi/plugins/processes/chirps3_dhis2_pipeline.py b/src/eo_api/routers/ogcapi/plugins/processes/chirps3_dhis2_pipeline.py
index 0351b3d..8cdb5d4 100644
--- a/src/eo_api/routers/ogcapi/plugins/processes/chirps3_dhis2_pipeline.py
+++ b/src/eo_api/routers/ogcapi/plugins/processes/chirps3_dhis2_pipeline.py
@@ -521,7 +521,7 @@ def execute(self, data: dict[str, Any], outputs: Any = None) -> tuple[str, dict[
                     "or increase DHIS2_HTTP_TIMEOUT_SECONDS."
                 ) from None
             except Exception as e:
-                raise ProcessorExecuteError(str(e)) from None
+                raise ProcessorExecuteError(str(e)) from e
             finally:
                 client.close()
 
diff --git a/src/eo_api/routers/ogcapi/plugins/processes/zonal_statistics.py b/src/eo_api/routers/ogcapi/plugins/processes/zonal_statistics.py
index 4cce0cb..9f6edaa 100644
--- a/src/eo_api/routers/ogcapi/plugins/processes/zonal_statistics.py
+++ b/src/eo_api/routers/ogcapi/plugins/processes/zonal_statistics.py
@@ -3,11 +3,9 @@
 from __future__ import annotations
 
 import json
-from pathlib import Path
 from typing import Any
-from urllib.parse import urlparse
-from urllib.request import urlopen
 
+import httpx
 import numpy as np
 import rasterio
 from pydantic import ValidationError
@@ -27,7 +25,7 @@
     "inputs": {
         "geojson": {
             "title": "GeoJSON FeatureCollection",
-            "description": "FeatureCollection object or URI/path to GeoJSON file.",
+            "description": "FeatureCollection object or HTTP(S) URL to GeoJSON file.",
             "schema": {"oneOf": [{"type": "object"}, {"type": "string"}]},
             "minOccurs": 1,
             "maxOccurs": 1,
@@ -109,16 +107,16 @@ def _read_geojson_input(geojson_input: dict[str, Any] | str) -> dict[str, Any]:
     if isinstance(geojson_input, dict):
         return geojson_input
 
-    parsed = urlparse(geojson_input)
-    if parsed.scheme in {"http", "https"}:
-        with urlopen(geojson_input) as response:
-            payload = response.read().decode("utf-8")
-        return _parse_geojson_object(payload)
+    if not geojson_input.startswith(("http://", "https://")):
+        raise ProcessorExecuteError("GeoJSON string input must be an HTTP(S) URL")
 
-    path = Path(geojson_input)
-    if not path.exists():
-        raise ProcessorExecuteError(f"GeoJSON file not found: {geojson_input}")
-    return _parse_geojson_object(path.read_text(encoding="utf-8"))
+    try:
+        response = httpx.get(geojson_input, timeout=30)
+        response.raise_for_status()
+    except httpx.HTTPError as err:
+        raise ProcessorExecuteError(f"Failed to fetch GeoJSON from URL: {err}") from err
+
+    return _parse_geojson_object(response.text)
 
 
 def _parse_geojson_object(payload: str) -> dict[str, Any]: