diff --git a/CLAUDE.md b/CLAUDE.md index 92294dbe7..6d4bf185f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -54,6 +54,21 @@ This updates `roadmap/README.md` from the feature frontmatter. ## Coding Conventions +### Public vs. Internal API + +A name is **private** if any segment of its import path starts with `_`. +This is enforced mechanically by Ruff rule `PLC2701`: no module may import a +`_foo` name from another module. Underscores on function names are otherwise a +convention — authors may use them as a hint that a function is file-local, but +this is not required and not enforced. + +- Public: `modelskill.matching.match`, `modelskill.utils.rename_coords_xr` +- Private (module path): `modelskill._names.get_name`, `modelskill.timeseries._timeseries.validate_data_var_name` +- File-local hint: a leading `_` on a function name inside any module, signaling that it is not used outside that file + +See [ADR-012](adr/012-public-private-api-convention.md). Checked in CI by +Ruff rule `PLC2701` (zero violations in `src/`; tests have a per-file ignore). + ### Docstrings - All docstrings use **NumPy format** (not Google or reStructuredText style) - Include sections: Parameters, Returns, Raises, Examples, See Also, Notes as appropriate @@ -169,7 +184,6 @@ Plots support both matplotlib (static) and plotly (interactive) backends. - Internal data storage uses xarray Datasets with standardized coordinate/variable names - Time coordinates use pandas datetime64 - Spatial coordinates: `x`, `y` (and `z` when applicable) -- Reserved names in `_RESERVED_NAMES` should not be used for model/observation names - The `Quantity` class handles physical quantities with units and validation ### Testing Structure (`tests/`) diff --git a/adr/012-public-private-api-convention.md b/adr/012-public-private-api-convention.md new file mode 100644 index 000000000..cd8054c78 --- /dev/null +++ b/adr/012-public-private-api-convention.md @@ -0,0 +1,56 @@ +# ADR-012: Public vs.
internal API convention + +**Status**: Accepted + +**Date**: 2026-05-14 + +## Context + +ModelSkill had accumulated an inconsistent privacy convention. Leading underscores appeared on both **file names** (`_comparison.py`, `_misc.py`, `_timeseries.py`) and **function or constant names** (`_get_name`, `_parse_metric`, `_validate_data_var_name`), with no written rule about what each level meant or how the two interacted. Symptoms: + +- `from modelskill.timeseries._timeseries import _validate_data_var_name` — privacy claimed at both the path segment *and* the function name, while the function was imported from 4+ other files across subpackages. +- `timeseries/__init__.py` placed five underscore-prefixed names in `__all__`, simultaneously declaring them "officially public" and "private." +- Pyright (Pylance) flagged `reportPrivateUsage` on every cross-module `_foo` import, generating ambient warnings during development. +- No written policy distinguished "what users can rely on" from "what the package uses internally." + +Three candidate definitions of "public" were considered: + +- **Strict**: only names in `modelskill.__all__` at the top level. +- **Documented**: anything appearing in `docs/api/`. +- **Conventional**: anything reachable as `modelskill.<...>` without a leading `_` in any path segment. + +The conventional definition was chosen because the team did not want documentation gaps to silently shrink the public API: a function reachable via a non-underscored import path could already have been adopted by users, even if undocumented. + +## Decision + +**Privacy is set by the module path, not the function name.** A name is private if any segment of its import path starts with `_`. The only mechanically enforced rule is `PLC2701`: no module may import a name starting with `_` from another module. A leading `_` on a function name is otherwise a convention with no enforced meaning — authors may use it as a hint that a function is file-local. 
Inside a private module, where the module path already carries the privacy signal, this hint is optional. + +Examples: + +- `modelskill.timeseries._timeseries.validate_data_var_name` — private (module path). +- `modelskill._names.get_name` — private (module path). +- `modelskill.metrics._format_directional_label` (hypothetical) — private (function name); used only inside `metrics.py`. +- `modelskill.matching.match` — public. + +This is the same convention used by scikit-learn (`sklearn.utils._param_validation.validate_params`). + +The convention is mechanically enforced by Ruff rule `PLC2701` (`import-private-name`), which flags `from x import _foo` style imports. After the renames applied with this ADR, the codebase has zero PLC2701 violations in `src/`. Tests are allowed to import private names via a per-file ignore, since white-box testing of internals is a legitimate need. + +## Alternatives Considered + +**Snapshot-test the public surface.** A test that asserts `dir(modelskill)` matches a frozen list catches accidental public additions but requires updating the expected list on every intentional public change. Rejected as ongoing maintenance cost for limited additional coverage beyond what Ruff and PR review already provide. + +**Introduce a `modelskill/_internal/` subpackage.** Concentrating all cross-cutting internal helpers in one location would make "shared internal API" structurally obvious. Rejected as pure churn: the helpers already live in natural subpackage homes, and moving them does not change the privacy story under this convention. + +**Rename `utils.py` → `_utils.py` and rely on a deprecation shim.** Would break legitimately-public functions (`rename_coords_xr`, `rename_coords_pd`, `make_unique_index`) that are imported from non-underscored paths internally and could already be used by downstream consumers. 
The actual cross-cutting *private* helpers (`_get_name`, `_get_idx`, `_RESERVED_NAMES`) were split out into a new `modelskill/_names.py` module instead, leaving `utils.py` as a public module. The regression helper that previously lived as `metrics._linear_regression` was split into its own single-purpose private module `modelskill/_regression.py`. Naming the new module `_names.py` rather than `_utils.py` is deliberate: it describes the concern (reserved names plus name/index resolution) instead of a contentless role. + +**Keep the status quo and configure Pyright to silence `reportPrivateUsage`.** Rejected because the warning is genuinely useful for catching future drift; removing the noise by suppression would also remove the signal. + +## Consequences + +- **Cross-module use of internal helpers is no longer flagged.** `from modelskill._names import get_name` is unambiguously legal; the module path says "internal." Pyright's `reportPrivateUsage` falls silent. +- **Function-name underscores remain a convention.** `PLC2701` enforces that no `_foo` may be imported across modules. Beyond that, underscores carry no formal meaning: authors may use them as a navigation hint that a function is file-local, in either public or private modules. +- **Future contributors have a clear default.** New internal helpers go in `_modules`, not in `_underscored_names` inside public modules. PRs that put internal helpers in `metrics.py` or `utils.py` with a leading underscore now stand out as the exception rather than the norm. +- **One breaking change is accepted.** Five underscored names were previously re-exported from `modelskill.timeseries` via `__all__` (`_parse_track_input`, etc.). These names are no longer reachable from `modelskill.timeseries`; the renamed functions live in `modelskill.timeseries._point`, `_track`, `_vertical`. By Python convention, leading-underscore names were never part of the stable API, so this is consistent with the new policy. 
+- **Tests retain access to internals** via a `tests/**/*.py` per-file ignore for `PLC2701`. White-box testing of private modules remains supported. +- **`metrics.py` no longer hosts `_parse_metric` or `_linear_regression`.** They moved to `comparison/_utils.py:parse_metric` and `_regression.py:linear_regression` respectively. The first is consumer-specific to the comparison subpackage; the second is a general regression helper used by both `metrics.lin_slope` and `plotting._scatter`. diff --git a/adr/README.md b/adr/README.md index eeff80d44..499df951a 100644 --- a/adr/README.md +++ b/adr/README.md @@ -30,6 +30,7 @@ Each ADR follows this structure: - [ADR-009](009-factory-pattern.md) - Factory pattern for type detection - [ADR-010](010-optional-domain-dependencies.md) - Optional dependencies for domain-specific model types (Draft) - [ADR-011](011-vertical-pre-extracted-columns.md) - VerticalModelResult ingests pre-extracted columns +- [ADR-012](012-public-private-api-convention.md) - Public vs. internal API convention ## Contributing diff --git a/pyproject.toml b/pyproject.toml index c7a71b275..37aa26e47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,13 @@ extend-exclude = ["notebooks"] [tool.ruff.lint] ignore = ["E501"] -select = ["E4", "E7", "E9", "F", "D200", "D205"] +select = ["E4", "E7", "E9", "F", "D200", "D205", "PLC2701"] +preview = true +explicit-preview-rules = true + +[tool.ruff.lint.per-file-ignores] +# Tests legitimately exercise package internals. +"tests/**/*.py" = ["PLC2701"] [tool.mypy] python_version = "3.10" diff --git a/src/modelskill/_names.py b/src/modelskill/_names.py new file mode 100644 index 000000000..bead22254 --- /dev/null +++ b/src/modelskill/_names.py @@ -0,0 +1,50 @@ +"""Reserved coordinate/variable names and name-or-index resolution. 
+ +Centralises the names ModelSkill reserves on its xarray data structures and +the small resolver used wherever a user supplies a name or positional index +against a list of valid names. Private per ADR-012. +""" + +from __future__ import annotations +from typing import Sequence, cast +from collections.abc import Hashable + + +RESERVED_COORD_NAMES = ["x", "y", "z", "time"] +RESERVED_COMPARER_VAR_NAMES = [*RESERVED_COORD_NAMES, "Observation"] + + +def get_name(x: int | str | None, valid_names: Sequence[Hashable]) -> str: + """Parse name/idx from list of valid names (e.g. obs from obs_names), return name.""" + return cast(str, valid_names[get_idx(x, valid_names)]) + + +def get_idx(x: int | str | None, valid_names: Sequence[Hashable]) -> int: + """Parse name/idx from list of valid names (e.g. obs from obs_names), return idx.""" + + if x is None: + if len(valid_names) == 1: + return 0 + else: + raise ValueError( + f"Multiple items available. Must specify name or index. Available items: {valid_names}" + ) + + n = len(valid_names) + if n == 0: + raise ValueError(f"Cannot select {x} from empty list!") + elif isinstance(x, str): + if x in valid_names: + idx = valid_names.index(x) + else: + raise KeyError(f"Name {x} could not be found in {valid_names}") + elif isinstance(x, int): + if x < 0: + x += n + if x >= 0 and x < n: + idx = x + else: + raise IndexError(f"Id {x} is out of range for {valid_names}") + else: + raise TypeError(f"Input {x} invalid! Must be None, str or int, not {type(x)}") + return idx diff --git a/src/modelskill/_regression.py b/src/modelskill/_regression.py new file mode 100644 index 000000000..2c550b18c --- /dev/null +++ b/src/modelskill/_regression.py @@ -0,0 +1,40 @@ +"""Linear regression helper shared by ``metrics.lin_slope`` and the scatter trendline. + +Private per ADR-012: the leading underscore on the module name signals internal use. 
+""" + +from __future__ import annotations +from typing import Tuple + +import numpy as np +from numpy.typing import ArrayLike + + +def linear_regression( + obs: ArrayLike, model: ArrayLike, reg_method: str = "ols" +) -> Tuple[float, float]: + """Fit a linear regression of ``model`` against ``obs`` and return (slope, intercept).""" + if len(obs) == 0: # type: ignore[arg-type] + return np.nan, np.nan + + if reg_method == "ols": + from scipy.stats import linregress + + reg = linregress(obs, model) + intercept = reg.intercept + slope = reg.slope + elif reg_method == "odr": + from scipy import odr + + data = odr.Data(obs, model) + odr_obj = odr.ODR(data, odr.unilinear) + output = odr_obj.run() + + intercept = output.beta[1] + slope = output.beta[0] + else: + raise NotImplementedError( + f"Regression method: {reg_method} not implemented, select 'ols' or 'odr'" + ) + + return slope, intercept diff --git a/src/modelskill/comparison/_collection.py b/src/modelskill/comparison/_collection.py index fd4023509..4a4d5ad27 100644 --- a/src/modelskill/comparison/_collection.py +++ b/src/modelskill/comparison/_collection.py @@ -28,13 +28,13 @@ from ..skill import SkillTable from ..skill_grid import SkillGrid -from ..utils import _get_name +from .._names import get_name from ._comparison import Comparer -from ..metrics import _parse_metric from ._utils import ( - _add_spatial_grid_to_df, - _groupby_df, - _parse_groupby, + add_spatial_grid_to_df, + groupby_df, + parse_groupby, + parse_metric, IdxOrNameTypes, TimeTypes, ) @@ -230,7 +230,7 @@ def __getitem__( return ComparerCollection([self[i] for i in idxs]) if isinstance(x, int): - name = _get_name(x, self.obs_names) + name = get_name(x, self.obs_names) return self._comparers[name] if isinstance(x, Iterable): @@ -332,16 +332,16 @@ def sel( models = [model] else: models = list(model) - mod_names: List[str] = [_get_name(m, self.mod_names) for m in models] + mod_names: List[str] = [get_name(m, self.mod_names) for m in models] if 
observation is None: observation = self.obs_names else: observation = [observation] if np.isscalar(observation) else observation # type: ignore - observation = [_get_name(o, self.obs_names) for o in observation] # type: ignore + observation = [get_name(o, self.obs_names) for o in observation] # type: ignore if (quantity is not None) and (self.n_quantities > 1): quantity = [quantity] if np.isscalar(quantity) else quantity # type: ignore - quantity = [_get_name(v, self.quantity_names) for v in quantity] # type: ignore + quantity = [get_name(v, self.quantity_names) for v in quantity] # type: ignore else: quantity = self.quantity_names @@ -482,14 +482,14 @@ def skill( """ cc = self - pmetrics = _parse_metric(metrics) + pmetrics = parse_metric(metrics) - agg_cols = _parse_groupby(by, n_mod=cc.n_models, n_qnt=cc.n_quantities) + agg_cols = parse_groupby(by, n_mod=cc.n_models, n_qnt=cc.n_quantities) agg_cols, attrs_keys = self._attrs_keys_in_by(agg_cols) df = cc._to_long_dataframe(attrs_keys=attrs_keys, observed=observed) - res = _groupby_df(df, by=agg_cols, metrics=pmetrics) + res = groupby_df(df, by=agg_cols, metrics=pmetrics) mtr_cols = [m.__name__ for m in pmetrics] # type: ignore res = res.dropna(subset=mtr_cols, how="all") # TODO: ok to remove empty? 
res = self._append_xy_to_res(res, cc) @@ -634,19 +634,19 @@ def gridded_skill( """ cmp = self - metrics = _parse_metric(metrics) + metrics = parse_metric(metrics) df = cmp._to_long_dataframe() - df = _add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize) + df = add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize) - agg_cols = _parse_groupby(by, n_mod=cmp.n_models, n_qnt=cmp.n_quantities) + agg_cols = parse_groupby(by, n_mod=cmp.n_models, n_qnt=cmp.n_quantities) if "x" not in agg_cols: agg_cols.insert(0, "x") if "y" not in agg_cols: agg_cols.insert(0, "y") df = df.drop(columns=["x", "y"]).rename(columns=dict(xBin="x", yBin="y")) - res = _groupby_df(df, by=agg_cols, metrics=metrics, n_min=n_min) + res = groupby_df(df, by=agg_cols, metrics=metrics, n_min=n_min) ds = res.to_xarray().squeeze() # change categorial index to coordinates @@ -724,7 +724,7 @@ def mean_skill( case _: raise ValueError("Invalid weights specification") - pmetrics = _parse_metric(metrics) + pmetrics = parse_metric(metrics) skilldf = ( self.skill(metrics=pmetrics) .to_dataframe() @@ -809,7 +809,7 @@ def score( {'mod': 8.414442957854142} """ - metric = _parse_metric(metric)[0] + metric = parse_metric(metric)[0] if not (callable(metric) or isinstance(metric, str)): raise ValueError("metric must be a string or a function") diff --git a/src/modelskill/comparison/_collection_plotter.py b/src/modelskill/comparison/_collection_plotter.py index b1df0cd2f..9960d7b1a 100644 --- a/src/modelskill/comparison/_collection_plotter.py +++ b/src/modelskill/comparison/_collection_plotter.py @@ -23,9 +23,9 @@ from .. 
import metrics as mtr from ..plotting import TaylorPoint, scatter, taylor_diagram -from ..plotting._misc import _get_fig_ax, _xtick_directional, _ytick_directional +from ..plotting._misc import get_fig_ax, xtick_directional, ytick_directional from ..settings import options -from ..utils import _get_idx +from .._names import get_idx from ._comparer_plotter import quantiles_xy @@ -271,8 +271,8 @@ def _scatter_one_model( ) if backend == "matplotlib" and self.is_directional: - _xtick_directional(ax, xlim) - _ytick_directional(ax, ylim) + xtick_directional(ax, xlim) + ytick_directional(ax, ylim) return ax @@ -302,7 +302,7 @@ def kde(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes: >>> cc.plot.kde(bw_method='silverman') """ - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) df = self.cc._to_long_dataframe() ax = df.obs_val.plot.kde( @@ -334,7 +334,7 @@ def kde(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes: ax.spines["left"].set_visible(False) if self.is_directional: - _xtick_directional(ax) + xtick_directional(ax) return ax @@ -417,12 +417,12 @@ def _hist_one_model( ): from ._comparison import MOD_COLORS - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) assert ( mod_name in self.cc.mod_names ), f"Model {mod_name} not found in collection" - mod_idx = _get_idx(mod_name, self.cc.mod_names) + mod_idx = get_idx(mod_name, self.cc.mod_names) title = ( _default_univarate_title("Histogram", self.cc) if title is None else title @@ -450,7 +450,7 @@ def _hist_one_model( ax.set_ylabel("count") if self.is_directional: - _xtick_directional(ax) + xtick_directional(ax) return ax @@ -566,7 +566,7 @@ def box(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes: >>> cc.plot.box(showmeans=True) >>> cc.plot.box(ax=ax, title="Box plot") """ - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) df = self.cc._to_long_dataframe() @@ -595,7 +595,7 @@ def box(self, *, ax=None, figsize=None, title=None, 
**kwargs) -> Axes: ax.set_title(title) if self.is_directional: - _ytick_directional(ax) + ytick_directional(ax) return ax @@ -638,7 +638,7 @@ def qq( """ cc = self.cc - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) df = cc._to_long_dataframe() @@ -684,8 +684,8 @@ def qq( ax.set_title(title) if self.is_directional: - _xtick_directional(ax) - _ytick_directional(ax) + xtick_directional(ax) + ytick_directional(ax) return ax @@ -756,7 +756,7 @@ def _residual_hist_one_model( **kwargs, ) -> Axes: """Residual histogram for one model only""" - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) df = self.cc.sel(model=mod_name)._to_long_dataframe() residuals = df.mod_val.values - df.obs_val.values diff --git a/src/modelskill/comparison/_comparer_plotter.py b/src/modelskill/comparison/_comparer_plotter.py index d226252e2..d7458ba89 100644 --- a/src/modelskill/comparison/_comparer_plotter.py +++ b/src/modelskill/comparison/_comparer_plotter.py @@ -18,12 +18,12 @@ import numpy as np # type: ignore from .. 
import metrics as mtr -from ..utils import _get_idx +from .._names import get_idx import matplotlib.colors as colors from ..plotting._misc import ( - _get_fig_ax, - _xtick_directional, - _ytick_directional, + get_fig_ax, + xtick_directional, + ytick_directional, quantiles_xy, ) from ..plotting import taylor_diagram, scatter, TaylorPoint @@ -93,7 +93,7 @@ def timeseries( title = cmp.name if backend == "matplotlib": - fig, ax = _get_fig_ax(ax, figsize) + fig, ax = get_fig_ax(ax, figsize) for j in range(cmp.n_models): key = cmp.mod_names[j] mod = cmp.raw_mod_data[key]._values_as_series @@ -109,7 +109,7 @@ def timeseries( ax.legend([*cmp.mod_names, cmp._obs_name]) ax.set_ylim(ylim) if self.is_directional: - _ytick_directional(ax, ylim) + ytick_directional(ax, ylim) ax.set_title(title) return ax @@ -226,11 +226,11 @@ def _hist_one_model( cmp = self.comparer assert mod_name in cmp.mod_names, f"Model {mod_name} not found in comparer" - mod_idx = _get_idx(mod_name, cmp.mod_names) + mod_idx = get_idx(mod_name, cmp.mod_names) title = f"{mod_name} vs {cmp.name}" if title is None else title - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) kwargs["alpha"] = alpha kwargs["density"] = density @@ -254,7 +254,7 @@ def _hist_one_model( ax.set_ylabel("count") if self.is_directional: - _xtick_directional(ax) + xtick_directional(ax) return ax @@ -291,7 +291,7 @@ def kde(self, ax=None, title=None, figsize=None, **kwargs) -> matplotlib.axes.Ax """ cmp = self.comparer - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) cmp.data.Observation.to_series().plot.kde( ax=ax, linestyle="dashed", label="Observation", **kwargs @@ -317,7 +317,7 @@ def kde(self, ax=None, title=None, figsize=None, **kwargs) -> matplotlib.axes.Ax ax.spines["left"].set_visible(False) if self.is_directional: - _xtick_directional(ax) + xtick_directional(ax) return ax @@ -360,7 +360,7 @@ def qq( """ cmp = self.comparer - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) 
x = cmp.data.Observation.values xmin, xmax = x.min(), x.max() @@ -404,8 +404,8 @@ def qq( ax.set_title(title or f"Q-Q plot for {cmp.name}") if self.is_directional: - _xtick_directional(ax) - _ytick_directional(ax) + xtick_directional(ax) + ytick_directional(ax) return ax @@ -442,7 +442,7 @@ def box(self, *, ax=None, title=None, figsize=None, **kwargs): """ cmp = self.comparer - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) cols = ["Observation"] + cmp.mod_names df = cmp.data[cols].to_dataframe()[cols] @@ -451,7 +451,7 @@ def box(self, *, ax=None, title=None, figsize=None, **kwargs): ax.set_title(title or cmp.name) if self.is_directional: - _ytick_directional(ax) + ytick_directional(ax) return ax @@ -669,8 +669,8 @@ def _scatter_one_model( ) if backend == "matplotlib" and self.is_directional: - _xtick_directional(ax, xlim) - _ytick_directional(ax, ylim) + xtick_directional(ax, xlim) + ytick_directional(ax, ylim) return ax @@ -833,7 +833,7 @@ def _residual_hist_one_model( **kwargs, ) -> matplotlib.axes.Axes: """Residual histogram for one model only""" - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) default_color = "#8B8D8E" color = default_color if color is None else color diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index aed4cc27e..ef8b0adf4 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -28,21 +28,21 @@ from ..types import GeometryType from ..obs import PointObservation, TrackObservation, NodeObservation from ..model import PointModelResult, TrackModelResult, VerticalModelResult -from ..timeseries._timeseries import _normalize_time_to_ns, _validate_data_var_name +from ..timeseries._timeseries import normalize_time_to_ns, validate_data_var_name from ._comparer_plotter import ComparerPlotter -from ..metrics import _parse_metric from ._utils import ( - _add_spatial_grid_to_df, - _groupby_df, - _parse_groupby, + 
add_spatial_grid_to_df, + groupby_df, + parse_groupby, + parse_metric, TimeTypes, IdxOrNameTypes, ) from ..skill import SkillTable from ..skill_grid import SkillGrid from ..settings import register_option -from ..utils import _get_name, _RESERVED_NAMES +from .._names import get_name, RESERVED_COMPARER_VAR_NAMES from .. import __version__ from ._vertical_comparison import VerticalAccessor @@ -64,7 +64,7 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset: # matched_data = self._matched_data_to_xarray(matched_data) assert "Observation" in data.data_vars - data = _normalize_time_to_ns(data) + data = normalize_time_to_ns(data) # no missing values allowed in Observation if data["Observation"].isnull().any(): @@ -86,7 +86,7 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset: assert len(vars) > 1, "dataset must have at least two data arrays" for v in data.data_vars: - v = _validate_data_var_name(str(v)) + v = validate_data_var_name(str(v)) assert ( len(data[v].dims) == 1 ), f"Only 0-dimensional data arrays are supported! 
{v} has {len(data[v].dims)} dimensions" @@ -234,24 +234,24 @@ def parse( """ if len(items) < 2: raise ValueError("data must contain at least two items") - obs_name = _get_name(obs_item, items) if obs_item else items[0] + obs_name = get_name(obs_item, items) if obs_item else items[0] # Check existence of items and convert to names if aux_items is not None: if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_names = [_get_name(a, items) for a in aux_items] + aux_names = [get_name(a, items) for a in aux_items] else: aux_names = [] - x_name = _get_name(x_item, items) if x_item is not None else None - y_name = _get_name(y_item, items) if y_item is not None else None - z_name = _get_name(z_item, items) if z_item is not None else None + x_name = get_name(x_item, items) if x_item is not None else None + y_name = get_name(y_item, items) if y_item is not None else None + z_name = get_name(z_item, items) if z_item is not None else None if mod_items is not None: if isinstance(mod_items, (str, int)): mod_items = [mod_items] - mod_names = [_get_name(m, items) for m in mod_items] + mod_names = [get_name(m, items) for m in mod_items] else: # Add remaining items as model items mod_names = [ @@ -575,9 +575,9 @@ def name(self) -> str: @name.setter def name(self, name: str) -> None: - if name in _RESERVED_NAMES: + if name in RESERVED_COMPARER_VAR_NAMES: raise ValueError( - f"Cannot rename to any of {_RESERVED_NAMES}, these are reserved names!" + f"Cannot rename to any of {RESERVED_COMPARER_VAR_NAMES}, these are reserved names!" 
) self.data.attrs["name"] = name @@ -748,10 +748,10 @@ def rename( # "ignore": silently remove keys that are not in allowed_keys mapping = {k: v for k, v in mapping.items() if k in allowed_keys} - if any([k in _RESERVED_NAMES for k in mapping.values()]): + if any([k in RESERVED_COMPARER_VAR_NAMES for k in mapping.values()]): # TODO: also check for duplicates raise ValueError( - f"Cannot rename to any of {_RESERVED_NAMES}, these are reserved names!" + f"Cannot rename to any of {RESERVED_COMPARER_VAR_NAMES}, these are reserved names!" ) # rename observation @@ -896,7 +896,7 @@ def sel( models = [model] else: models = list(model) - mod_names: List[str] = [_get_name(m, self.mod_names) for m in models] + mod_names: List[str] = [get_name(m, self.mod_names) for m in models] dropped_models = [m for m in self.mod_names if m not in mod_names] d = d.drop_vars(dropped_models) raw_mod_data = {m: raw_mod_data[m] for m in mod_names} @@ -1089,16 +1089,16 @@ def skill( 2017-10-28 0 NaN NaN NaN NaN NaN NaN NaN 2017-10-29 41 0.33 0.41 0.25 0.36 0.96 0.06 0.99 """ - metrics = _parse_metric(metrics, directional=self.quantity.is_directional) + metrics = parse_metric(metrics, directional=self.quantity.is_directional) cmp = self if cmp.n_points == 0: raise ValueError("No data selected for skill assessment") - by = _parse_groupby(by, n_mod=cmp.n_models, n_qnt=1) + by = parse_groupby(by, n_mod=cmp.n_models, n_qnt=1) df = cmp._to_long_dataframe() - res = _groupby_df(df, by=by, metrics=metrics) + res = groupby_df(df, by=by, metrics=metrics) res["x"] = np.nan if self.gtype == "track" or cmp.x is None else cmp.x res["y"] = np.nan if self.gtype == "track" or cmp.y is None else cmp.y res = self._add_as_col_if_not_in_index(df, skilldf=res) @@ -1150,7 +1150,7 @@ def score( >>> cmp.score(metric="mape") {'mod': 11.567399646108198} """ - metric = _parse_metric(metric)[0] + metric = parse_metric(metric)[0] if not (callable(metric) or isinstance(metric, str)): raise ValueError("metric must be a string 
or a function") @@ -1231,21 +1231,21 @@ def gridded_skill( cmp = self - metrics = _parse_metric(metrics) + metrics = parse_metric(metrics) if cmp.n_points == 0: raise ValueError("No data to compare") df = cmp._to_long_dataframe() - df = _add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize) + df = add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize) - agg_cols = _parse_groupby(by=by, n_mod=cmp.n_models, n_qnt=1) + agg_cols = parse_groupby(by=by, n_mod=cmp.n_models, n_qnt=1) if "x" not in agg_cols: agg_cols.insert(0, "x") if "y" not in agg_cols: agg_cols.insert(0, "y") df = df.drop(columns=["x", "y"]).rename(columns=dict(xBin="x", yBin="y")) - res = _groupby_df(df, by=agg_cols, metrics=metrics, n_min=n_min) + res = groupby_df(df, by=agg_cols, metrics=metrics, n_min=n_min) ds = res.to_xarray().squeeze() # change categorial index to coordinates diff --git a/src/modelskill/comparison/_utils.py b/src/modelskill/comparison/_utils.py index ff22c9b7a..2cf62e6f0 100644 --- a/src/modelskill/comparison/_utils.py +++ b/src/modelskill/comparison/_utils.py @@ -1,15 +1,55 @@ from __future__ import annotations +import inspect from typing import Callable, Iterable, List, Tuple, Union from datetime import datetime import numpy as np import pandas as pd +from ..metrics import default_circular_metrics, get_metric +from ..settings import options + TimeTypes = Union[str, np.datetime64, pd.Timestamp, datetime] IdxOrNameTypes = Union[int, str, List[int], List[str]] -def _add_spatial_grid_to_df( +def parse_metric( + metric: str | Iterable[str] | Callable | Iterable[Callable] | None, + *, + directional: bool = False, +) -> List[Callable]: + if metric is None: + if directional: + return default_circular_metrics + else: + # could be a list of str! 
+ return [get_metric(m) for m in options.metrics.list] + + if isinstance(metric, str): + metrics: list = [metric] + elif callable(metric): + metrics = [metric] + elif isinstance(metric, Iterable): + metrics = list(metric) + else: + raise TypeError(f"metric {metric} must be a string, callable or iterable") + + parsed_metrics = [] + + for m in metrics: + if isinstance(m, str): + parsed_metrics.append(get_metric(m)) + elif callable(m): + if len(inspect.signature(m).parameters) < 2: + raise ValueError( + "Metrics must have at least two arguments (obs, model)" + ) + parsed_metrics.append(m) + else: + raise TypeError(f"metric {m} must be a string or callable") + + return parsed_metrics + + +def add_spatial_grid_to_df( df: pd.DataFrame, bins, binsize: float | None ) -> pd.DataFrame: if binsize is None: @@ -53,7 +93,7 @@ def _add_spatial_grid_to_df( return df -def _groupby_df( +def groupby_df( df: pd.DataFrame, *, by: List[str | pd.Grouper], @@ -71,8 +111,8 @@ def calc_metrics(group: pd.DataFrame) -> pd.Series: row[metric.__name__] = metric(group.obs_val, group.mod_val) return pd.Series(row) - if _dt_in_by(by): - df, by = _add_dt_to_df(df, by) + if dt_in_by(by): + df, by = add_dt_to_df(df, by) # sort=False to avoid re-ordering compared to original cc (also for performance) res = df.groupby(by=by, observed=False, sort=False, group_keys=True)[ @@ -90,7 +130,7 @@ def calc_metrics(group: pd.DataFrame) -> pd.Series: return res -def _dt_in_by(by): +def dt_in_by(by): by = [by] if isinstance(by, str) else by if any(str(by).startswith("dt:") for by in by): return True @@ -114,7 +154,7 @@ def _dt_in_by(by): ] -def _add_dt_to_df(df: pd.DataFrame, by: List[str]) -> Tuple[pd.DataFrame, List[str]]: +def add_dt_to_df(df: pd.DataFrame, by: List[str]) -> Tuple[pd.DataFrame, List[str]]: ser = df["time"] assert isinstance(by, list) # by = [by] if isinstance(by, str) else by @@ -138,7 +178,7 @@ def _add_dt_to_df(df: pd.DataFrame, by: List[str]) -> Tuple[pd.DataFrame, List[s return df, by -def _parse_groupby( +def parse_groupby( by: str | Iterable[str] | None, *, n_mod: int,
n_qnt: int ) -> List[str | pd.Grouper]: if by is None: diff --git a/src/modelskill/comparison/_vertical_comparison.py b/src/modelskill/comparison/_vertical_comparison.py index c6f2a6918..9ebfcfdda 100644 --- a/src/modelskill/comparison/_vertical_comparison.py +++ b/src/modelskill/comparison/_vertical_comparison.py @@ -5,13 +5,13 @@ from typing import TYPE_CHECKING, Callable, Iterable, Sequence, Tuple from ..types import GeometryType -from ..plotting._misc import _get_fig_ax +from ..plotting._misc import get_fig_ax import xarray as xr from matplotlib import dates as mdates from ..model import PointModelResult, TrackModelResult from ..model.network import NodeModelResult from ..model.vertical import VerticalModelResult -from ..metrics import _parse_metric +from ._utils import parse_metric from ..skill_profile import SkillProfile if TYPE_CHECKING: @@ -52,7 +52,7 @@ def profile( from ._comparison import MOD_COLORS cmp = self.comparer - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) title = title if title is not None else cmp.name @@ -139,7 +139,7 @@ def hovmoller( cmp = self.comparer mod_name = self._get_model_name(model) - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) if title is None: title = f"{mod_name} and Observations" @@ -328,7 +328,7 @@ def skill( if cmp.n_points == 0: raise ValueError("No data selected for skill assessment") - metric_funcs = _parse_metric(metrics, directional=cmp.quantity.is_directional) + metric_funcs = parse_metric(metrics, directional=cmp.quantity.is_directional) def calculate_metric(g): obs = g[cmp._obs_name] diff --git a/src/modelskill/metrics.py b/src/modelskill/metrics.py index 5a6f008a0..b5a9f45e5 100644 --- a/src/modelskill/metrics.py +++ b/src/modelskill/metrics.py @@ -1,18 +1,15 @@ """Metrics for evaluating the difference between a model and an observation.""" from __future__ import annotations -import inspect import sys import warnings from typing import ( Any, Callable, - Iterable, List, 
Literal, Set, - Tuple, TypeVar, Union, ) @@ -22,7 +19,7 @@ from numpy.typing import ArrayLike from scipy import stats -from .settings import options +from ._regression import linear_regression defined_metrics: Set[str] = set() metrics_with_units: Set[str] = set() @@ -722,36 +719,7 @@ def lin_slope(obs: ArrayLike, model: ArrayLike, reg_method="ols") -> Any: Range: $(-\infty, \infty )$; Best: 1 """ assert obs.size == model.size - return _linear_regression(obs, model, reg_method)[0] - - -def _linear_regression( - obs: ArrayLike, model: ArrayLike, reg_method="ols" -) -> Tuple[float, float]: - if len(obs) == 0: - return np.nan, np.nan # TODO raise error? - - if reg_method == "ols": - from scipy.stats import linregress as _linregress - - reg = _linregress(obs, model) - intercept = reg.intercept - slope = reg.slope - elif reg_method == "odr": - from scipy import odr - - data = odr.Data(obs, model) - odr_obj = odr.ODR(data, odr.unilinear) - output = odr_obj.run() - - intercept = output.beta[1] - slope = output.beta[0] - else: - raise NotImplementedError( - f"Regression method: {reg_method} not implemented, select 'ols' or 'odr'" - ) - - return slope, intercept + return linear_regression(obs, model, reg_method)[0] def _std_obs(obs: ArrayLike, model: ArrayLike) -> Any: @@ -1164,42 +1132,6 @@ def get_metric(metric: Union[str, Callable]) -> Callable: ) -def _parse_metric( - metric: str | Iterable[str] | Callable | Iterable[Callable] | None, - *, - directional: bool = False, -) -> List[Callable]: - if metric is None: - if directional: - return default_circular_metrics - else: - # could be a list of str! 
- return [get_metric(m) for m in options.metrics.list] - - if isinstance(metric, str): - metrics: list = [metric] - elif callable(metric): - metrics = [metric] - elif isinstance(metric, Iterable): - metrics = list(metric) - - parsed_metrics = [] - - for m in metrics: - if isinstance(m, str): - parsed_metrics.append(get_metric(m)) - elif callable(m): - if len(inspect.signature(m).parameters) < 2: - raise ValueError( - "Metrics must have at least two arguments (obs, model)" - ) - parsed_metrics.append(m) - else: - raise TypeError(f"metric {m} must be a string or callable") - - return parsed_metrics - - def is_best(metric: str, expected: str | int) -> bool: try: func = get_metric(metric) diff --git a/src/modelskill/model/_base.py b/src/modelskill/model/_base.py index 23de28e51..15a51fe91 100644 --- a/src/modelskill/model/_base.py +++ b/src/modelskill/model/_base.py @@ -12,7 +12,7 @@ from .track import TrackModelResult from .vertical import VerticalModelResult -from ..utils import _get_name +from .._names import get_name from ..obs import Observation, PointObservation, TrackObservation, VerticalObservation @@ -48,10 +48,10 @@ def _parse_items( f"Input has more than 1 item, but item was not given! 
Available items: {avail_items}" ) - item = _get_name(item, valid_names=avail_items) + item = get_name(item, valid_names=avail_items) if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_items_str = [_get_name(i, valid_names=avail_items) for i in aux_items or []] + aux_items_str = [get_name(i, valid_names=avail_items) for i in aux_items or []] # check that there are no duplicates res = SelectedItems(values=item, aux=aux_items_str) @@ -63,7 +63,7 @@ def _parse_items( return res -def _validate_overlap_in_time(time: pd.DatetimeIndex, observation: Observation) -> None: +def validate_overlap_in_time(time: pd.DatetimeIndex, observation: Observation) -> None: overlap_in_time = ( time[0] <= observation.time[-1] and time[-1] >= observation.time[0] ) diff --git a/src/modelskill/model/dfsu.py b/src/modelskill/model/dfsu.py index 7fec7bb4e..a8db8e8e0 100644 --- a/src/modelskill/model/dfsu.py +++ b/src/modelskill/model/dfsu.py @@ -7,10 +7,10 @@ import numpy as np import pandas as pd -from ._base import SpatialField, _validate_overlap_in_time, SelectedItems +from ._base import SpatialField, validate_overlap_in_time, SelectedItems from ..types import UnstructuredType from ..quantity import Quantity -from ..utils import _get_idx +from .._names import get_idx from .point import PointModelResult from .track import TrackModelResult from .vertical import VerticalModelResult @@ -78,7 +78,7 @@ def __init__( data, (mikeio.dfsu.Dfsu2DH, mikeio.dfsu.Dfsu3D, mikeio.Dataset) ): item_names = [i.name for i in data.items] - idx = _get_idx(x=item, valid_names=item_names) + idx = get_idx(x=item, valid_names=item_names) item_info = data.items[idx] self.sel_items = SelectedItems.parse( @@ -137,7 +137,7 @@ def extract( """ method = self._parse_spatial_method(spatial_method) - _validate_overlap_in_time(self.time, observation) + validate_overlap_in_time(self.time, observation) if isinstance(observation, PointObservation): return self._extract_point(observation, spatial_method=method) 
elif isinstance(observation, TrackObservation): diff --git a/src/modelskill/model/grid.py b/src/modelskill/model/grid.py index 8d1a35178..b81ec9d8e 100644 --- a/src/modelskill/model/grid.py +++ b/src/modelskill/model/grid.py @@ -6,7 +6,7 @@ import pandas as pd import xarray as xr -from ._base import SpatialField, _validate_overlap_in_time, SelectedItems +from ._base import SpatialField, validate_overlap_in_time, SelectedItems from ..utils import rename_coords_xr, rename_coords_pd from ..types import GridType from ..quantity import Quantity @@ -148,7 +148,7 @@ def extract( raise NotImplementedError( "Extraction of VerticalObservation from GridModelResult is not implemented yet." ) - _validate_overlap_in_time(self.time, observation) + validate_overlap_in_time(self.time, observation) if isinstance(observation, PointObservation): return self._extract_point(observation, spatial_method) elif isinstance(observation, TrackObservation): diff --git a/src/modelskill/model/network.py b/src/modelskill/model/network.py index 17c62035c..4df6006ce 100644 --- a/src/modelskill/model/network.py +++ b/src/modelskill/model/network.py @@ -7,7 +7,8 @@ import pandas as pd import xarray as xr -from modelskill.timeseries import TimeSeries, _parse_network_node_input +from modelskill.timeseries import TimeSeries +from modelskill.timeseries._point import parse_network_node_input from ._base import SelectedItems from ..obs import NodeObservation, ReachObservation from ..quantity import Quantity @@ -58,7 +59,7 @@ def __init__( aux_items: Sequence[int | str] | None = None, ): if not self._is_input_validated(data): - data = _parse_network_node_input( + data = parse_network_node_input( data, name=name, item=item, diff --git a/src/modelskill/model/point.py b/src/modelskill/model/point.py index 4e071c4e9..5441438be 100644 --- a/src/modelskill/model/point.py +++ b/src/modelskill/model/point.py @@ -6,7 +6,8 @@ from ..obs import Observation from ..types import PointType from ..quantity import Quantity 
-from ..timeseries import TimeSeries, _parse_xyz_point_input +from ..timeseries import TimeSeries +from ..timeseries._point import parse_xyz_point_input from ..timeseries._align import align_data @@ -52,7 +53,7 @@ def __init__( aux_items: Sequence[int | str] | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_xyz_point_input( + data = parse_xyz_point_input( data, name=name, item=item, diff --git a/src/modelskill/model/track.py b/src/modelskill/model/track.py index 612cb0063..9543e43bc 100644 --- a/src/modelskill/model/track.py +++ b/src/modelskill/model/track.py @@ -8,7 +8,8 @@ from ..types import TrackType from ..obs import TrackObservation from ..quantity import Quantity -from ..timeseries import TimeSeries, _parse_track_input +from ..timeseries import TimeSeries +from ..timeseries._track import parse_track_input class TrackModelResult(TimeSeries): @@ -55,7 +56,7 @@ def __init__( aux_items: Sequence[int | str] | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_track_input( + data = parse_track_input( data=data, name=name, item=item, diff --git a/src/modelskill/model/vertical.py b/src/modelskill/model/vertical.py index b1e529ccb..abd46a525 100644 --- a/src/modelskill/model/vertical.py +++ b/src/modelskill/model/vertical.py @@ -7,7 +7,8 @@ from ..types import VerticalType from ..quantity import Quantity -from ..timeseries import TimeSeries, _parse_vertical_input +from ..timeseries import TimeSeries +from ..timeseries._vertical import parse_vertical_input from ..obs import VerticalObservation @@ -58,7 +59,7 @@ def __init__( aux_items: Sequence[int | str] | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_vertical_input( + data = parse_vertical_input( data=data, name=name, item=item, diff --git a/src/modelskill/network.py b/src/modelskill/network.py index 071bc01f6..7b08c1a22 100644 --- a/src/modelskill/network.py +++ b/src/modelskill/network.py @@ -445,6 +445,7 @@ def 
from_res1d( list_of_reaches = cls._load_res1d_network(res, nodes_list, reaches_list) return cls(list_of_reaches) + @staticmethod def _load_res1d_network( res: Res1D, @@ -491,7 +492,9 @@ def _generate_alias_map(g: nx.Graph) -> dict[str | tuple[str, float], int]: return {g.nodes[id]["alias"]: id for id in g.nodes()} @staticmethod - def _generate_reaches_dict(reaches: Sequence[NetworkReach]) -> dict[str, NetworkReach]: + def _generate_reaches_dict( + reaches: Sequence[NetworkReach], + ) -> dict[str, NetworkReach]: return {r.id: r for r in reaches} @staticmethod diff --git a/src/modelskill/obs.py b/src/modelskill/obs.py index e4f653bbe..1325915b6 100644 --- a/src/modelskill/obs.py +++ b/src/modelskill/obs.py @@ -21,14 +21,14 @@ from .types import PointType, TrackType, VerticalType, GeometryType, DataInputType from . import Quantity -from .timeseries import ( - TimeSeries, - _parse_xyz_point_input, - _parse_track_input, - _parse_vertical_input, - _parse_network_node_input, - _parse_network_breakpoint_input, +from .timeseries import TimeSeries +from .timeseries._point import ( + parse_xyz_point_input, + parse_network_node_input, + parse_network_breakpoint_input, ) +from .timeseries._track import parse_track_input +from .timeseries._vertical import parse_vertical_input # NetCDF attributes can only be str, int, float https://unidata.github.io/netcdf4-python/#attributes-in-a-netcdf-file @@ -233,7 +233,7 @@ def __init__( attrs: dict | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_xyz_point_input( + data = parse_xyz_point_input( data, name=name, item=item, @@ -352,7 +352,7 @@ def __init__( attrs: dict | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_track_input( + data = parse_track_input( data=data, name=name, item=item, @@ -451,7 +451,7 @@ def __init__( attrs: dict | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_vertical_input( + data = parse_vertical_input( data, name=name, 
item=item, @@ -549,7 +549,7 @@ def __init__( if isinstance(at, tuple): reach, distance = str(at[0]), float(at[1]) if not self._is_input_validated(data): - data = _parse_network_breakpoint_input( + data = parse_network_breakpoint_input( data, name=name, item=item, @@ -560,7 +560,7 @@ def __init__( ) else: if not self._is_input_validated(data): - data = _parse_network_node_input( + data = parse_network_node_input( data, name=name, item=item, @@ -747,7 +747,7 @@ def __init__( attrs: dict | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_network_breakpoint_input( + data = parse_network_breakpoint_input( data, name=name, item=item, diff --git a/src/modelskill/plotting/_misc.py b/src/modelskill/plotting/_misc.py index e41a8e214..cc6d08ed4 100644 --- a/src/modelskill/plotting/_misc.py +++ b/src/modelskill/plotting/_misc.py @@ -11,13 +11,13 @@ from ..obs import unit_display_name -def _get_ax(ax=None, figsize=None): +def get_ax(ax=None, figsize=None): if ax is None: _, ax = plt.subplots(figsize=figsize) return ax -def _get_fig_ax(ax: Axes | None = None, figsize=None): +def get_fig_ax(ax: Axes | None = None, figsize=None): if ax is None: fig, ax = plt.subplots(figsize=figsize) else: @@ -25,7 +25,7 @@ def _get_fig_ax(ax: Axes | None = None, figsize=None): return fig, ax -def _xtick_directional(ax, xlim=None): +def xtick_directional(ax, xlim=None): """Set x-ticks for directional data""" ticks = ticks = _xyticks(lim=xlim) if len(ticks) > 2: @@ -34,7 +34,7 @@ def _xtick_directional(ax, xlim=None): ax.set_xlim(0, 360) -def _ytick_directional(ax, ylim=None): +def ytick_directional(ax, ylim=None): """Set y-ticks for directional data""" ticks = _xyticks(lim=ylim) if len(ticks) > 2: diff --git a/src/modelskill/plotting/_scatter.py b/src/modelskill/plotting/_scatter.py index e8d5c0600..564481476 100644 --- a/src/modelskill/plotting/_scatter.py +++ b/src/modelskill/plotting/_scatter.py @@ -17,8 +17,8 @@ import modelskill.settings as settings from 
modelskill.settings import options -from ..metrics import _linear_regression -from ._misc import quantiles_xy, sample_points, format_skill_table, _get_fig_ax +from .._regression import linear_regression +from ._misc import quantiles_xy, sample_points, format_skill_table, get_fig_ax def scatter( @@ -283,7 +283,7 @@ def _scatter_matplotlib( cmap=None, **kwargs, ) -> matplotlib.axes.Axes: - fig, ax = _get_fig_ax(ax, figsize) + fig, ax = get_fig_ax(ax, figsize) if len(x) < 2: raise ValueError("Not enough data to plot. At least 2 points are required.") @@ -334,11 +334,11 @@ def _scatter_matplotlib( if reg_method: if fit_to_quantiles: - slope, intercept = _linear_regression( + slope, intercept = linear_regression( obs=xq, model=yq, reg_method=reg_method ) else: - slope, intercept = _linear_regression(obs=x, model=y, reg_method=reg_method) + slope, intercept = linear_regression(obs=x, model=y, reg_method=reg_method) ax.plot( x_trend, @@ -456,11 +456,11 @@ def _scatter_plotly( if reg_method: if fit_to_quantiles: - slope, intercept = _linear_regression( + slope, intercept = linear_regression( obs=xq, model=yq, reg_method=reg_method ) else: - slope, intercept = _linear_regression(obs=x, model=y, reg_method=reg_method) + slope, intercept = linear_regression(obs=x, model=y, reg_method=reg_method) regression_line = go.Scatter( x=x_trend, diff --git a/src/modelskill/plotting/_spatial_overview.py b/src/modelskill/plotting/_spatial_overview.py index 739369147..c1f0df5d1 100644 --- a/src/modelskill/plotting/_spatial_overview.py +++ b/src/modelskill/plotting/_spatial_overview.py @@ -10,7 +10,7 @@ from ..model.track import TrackModelResult from ..model.vertical import VerticalModelResult from ..obs import Observation, PointObservation, TrackObservation, VerticalObservation -from ._misc import _get_ax +from ._misc import get_ax def spatial_overview( @@ -66,7 +66,7 @@ def spatial_overview( obs = [] if obs is None else list(obs) if isinstance(obs, Iterable) else [obs] # type: ignore 
mods = [] if mod is None else list(mod) if isinstance(mod, Iterable) else [mod] # type: ignore - ax = _get_ax(ax=ax, figsize=figsize) + ax = get_ax(ax=ax, figsize=figsize) # TODO: support Gridded ModelResults for m in mods: diff --git a/src/modelskill/plotting/_temporal_coverage.py b/src/modelskill/plotting/_temporal_coverage.py index a861728a5..5b5bfc639 100644 --- a/src/modelskill/plotting/_temporal_coverage.py +++ b/src/modelskill/plotting/_temporal_coverage.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import numpy as np -from ._misc import _get_fig_ax +from ._misc import get_fig_ax def temporal_coverage( @@ -79,7 +79,7 @@ def temporal_coverage( ysize = max(2.0, 0.45 * n_lines) figsize = (7, ysize) - fig, ax = _get_fig_ax(ax=ax, figsize=figsize) + fig, ax = get_fig_ax(ax=ax, figsize=figsize) y = np.repeat(0.0, 2) labels = [] diff --git a/src/modelskill/skill.py b/src/modelskill/skill.py index d16bda8dd..eeeb67a65 100644 --- a/src/modelskill/skill.py +++ b/src/modelskill/skill.py @@ -9,7 +9,7 @@ from matplotlib.axes import Axes from matplotlib.colors import Colormap -from .plotting._misc import _get_fig_ax +from .plotting._misc import get_fig_ax from .metrics import small_is_best, large_is_best, zero_is_best, one_is_best @@ -249,7 +249,7 @@ def grid( if figsize is None: figsize = (nx, ny) - fig, ax = _get_fig_ax(ax, figsize) + fig, ax = get_fig_ax(ax, figsize) assert ax is not None pcm = ax.pcolormesh(df, cmap=cmap, vmin=vmin, vmax=vmax) ax.set_xticks(np.arange(nx) + 0.5) diff --git a/src/modelskill/timeseries/__init__.py b/src/modelskill/timeseries/__init__.py index f52f17271..457a755cd 100644 --- a/src/modelskill/timeseries/__init__.py +++ b/src/modelskill/timeseries/__init__.py @@ -1,17 +1,3 @@ from ._timeseries import TimeSeries -from ._point import ( - _parse_xyz_point_input, - _parse_network_node_input, - _parse_network_breakpoint_input, -) -from ._track import _parse_track_input -from ._vertical import _parse_vertical_input -__all__ = [ - 
"TimeSeries", - "_parse_xyz_point_input", - "_parse_track_input", - "_parse_vertical_input", - "_parse_network_node_input", - "_parse_network_breakpoint_input", -] +__all__ = ["TimeSeries"] diff --git a/src/modelskill/timeseries/_point.py b/src/modelskill/timeseries/_point.py index f0742f4dc..798eddbcd 100644 --- a/src/modelskill/timeseries/_point.py +++ b/src/modelskill/timeseries/_point.py @@ -10,8 +10,8 @@ from ..types import GeometryType, PointType from ..quantity import Quantity -from ..utils import _get_name -from ._timeseries import _normalize_time_to_ns, _validate_data_var_name +from .._names import get_name +from ._timeseries import normalize_time_to_ns, validate_data_var_name from ._coords import XYZCoords, NodeCoords, ReachCoords @@ -39,10 +39,10 @@ def _parse_point_items( f"Input has more than 1 item, but item was not given! Available items: {items}" ) - item = _get_name(item, valid_names=items) + item = get_name(item, valid_names=items) if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_items_str = [_get_name(i, valid_names=items) for i in aux_items or []] + aux_items_str = [get_name(i, valid_names=items) for i in aux_items or []] # check that there are no duplicates res = PointItem(values=item, aux=aux_items_str) @@ -121,9 +121,9 @@ def _convert_to_dataset( data = data.rename({time_dim_name: "time"}) ds = data - ds = _normalize_time_to_ns(ds) + ds = normalize_time_to_ns(ds) - name = _validate_data_var_name(varname) + name = validate_data_var_name(varname) n_unique_times = len(ds.time.to_index().unique()) if n_unique_times < len(ds.time): @@ -266,7 +266,7 @@ def _parse_point_input( return ds -def _parse_xyz_point_input( +def parse_xyz_point_input( data: PointType, name: str | None, item: str | int | None, @@ -281,7 +281,7 @@ def _parse_xyz_point_input( return ds -def _parse_network_node_input( +def parse_network_node_input( data: PointType, name: str | None, item: str | int | None, @@ -296,7 +296,7 @@ def _parse_network_node_input( 
return ds -def _parse_network_breakpoint_input( +def parse_network_breakpoint_input( data: PointType, name: str | None, item: str | int | None, diff --git a/src/modelskill/timeseries/_timeseries.py b/src/modelskill/timeseries/_timeseries.py index bba5d78f7..d840e54b3 100644 --- a/src/modelskill/timeseries/_timeseries.py +++ b/src/modelskill/timeseries/_timeseries.py @@ -10,6 +10,7 @@ from ..types import GeometryType from ..quantity import Quantity +from .._names import RESERVED_COORD_NAMES from ._plotter import TimeSeriesPlotter, MatplotlibTimeSeriesPlotter from .. import __version__ @@ -28,18 +29,17 @@ ] -def _validate_data_var_name(name: str) -> str: +def validate_data_var_name(name: str) -> str: if not isinstance(name, str): raise TypeError("name must be a string") - RESERVED_NAMES = ["x", "y", "z", "time"] - if name in RESERVED_NAMES: + if name in RESERVED_COORD_NAMES: raise ValueError( f"name '{name}' is reserved and cannot be used! Please choose another name." ) return name -def _normalize_time_to_ns(ds: xr.Dataset) -> xr.Dataset: +def normalize_time_to_ns(ds: xr.Dataset) -> xr.Dataset: """Cast a dataset's time coordinate to ``datetime64[ns]``. Under pandas 3.0 the default datetime resolution is no longer nanoseconds, @@ -128,7 +128,7 @@ def _validate_dataset(ds: xr.Dataset) -> xr.Dataset: name = "" n_primary = 0 for v in vars: - v = _validate_data_var_name(str(v)) + v = validate_data_var_name(str(v)) assert ( len(ds[v].dims) == 1 ), f"Only 0-dimensional data arrays are supported! 
{v} has {len(ds[v].dims)} dimensions" @@ -204,7 +204,7 @@ def name(self) -> str: @name.setter def name(self, name: str) -> None: - name = _validate_data_var_name(name) + name = validate_data_var_name(name) self.data = self.data.rename({self._val_item: name}) @property diff --git a/src/modelskill/timeseries/_track.py b/src/modelskill/timeseries/_track.py index 9dcee6dea..d12c4f969 100644 --- a/src/modelskill/timeseries/_track.py +++ b/src/modelskill/timeseries/_track.py @@ -11,8 +11,9 @@ from ..types import GeometryType, TrackType from ..quantity import Quantity -from ..utils import _get_name, make_unique_index -from ._timeseries import _validate_data_var_name +from .._names import get_name +from ..utils import make_unique_index +from ._timeseries import validate_data_var_name @dataclass @@ -47,12 +48,12 @@ def _parse_track_items( f"Input has more than 3 items, but item was not given! Available items: {items}" ) - item = _get_name(item, valid_names=items) - x_item = _get_name(x_item, valid_names=items) - y_item = _get_name(y_item, valid_names=items) + item = get_name(item, valid_names=items) + x_item = get_name(x_item, valid_names=items) + y_item = get_name(y_item, valid_names=items) if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_items_str = [_get_name(i, valid_names=items) for i in aux_items or []] + aux_items_str = [get_name(i, valid_names=items) for i in aux_items or []] # check that there are no duplicates res = TrackItem(x=x_item, y=y_item, values=item, aux=aux_items_str) @@ -62,7 +63,7 @@ def _parse_track_items( return res -def _parse_track_input( +def parse_track_input( data: TrackType, name: str | None, item: str | int | None, @@ -100,7 +101,7 @@ def _parse_track_input( valid_items, x_item=x_item, y_item=y_item, item=item, aux_items=aux_items ) name = name or sel_items.values - name = _validate_data_var_name(name) + name = validate_data_var_name(name) # parse quantity if isinstance(data, mikeio.Dataset): diff --git 
a/src/modelskill/timeseries/_vertical.py b/src/modelskill/timeseries/_vertical.py index 57226ad58..a684cd14a 100644 --- a/src/modelskill/timeseries/_vertical.py +++ b/src/modelskill/timeseries/_vertical.py @@ -13,8 +13,8 @@ from ..types import GeometryType, VerticalType from ..quantity import Quantity -from ..utils import _get_name -from ._timeseries import _validate_data_var_name +from .._names import get_name +from ._timeseries import validate_data_var_name @dataclass @@ -47,11 +47,11 @@ def _parse_vertical_items( f"Input has more than 2 items, but item was not given! Available items: {items}" ) - item = _get_name(item, valid_names=items) - z_item = _get_name(z_item, valid_names=items) + item = get_name(item, valid_names=items) + z_item = get_name(z_item, valid_names=items) if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_items_str = [_get_name(i, valid_names=items) for i in aux_items or []] + aux_items_str = [get_name(i, valid_names=items) for i in aux_items or []] # check that there are no duplicates res = VerticalItem(z=z_item, values=item, aux=aux_items_str) @@ -80,7 +80,7 @@ def _include_location( return ds -def _parse_vertical_input( +def parse_vertical_input( data: VerticalType, name: Optional[str], item: str | int | None, @@ -118,7 +118,7 @@ def _parse_vertical_input( ) name = name or sel_items.values - name = _validate_data_var_name(name) + name = validate_data_var_name(name) # parse quantity if isinstance(data, mikeio.Dataset): diff --git a/src/modelskill/utils.py b/src/modelskill/utils.py index 43435dd93..f21c3cff0 100644 --- a/src/modelskill/utils.py +++ b/src/modelskill/utils.py @@ -1,12 +1,9 @@ from __future__ import annotations -from typing import Sequence, cast import warnings import numpy as np import pandas as pd import xarray as xr -from collections.abc import Hashable, Iterable - -_RESERVED_NAMES = ["Observation", "time", "x", "y", "z"] +from collections.abc import Iterable POS_COORDINATE_NAME_MAPPING = { "lon": "x", @@ 
-156,39 +153,3 @@ def make_unique_index( new_index = df_index + offset_in_ns * tmp assert isinstance(new_index, pd.DatetimeIndex) return new_index - - -def _get_name(x: int | str | None, valid_names: Sequence[Hashable]) -> str: - """Parse name/idx from list of valid names (e.g. obs from obs_names), return name""" - return cast(str, valid_names[_get_idx(x, valid_names)]) - - -def _get_idx(x: int | str | None, valid_names: Sequence[Hashable]) -> int: - """Parse name/idx from list of valid names (e.g. obs from obs_names), return idx""" - - if x is None: - if len(valid_names) == 1: - return 0 - else: - raise ValueError( - f"Multiple items available. Must specify name or index. Available items: {valid_names}" - ) - - n = len(valid_names) - if n == 0: - raise ValueError(f"Cannot select {x} from empty list!") - elif isinstance(x, str): - if x in valid_names: - idx = valid_names.index(x) - else: - raise KeyError(f"Name {x} could not be found in {valid_names}") - elif isinstance(x, int): - if x < 0: # Handle negative indices - x += n - if x >= 0 and x < n: - idx = x - else: - raise IndexError(f"Id {x} is out of range for {valid_names}") - else: - raise TypeError(f"Input {x} invalid! Must be None, str or int, not {type(x)}") - return idx
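
---

Reviewer note: the ADR's privacy rule — a name is private if any segment of its import path starts with `_` — reduces to a one-line predicate. This is an illustration of the convention only, not code from the diff:

```python
def is_private(dotted_path: str) -> bool:
    """True if any segment of the import path starts with '_' (ADR-012 rule)."""
    return any(seg.startswith("_") for seg in dotted_path.split("."))
```

This reproduces the CLAUDE.md examples: `modelskill.matching.match` is public, while `modelskill._names.get_name` and `modelskill.timeseries._timeseries.validate_data_var_name` are private via their `_`-prefixed path segments, regardless of the function names themselves.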
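The new `modelskill/_names.py` module imported throughout this patch (`from .._names import get_name, get_idx, RESERVED_COORD_NAMES`) is not itself shown. A minimal sketch, assuming it carries over the helpers deleted from `utils.py` unchanged apart from the dropped underscore; the contents of `RESERVED_COORD_NAMES` are inferred from the old `RESERVED_NAMES` list in `_timeseries.py`:

```python
from __future__ import annotations

from collections.abc import Hashable
from typing import Sequence, cast

# Assumed contents: inferred from the former RESERVED_NAMES in _timeseries.py.
RESERVED_COORD_NAMES = ["x", "y", "z", "time"]


def get_name(x: int | str | None, valid_names: Sequence[Hashable]) -> str:
    """Resolve a name or index against a list of valid names; return the name."""
    return cast(str, valid_names[get_idx(x, valid_names)])


def get_idx(x: int | str | None, valid_names: Sequence[Hashable]) -> int:
    """Resolve a name or index against a list of valid names; return the index."""
    if x is None:
        if len(valid_names) == 1:
            return 0
        raise ValueError(
            f"Multiple items available. Must specify name or index. "
            f"Available items: {valid_names}"
        )

    n = len(valid_names)
    if n == 0:
        raise ValueError(f"Cannot select {x} from empty list!")
    if isinstance(x, str):
        if x not in valid_names:
            raise KeyError(f"Name {x} could not be found in {valid_names}")
        return valid_names.index(x)
    if isinstance(x, int):
        if x < 0:  # handle negative indices
            x += n
        if 0 <= x < n:
            return x
        raise IndexError(f"Id {x} is out of range for {valid_names}")
    raise TypeError(f"Input {x} invalid! Must be None, str or int, not {type(x)}")
```

Because `_names.py` is a single `_`-prefixed path segment, the module is private under ADR-012 while its function names stay underscore-free, so the cross-module imports in this patch (`model/_base.py`, `model/dfsu.py`, `timeseries/_point.py`, …) no longer trip `PLC2701` or Pyright's `reportPrivateUsage`.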
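Likewise, `modelskill/_regression.py` (imported in `metrics.py` and `plotting/_scatter.py` as `linear_regression`) does not appear in the patch. A sketch assuming it is a straight move of the `_linear_regression` body removed from `metrics.py`, keeping the lazy scipy imports:

```python
from __future__ import annotations

from typing import Tuple

import numpy as np
from numpy.typing import ArrayLike


def linear_regression(
    obs: ArrayLike, model: ArrayLike, reg_method: str = "ols"
) -> Tuple[float, float]:
    """Fit model against obs; return (slope, intercept).

    reg_method is either "ols" (ordinary least squares) or "odr"
    (orthogonal distance regression).
    """
    if len(obs) == 0:
        return np.nan, np.nan  # TODO raise error?

    if reg_method == "ols":
        from scipy.stats import linregress

        reg = linregress(obs, model)
        return reg.slope, reg.intercept
    elif reg_method == "odr":
        from scipy import odr

        data = odr.Data(obs, model)
        output = odr.ODR(data, odr.unilinear).run()
        # For odr.unilinear, beta[0] is the slope and beta[1] the intercept.
        return output.beta[0], output.beta[1]
    raise NotImplementedError(
        f"Regression method: {reg_method} not implemented, select 'ols' or 'odr'"
    )
```

As with `_names.py`, privacy lives in the `_regression` path segment, so both `metrics.lin_slope` and the scatter plots can share the helper across modules without a `_`-prefixed function name.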