From 4e61a6c15e3b5cb86cecc981600722b55dbb9e69 Mon Sep 17 00:00:00 2001 From: Morten Hansen Date: Thu, 26 Feb 2026 15:22:02 +0100 Subject: [PATCH 1/3] refactor: split main.py into startup, lifecycle, and router modules Extract early-boot side effects (env vars, PROJ config, logging, dotenv, OpenAPI generation) into startup.py, Prefect lifespan management into lifecycle.py, and the /ogcapi redirect into routers/root.py. Reduces main.py from 138 lines to ~25 lines of pure app wiring. --- src/eo_api/lifecycle.py | 34 ++++++++++ src/eo_api/main.py | 124 ++----------------------------------- src/eo_api/routers/root.py | 7 +++ src/eo_api/startup.py | 81 ++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 118 deletions(-) create mode 100644 src/eo_api/lifecycle.py create mode 100644 src/eo_api/startup.py diff --git a/src/eo_api/lifecycle.py b/src/eo_api/lifecycle.py new file mode 100644 index 0000000..d3c9735 --- /dev/null +++ b/src/eo_api/lifecycle.py @@ -0,0 +1,34 @@ +"""Application lifespan: Prefect server bootstrap and flow runner.""" + +import asyncio +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +from fastapi import FastAPI + + +async def _serve_flows() -> None: + """Register Prefect deployments and start a runner to execute them.""" + from prefect.runner import Runner + + from eo_api.prefect_flows.flows import ALL_FLOWS + + runner = Runner() + for fl in ALL_FLOWS: + await runner.aadd_flow(fl, name=fl.name) + await runner.start() + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[None]: + """Start Prefect server, then register and serve pipeline deployments.""" + from eo_api.routers import prefect + + # Mounted sub-apps don't get their lifespans called automatically, + # so we trigger the Prefect server's lifespan here to initialize + # the database, docket, and background workers. 
+ prefect_app = prefect.app + async with prefect_app.router.lifespan_context(prefect_app): + task = asyncio.create_task(_serve_flows()) + yield + task.cancel() diff --git a/src/eo_api/main.py b/src/eo_api/main.py index c7dbe45..f99667a 100644 --- a/src/eo_api/main.py +++ b/src/eo_api/main.py @@ -1,116 +1,11 @@ -"""DHIS2 EO API - Earth observation data API for DHIS2. +"""DHIS2 EO API -- Earth observation data API for DHIS2.""" -load_dotenv() is called before pygeoapi import because pygeoapi -reads PYGEOAPI_CONFIG and PYGEOAPI_OPENAPI at import time. - -Prefect UI env vars are set before any imports because Prefect -caches its settings on first import. -""" - -import logging -import os -import warnings -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from importlib.util import find_spec -from pathlib import Path -from typing import Any, cast - -os.environ.setdefault("PREFECT_UI_SERVE_BASE", "/prefect/") -os.environ.setdefault("PREFECT_UI_API_URL", "/prefect/api") -os.environ.setdefault("PREFECT_SERVER_API_BASE_PATH", "/prefect/api") -os.environ.setdefault("PREFECT_API_URL", "http://localhost:8000/prefect/api") -os.environ.setdefault("PREFECT_SERVER_ANALYTICS_ENABLED", "false") -os.environ.setdefault("PREFECT_SERVER_UI_SHOW_PROMOTIONAL_CONTENT", "false") - - -def _configure_proj_data() -> None: - """Point PROJ to rasterio bundled data to avoid mixed-install conflicts.""" - spec = find_spec("rasterio") - if spec is None or spec.origin is None: - return - - proj_data = Path(spec.origin).parent / "proj_data" - if not proj_data.is_dir(): - return - - proj_data_path = str(proj_data) - os.environ["PROJ_DATA"] = proj_data_path - os.environ["PROJ_LIB"] = proj_data_path - - -_configure_proj_data() - -warnings.filterwarnings("ignore", message="ecCodes .* or higher is recommended") -warnings.filterwarnings("ignore", message=r"Engine 'cfgrib' loading failed:[\s\S]*", category=RuntimeWarning) - 
-logging.getLogger("pygeoapi.api.processes").setLevel(logging.ERROR) -logging.getLogger("pygeoapi.l10n").setLevel(logging.ERROR) - -from dotenv import load_dotenv # noqa: E402 - -load_dotenv() - -openapi_path = os.getenv("PYGEOAPI_OPENAPI") -config_path = os.getenv("PYGEOAPI_CONFIG") -if openapi_path and config_path and not Path(openapi_path).exists(): - from pygeoapi.openapi import generate_openapi_document # noqa: E402 - - with Path(config_path).open(encoding="utf-8") as config_file: - openapi_doc = generate_openapi_document( - config_file, - output_format=cast(Any, "yaml"), - fail_on_invalid_collection=False, - ) - Path(openapi_path).write_text(openapi_doc, encoding="utf-8") - warnings.warn(f"Generated missing OpenAPI document at '{openapi_path}'.", RuntimeWarning) - -from fastapi import FastAPI # noqa: E402 -from fastapi.middleware.cors import CORSMiddleware # noqa: E402 -from fastapi.responses import RedirectResponse # noqa: E402 - -from eo_api.routers import cog, ogcapi, pipelines, prefect, root # noqa: E402 - -# Keep app progress logs visible while muting noisy third-party info logs. 
-eo_logger = logging.getLogger("eo_api") -eo_logger.setLevel(logging.INFO) -if not eo_logger.handlers: - handler = logging.StreamHandler() - handler.setLevel(logging.INFO) - handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s - %(message)s")) - eo_logger.addHandler(handler) -eo_logger.propagate = False - -logging.getLogger("dhis2eo").setLevel(logging.WARNING) -logging.getLogger("xarray").setLevel(logging.WARNING) - - -async def _serve_flows() -> None: - """Register Prefect deployments and start a runner to execute them.""" - from prefect.runner import Runner - - from eo_api.prefect_flows.flows import ALL_FLOWS - - runner = Runner() - for fl in ALL_FLOWS: - await runner.aadd_flow(fl, name=fl.name) - await runner.start() - - -@asynccontextmanager -async def lifespan(app: FastAPI) -> AsyncIterator[None]: - """Start Prefect server, then register and serve pipeline deployments.""" - import asyncio - - # Mounted sub-apps don't get their lifespans called automatically, - # so we trigger the Prefect server's lifespan here to initialize - # the database, docket, and background workers. 
- prefect_app = prefect.app - async with prefect_app.router.lifespan_context(prefect_app): - task = asyncio.create_task(_serve_flows()) - yield - task.cancel() +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +import eo_api.startup # noqa: F401 # pyright: ignore[reportUnusedImport] +from eo_api.lifecycle import lifespan +from eo_api.routers import cog, ogcapi, pipelines, prefect, root app = FastAPI(lifespan=lifespan) @@ -126,12 +21,5 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: app.include_router(cog.router, prefix="/cog", tags=["Cloud Optimized GeoTIFF"]) app.include_router(pipelines.router, prefix="/pipelines", tags=["Pipelines"]) - -@app.get("/ogcapi", include_in_schema=False) -async def ogcapi_redirect() -> RedirectResponse: - """Redirect /ogcapi to /ogcapi/ for trailing-slash consistency.""" - return RedirectResponse(url="/ogcapi/") - - app.mount(path="/ogcapi", app=ogcapi.app) app.mount(path="/", app=prefect.app) diff --git a/src/eo_api/routers/root.py b/src/eo_api/routers/root.py index a9cbc60..c497ac6 100644 --- a/src/eo_api/routers/root.py +++ b/src/eo_api/routers/root.py @@ -4,6 +4,7 @@ from importlib.metadata import version from fastapi import APIRouter, Request +from fastapi.responses import RedirectResponse from eo_api.schemas import AppInfo, HealthStatus, Link, RootResponse, Status @@ -40,3 +41,9 @@ def info() -> AppInfo: pygeoapi_version=version("pygeoapi"), uvicorn_version=version("uvicorn"), ) + + +@router.get("/ogcapi", include_in_schema=False) +async def ogcapi_redirect() -> RedirectResponse: + """Redirect /ogcapi to /ogcapi/ for trailing-slash consistency.""" + return RedirectResponse(url="/ogcapi/") diff --git a/src/eo_api/startup.py b/src/eo_api/startup.py new file mode 100644 index 0000000..710d8e9 --- /dev/null +++ b/src/eo_api/startup.py @@ -0,0 +1,81 @@ +"""Early-boot side effects: env vars, PROJ config, logging, dotenv, OpenAPI. 
+ +This module is imported before any other eo_api modules so that +environment variables and logging are configured before Prefect/pygeoapi +read them at import time. +""" + +import logging +import os +import warnings +from importlib.util import find_spec +from pathlib import Path +from typing import Any, cast + +# -- Prefect env-var defaults (must be set before Prefect is imported) -------- +os.environ.setdefault("PREFECT_UI_SERVE_BASE", "/prefect/") +os.environ.setdefault("PREFECT_UI_API_URL", "/prefect/api") +os.environ.setdefault("PREFECT_SERVER_API_BASE_PATH", "/prefect/api") +os.environ.setdefault("PREFECT_API_URL", "http://localhost:8000/prefect/api") +os.environ.setdefault("PREFECT_SERVER_ANALYTICS_ENABLED", "false") +os.environ.setdefault("PREFECT_SERVER_UI_SHOW_PROMOTIONAL_CONTENT", "false") + + +# -- PROJ data configuration -------------------------------------------------- +def _configure_proj_data() -> None: + """Point PROJ to rasterio bundled data to avoid mixed-install conflicts.""" + spec = find_spec("rasterio") + if spec is None or spec.origin is None: + return + + proj_data = Path(spec.origin).parent / "proj_data" + if not proj_data.is_dir(): + return + + proj_data_path = str(proj_data) + os.environ["PROJ_DATA"] = proj_data_path + os.environ["PROJ_LIB"] = proj_data_path + + +_configure_proj_data() + +# -- Warning filters --------------------------------------------------------- +warnings.filterwarnings("ignore", message="ecCodes .* or higher is recommended") +warnings.filterwarnings("ignore", message=r"Engine 'cfgrib' loading failed:[\s\S]*", category=RuntimeWarning) + +# -- Silence noisy third-party loggers early ---------------------------------- +logging.getLogger("pygeoapi.api.processes").setLevel(logging.ERROR) +logging.getLogger("pygeoapi.l10n").setLevel(logging.ERROR) + +# -- Load .env (must happen before pygeoapi reads PYGEOAPI_CONFIG) ------------ +from dotenv import load_dotenv # noqa: E402 + +load_dotenv() + +# -- Generate missing 
OpenAPI document ---------------------------------------- +openapi_path = os.getenv("PYGEOAPI_OPENAPI") +config_path = os.getenv("PYGEOAPI_CONFIG") +if openapi_path and config_path and not Path(openapi_path).exists(): + from pygeoapi.openapi import generate_openapi_document # noqa: E402 + + with Path(config_path).open(encoding="utf-8") as config_file: + openapi_doc = generate_openapi_document( + config_file, + output_format=cast(Any, "yaml"), + fail_on_invalid_collection=False, + ) + Path(openapi_path).write_text(openapi_doc, encoding="utf-8") + warnings.warn(f"Generated missing OpenAPI document at '{openapi_path}'.", RuntimeWarning) + +# -- eo_api / third-party logging setup --------------------------------------- +eo_logger = logging.getLogger("eo_api") +eo_logger.setLevel(logging.INFO) +if not eo_logger.handlers: + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s - %(message)s")) + eo_logger.addHandler(handler) +eo_logger.propagate = False + +logging.getLogger("dhis2eo").setLevel(logging.WARNING) +logging.getLogger("xarray").setLevel(logging.WARNING) From d1c95dffee08c248381c03d35b8a4ea1d00f2557 Mon Sep 17 00:00:00 2001 From: Morten Hansen Date: Thu, 26 Feb 2026 15:23:38 +0100 Subject: [PATCH 2/3] feat(mypy): enable pydantic mypy plugin Add type annotation for FeatureCollection variable to satisfy the stricter checking from the pydantic plugin. 
--- pyproject.toml | 1 + src/eo_api/routers/ogcapi/plugins/providers/dhis2_org_units.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6369ef8..f17f6e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ docstring-code-line-length = "dynamic" [tool.mypy] python_version = "3.13" +plugins = ["pydantic.mypy"] warn_return_any = true warn_unused_configs = true disallow_untyped_defs = true diff --git a/src/eo_api/routers/ogcapi/plugins/providers/dhis2_org_units.py b/src/eo_api/routers/ogcapi/plugins/providers/dhis2_org_units.py index 5a06587..e702274 100644 --- a/src/eo_api/routers/ogcapi/plugins/providers/dhis2_org_units.py +++ b/src/eo_api/routers/ogcapi/plugins/providers/dhis2_org_units.py @@ -52,7 +52,7 @@ def query( number_matched = len(org_units) page = org_units[offset : offset + limit] - fc = FeatureCollection( + fc: FeatureCollection = FeatureCollection( type="FeatureCollection", features=[org_unit_to_feature(ou) for ou in page], ) From 309d81e257fedb2f120df1d1e71328693f215c8e Mon Sep 17 00:00:00 2001 From: Morten Hansen Date: Thu, 26 Feb 2026 18:25:56 +0100 Subject: [PATCH 3/3] fix: remove datasets lint excludes and fix all linting issues Remove ruff, mypy, and pyright exclude entries for src/eo_api/datasets. Add module docstrings, function docstrings, and type annotations across all datasets files. Replace deprecated tempfile.mktemp with NamedTemporaryFile. Add missing third-party modules to mypy ignore_missing_imports. Remove accidental root __init__.py. 
--- pyproject.toml | 10 +- src/eo_api/datasets/api.py | 186 ++++++++++++++----------- src/eo_api/datasets/cache.py | 221 +++++++++++++++--------------- src/eo_api/datasets/constants.py | 6 +- src/eo_api/datasets/preprocess.py | 16 +-- src/eo_api/datasets/raster.py | 113 +++++++-------- src/eo_api/datasets/registry.py | 43 +++--- src/eo_api/datasets/serialize.py | 80 +++++------ src/eo_api/datasets/units.py | 33 +++-- src/eo_api/datasets/utils.py | 67 +++++---- 10 files changed, 376 insertions(+), 399 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ad3e79e..d319dfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,6 @@ dependencies = [ [tool.ruff] target-version = "py313" line-length = 120 -exclude = [ - "src/eo_api/datasets" -] [tool.ruff.lint] fixable = ["ALL"] @@ -59,21 +56,18 @@ no_implicit_optional = true warn_unused_ignores = true strict_equality = true mypy_path = ["src"] -exclude = [ - "src/eo_api/datasets" -] [[tool.mypy.overrides]] module = "tests.*" disallow_untyped_defs = false [[tool.mypy.overrides]] -module = ["dhis2eo.*", "dhis2_client.*", "pygeoapi.*", "titiler.*", "rasterio.*", "pygeofilter.*", "prefect.*", "requests.*"] +module = ["dhis2eo.*", "dhis2_client.*", "pygeoapi.*", "titiler.*", "rasterio.*", "pygeofilter.*", "prefect.*", "requests.*", "geopandas.*", "earthkit.*", "metpy.*", "matplotlib.*", "yaml"] ignore_missing_imports = true [tool.pyright] include = ["src", "tests"] -exclude = ["**/.venv", "src/eo_api/datasets"] +exclude = ["**/.venv"] pythonVersion = "3.13" typeCheckingMode = "strict" useLibraryCodeForTypes = true diff --git a/src/eo_api/datasets/api.py b/src/eo_api/datasets/api.py index 7c25fe0..5273335 100644 --- a/src/eo_api/datasets/api.py +++ b/src/eo_api/datasets/api.py @@ -1,4 +1,8 @@ +"""FastAPI router exposing dataset endpoints.""" +from typing import Any + +import xarray as xr from fastapi import APIRouter, BackgroundTasks, HTTPException, Response from fastapi.responses import FileResponse from 
starlette.background import BackgroundTask @@ -7,146 +11,164 @@ router = APIRouter() + @router.get("/") -def list_datasets(): - """Returned list of available datasets from registry. - """ - datasets = registry.list_datasets() - return datasets +def list_datasets() -> list[dict[str, Any]]: + """Return list of available datasets from registry.""" + return registry.list_datasets() + -def get_dataset_or_404(dataset_id: str): +def _get_dataset_or_404(dataset_id: str) -> dict[str, Any]: + """Look up a dataset by ID or raise 404.""" dataset = registry.get_dataset(dataset_id) if not dataset: raise HTTPException(status_code=404, detail=f"Dataset '{dataset_id}' not found") return dataset + @router.get("/{dataset_id}", response_model=dict) -def get_dataset(dataset_id: str): - """Get a single dataset by ID. - """ - dataset = get_dataset_or_404(dataset_id) +def get_dataset(dataset_id: str) -> dict[str, Any]: + """Get a single dataset by ID.""" + dataset = _get_dataset_or_404(dataset_id) cache_info = cache.get_cache_info(dataset) dataset.update(cache_info) return dataset + @router.get("/{dataset_id}/build_cache", response_model=dict) -def build_dataset_cache(dataset_id: str, start: str, end: str | None = None, overwrite: bool = False, background_tasks: BackgroundTasks = None): - """Download and cache dataset as local netcdf files direct from the source. 
- """ - dataset = get_dataset_or_404(dataset_id) +def build_dataset_cache( + dataset_id: str, + start: str, + end: str | None = None, + overwrite: bool = False, + background_tasks: BackgroundTasks | None = None, +) -> dict[str, str]: + """Download and cache dataset as local netcdf files direct from the source.""" + dataset = _get_dataset_or_404(dataset_id) cache.build_dataset_cache(dataset, start=start, end=end, overwrite=overwrite, background_tasks=background_tasks) - return {'status': 'Dataset caching request submitted for processing'} + return {"status": "Dataset caching request submitted for processing"} + @router.get("/{dataset_id}/optimize_cache", response_model=dict) -def optimize_dataset_cache(dataset_id: str, background_tasks: BackgroundTasks = None): - """Optimize dataset cache by collecting all cache files to a single zarr archive. - """ - dataset = get_dataset_or_404(dataset_id) - background_tasks.add_task(cache.optimize_dataset_cache, dataset) - return {'status': 'Dataset cache optimization submitted for processing'} - -def get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation): +def optimize_dataset_cache( + dataset_id: str, + background_tasks: BackgroundTasks | None = None, +) -> dict[str, str]: + """Optimize dataset cache by collecting all cache files to a single zarr archive.""" + dataset = _get_dataset_or_404(dataset_id) + if background_tasks is not None: + background_tasks.add_task(cache.optimize_dataset_cache, dataset) + return {"status": "Dataset cache optimization submitted for processing"} + + +def _get_dataset_period_type( + dataset: dict[str, Any], + period_type: str, + start: str, + end: str, + temporal_aggregation: str, +) -> xr.Dataset: + """Load and temporally aggregate a dataset.""" # TODO: maybe move this and similar somewhere better like a pipelines.py file? - # ... 
- - # get raster data ds = raster.get_data(dataset, start, end) - - # aggregate to period type ds = raster.to_timeperiod(ds, dataset, period_type, statistic=temporal_aggregation) - - # return return ds -@router.get("/{dataset_id}/{period_type}/orgunits", response_model=list) -def get_dataset_period_type_org_units(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str, spatial_aggregation: str): - """Get a dataset dynamically aggregated to a given period type and org units and return json values. - """ - # get dataset metadata - dataset = get_dataset_or_404(dataset_id) - # get dataset for period type and start/end period - ds = get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation) +@router.get("/{dataset_id}/{period_type}/orgunits", response_model=list) +def get_dataset_period_type_org_units( + dataset_id: str, + period_type: str, + start: str, + end: str, + temporal_aggregation: str, + spatial_aggregation: str, +) -> list[dict[str, Any]]: + """Get a dataset aggregated to a given period type and org units as JSON values.""" + dataset = _get_dataset_or_404(dataset_id) + ds = _get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation) - # aggregate to geojson features df = raster.to_features(ds, dataset, features=constants.ORG_UNITS_GEOJSON, statistic=spatial_aggregation) # convert units if needed (inplace) - # NOTE: here we do it after agggregation to dataframe to speedup computation + # NOTE: here we do it after aggregation to dataframe to speedup computation units.convert_pandas_units(df, dataset) - # serialize to json - data = serialize.dataframe_to_json_data(df, dataset, period_type) - return data + return serialize.dataframe_to_json_data(df, dataset, period_type) + @router.get("/{dataset_id}/{period_type}/orgunits/preview", response_model=list) -def get_dataset_period_type_org_units_preview(dataset_id: str, period_type: str, period: str, temporal_aggregation: str, spatial_aggregation: 
str): - """Preview a PNG map image of a dataset dynamically aggregated to a given period and org units. - """ - # get dataset metadata - dataset = get_dataset_or_404(dataset_id) +def get_dataset_period_type_org_units_preview( + dataset_id: str, + period_type: str, + period: str, + temporal_aggregation: str, + spatial_aggregation: str, +) -> Response: + """Preview a PNG map image of a dataset aggregated to a given period and org units.""" + dataset = _get_dataset_or_404(dataset_id) - # get dataset for period type and a single period start = end = period - ds = get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation) + ds = _get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation) - # aggregate to geojson features df = raster.to_features(ds, dataset, features=constants.ORG_UNITS_GEOJSON, statistic=spatial_aggregation) # convert units if needed (inplace) - # NOTE: here we do it after agggregation to dataframe to speedup computation + # NOTE: here we do it after aggregation to dataframe to speedup computation units.convert_pandas_units(df, dataset) - # serialize to image image_data = serialize.dataframe_to_preview(df, dataset, period_type) - - # return as image return Response(content=image_data, media_type="image/png") -@router.get("/{dataset_id}/{period_type}/raster") -def get_dataset_period_type_raster(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str): - """Get a dataset dynamically aggregated to a given period type and return as downloadable raster file. 
- """ - # get dataset metadata - dataset = get_dataset_or_404(dataset_id) - # get dataset for period type and start/end period - ds = get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation) +@router.get("/{dataset_id}/{period_type}/raster") +def get_dataset_period_type_raster( + dataset_id: str, + period_type: str, + start: str, + end: str, + temporal_aggregation: str, +) -> FileResponse: + """Get a dataset aggregated to a given period type as a downloadable raster file.""" + dataset = _get_dataset_or_404(dataset_id) + ds = _get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation) - # convert units if needed (inplace) units.convert_xarray_units(ds, dataset) - # serialize to temporary netcdf file_path = serialize.xarray_to_temporary_netcdf(ds) - - # return as streaming file and delete after completion return FileResponse( file_path, media_type="application/x-netcdf", - filename='eo-api-raster-download.nc', - background=BackgroundTask(serialize.cleanup_file, file_path) + filename="eo-api-raster-download.nc", + background=BackgroundTask(serialize.cleanup_file, file_path), ) + @router.get("/{dataset_id}/{period_type}/raster/preview") -def get_dataset_period_type_raster_preview(dataset_id: str, period_type: str, period: str, temporal_aggregation: str): - """Preview a PNG map image of a dataset dynamically aggregated to a given period. 
- """ - # get dataset metadata - dataset = get_dataset_or_404(dataset_id) +def get_dataset_period_type_raster_preview( + dataset_id: str, + period_type: str, + period: str, + temporal_aggregation: str, +) -> Response: + """Preview a PNG map image of a dataset aggregated to a given period.""" + dataset = _get_dataset_or_404(dataset_id) - # get dataset for period type and a single period start = end = period - ds = get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation) + ds = _get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation) - # convert units if needed (inplace) units.convert_xarray_units(ds, dataset) - # serialize to image image_data = serialize.xarray_to_preview(ds, dataset, period_type) - - # return as image return Response(content=image_data, media_type="image/png") + @router.get("/{dataset_id}/{period_type}/tiles") -def get_dataset_period_type_tiles(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str): - pass +def get_dataset_period_type_tiles( + dataset_id: str, + period_type: str, + start: str, + end: str, + temporal_aggregation: str, +) -> None: + """Placeholder for future tile-based dataset access.""" diff --git a/src/eo_api/datasets/cache.py b/src/eo_api/datasets/cache.py index e8d7cbd..2597985 100644 --- a/src/eo_api/datasets/cache.py +++ b/src/eo_api/datasets/cache.py @@ -1,179 +1,172 @@ +"""Dataset cache: download, store, and optimize raster data as local files.""" + import datetime import importlib import inspect import logging +from collections.abc import Callable from pathlib import Path +from typing import Any import xarray as xr +from fastapi import BackgroundTasks from .constants import BBOX, CACHE_OVERRIDE, COUNTRY_CODE from .utils import get_lon_lat_dims, get_time_dim, numpy_period_string -# logger logger = logging.getLogger(__name__) -# paths SCRIPT_DIR = Path(__file__).parent.resolve() -CACHE_DIR = SCRIPT_DIR / 'cache' +_cache_dir = SCRIPT_DIR / "cache" if 
CACHE_OVERRIDE: - CACHE_DIR = Path(CACHE_OVERRIDE) - -def build_dataset_cache(dataset, start, end, overwrite, background_tasks): - # get download function - cache_info = dataset['cacheInfo'] - eo_download_func_path = cache_info['eoFunction'] - eo_download_func = get_dynamic_function(eo_download_func_path) - #logger.info(eo_download_func_path, eo_download_func) - - # construct standard params - params = cache_info['defaultParams'] - params.update({ - 'start': start, - 'end': end or datetime.date.today().isoformat(), # todays date if empty - 'dirname': CACHE_DIR, - 'prefix': get_cache_prefix(dataset), - 'overwrite': overwrite, - }) - - # add in varying spatial args + _cache_dir = Path(CACHE_OVERRIDE) +CACHE_DIR: Path = _cache_dir + + +def build_dataset_cache( + dataset: dict[str, Any], + start: str, + end: str | None, + overwrite: bool, + background_tasks: BackgroundTasks | None, +) -> None: + """Download dataset from source and store as local NetCDF cache files.""" + cache_info = dataset["cacheInfo"] + eo_download_func_path = cache_info["eoFunction"] + eo_download_func = _get_dynamic_function(eo_download_func_path) + + params: dict[str, Any] = dict(cache_info["defaultParams"]) + params.update( + { + "start": start, + "end": end or datetime.date.today().isoformat(), + "dirname": CACHE_DIR, + "prefix": _get_cache_prefix(dataset), + "overwrite": overwrite, + } + ) + sig = inspect.signature(eo_download_func) - if 'bbox' in sig.parameters.keys(): - params['bbox'] = BBOX - elif 'country_code' in sig.parameters.keys(): - params['country_code'] = COUNTRY_CODE + if "bbox" in sig.parameters: + params["bbox"] = BBOX + elif "country_code" in sig.parameters: + params["country_code"] = COUNTRY_CODE + + if background_tasks is not None: + background_tasks.add_task(eo_download_func, **params) - # execute the download - background_tasks.add_task(eo_download_func, **params) -def optimize_dataset_cache(dataset): - logger.info(f'Optimizing cache for dataset {dataset["id"]}') +def 
optimize_dataset_cache(dataset: dict[str, Any]) -> None: + """Collect all cache files into a single optimised zarr archive.""" + logger.info(f"Optimizing cache for dataset {dataset['id']}") - # open all cache files as xarray files = get_cache_files(dataset) - logger.info(f'Opening {len(files)} files from cache') - # for fil in files: - # d = xr.open_dataset(fil) - # print(d) - # fdsfs + logger.info(f"Opening {len(files)} files from cache") ds = xr.open_mfdataset(files) # trim to only minimal vars and coords - logger.info('Trimming unnecessary variables and coordinates') - varname = dataset['variable'] + logger.info("Trimming unnecessary variables and coordinates") + varname = dataset["variable"] ds = ds[[varname]] keep_coords = [get_time_dim(ds)] + list(get_lon_lat_dims(ds)) - drop_coords = [ - c for c in ds.coords - if c not in keep_coords - ] + drop_coords = [c for c in ds.coords if c not in keep_coords] ds = ds.drop_vars(drop_coords) # determine optimal chunk sizes - logger.info('Determining optimal chunk size for zarr archive') - ds_autochunk = ds.chunk('auto').unify_chunks() - # extract the first chunk size for each dimension to force uniformity - uniform_chunks = {dim: ds_autochunk.chunks[dim][0] for dim in ds_autochunk.dims} - # override with time space chunks - time_space_chunks = compute_time_space_chunks(ds, dataset) - uniform_chunks.update( time_space_chunks ) - logging.info(f'--> {uniform_chunks}') + logger.info("Determining optimal chunk size for zarr archive") + ds_autochunk = ds.chunk("auto").unify_chunks() + uniform_chunks: dict[str, Any] = {str(dim): ds_autochunk.chunks[dim][0] for dim in ds_autochunk.dims} + time_space_chunks = _compute_time_space_chunks(ds, dataset) + uniform_chunks.update(time_space_chunks) + logging.info(f"--> {uniform_chunks}") # save as zarr - logger.info('Saving to optimized zarr file') - zarr_path = CACHE_DIR / f'{get_cache_prefix(dataset)}.zarr' + logger.info("Saving to optimized zarr file") + zarr_path = CACHE_DIR / 
f"{_get_cache_prefix(dataset)}.zarr" ds_chunked = ds.chunk(uniform_chunks) - ds_chunked.to_zarr(zarr_path, mode='w') + ds_chunked.to_zarr(zarr_path, mode="w") ds_chunked.close() - logger.info('Finished cache optimization') + logger.info("Finished cache optimization") -def compute_time_space_chunks(ds, dataset, max_spatial_chunk=256): - chunks = {} - # time - # set to common access patterns depending on original dataset period - # TODO: could potentially allow this to be customized in the dataset yaml file +def _compute_time_space_chunks( + ds: xr.Dataset, + dataset: dict[str, Any], + max_spatial_chunk: int = 256, +) -> dict[str, int]: + """Compute chunk sizes tuned for common temporal access patterns.""" + chunks: dict[str, int] = {} + dim = get_time_dim(ds) - period_type = dataset['periodType'] - if period_type == 'hourly': + period_type = dataset["periodType"] + if period_type == "hourly": chunks[dim] = 24 * 7 - elif period_type == 'daily': + elif period_type == "daily": chunks[dim] = 30 - elif period_type == 'monthly': + elif period_type == "monthly": chunks[dim] = 12 - elif period_type == 'yearly': + elif period_type == "yearly": chunks[dim] = 1 - # space - lon_dim,lat_dim = get_lon_lat_dims(ds) + lon_dim, lat_dim = get_lon_lat_dims(ds) chunks[lon_dim] = min(ds.sizes[lon_dim], max_spatial_chunk) chunks[lat_dim] = min(ds.sizes[lat_dim], max_spatial_chunk) return chunks -def get_cache_info(dataset): - # find all files with cache prefix + +def get_cache_info(dataset: dict[str, Any]) -> dict[str, Any]: + """Return temporal and spatial coverage metadata for the cached dataset.""" files = get_cache_files(dataset) if not files: - cache_info = dict( - temporal_coverage = None, - spatial_coverage = None, - ) - return cache_info + return {"temporal_coverage": None, "spatial_coverage": None} - # open first of sorted filenames, should be sufficient to get earliest time period ds = xr.open_dataset(sorted(files)[0]) - # get dim names time_dim = get_time_dim(ds) lon_dim, 
lat_dim = get_lon_lat_dims(ds) - # get start time - start = numpy_period_string(ds[time_dim].min().values, dataset['periodType']) + start = numpy_period_string(ds[time_dim].min().values, dataset["periodType"]) # type: ignore[arg-type] - # get space scope - xmin,xmax = ds[lon_dim].min().item(), ds[lon_dim].max().item() - ymin,ymax = ds[lat_dim].min().item(), ds[lat_dim].max().item() + xmin, xmax = ds[lon_dim].min().item(), ds[lon_dim].max().item() + ymin, ymax = ds[lat_dim].min().item(), ds[lat_dim].max().item() - # open last of sorted filenames, should be sufficient to get latest time period ds = xr.open_dataset(sorted(files)[-1]) + end = numpy_period_string(ds[time_dim].max().values, dataset["periodType"]) # type: ignore[arg-type] - # get end time - end = numpy_period_string(ds[time_dim].max().values, dataset['periodType']) + return { + "coverage": { + "temporal": {"start": start, "end": end}, + "spatial": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax}, + } + } - # cache info - cache_info = dict( - coverage=dict( - temporal = {'start': start, 'end': end}, - spatial = {'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax}, - ) - ) - return cache_info - -def get_cache_prefix(dataset): - prefix = dataset['id'] - return prefix - -def get_cache_files(dataset): - # TODO: this is not bulletproof, eg 2m_temperature might also get another dataset named 2m_temperature_modified - # ...probably need a delimeter to specify end of dataset name... - prefix = get_cache_prefix(dataset) - files = list(CACHE_DIR.glob(f'{prefix}*.nc')) - return files - -def get_zarr_path(dataset): - prefix = get_cache_prefix(dataset) - optimized = CACHE_DIR / f'{prefix}.zarr' + +def _get_cache_prefix(dataset: dict[str, Any]) -> str: + return str(dataset["id"]) + + +def get_cache_files(dataset: dict[str, Any]) -> list[Path]: + """Return all NetCDF cache files matching this dataset's prefix.""" + # TODO: not bulletproof -- e.g. 
2m_temperature matches 2m_temperature_modified + prefix = _get_cache_prefix(dataset) + return list(CACHE_DIR.glob(f"{prefix}*.nc")) + + +def get_zarr_path(dataset: dict[str, Any]) -> Path | None: + """Return the optimised zarr archive path if it exists.""" + prefix = _get_cache_prefix(dataset) + optimized = CACHE_DIR / f"{prefix}.zarr" if optimized.exists(): return optimized + return None + -def get_dynamic_function(full_path): - # Split the path into: 'dhis2eo.data.cds.era5_land.hourly' and 'function' - parts = full_path.split('.') +def _get_dynamic_function(full_path: str) -> Callable[..., Any]: + """Import and return a function given its dotted module path.""" + parts = full_path.split(".") module_path = ".".join(parts[:-1]) function_name = parts[-1] - - # This handles all the intermediate sub-package imports automatically module = importlib.import_module(module_path) - - return getattr(module, function_name) + return getattr(module, function_name) # type: ignore[no-any-return] diff --git a/src/eo_api/datasets/constants.py b/src/eo_api/datasets/constants.py index e4cc206..b3b51a2 100644 --- a/src/eo_api/datasets/constants.py +++ b/src/eo_api/datasets/constants.py @@ -1,3 +1,5 @@ +"""Module-level constants loaded at import time (DHIS2 org units, bbox, env config).""" + import json import os @@ -13,5 +15,5 @@ # env variables we need from .env # TODO: should probably centralize to shared config module -COUNTRY_CODE = os.getenv('COUNTRY_CODE') -CACHE_OVERRIDE = os.getenv('CACHE_OVERRIDE') +COUNTRY_CODE = os.getenv("COUNTRY_CODE") +CACHE_OVERRIDE = os.getenv("CACHE_OVERRIDE") diff --git a/src/eo_api/datasets/preprocess.py b/src/eo_api/datasets/preprocess.py index 62ef101..eacbe55 100644 --- a/src/eo_api/datasets/preprocess.py +++ b/src/eo_api/datasets/preprocess.py @@ -1,26 +1,26 @@ +"""Preprocessing functions applied to raster datasets before aggregation.""" + import logging import xarray as xr -# logger logger = logging.getLogger(__name__) -def 
def deaccumulate_era5(ds_cumul: xr.Dataset) -> xr.Dataset:
    """Convert ERA5 cumulative hourly data to incremental hourly data.

    ERA5 accumulates some variables (e.g. precipitation) over the day and
    resets at 00:00; this recovers per-hour increments by differencing along
    ``valid_time`` and keeping the raw cumulative value at each reset hour.

    Args:
        ds_cumul: Hourly dataset with a ``valid_time`` dimension whose values
            are cumulative within each day.

    Returns:
        Dataset of per-hour increments on the same ``valid_time`` axis.
    """
    logger.info("Deaccumulating ERA5 dataset")

    # NOTE: this is hardcoded to era5 specific cumulative patterns and varnames
    # shift all values to previous hour, so the values don't spill over to the next day
    ds_cumul = ds_cumul.shift(valid_time=-1)

    # convert cumulative to diffs; reindex restores the first timestamp that
    # diff() drops (it comes back as NaN and is patched by the reset mask below)
    ds_diffs = ds_cumul.diff(dim="valid_time")
    ds_diffs = ds_diffs.reindex(valid_time=ds_cumul.valid_time)

    # use cumul values where accumulation resets (00:00) and diff everywhere else
    # NOTE(review): after the shift above, the final timestamp holds NaN —
    # presumably acceptable for downstream aggregation; confirm with callers.
    is_reset = ds_cumul["valid_time"].dt.hour == 0
    ds_hourly = xr.where(is_reset, ds_cumul, ds_diffs)

    return ds_hourly  # type: ignore[no-any-return]
