Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
001851f
first partial commit of dataset registry and caching
Feb 23, 2026
76177ed
chore: allow arbitrary dhis2eo download funcs, get dataset cache info…
Feb 23, 2026
b3d66bf
chore: define datasets router prefix centrally
Feb 23, 2026
8d1b3af
complete aggregate pipeline with unit conversion, hacky constants module
Feb 23, 2026
8e99b86
make aggregation statistics dynamic and not hardcoded for each dataset
Feb 24, 2026
073a6a1
reorganize api so time and space aggregation happens hierarchically b…
Feb 24, 2026
937a277
move some funcs to utils.py
Feb 24, 2026
b48bd80
add raster with time period aggregation download endpoint as part of t…
Feb 24, 2026
7752544
add dummy tiles endpoint
Feb 24, 2026
21c6d00
apply unit conversion earlier so it also affects raster download
Feb 24, 2026
297c6e7
switch to proper logging
Feb 24, 2026
7890d1f
check and raise dataset id errors more gracefully
Feb 24, 2026
b9d27c2
validate target period type and skip temporal aggregation if same per…
Feb 24, 2026
d95c1d6
fix misc dataset metadata
Feb 24, 2026
e243b8d
fix period type validation bug, redo json serialization and show corr…
Feb 24, 2026
ccca427
easier to switch out to different org units, cache builds to latest d…
Feb 24, 2026
0e56e82
switch to more reliable cache background worker, fix dynamic bbox error
Feb 24, 2026
825e3ea
speedup unit conversion by doing it on fewer values, stabilize intern…
Feb 24, 2026
134070e
add cache optimization which builds zarr archive, dataset openers use…
Feb 24, 2026
367b06d
improve read and write of zarr cache
Feb 25, 2026
62c33db
fix same period type aggregation error, clarify array subsetting in t…
Feb 25, 2026
0c426e2
smaller chunk size for improved read speed, memory, and optimize buil…
Feb 25, 2026
a87c106
add simple map image preview endpoints for orgunits and raster endpoints
Feb 25, 2026
2786da9
fix more proper preview map figure generation that avoids thread-unsa…
Feb 25, 2026
437dd65
upgrade to dhis2eo v1.1.1, tweak preview figures
Feb 26, 2026
15d69ff
merge with main
Feb 26, 2026
42f33bd
remove old files, make datasets module work with new setup
Feb 26, 2026
335adac
update docstring for preview endpoints
Feb 26, 2026
9ccfc1f
some ruff fixes, ignore datasets folder for ruff and mypy checks
Feb 26, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
__pycache__/
src/eo_api/datasets/cache
.venv/
.env
eo_api.egg-info/
Expand Down
Empty file added __init__.py
Empty file.
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,17 @@ dependencies = [
"geojson-pydantic>=2.1.0",
"httpx>=0.28.1",
"prefect>=3.6",
"earthkit-transforms==0.5.*",
"metpy>=1.7,<2",
"zarr==3.1.5",
]

[tool.ruff]
target-version = "py313"
line-length = 120
exclude = [
"src/eo_api/datasets"
]

[tool.ruff.lint]
fixable = ["ALL"]
Expand Down Expand Up @@ -52,6 +58,9 @@ no_implicit_optional = true
warn_unused_ignores = true
strict_equality = true
mypy_path = ["src"]
exclude = [
"src/eo_api/datasets"
]

[[tool.mypy.overrides]]
module = "tests.*"
Expand Down
Empty file added src/eo_api/datasets/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions src/eo_api/datasets/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@

from fastapi import APIRouter, BackgroundTasks, HTTPException, Response
from fastapi.responses import FileResponse
from starlette.background import BackgroundTask

from . import cache, constants, raster, registry, serialize, units

router = APIRouter()

@router.get("/")
def list_datasets():
    """Return the list of available datasets from the registry."""
    return registry.list_datasets()

def get_dataset_or_404(dataset_id: str):
    """Look up a dataset by id in the registry, raising HTTP 404 when unknown."""
    dataset = registry.get_dataset(dataset_id)
    if dataset:
        return dataset
    raise HTTPException(status_code=404, detail=f"Dataset '{dataset_id}' not found")

@router.get("/{dataset_id}", response_model=dict)
def get_dataset(dataset_id: str):
    """Get a single dataset by ID, including its cache coverage info.

    Raises HTTP 404 (via get_dataset_or_404) if the dataset id is unknown.
    """
    dataset = get_dataset_or_404(dataset_id)
    cache_info = cache.get_cache_info(dataset)
    # merge into a fresh dict rather than dataset.update(cache_info):
    # the registry may hand back its own dict, and mutating it in place
    # would accumulate (possibly stale) cache info across requests
    return {**dataset, **cache_info}

