Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion dataretrieval/waterdata/nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def get_nearest_continuous(
... )
"""
_check_nearest_kwargs(kwargs, on_tie)
targets = pd.DatetimeIndex(pd.to_datetime(targets, utc=True))
targets = _coerce_targets(targets)
window_td = pd.Timedelta(window)

if len(targets) == 0:
Expand All @@ -151,6 +151,11 @@ def get_nearest_continuous(
filter_lang="cql-text",
**kwargs,
)
if "time" not in df.columns:
raise ValueError(
"get_nearest_continuous requires a 'time' column in the response; "
"if a `properties` kwarg was passed, include 'time' in it"
)
if df.empty:
return _empty_nearest_result(df), md

Expand All @@ -172,6 +177,21 @@ def get_nearest_continuous(
return pd.DataFrame(selected).reset_index(drop=True), md


def _coerce_targets(targets) -> pd.DatetimeIndex:
    """Normalise ``targets`` into a ``DatetimeIndex`` parsed with ``utc=True``.

    ``targets`` is handed to ``pandas.to_datetime``; what comes back decides
    the result:

    * a ``DatetimeIndex`` is returned unchanged;
    * a scalar (as judged by ``pandas.api.types.is_scalar``) is placed in a
      one-element ``DatetimeIndex``;
    * anything else (e.g. the ``Series`` produced for ``Series`` input) is
      passed to the ``DatetimeIndex`` constructor so every element is kept.
    """
    converted = pd.to_datetime(targets, utc=True)
    if not isinstance(converted, pd.DatetimeIndex):
        # A scalar must be boxed in a list first; the DatetimeIndex
        # constructor needs a collection rather than a single value.
        boxed = [converted] if pd.api.types.is_scalar(converted) else converted
        converted = pd.DatetimeIndex(boxed)
    return converted


def _check_nearest_kwargs(kwargs: dict, on_tie: OnTie) -> None:
"""Reject kwargs the helper owns; validate ``on_tie``."""
for forbidden in ("time", "filter", "filter_lang"):
Expand Down
65 changes: 65 additions & 0 deletions tests/waterdata_nearest_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,68 @@ def test_forwards_kwargs_to_get_continuous(patch_get_continuous):
_, kwargs = patch_get_continuous.call_args
assert kwargs["statistic_id"] == "00011"
assert kwargs["approval_status"] == "Approved"


def test_accepts_single_string_target(patch_get_continuous):
    """A lone string target is accepted and echoed back as ``target_time``.

    Regression guard: the earlier implementation wrapped the result of
    ``pd.to_datetime("...", utc=True)`` in ``pd.DatetimeIndex``. For a single
    string that result is a scalar ``Timestamp``, and the wrap raised
    TypeError.
    """
    target = "2023-06-15T10:30:31Z"
    fake_response = _fake_df([{"time": "2023-06-15T10:30:00Z", "value": 22.4}])
    patch_get_continuous.return_value = (fake_response, mock.Mock())

    result, _ = get_nearest_continuous(target, monitoring_location_id="USGS-02238500")

    assert len(result) == 1
    assert result["target_time"].iloc[0] == pd.Timestamp(target)


def test_accepts_single_timestamp_target(patch_get_continuous):
    """A lone ``pd.Timestamp`` target yields exactly one result row."""
    rows = [{"time": "2023-06-15T10:30:00Z", "value": 22.4}]
    patch_get_continuous.return_value = (_fake_df(rows), mock.Mock())

    result, _ = get_nearest_continuous(
        pd.Timestamp("2023-06-15T10:30:31Z"),
        monitoring_location_id="USGS-02238500",
    )

    assert len(result) == 1


def test_accepts_pandas_series_targets(patch_get_continuous):
    """Every element of a ``pd.Series`` of targets produces a result row.

    Guards against the Series being collapsed to its first element.
    """
    rows = [
        {"time": "2023-06-15T10:30:00Z", "value": 22.4},
        {"time": "2023-06-16T10:30:00Z", "value": 22.5},
    ]
    patch_get_continuous.return_value = (_fake_df(rows), mock.Mock())

    result, _ = get_nearest_continuous(
        pd.Series(["2023-06-15T10:30:31Z", "2023-06-16T10:30:31Z"]),
        monitoring_location_id="USGS-02238500",
    )

    assert len(result) == 2


def test_missing_time_column_raises_helpful_error(patch_get_continuous):
    """A response lacking a 'time' column surfaces as a clear ValueError.

    This can happen when the caller's `properties` selection leaves 'time'
    out; the helper should explain that instead of failing with KeyError.
    """
    columns_without_time = {
        "value": [22.4],
        "monitoring_location_id": ["USGS-02238500"],
    }
    patch_get_continuous.return_value = (
        pd.DataFrame(columns_without_time),
        mock.Mock(),
    )
    call_kwargs = {
        "monitoring_location_id": "USGS-02238500",
        "properties": ["value", "monitoring_location_id"],
    }

    with pytest.raises(ValueError, match="'time' column"):
        get_nearest_continuous(["2023-06-15T10:30:31Z"], **call_kwargs)
Loading