def get_data(dataset: dict[str, Any], start: str, end: str) -> xr.Dataset:
    """Load an xarray raster dataset for the given time range.

    Prefers the optimised zarr archive when present; otherwise falls back to
    opening the raw NetCDF cache files, which is slower.
    """
    logger.info("Opening dataset")

    zarr_path = cache.get_zarr_path(dataset)
    if zarr_path is None:
        # No optimised archive — open the individual NetCDF cache files.
        logger.warning(
            f"Could not find optimized zarr file for dataset {dataset['id']}, using slower netcdf files instead."
        )
        ds = xr.open_mfdataset(
            cache.get_cache_files(dataset),
            data_vars="minimal",
            coords="minimal",  # pyright: ignore[reportArgumentType]
            compat="override",
        )
    else:
        ds = xr.open_zarr(zarr_path, consolidated=True)

    # Restrict to the requested time window.
    logger.info(f"Subsetting time to {start} and {end}")
    ds = ds.sel(**{get_time_dim(ds): slice(start, end)})  # pyright: ignore[reportArgumentType]

    # Apply any dataset-specific preprocessing steps declared in the registry.
    for step_name in dataset.get("preProcess", []):
        step = getattr(preprocess, step_name)
        ds = step(ds)

    return ds  # type: ignore[no-any-return]