@router.get("/{dataset_id}/build_cache", response_model=dict)
def build_dataset_cache(dataset_id: str, start: str, end: str | None = None, overwrite: bool = False, background_tasks: BackgroundTasks = None):
    """Download and cache dataset as local netcdf files direct from the source."""
    dataset = get_dataset_or_404(dataset_id)
    # the cache module schedules the actual download on the background tasks
    cache.build_dataset_cache(
        dataset,
        start=start,
        end=end,
        overwrite=overwrite,
        background_tasks=background_tasks,
    )
    return {'status': 'Dataset caching request submitted for processing'}

@router.get("/{dataset_id}/optimize_cache", response_model=dict)
def optimize_dataset_cache(dataset_id: str, background_tasks: BackgroundTasks = None):
    """Optimize dataset cache by collecting all cache files to a single zarr archive."""
    dataset = get_dataset_or_404(dataset_id)
    # run the (potentially long) optimization outside the request/response cycle
    background_tasks.add_task(cache.optimize_dataset_cache, dataset)
    return {'status': 'Dataset cache optimization submitted for processing'}

def get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation):
    """Open the cached raster data for *dataset* between *start* and *end*
    and temporally aggregate it to *period_type* using *temporal_aggregation*.

    TODO: maybe move this and similar somewhere better like a pipelines.py file?
    """
    data = raster.get_data(dataset, start, end)
    return raster.to_timeperiod(data, dataset, period_type, statistic=temporal_aggregation)

@router.get("/{dataset_id}/{period_type}/orgunits", response_model=list)
def get_dataset_period_type_org_units(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str, spatial_aggregation: str):
    """Get a dataset dynamically aggregated to a given period type and org units and return json values."""
    dataset = get_dataset_or_404(dataset_id)

    # temporally aggregate the cached raster to the requested period type
    ds = get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation)

    # reduce the raster to one value per org-unit feature
    df = raster.to_features(ds, dataset, features=constants.ORG_UNITS_GEOJSON, statistic=spatial_aggregation)

    # convert units if needed (in-place); done after aggregation so the
    # conversion runs on far fewer values
    units.convert_pandas_units(df, dataset)

    # serialize to json
    return serialize.dataframe_to_json_data(df, dataset, period_type)

@router.get("/{dataset_id}/{period_type}/orgunits/preview", response_model=list)
def get_dataset_period_type_org_units_preview(dataset_id: str, period_type: str, period: str, temporal_aggregation: str, spatial_aggregation: str):
    """Preview a PNG map image of a dataset dynamically aggregated to a given period and org units."""
    dataset = get_dataset_or_404(dataset_id)

    # a preview covers exactly one period, so start and end coincide
    ds = get_dataset_period_type(dataset, period_type, period, period, temporal_aggregation)

    # reduce the raster to one value per org-unit feature
    df = raster.to_features(ds, dataset, features=constants.ORG_UNITS_GEOJSON, statistic=spatial_aggregation)

    # convert units if needed (in-place); done after aggregation so the
    # conversion runs on far fewer values
    units.convert_pandas_units(df, dataset)

    # render the map and return it as a PNG response
    image_data = serialize.dataframe_to_preview(df, dataset, period_type)
    return Response(content=image_data, media_type="image/png")

@router.get("/{dataset_id}/{period_type}/raster")
def get_dataset_period_type_raster(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str):
    """Get a dataset dynamically aggregated to a given period type and return as downloadable raster file."""
    dataset = get_dataset_or_404(dataset_id)

    # temporally aggregate the cached raster to the requested period type
    ds = get_dataset_period_type(dataset, period_type, start, end, temporal_aggregation)

    # convert units if needed (in-place on the xarray dataset)
    units.convert_xarray_units(ds, dataset)

    # write to a temporary netcdf file on disk
    file_path = serialize.xarray_to_temporary_netcdf(ds)

    # stream the file back and delete it once the response has been sent
    return FileResponse(
        file_path,
        media_type="application/x-netcdf",
        filename='eo-api-raster-download.nc',
        background=BackgroundTask(serialize.cleanup_file, file_path),
    )

@router.get("/{dataset_id}/{period_type}/raster/preview")
def get_dataset_period_type_raster_preview(dataset_id: str, period_type: str, period: str, temporal_aggregation: str):
    """Preview a PNG map image of a dataset dynamically aggregated to a given period."""
    dataset = get_dataset_or_404(dataset_id)

    # a preview covers exactly one period, so start and end coincide
    ds = get_dataset_period_type(dataset, period_type, period, period, temporal_aggregation)

    # convert units if needed (in-place on the xarray dataset)
    units.convert_xarray_units(ds, dataset)

    # render the map and return it as a PNG response
    image_data = serialize.xarray_to_preview(ds, dataset, period_type)
    return Response(content=image_data, media_type="image/png")

@router.get("/{dataset_id}/{period_type}/tiles")
def get_dataset_period_type_tiles(dataset_id: str, period_type: str, start: str, end: str, temporal_aggregation: str):
    # Stub endpoint: tiles are not implemented yet; currently responds
    # with HTTP 200 and a null JSON body.
    pass
179 changes: 179 additions & 0 deletions src/eo_api/datasets/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import datetime
import importlib
import inspect
import logging
from pathlib import Path

import xarray as xr

from .constants import BBOX, CACHE_OVERRIDE, COUNTRY_CODE
from .utils import get_lon_lat_dims, get_time_dim, numpy_period_string

# module-level logger, named after this module
logger = logging.getLogger(__name__)

# paths: cache files live next to this module by default,
# unless CACHE_OVERRIDE (from the environment, see constants.py) points elsewhere
SCRIPT_DIR = Path(__file__).parent.resolve()
CACHE_DIR = SCRIPT_DIR / 'cache'
if CACHE_OVERRIDE:
    CACHE_DIR = Path(CACHE_OVERRIDE)

def build_dataset_cache(dataset, start, end, overwrite, background_tasks):
    """Schedule a background download of *dataset* into the local cache.

    Args:
        dataset: dataset metadata dict from the registry; its 'cacheInfo'
            entry must contain 'eoFunction' (dotted path to a dhis2eo
            download function) and 'defaultParams'.
        start: start period string.
        end: end period string; falsy means "up to today".
        overwrite: whether to re-download files already in the cache.
        background_tasks: FastAPI BackgroundTasks the download is queued on.
    """
    # resolve the download function from its dotted path
    cache_info = dataset['cacheInfo']
    eo_download_func = get_dynamic_function(cache_info['eoFunction'])

    # copy the defaults so we never mutate the registry's dict in place
    # (otherwise start/end/overwrite from one request would leak into the next)
    params = dict(cache_info['defaultParams'])
    params.update({
        'start': start,
        'end': end or datetime.date.today().isoformat(),  # today's date if empty
        'dirname': CACHE_DIR,
        'prefix': get_cache_prefix(dataset),
        'overwrite': overwrite,
    })

    # add whichever spatial argument this particular download function supports
    sig = inspect.signature(eo_download_func)
    if 'bbox' in sig.parameters:
        params['bbox'] = BBOX
    elif 'country_code' in sig.parameters:
        params['country_code'] = COUNTRY_CODE

    # execute the download in the background
    background_tasks.add_task(eo_download_func, **params)

def optimize_dataset_cache(dataset):
    """Consolidate all cached netcdf files for *dataset* into one zarr archive.

    The zarr store at CACHE_DIR/<prefix>.zarr is rewritten from scratch
    (mode='w') with uniform chunk sizes tuned for the dataset's period type
    and a bounded spatial chunk (see compute_time_space_chunks).
    """
    logger.info(f'Optimizing cache for dataset {dataset["id"]}')

    # open all cache files as one virtual xarray dataset
    files = get_cache_files(dataset)
    logger.info(f'Opening {len(files)} files from cache')
    ds = xr.open_mfdataset(files)

    # trim to only the target variable and the minimal time/space coords
    logger.info('Trimming unnecessary variables and coordinates')
    varname = dataset['variable']
    ds = ds[[varname]]
    keep_coords = [get_time_dim(ds)] + list(get_lon_lat_dims(ds))
    drop_coords = [
        c for c in ds.coords
        if c not in keep_coords
    ]
    ds = ds.drop_vars(drop_coords)

    # determine optimal chunk sizes
    logger.info('Determining optimal chunk size for zarr archive')
    ds_autochunk = ds.chunk('auto').unify_chunks()
    # extract the first chunk size for each dimension to force uniformity
    uniform_chunks = {dim: ds_autochunk.chunks[dim][0] for dim in ds_autochunk.dims}
    # override with access-pattern-aware time/space chunks
    uniform_chunks.update(compute_time_space_chunks(ds, dataset))
    # use the module logger (the original called logging.info, i.e. the root logger)
    logger.info(f'--> {uniform_chunks}')

    # save as zarr
    logger.info('Saving to optimized zarr file')
    zarr_path = CACHE_DIR / f'{get_cache_prefix(dataset)}.zarr'
    ds_chunked = ds.chunk(uniform_chunks)
    ds_chunked.to_zarr(zarr_path, mode='w')
    ds_chunked.close()

    logger.info('Finished cache optimization')