def to_timeperiod(
    ds: xr.Dataset,
    dataset: dict[str, Any],
    period_type: str,
    statistic: str,
    timezone_offset: int = 0,
) -> xr.Dataset:
    """Aggregate an xarray dataset to another period type.

    Supported aggregations: hourly -> daily/monthly and daily -> monthly;
    anything else raises. Returns ``ds`` unchanged when the dataset is
    already at the requested period.

    Args:
        ds: Input raster dataset.
        dataset: Registry entry; ``periodType`` and ``variable`` are read.
        period_type: Target period ("hourly", "daily", "monthly", "yearly").
        statistic: Reduction passed to earthkit (e.g. "mean", "sum").
        timezone_offset: Hours to shift timestamps before hourly reductions,
            so day boundaries match the local timezone.

    Raises:
        ValueError: If ``period_type`` is unknown or the source-to-target
            aggregation is unsupported.
    """
    valid_period_types = ["hourly", "daily", "monthly", "yearly"]
    if period_type not in valid_period_types:
        raise ValueError(f"Period type not supported: {period_type}")

    # No-op when already at the requested period.
    if dataset["periodType"] == period_type:
        return ds

    logger.info(f"Aggregating period type from {dataset['periodType']} to {period_type}")

    # Only the registry-declared variable is aggregated.
    varname = dataset["variable"]
    arr = ds[varname]

    # Remember the valid-pixel mask from one time step so it can be
    # re-applied after aggregation (reductions can turn NaN cells into 0s).
    time_dim = get_time_dim(ds)
    valid = arr.isel({time_dim: 0}).notnull()

    if dataset["periodType"] == "hourly":
        if period_type == "daily":
            arr = transforms.temporal.daily_reduce(
                arr,
                how=statistic,
                time_shift={"hours": timezone_offset},
                remove_partial_periods=False,
            )
        elif period_type == "monthly":
            arr = transforms.temporal.monthly_reduce(
                arr,
                how=statistic,
                time_shift={"hours": timezone_offset},
                remove_partial_periods=False,
            )
        else:
            raise ValueError(f"Unsupported period aggregation from {dataset['periodType']} to {period_type}")

    elif dataset["periodType"] == "daily":
        if period_type == "monthly":
            arr = transforms.temporal.monthly_reduce(
                arr,
                how=statistic,
                remove_partial_periods=False,
            )
        else:
            raise ValueError(f"Unsupported period aggregation from {dataset['periodType']} to {period_type}")

    else:
        raise ValueError(f"Unsupported period aggregation from {dataset['periodType']} to {period_type}")

    # Re-apply the original mask in case the aggregation turned NaNs into 0s.
    arr = xr.where(valid, arr, None)

    # IMPORTANT: compute eagerly to avoid dragging a slow dask graph downstream.
    arr = arr.compute()

    ds = arr.to_dataset()

    return ds
def to_features(
    ds: xr.Dataset,
    dataset: dict[str, Any],
    features: dict[str, Any],
    statistic: str,
) -> pd.DataFrame:
    """Aggregate an xarray dataset to GeoJSON features and return a DataFrame."""
    logger.info("Aggregating to org units")

    # Parse the GeoJSON feature collection into a GeoDataFrame.
    feature_frame = gpd.read_file(json.dumps(features))

    # Reduce the registry-declared variable over each feature geometry.
    reduced = transforms.spatial.reduce(
        ds[dataset["variable"]],
        feature_frame,
        mask_dim="id",  # TODO: DONT HARDCODE
        how=statistic,
    )

    return reduced.to_dataframe().reset_index()  # type: ignore[no-any-return]
def list_datasets() -> list[dict[str, Any]]:
    """Load all YAML files in the registry folder and return a flat list of datasets.

    Files are read in sorted filename order for deterministic results.
    Files that fail to parse are logged and skipped. An empty YAML file
    (``safe_load`` returns ``None``) contributes no datasets instead of
    being logged as a parse error, and a non-list document is skipped with
    a warning rather than being iterated element-wise.

    Raises:
        ValueError: If the registry folder does not exist.
    """
    folder = CONFIGS_DIR

    if not folder.is_dir():
        raise ValueError(f"Path is not a directory: {folder}")

    datasets: list[dict[str, Any]] = []
    for file_path in sorted(folder.glob("*.y*ml")):
        try:
            # Keep the try body minimal: only the read/parse can raise here.
            with open(file_path, encoding="utf-8") as f:
                file_datasets = yaml.safe_load(f)
        except Exception:
            logger.exception("Error loading %s", file_path.name)
            continue

        if isinstance(file_datasets, list):
            datasets.extend(file_datasets)
        elif file_datasets is not None:
            logger.warning(
                "Skipping %s: expected a YAML list, got %s",
                file_path.name,
                type(file_datasets).__name__,
            )

    return datasets