def compute_time_space_chunks(ds, dataset, max_spatial_chunk=256):
    """Pick chunk sizes for the time and space dimensions of *ds*.

    The time chunk follows common access patterns for the dataset's native
    period type; spatial chunks are capped at *max_spatial_chunk*.
    """
    # TODO: could potentially allow this to be customized in the dataset yaml file
    time_chunk_by_period = {
        'hourly': 24 * 7,
        'daily': 30,
        'monthly': 12,
        'yearly': 1,
    }

    chunks = {}

    # time: only set when the period type is one we know an access pattern for
    time_dim = get_time_dim(ds)
    period_type = dataset['periodType']
    if period_type in time_chunk_by_period:
        chunks[time_dim] = time_chunk_by_period[period_type]

    # space: never chunk larger than the dimension itself
    for dim in get_lon_lat_dims(ds):
        chunks[dim] = min(ds.sizes[dim], max_spatial_chunk)

    return chunks

def get_cache_info(dataset):
    """Summarize the temporal and spatial coverage of a dataset's cache files.

    Returns a dict with a 'coverage' key holding 'temporal' (start/end period
    strings) and 'spatial' (bounding box) entries; both are None when no
    cache files exist yet, so API consumers always see the same shape
    (previously the empty case returned different keys:
    temporal_coverage/spatial_coverage).
    """
    # find all files with cache prefix; sort once so first/last are the
    # earliest/latest filenames
    files = sorted(get_cache_files(dataset))
    if not files:
        return dict(
            coverage=dict(
                temporal=None,
                spatial=None,
            )
        )

    period_type = dataset['periodType']

    # the first sorted filename should be sufficient to get the earliest time
    # period; use a context manager so the file handle is not leaked
    with xr.open_dataset(files[0]) as ds:
        time_dim = get_time_dim(ds)
        lon_dim, lat_dim = get_lon_lat_dims(ds)

        # start time
        start = numpy_period_string(ds[time_dim].min().values, period_type)

        # spatial scope
        xmin, xmax = ds[lon_dim].min().item(), ds[lon_dim].max().item()
        ymin, ymax = ds[lat_dim].min().item(), ds[lat_dim].max().item()

    # the last sorted filename should be sufficient to get the latest time period
    with xr.open_dataset(files[-1]) as ds:
        end = numpy_period_string(ds[time_dim].max().values, period_type)

    return dict(
        coverage=dict(
            temporal={'start': start, 'end': end},
            spatial={'xmin': xmin, 'ymin': ymin, 'xmax': xmax, 'ymax': ymax},
        )
    )

def get_cache_prefix(dataset):
    """Return the filename prefix used for this dataset's cache files (its id)."""
    return dataset['id']

def get_cache_files(dataset):
    """List all cached netcdf files whose name starts with the dataset's prefix."""
    # TODO: this is not bulletproof, eg 2m_temperature might also get another
    # dataset named 2m_temperature_modified
    # ...probably need a delimeter to specify end of dataset name...
    pattern = f'{get_cache_prefix(dataset)}*.nc'
    return list(CACHE_DIR.glob(pattern))

def get_zarr_path(dataset):
    """Return the path of the optimized zarr archive, or None if not built yet."""
    candidate = CACHE_DIR / f'{get_cache_prefix(dataset)}.zarr'
    return candidate if candidate.exists() else None

def get_dynamic_function(full_path):
    """Import and return a function given its dotted path.

    e.g. 'dhis2eo.data.cds.era5_land.hourly.get' resolves the module
    'dhis2eo.data.cds.era5_land.hourly' and returns its 'get' attribute.
    importlib.import_module handles all intermediate sub-package imports.
    """
    module_path, function_name = full_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    return getattr(module, function_name)
17 changes: 17 additions & 0 deletions src/eo_api/datasets/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import json
import os

import geopandas as gpd

from ..integrations.dhis2_adapter import create_client, get_org_units_geojson

# load geojson from dhis2 at startup and keep in-memory
# TODO: should probably save to file instead
# NOTE(review): this runs at import time, so a DHIS2 connection failure
# here prevents the whole datasets package from importing — confirm this
# is acceptable for startup behavior
client = create_client()
# org units at level 2 — presumably the intended admin level; verify
ORG_UNITS_GEOJSON = get_org_units_geojson(client, level=2)
# overall bounding box of all org units as [xmin, ymin, xmax, ymax] floats,
# derived by round-tripping the geojson through geopandas' total_bounds
BBOX = list(map(float, gpd.read_file(json.dumps(ORG_UNITS_GEOJSON)).total_bounds))

# env variables we need from .env
# TODO: should probably centralize to shared config module
COUNTRY_CODE = os.getenv('COUNTRY_CODE')
CACHE_OVERRIDE = os.getenv('CACHE_OVERRIDE')
Loading
Loading