def dataframe_to_json_data(df: pd.DataFrame, dataset: dict[str, Any], period_type: str) -> list[dict[str, Any]]:
    """Convert a DataFrame to a list of ``{period, orgunit, value}`` dicts."""
    # Map the dataset-specific column names onto the fixed output schema.
    column_map = {
        get_time_dim(df): "period",
        "id": "orgunit",
        dataset["variable"]: "value",
    }

    records = df[list(column_map)].rename(columns=column_map)

    # Format timestamps as period strings appropriate for the period type.
    records["period"] = pandas_period_string(records["period"], period_type)

    return records.to_dict(orient="records")  # type: ignore[return-value]
def xarray_to_preview(ds: xr.Dataset, dataset: dict[str, Any], period_type: str) -> bytes:
    """Render an xarray Dataset as a PNG map image.

    The dataset must contain exactly one time period after period-string
    conversion; the period becomes the plot title.

    Args:
        ds: Dataset holding the variable to plot.
        dataset: Registry entry; ``variable`` is read.
        period_type: Period granularity used to format the time coordinate.

    Returns:
        PNG image bytes.

    Raises:
        ValueError: If the dataset spans more than one period.
    """
    logger.info("Generating xarray map preview")
    time_dim = get_time_dim(ds)
    varname = dataset["variable"]

    temp_ds = ds[[time_dim, varname]]
    temp_ds = temp_ds.assign_coords({time_dim: lambda x: numpy_period_array(x[time_dim].values, period_type)})

    # Validate explicitly: `assert` is stripped under `python -O`, which would
    # silently let multi-period data through and plot an ambiguous title.
    n_periods = len(temp_ds[time_dim].values)
    if n_periods != 1:
        raise ValueError(f"Expected exactly one period for preview, got {n_periods}")

    fig = Figure()
    ax = fig.subplots()
    period = temp_ds[time_dim].values[0]
    temp_ds[varname].plot(ax=ax)
    ax.set_title(f"{period}")

    # Render to an in-memory PNG buffer.
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150)
    buf.seek(0)
    image_data = buf.getvalue()
    buf.close()
    return image_data
def convert_pandas_units(ds: Any, dataset: dict[str, Any]) -> None:
    """Convert values in a pandas DataFrame column from source to target units.

    Mutates ``ds`` in place; does nothing when no target unit is configured
    or source and target units already match.
    """
    varname = dataset["variable"]
    from_units = dataset["units"]
    to_units = dataset.get("convertUnits")

    # Guard clause: nothing to do without a distinct target unit.
    if not to_units or to_units == from_units:
        logger.info("No unit conversion needed")
        return

    logger.info(f"Applying unit conversion from {from_units} to {to_units}...")
    # Attach the source units, convert, and write the bare magnitudes back.
    quantity = ds[varname].values * units(from_units)
    ds[varname] = quantity.to(to_units).magnitude
def get_time_dim(ds: Any) -> str:
    """Return the name of the time dimension in a dataset or dataframe.

    Checks ``valid_time`` before ``time`` so ERA5-style datasets resolve to
    their native dimension name. Works for both xarray objects (attribute
    access on dims/coords) and pandas DataFrames (attribute access on
    columns).

    Raises:
        ValueError: If no known time dimension/column is present.
    """
    for time_name in ["valid_time", "time"]:
        if hasattr(ds, time_name):
            return time_name
    # Bug fix: the old message formatted `ds.coordinates`, which itself raises
    # AttributeError for pandas DataFrames (they have no `.coordinates`),
    # masking the intended ValueError. Describe the object defensively.
    available = getattr(ds, "coordinates", None) or getattr(ds, "columns", None) or type(ds).__name__
    raise ValueError(f"Unable to find time dimension: {available}")
def numpy_period_array(t_array: np.ndarray[Any, Any], period_type: str) -> np.ndarray[Any, Any]:
    """Convert an array of numpy datetimes to truncated period strings.

    Vectorized counterpart of ``numpy_period_string``: formats the whole
    array as ISO strings in one pass, then truncates to the period prefix.

    Args:
        t_array: Array of ``datetime64`` values.
        period_type: One of "hourly", "daily", "monthly", "yearly".

    Returns:
        Array of fixed-width unicode period strings.

    Raises:
        ValueError: If ``period_type`` is not recognised (previously this
            surfaced as a bare ``KeyError``, inconsistent with
            ``numpy_period_string``).
    """
    # TODO: this and numpy_period_string should be merged
    # Map periods to string lengths: YYYY-MM-DDTHH (13), YYYY-MM-DD (10), etc.
    lengths = {"hourly": 13, "daily": 10, "monthly": 7, "yearly": 4}
    if period_type not in lengths:
        raise ValueError(f"Unknown periodType: {period_type}")
    s = np.datetime_as_string(t_array, unit="s")
    return s.astype(f"U{lengths[period_type]}")