From 1a65a5ab8542150e5c91197fd4285fd05f5620ec Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Thu, 14 May 2026 12:48:32 +0200 Subject: [PATCH 1/3] Adopt module-path privacy convention (ADR-012) Function names inside private modules drop the leading underscore; the module path already signals "internal." Cross-cutting helpers that lived in public modules (utils.py, metrics.py) move into a new modelskill/_utils.py or into comparison/_utils.py so they don't accidentally become public surface. Pyright's reportPrivateUsage warnings disappear; Ruff PLC2701 now enforces the rule in CI. See adr/012-public-private-api-convention.md for the policy and rationale. --- CLAUDE.md | 16 +++- adr/012-public-private-api-convention.md | 56 +++++++++++++ adr/README.md | 1 + pyproject.toml | 8 +- src/modelskill/_utils.py | 82 +++++++++++++++++++ src/modelskill/comparison/_collection.py | 36 ++++---- .../comparison/_collection_plotter.py | 30 +++---- .../comparison/_comparer_plotter.py | 38 ++++----- src/modelskill/comparison/_comparison.py | 54 ++++++------ src/modelskill/comparison/_utils.py | 54 ++++++++++-- .../comparison/_vertical_comparison.py | 10 +-- src/modelskill/metrics.py | 72 +--------------- src/modelskill/model/_base.py | 8 +- src/modelskill/model/dfsu.py | 8 +- src/modelskill/model/grid.py | 4 +- src/modelskill/model/network.py | 5 +- src/modelskill/model/point.py | 5 +- src/modelskill/model/track.py | 5 +- src/modelskill/model/vertical.py | 5 +- src/modelskill/network.py | 5 +- src/modelskill/obs.py | 26 +++--- src/modelskill/plotting/_misc.py | 8 +- src/modelskill/plotting/_scatter.py | 14 ++-- src/modelskill/plotting/_spatial_overview.py | 4 +- src/modelskill/plotting/_temporal_coverage.py | 4 +- src/modelskill/skill.py | 4 +- src/modelskill/timeseries/__init__.py | 16 +--- src/modelskill/timeseries/_point.py | 18 ++-- src/modelskill/timeseries/_timeseries.py | 8 +- src/modelskill/timeseries/_track.py | 17 ++-- src/modelskill/timeseries/_vertical.py | 14 ++-- src/modelskill/utils.py | 41 +--------- 32 files changed, 381 insertions(+), 295 deletions(-) create mode 100644 adr/012-public-private-api-convention.md create mode 100644 src/modelskill/_utils.py diff --git a/CLAUDE.md b/CLAUDE.md index 92294dbe7..69ec4327e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -54,6 +54,20 @@ This updates `roadmap/README.md` from the feature frontmatter. ## Coding Conventions +### Public vs. Internal API + +A name is **private** if any segment of its import path starts with `_`. +Functions inside private modules use plain names (no leading `_`) — the path +already carries the privacy signal. Reserve leading-`_` function names for +symbols that live in a public module but are used in only one file. + +- Public: `modelskill.matching.match`, `modelskill.utils.rename_coords_xr` +- Private (module path): `modelskill._utils.get_name`, `modelskill.timeseries._timeseries.validate_data_var_name` +- File-private: a `_helper` defined and used only inside `metrics.py` + +See [ADR-012](adr/012-public-private-api-convention.md). Enforced in CI by +Ruff rule `PLC2701` (zero violations in `src/`; tests have a per-file ignore). + ### Docstrings - All docstrings use **NumPy format** (not Google or reStructuredText style) - Include sections: Parameters, Returns, Raises, Examples, See Also, Notes as appropriate @@ -169,7 +183,7 @@ Plots support both matplotlib (static) and plotly (interactive) backends. - Internal data storage uses xarray Datasets with standardized coordinate/variable names - Time coordinates use pandas datetime64 - Spatial coordinates: `x`, `y` (and `z` when applicable) -- Reserved names in `_RESERVED_NAMES` should not be used for model/observation names +- Reserved names in `modelskill._utils.RESERVED_NAMES` should not be used for model/observation names - The `Quantity` class handles physical quantities with units and validation ### Testing Structure (`tests/`) diff --git a/adr/012-public-private-api-convention.md b/adr/012-public-private-api-convention.md new file mode 100644 index 000000000..a7a18b567 --- /dev/null +++ b/adr/012-public-private-api-convention.md @@ -0,0 +1,56 @@ +# ADR-012: Public vs. internal API convention + +**Status**: Accepted + +**Date**: 2026-05-14 + +## Context + +ModelSkill had accumulated an inconsistent privacy convention. Leading underscores appeared on both **file names** (`_comparison.py`, `_misc.py`, `_timeseries.py`) and **function or constant names** (`_get_name`, `_parse_metric`, `_validate_data_var_name`), with no written rule about what each level meant or how the two interacted. Symptoms: + +- `from modelskill.timeseries._timeseries import _validate_data_var_name` — privacy claimed at both the path segment *and* the function name, while the function was imported from 4+ other files across subpackages. +- `timeseries/__init__.py` placed five underscore-prefixed names in `__all__`, simultaneously declaring them "officially public" and "private." +- Pyright (Pylance) flagged `reportPrivateUsage` on every cross-module `_foo` import, generating ambient warnings during development. +- No written policy distinguished "what users can rely on" from "what the package uses internally." + +Three candidate definitions of "public" were considered: + +- **Strict**: only names in `modelskill.__all__` at the top level. +- **Documented**: anything appearing in `docs/api/`. +- **Conventional**: anything reachable as `modelskill.<...>` without a leading `_` in any path segment. + +The conventional definition was chosen because the team did not want documentation gaps to silently shrink the public API: a function reachable via a non-underscored import path could already have been adopted by users, even if undocumented. + +## Decision + +**Privacy is set by the module path, not the function name.** A name is private if any segment of its import path starts with `_`. Functions inside private modules use plain names (no leading `_`); the path already carries the privacy signal. A leading `_` on a function or class name is reserved for symbols that live in a public module but are file-local — i.e., used only inside that one file. + +Examples: + +- `modelskill.timeseries._timeseries.validate_data_var_name` — private (module path). +- `modelskill._utils.get_name` — private (module path). +- `modelskill.metrics._format_directional_label` (hypothetical) — private (function name); used only inside `metrics.py`. +- `modelskill.matching.match` — public. + +This is the same convention used by scikit-learn (`sklearn.utils._param_validation.validate_params`). + +The convention is mechanically enforced by Ruff rule `PLC2701` (`import-private-name`), which flags `from x import _foo` style imports. After the renames applied with this ADR, the codebase has zero PLC2701 violations in `src/`. Tests are allowed to import private names via a per-file ignore, since white-box testing of internals is a legitimate need. + +## Alternatives Considered + +**Snapshot-test the public surface.** A test that asserts `dir(modelskill)` matches a frozen list catches accidental public additions but requires updating the expected list on every intentional public change. Rejected as ongoing maintenance cost for limited additional coverage beyond what Ruff and PR review already provide. + +**Introduce a `modelskill/_internal/` subpackage.** Concentrating all cross-cutting internal helpers in one location would make "shared internal API" structurally obvious. Rejected as pure churn: the helpers already live in natural subpackage homes, and moving them does not change the privacy story under this convention. + +**Rename `utils.py` → `_utils.py` and rely on a deprecation shim.** Would break legitimately-public functions (`rename_coords_xr`, `rename_coords_pd`, `make_unique_index`) that are imported from non-underscored paths internally and could already be used by downstream consumers. The actual cross-cutting *private* helpers (`_get_name`, `_get_idx`, `_RESERVED_NAMES`) were split out into a new `modelskill/_utils.py` module instead, leaving `utils.py` as a public module. + +**Keep the status quo and configure Pyright to silence `reportPrivateUsage`.** Rejected because the warning is genuinely useful for catching future drift; removing the noise by suppression would also remove the signal. + +## Consequences + +- **Cross-module use of internal helpers is no longer flagged.** `from modelskill._utils import get_name` is unambiguously legal; the module path says "internal." Pyright's `reportPrivateUsage` falls silent. +- **Function-name underscores carry a sharper meaning.** A `_function_name` now signals "private to this one file" — never "private to this subpackage." Reviewers can use that signal at face value. +- **Future contributors have a clear default.** New internal helpers go in `_modules`, not in `_underscored_names` inside public modules. PRs that put internal helpers in `metrics.py` or `utils.py` with a leading underscore now stand out as the exception rather than the norm. +- **One breaking change is accepted.** Five underscored names were previously re-exported from `modelskill.timeseries` via `__all__` (`_parse_track_input`, etc.). These names are no longer reachable from `modelskill.timeseries`; the renamed functions live in `modelskill.timeseries._point`, `_track`, `_vertical`. By Python convention, leading-underscore names were never part of the stable API, so this is consistent with the new policy. +- **Tests retain access to internals** via a `tests/**/*.py` per-file ignore for `PLC2701`. White-box testing of private modules remains supported. +- **`metrics.py` no longer hosts `_parse_metric` or `_linear_regression`.** They moved to `comparison/_utils.py:parse_metric` and `_utils.py:linear_regression` respectively. The first is consumer-specific to the comparison subpackage; the second is a general regression helper used by both `metrics.lin_slope` and `plotting._scatter`. diff --git a/adr/README.md b/adr/README.md index eeff80d44..499df951a 100644 --- a/adr/README.md +++ b/adr/README.md @@ -30,6 +30,7 @@ Each ADR follows this structure: - [ADR-009](009-factory-pattern.md) - Factory pattern for type detection - [ADR-010](010-optional-domain-dependencies.md) - Optional dependencies for domain-specific model types (Draft) - [ADR-011](011-vertical-pre-extracted-columns.md) - VerticalModelResult ingests pre-extracted columns +- [ADR-012](012-public-private-api-convention.md) - Public vs. internal API convention ## Contributing diff --git a/pyproject.toml b/pyproject.toml index c7a71b275..37aa26e47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,13 @@ extend-exclude = ["notebooks"] [tool.ruff.lint] ignore = ["E501"] -select = ["E4", "E7", "E9", "F", "D200", "D205"] +select = ["E4", "E7", "E9", "F", "D200", "D205", "PLC2701"] +preview = true +explicit-preview-rules = true + +[tool.ruff.lint.per-file-ignores] +# Tests legitimately exercise package internals. +"tests/**/*.py" = ["PLC2701"] [tool.mypy] python_version = "3.10" diff --git a/src/modelskill/_utils.py b/src/modelskill/_utils.py new file mode 100644 index 000000000..d9363b63e --- /dev/null +++ b/src/modelskill/_utils.py @@ -0,0 +1,82 @@ +"""Package-internal helpers shared across modelskill subpackages. + +The leading underscore on the module name signals that this is internal API: +modelskill itself imports freely from here, but downstream consumers must not. +See ADR-012 for the public/private convention. +""" + +from __future__ import annotations +from typing import Sequence, Tuple, cast +from collections.abc import Hashable + +import numpy as np +from numpy.typing import ArrayLike + + +RESERVED_NAMES = ["Observation", "time", "x", "y", "z"] + + +def get_name(x: int | str | None, valid_names: Sequence[Hashable]) -> str: + """Parse name/idx from list of valid names (e.g. obs from obs_names), return name.""" + return cast(str, valid_names[get_idx(x, valid_names)]) + + +def get_idx(x: int | str | None, valid_names: Sequence[Hashable]) -> int: + """Parse name/idx from list of valid names (e.g. obs from obs_names), return idx.""" + + if x is None: + if len(valid_names) == 1: + return 0 + else: + raise ValueError( + f"Multiple items available. Must specify name or index. Available items: {valid_names}" + ) + + n = len(valid_names) + if n == 0: + raise ValueError(f"Cannot select {x} from empty list!") + elif isinstance(x, str): + if x in valid_names: + idx = valid_names.index(x) + else: + raise KeyError(f"Name {x} could not be found in {valid_names}") + elif isinstance(x, int): + if x < 0: + x += n + if x >= 0 and x < n: + idx = x + else: + raise IndexError(f"Id {x} is out of range for {valid_names}") + else: + raise TypeError(f"Input {x} invalid! Must be None, str or int, not {type(x)}") + return idx + + +def linear_regression( + obs: ArrayLike, model: ArrayLike, reg_method: str = "ols" +) -> Tuple[float, float]: + """Fit a linear regression of ``model`` against ``obs`` and return (slope, intercept).""" + if len(obs) == 0: # type: ignore[arg-type] + return np.nan, np.nan + + if reg_method == "ols": + from scipy.stats import linregress + + reg = linregress(obs, model) + intercept = reg.intercept + slope = reg.slope + elif reg_method == "odr": + from scipy import odr + + data = odr.Data(obs, model) + odr_obj = odr.ODR(data, odr.unilinear) + output = odr_obj.run() + + intercept = output.beta[1] + slope = output.beta[0] + else: + raise NotImplementedError( + f"Regression method: {reg_method} not implemented, select 'ols' or 'odr'" + ) + + return slope, intercept diff --git a/src/modelskill/comparison/_collection.py b/src/modelskill/comparison/_collection.py index fd4023509..74592299a 100644 --- a/src/modelskill/comparison/_collection.py +++ b/src/modelskill/comparison/_collection.py @@ -28,13 +28,13 @@ from ..skill import SkillTable from ..skill_grid import SkillGrid -from ..utils import _get_name +from .._utils import get_name from ._comparison import Comparer -from ..metrics import _parse_metric from ._utils import ( - _add_spatial_grid_to_df, - _groupby_df, - _parse_groupby, + add_spatial_grid_to_df, + groupby_df, + parse_groupby, + parse_metric, IdxOrNameTypes, TimeTypes, ) @@ -230,7 +230,7 @@ def __getitem__( return ComparerCollection([self[i] for i in idxs]) if isinstance(x, int): - name = _get_name(x, self.obs_names) + name = get_name(x, self.obs_names) return self._comparers[name] if isinstance(x, Iterable): @@ -332,16 +332,16 @@ def sel( models = [model] else: models = list(model) - mod_names: List[str] = [_get_name(m, self.mod_names) for m in models] + mod_names: List[str] = [get_name(m, self.mod_names) for m in models] if observation is None: observation = self.obs_names else: observation = [observation] if np.isscalar(observation) else observation # type: ignore - observation = [_get_name(o, self.obs_names) for o in observation] # type: ignore + observation = [get_name(o, self.obs_names) for o in observation] # type: ignore if (quantity is not None) and (self.n_quantities > 1): quantity = [quantity] if np.isscalar(quantity) else quantity # type: ignore - quantity = [_get_name(v, self.quantity_names) for v in quantity] # type: ignore + quantity = [get_name(v, self.quantity_names) for v in quantity] # type: ignore else: quantity = self.quantity_names @@ -482,14 +482,14 @@ def skill( """ cc = self - pmetrics = _parse_metric(metrics) + pmetrics = parse_metric(metrics) - agg_cols = _parse_groupby(by, n_mod=cc.n_models, n_qnt=cc.n_quantities) + agg_cols = parse_groupby(by, n_mod=cc.n_models, n_qnt=cc.n_quantities) agg_cols, attrs_keys = self._attrs_keys_in_by(agg_cols) df = cc._to_long_dataframe(attrs_keys=attrs_keys, observed=observed) - res = _groupby_df(df, by=agg_cols, metrics=pmetrics) + res = groupby_df(df, by=agg_cols, metrics=pmetrics) mtr_cols = [m.__name__ for m in pmetrics] # type: ignore res = res.dropna(subset=mtr_cols, how="all") # TODO: ok to remove empty? res = self._append_xy_to_res(res, cc) @@ -634,19 +634,19 @@ def gridded_skill( """ cmp = self - metrics = _parse_metric(metrics) + metrics = parse_metric(metrics) df = cmp._to_long_dataframe() - df = _add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize) + df = add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize) - agg_cols = _parse_groupby(by, n_mod=cmp.n_models, n_qnt=cmp.n_quantities) + agg_cols = parse_groupby(by, n_mod=cmp.n_models, n_qnt=cmp.n_quantities) if "x" not in agg_cols: agg_cols.insert(0, "x") if "y" not in agg_cols: agg_cols.insert(0, "y") df = df.drop(columns=["x", "y"]).rename(columns=dict(xBin="x", yBin="y")) - res = _groupby_df(df, by=agg_cols, metrics=metrics, n_min=n_min) + res = groupby_df(df, by=agg_cols, metrics=metrics, n_min=n_min) ds = res.to_xarray().squeeze() # change categorial index to coordinates @@ -724,7 +724,7 @@ def mean_skill( case _: raise ValueError("Invalid weights specification") - pmetrics = _parse_metric(metrics) + pmetrics = parse_metric(metrics) skilldf = ( self.skill(metrics=pmetrics) .to_dataframe() @@ -809,7 +809,7 @@ def score( {'mod': 8.414442957854142} """ - metric = _parse_metric(metric)[0] + metric = parse_metric(metric)[0] if not (callable(metric) or isinstance(metric, str)): raise ValueError("metric must be a string or a function") diff --git a/src/modelskill/comparison/_collection_plotter.py b/src/modelskill/comparison/_collection_plotter.py index b1df0cd2f..d7b0af405 100644 --- a/src/modelskill/comparison/_collection_plotter.py +++ b/src/modelskill/comparison/_collection_plotter.py @@ -23,9 +23,9 @@ from .. import metrics as mtr from ..plotting import TaylorPoint, scatter, taylor_diagram -from ..plotting._misc import _get_fig_ax, _xtick_directional, _ytick_directional +from ..plotting._misc import get_fig_ax, xtick_directional, ytick_directional from ..settings import options -from ..utils import _get_idx +from .._utils import get_idx from ._comparer_plotter import quantiles_xy @@ -271,8 +271,8 @@ def _scatter_one_model( ) if backend == "matplotlib" and self.is_directional: - _xtick_directional(ax, xlim) - _ytick_directional(ax, ylim) + xtick_directional(ax, xlim) + ytick_directional(ax, ylim) return ax @@ -302,7 +302,7 @@ def kde(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes: >>> cc.plot.kde(bw_method='silverman') """ - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) df = self.cc._to_long_dataframe() ax = df.obs_val.plot.kde( @@ -334,7 +334,7 @@ def kde(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes: ax.spines["left"].set_visible(False) if self.is_directional: - _xtick_directional(ax) + xtick_directional(ax) return ax @@ -417,12 +417,12 @@ def _hist_one_model( ): from ._comparison import MOD_COLORS - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) assert ( mod_name in self.cc.mod_names ), f"Model {mod_name} not found in collection" - mod_idx = _get_idx(mod_name, self.cc.mod_names) + mod_idx = get_idx(mod_name, self.cc.mod_names) title = ( _default_univarate_title("Histogram", self.cc) if title is None else title @@ -450,7 +450,7 @@ def _hist_one_model( ax.set_ylabel("count") if self.is_directional: - _xtick_directional(ax) + xtick_directional(ax) return ax @@ -566,7 +566,7 @@ def box(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes: >>> cc.plot.box(showmeans=True) >>> cc.plot.box(ax=ax, title="Box plot") """ - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) df = self.cc._to_long_dataframe() @@ -595,7 +595,7 @@ def box(self, *, ax=None, figsize=None, title=None, **kwargs) -> Axes: ax.set_title(title) if self.is_directional: - _ytick_directional(ax) + ytick_directional(ax) return ax @@ -638,7 +638,7 @@ def qq( """ cc = self.cc - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) df = cc._to_long_dataframe() @@ -684,8 +684,8 @@ def qq( ax.set_title(title) if self.is_directional: - _xtick_directional(ax) - _ytick_directional(ax) + xtick_directional(ax) + ytick_directional(ax) return ax @@ -756,7 +756,7 @@ def _residual_hist_one_model( **kwargs, ) -> Axes: """Residual histogram for one model only""" - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) df = self.cc.sel(model=mod_name)._to_long_dataframe() residuals = df.mod_val.values - df.obs_val.values diff --git a/src/modelskill/comparison/_comparer_plotter.py b/src/modelskill/comparison/_comparer_plotter.py index d226252e2..7ec75486e 100644 --- a/src/modelskill/comparison/_comparer_plotter.py +++ b/src/modelskill/comparison/_comparer_plotter.py @@ -18,12 +18,12 @@ import numpy as np # type: ignore from .. import metrics as mtr -from ..utils import _get_idx +from .._utils import get_idx import matplotlib.colors as colors from ..plotting._misc import ( - _get_fig_ax, - _xtick_directional, - _ytick_directional, + get_fig_ax, + xtick_directional, + ytick_directional, quantiles_xy, ) from ..plotting import taylor_diagram, scatter, TaylorPoint @@ -93,7 +93,7 @@ def timeseries( title = cmp.name if backend == "matplotlib": - fig, ax = _get_fig_ax(ax, figsize) + fig, ax = get_fig_ax(ax, figsize) for j in range(cmp.n_models): key = cmp.mod_names[j] mod = cmp.raw_mod_data[key]._values_as_series @@ -109,7 +109,7 @@ def timeseries( ax.legend([*cmp.mod_names, cmp._obs_name]) ax.set_ylim(ylim) if self.is_directional: - _ytick_directional(ax, ylim) + ytick_directional(ax, ylim) ax.set_title(title) return ax @@ -226,11 +226,11 @@ def _hist_one_model( cmp = self.comparer assert mod_name in cmp.mod_names, f"Model {mod_name} not found in comparer" - mod_idx = _get_idx(mod_name, cmp.mod_names) + mod_idx = get_idx(mod_name, cmp.mod_names) title = f"{mod_name} vs {cmp.name}" if title is None else title - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) kwargs["alpha"] = alpha kwargs["density"] = density @@ -254,7 +254,7 @@ def _hist_one_model( ax.set_ylabel("count") if self.is_directional: - _xtick_directional(ax) + xtick_directional(ax) return ax @@ -291,7 +291,7 @@ def kde(self, ax=None, title=None, figsize=None, **kwargs) -> matplotlib.axes.Ax """ cmp = self.comparer - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) cmp.data.Observation.to_series().plot.kde( ax=ax, linestyle="dashed", label="Observation", **kwargs @@ -317,7 +317,7 @@ def kde(self, ax=None, title=None, figsize=None, **kwargs) -> matplotlib.axes.Ax ax.spines["left"].set_visible(False) if self.is_directional: - _xtick_directional(ax) + xtick_directional(ax) return ax @@ -360,7 +360,7 @@ def qq( """ cmp = self.comparer - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) x = cmp.data.Observation.values xmin, xmax = x.min(), x.max() @@ -404,8 +404,8 @@ def qq( ax.set_title(title or f"Q-Q plot for {cmp.name}") if self.is_directional: - _xtick_directional(ax) - _ytick_directional(ax) + xtick_directional(ax) + ytick_directional(ax) return ax @@ -442,7 +442,7 @@ def box(self, *, ax=None, title=None, figsize=None, **kwargs): """ cmp = self.comparer - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) cols = ["Observation"] + cmp.mod_names df = cmp.data[cols].to_dataframe()[cols] @@ -451,7 +451,7 @@ def box(self, *, ax=None, title=None, figsize=None, **kwargs): ax.set_title(title or cmp.name) if self.is_directional: - _ytick_directional(ax) + ytick_directional(ax) return ax @@ -669,8 +669,8 @@ def _scatter_one_model( ) if backend == "matplotlib" and self.is_directional: - _xtick_directional(ax, xlim) - _ytick_directional(ax, ylim) + xtick_directional(ax, xlim) + ytick_directional(ax, ylim) return ax @@ -833,7 +833,7 @@ def _residual_hist_one_model( **kwargs, ) -> matplotlib.axes.Axes: """Residual histogram for one model only""" - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) default_color = "#8B8D8E" color = default_color if color is None else color diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index aed4cc27e..701972594 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -28,21 +28,21 @@ from ..types import GeometryType from ..obs import PointObservation, TrackObservation, NodeObservation from ..model import PointModelResult, TrackModelResult, VerticalModelResult -from ..timeseries._timeseries import _normalize_time_to_ns, _validate_data_var_name +from ..timeseries._timeseries import normalize_time_to_ns, validate_data_var_name from ._comparer_plotter import ComparerPlotter -from ..metrics import _parse_metric from ._utils import ( - _add_spatial_grid_to_df, - _groupby_df, - _parse_groupby, + add_spatial_grid_to_df, + groupby_df, + parse_groupby, + parse_metric, TimeTypes, IdxOrNameTypes, ) from ..skill import SkillTable from ..skill_grid import SkillGrid from ..settings import register_option -from ..utils import _get_name, _RESERVED_NAMES +from .._utils import get_name, RESERVED_NAMES from .. import __version__ from ._vertical_comparison import VerticalAccessor @@ -64,7 +64,7 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset: # matched_data = self._matched_data_to_xarray(matched_data) assert "Observation" in data.data_vars - data = _normalize_time_to_ns(data) + data = normalize_time_to_ns(data) # no missing values allowed in Observation if data["Observation"].isnull().any(): @@ -86,7 +86,7 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset: assert len(vars) > 1, "dataset must have at least two data arrays" for v in data.data_vars: - v = _validate_data_var_name(str(v)) + v = validate_data_var_name(str(v)) assert ( len(data[v].dims) == 1 ), f"Only 0-dimensional data arrays are supported! {v} has {len(data[v].dims)} dimensions" @@ -234,24 +234,24 @@ def parse( """ if len(items) < 2: raise ValueError("data must contain at least two items") - obs_name = _get_name(obs_item, items) if obs_item else items[0] + obs_name = get_name(obs_item, items) if obs_item else items[0] # Check existence of items and convert to names if aux_items is not None: if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_names = [_get_name(a, items) for a in aux_items] + aux_names = [get_name(a, items) for a in aux_items] else: aux_names = [] - x_name = _get_name(x_item, items) if x_item is not None else None - y_name = _get_name(y_item, items) if y_item is not None else None - z_name = _get_name(z_item, items) if z_item is not None else None + x_name = get_name(x_item, items) if x_item is not None else None + y_name = get_name(y_item, items) if y_item is not None else None + z_name = get_name(z_item, items) if z_item is not None else None if mod_items is not None: if isinstance(mod_items, (str, int)): mod_items = [mod_items] - mod_names = [_get_name(m, items) for m in mod_items] + mod_names = [get_name(m, items) for m in mod_items] else: # Add remaining items as model items mod_names = [ @@ -575,9 +575,9 @@ def name(self) -> str: @name.setter def name(self, name: str) -> None: - if name in _RESERVED_NAMES: + if name in RESERVED_NAMES: raise ValueError( - f"Cannot rename to any of {_RESERVED_NAMES}, these are reserved names!" + f"Cannot rename to any of {RESERVED_NAMES}, these are reserved names!" ) self.data.attrs["name"] = name @@ -748,10 +748,10 @@ def rename( # "ignore": silently remove keys that are not in allowed_keys mapping = {k: v for k, v in mapping.items() if k in allowed_keys} - if any([k in _RESERVED_NAMES for k in mapping.values()]): + if any([k in RESERVED_NAMES for k in mapping.values()]): # TODO: also check for duplicates raise ValueError( - f"Cannot rename to any of {_RESERVED_NAMES}, these are reserved names!" + f"Cannot rename to any of {RESERVED_NAMES}, these are reserved names!" ) # rename observation @@ -896,7 +896,7 @@ def sel( models = [model] else: models = list(model) - mod_names: List[str] = [_get_name(m, self.mod_names) for m in models] + mod_names: List[str] = [get_name(m, self.mod_names) for m in models] dropped_models = [m for m in self.mod_names if m not in mod_names] d = d.drop_vars(dropped_models) raw_mod_data = {m: raw_mod_data[m] for m in mod_names} @@ -1089,16 +1089,16 @@ def skill( 2017-10-28 0 NaN NaN NaN NaN NaN NaN NaN 2017-10-29 41 0.33 0.41 0.25 0.36 0.96 0.06 0.99 """ - metrics = _parse_metric(metrics, directional=self.quantity.is_directional) + metrics = parse_metric(metrics, directional=self.quantity.is_directional) cmp = self if cmp.n_points == 0: raise ValueError("No data selected for skill assessment") - by = _parse_groupby(by, n_mod=cmp.n_models, n_qnt=1) + by = parse_groupby(by, n_mod=cmp.n_models, n_qnt=1) df = cmp._to_long_dataframe() - res = _groupby_df(df, by=by, metrics=metrics) + res = groupby_df(df, by=by, metrics=metrics) res["x"] = np.nan if self.gtype == "track" or cmp.x is None else cmp.x res["y"] = np.nan if self.gtype == "track" or cmp.y is None else cmp.y res = self._add_as_col_if_not_in_index(df, skilldf=res) @@ -1150,7 +1150,7 @@ def score( >>> cmp.score(metric="mape") {'mod': 11.567399646108198} """ - metric = _parse_metric(metric)[0] + metric = parse_metric(metric)[0] if not (callable(metric) or isinstance(metric, str)): raise ValueError("metric must be a string or a function") @@ -1231,21 +1231,21 @@ def gridded_skill( cmp = self - metrics = _parse_metric(metrics) + metrics = parse_metric(metrics) if cmp.n_points == 0: raise ValueError("No data to compare") df = cmp._to_long_dataframe() - df = _add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize) + df = add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize) - agg_cols = _parse_groupby(by=by, n_mod=cmp.n_models, n_qnt=1) + agg_cols = parse_groupby(by=by, n_mod=cmp.n_models, n_qnt=1) if "x" not in agg_cols: agg_cols.insert(0, "x") if "y" not in agg_cols: agg_cols.insert(0, "y") df = df.drop(columns=["x", "y"]).rename(columns=dict(xBin="x", yBin="y")) - res = _groupby_df(df, by=agg_cols, metrics=metrics, n_min=n_min) + res = groupby_df(df, by=agg_cols, metrics=metrics, n_min=n_min) ds = res.to_xarray().squeeze() # change categorial index to coordinates diff --git a/src/modelskill/comparison/_utils.py b/src/modelskill/comparison/_utils.py index ff22c9b7a..2cf62e6f0 100644 --- a/src/modelskill/comparison/_utils.py +++ b/src/modelskill/comparison/_utils.py @@ -1,15 +1,55 @@ from __future__ import annotations +import inspect from typing import Callable, Iterable, List, Tuple, Union from datetime import datetime import numpy as np import pandas as pd +from ..metrics import default_circular_metrics, get_metric +from ..settings import options + TimeTypes = Union[str, np.datetime64, pd.Timestamp, datetime] IdxOrNameTypes = Union[int, str, List[int], List[str]] -def _add_spatial_grid_to_df( +def parse_metric( + metric: str | Iterable[str] | Callable | Iterable[Callable] | None, + *, + directional: bool = False, +) -> List[Callable]: + if metric is None: + if directional: + return default_circular_metrics + else: + # could be a list of str! + return [get_metric(m) for m in options.metrics.list] + + if isinstance(metric, str): + metrics: list = [metric] + elif callable(metric): + metrics = [metric] + elif isinstance(metric, Iterable): + metrics = list(metric) + + parsed_metrics = [] + + for m in metrics: + if isinstance(m, str): + parsed_metrics.append(get_metric(m)) + elif callable(m): + if len(inspect.signature(m).parameters) < 2: + raise ValueError( + "Metrics must have at least two arguments (obs, model)" + ) + parsed_metrics.append(m) + else: + raise TypeError(f"metric {m} must be a string or callable") + + return parsed_metrics + + +def add_spatial_grid_to_df( df: pd.DataFrame, bins, binsize: float | None ) -> pd.DataFrame: if binsize is None: @@ -53,7 +93,7 @@ def _add_spatial_grid_to_df( return df -def _groupby_df( +def groupby_df( df: pd.DataFrame, *, by: List[str | pd.Grouper], @@ -71,8 +111,8 @@ def calc_metrics(group: pd.DataFrame) -> pd.Series: row[metric.__name__] = metric(group.obs_val, group.mod_val) return pd.Series(row) - if _dt_in_by(by): - df, by = _add_dt_to_df(df, by) + if dt_in_by(by): + df, by = add_dt_to_df(df, by) # sort=False to avoid re-ordering compared to original cc (also for performance) res = df.groupby(by=by, observed=False, sort=False, group_keys=True)[ @@ -90,7 +130,7 @@ def calc_metrics(group: pd.DataFrame) -> pd.Series: return res -def _dt_in_by(by): +def dt_in_by(by): by = [by] if isinstance(by, str) else by if any(str(by).startswith("dt:") for by in by): return True @@ -114,7 +154,7 @@ def _dt_in_by(by): ] -def _add_dt_to_df(df: pd.DataFrame, by: List[str]) -> Tuple[pd.DataFrame, List[str]]: +def add_dt_to_df(df: pd.DataFrame, by: List[str]) -> Tuple[pd.DataFrame, List[str]]: ser = df["time"] assert isinstance(by, list) # by = [by] if isinstance(by, str) else by @@ -138,7 +178,7 @@ def _add_dt_to_df(df: pd.DataFrame, by: List[str]) -> Tuple[pd.DataFrame, List[s return df, by -def _parse_groupby( +def parse_groupby( by: str | Iterable[str] | None, *, n_mod: int, n_qnt: int ) -> List[str | pd.Grouper]: if by is None: diff --git a/src/modelskill/comparison/_vertical_comparison.py b/src/modelskill/comparison/_vertical_comparison.py index c6f2a6918..9ebfcfdda 100644 --- a/src/modelskill/comparison/_vertical_comparison.py +++ b/src/modelskill/comparison/_vertical_comparison.py @@ -5,13 +5,13 @@ from typing import TYPE_CHECKING, Callable, Iterable, Sequence, Tuple from ..types import GeometryType -from ..plotting._misc import _get_fig_ax +from ..plotting._misc import get_fig_ax import xarray as xr from matplotlib import dates as mdates from ..model import PointModelResult, TrackModelResult from ..model.network import NodeModelResult from ..model.vertical import VerticalModelResult -from ..metrics import _parse_metric +from ._utils import parse_metric from ..skill_profile import SkillProfile if TYPE_CHECKING: @@ -52,7 +52,7 @@ def profile( from ._comparison import MOD_COLORS cmp = self.comparer - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) title = title if title is not None else cmp.name @@ -139,7 +139,7 @@ def hovmoller( cmp = self.comparer mod_name = self._get_model_name(model) - _, ax = _get_fig_ax(ax, figsize) + _, ax = get_fig_ax(ax, figsize) if title is None: title = f"{mod_name} and Observations" @@ -328,7 +328,7 @@ def skill( if cmp.n_points == 0: raise ValueError("No data selected for skill assessment") - metric_funcs = _parse_metric(metrics, directional=cmp.quantity.is_directional) + metric_funcs = parse_metric(metrics, directional=cmp.quantity.is_directional) def calculate_metric(g): obs = g[cmp._obs_name] diff --git a/src/modelskill/metrics.py b/src/modelskill/metrics.py index 5a6f008a0..fd19fd2a9 100644 --- a/src/modelskill/metrics.py +++ b/src/modelskill/metrics.py @@ -1,18 +1,15 @@ """Metrics for evaluating the difference between a model and an observation.""" from __future__ import annotations -import inspect import sys import warnings from typing import ( Any, Callable, - Iterable, List, Literal, Set, - Tuple, TypeVar, Union, ) @@ -22,7 +19,7 @@ from numpy.typing import ArrayLike from scipy import stats -from .settings import options +from ._utils import linear_regression defined_metrics: Set[str] = set() metrics_with_units: Set[str] = set() @@ -722,36 +719,7 @@ def lin_slope(obs: ArrayLike, model: ArrayLike, reg_method="ols") -> Any: Range: $(-\infty, \infty )$; Best: 1 """ assert obs.size == model.size - return _linear_regression(obs, model, reg_method)[0] - - -def _linear_regression( - obs: ArrayLike, model: ArrayLike, reg_method="ols" -) -> Tuple[float, float]: - if len(obs) == 0: - return np.nan, np.nan # TODO raise error? - - if reg_method == "ols": - from scipy.stats import linregress as _linregress - - reg = _linregress(obs, model) - intercept = reg.intercept - slope = reg.slope - elif reg_method == "odr": - from scipy import odr - - data = odr.Data(obs, model) - odr_obj = odr.ODR(data, odr.unilinear) - output = odr_obj.run() - - intercept = output.beta[1] - slope = output.beta[0] - else: - raise NotImplementedError( - f"Regression method: {reg_method} not implemented, select 'ols' or 'odr'" - ) - - return slope, intercept + return linear_regression(obs, model, reg_method)[0] def _std_obs(obs: ArrayLike, model: ArrayLike) -> Any: @@ -1164,42 +1132,6 @@ def get_metric(metric: Union[str, Callable]) -> Callable: ) -def _parse_metric( - metric: str | Iterable[str] | Callable | Iterable[Callable] | None, - *, - directional: bool = False, -) -> List[Callable]: - if metric is None: - if directional: - return default_circular_metrics - else: - # could be a list of str! - return [get_metric(m) for m in options.metrics.list] - - if isinstance(metric, str): - metrics: list = [metric] - elif callable(metric): - metrics = [metric] - elif isinstance(metric, Iterable): - metrics = list(metric) - - parsed_metrics = [] - - for m in metrics: - if isinstance(m, str): - parsed_metrics.append(get_metric(m)) - elif callable(m): - if len(inspect.signature(m).parameters) < 2: - raise ValueError( - "Metrics must have at least two arguments (obs, model)" - ) - parsed_metrics.append(m) - else: - raise TypeError(f"metric {m} must be a string or callable") - - return parsed_metrics - - def is_best(metric: str, expected: str | int) -> bool: try: func = get_metric(metric) diff --git a/src/modelskill/model/_base.py b/src/modelskill/model/_base.py index 23de28e51..3040520c5 100644 --- a/src/modelskill/model/_base.py +++ b/src/modelskill/model/_base.py @@ -12,7 +12,7 @@ from .track import TrackModelResult from .vertical import VerticalModelResult -from ..utils import _get_name +from .._utils import get_name from ..obs import Observation, PointObservation, TrackObservation, VerticalObservation @@ -48,10 +48,10 @@ def _parse_items( f"Input has more than 1 item, but item was not given! Available items: {avail_items}" ) - item = _get_name(item, valid_names=avail_items) + item = get_name(item, valid_names=avail_items) if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_items_str = [_get_name(i, valid_names=avail_items) for i in aux_items or []] + aux_items_str = [get_name(i, valid_names=avail_items) for i in aux_items or []] # check that there are no duplicates res = SelectedItems(values=item, aux=aux_items_str) @@ -63,7 +63,7 @@ def _parse_items( return res -def _validate_overlap_in_time(time: pd.DatetimeIndex, observation: Observation) -> None: +def validate_overlap_in_time(time: pd.DatetimeIndex, observation: Observation) -> None: overlap_in_time = ( time[0] <= observation.time[-1] and time[-1] >= observation.time[0] ) diff --git a/src/modelskill/model/dfsu.py b/src/modelskill/model/dfsu.py index 7fec7bb4e..32c705b9a 100644 --- a/src/modelskill/model/dfsu.py +++ b/src/modelskill/model/dfsu.py @@ -7,10 +7,10 @@ import numpy as np import pandas as pd -from ._base import SpatialField, _validate_overlap_in_time, SelectedItems +from ._base import SpatialField, validate_overlap_in_time, SelectedItems from ..types import UnstructuredType from ..quantity import Quantity -from ..utils import _get_idx +from .._utils import get_idx from .point import PointModelResult from .track import TrackModelResult from .vertical import VerticalModelResult @@ -78,7 +78,7 @@ def __init__( data, (mikeio.dfsu.Dfsu2DH, mikeio.dfsu.Dfsu3D, mikeio.Dataset) ): item_names = [i.name for i in data.items] - idx = _get_idx(x=item, valid_names=item_names) + idx = get_idx(x=item, valid_names=item_names) item_info = data.items[idx] self.sel_items = SelectedItems.parse( @@ -137,7 +137,7 @@ def extract( """ method = self._parse_spatial_method(spatial_method) - _validate_overlap_in_time(self.time, observation) + validate_overlap_in_time(self.time, observation) if isinstance(observation, PointObservation): return self._extract_point(observation, spatial_method=method) elif isinstance(observation, TrackObservation): diff --git a/src/modelskill/model/grid.py b/src/modelskill/model/grid.py index 8d1a35178..b81ec9d8e 100644 --- a/src/modelskill/model/grid.py +++ b/src/modelskill/model/grid.py @@ -6,7 +6,7 @@ import pandas as pd import xarray as xr -from ._base import SpatialField, _validate_overlap_in_time, SelectedItems +from ._base import SpatialField, validate_overlap_in_time, SelectedItems from ..utils import rename_coords_xr, rename_coords_pd from ..types import GridType from ..quantity import Quantity @@ -148,7 +148,7 @@ def extract( raise NotImplementedError( "Extraction of VerticalObservation from GridModelResult is not implemented yet." ) - _validate_overlap_in_time(self.time, observation) + validate_overlap_in_time(self.time, observation) if isinstance(observation, PointObservation): return self._extract_point(observation, spatial_method) elif isinstance(observation, TrackObservation): diff --git a/src/modelskill/model/network.py b/src/modelskill/model/network.py index 17c62035c..4df6006ce 100644 --- a/src/modelskill/model/network.py +++ b/src/modelskill/model/network.py @@ -7,7 +7,8 @@ import pandas as pd import xarray as xr -from modelskill.timeseries import TimeSeries, _parse_network_node_input +from modelskill.timeseries import TimeSeries +from modelskill.timeseries._point import parse_network_node_input from ._base import SelectedItems from ..obs import NodeObservation, ReachObservation from ..quantity import Quantity @@ -58,7 +59,7 @@ def __init__( aux_items: Sequence[int | str] | None = None, ): if not self._is_input_validated(data): - data = _parse_network_node_input( + data = parse_network_node_input( data, name=name, item=item, diff --git a/src/modelskill/model/point.py b/src/modelskill/model/point.py index 4e071c4e9..5441438be 100644 --- a/src/modelskill/model/point.py +++ b/src/modelskill/model/point.py @@ -6,7 +6,8 @@ from ..obs import Observation from ..types import PointType from ..quantity import Quantity -from ..timeseries import TimeSeries, _parse_xyz_point_input +from ..timeseries import TimeSeries +from ..timeseries._point import parse_xyz_point_input from ..timeseries._align import align_data @@ -52,7 +53,7 @@ def __init__( aux_items: Sequence[int | str] | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_xyz_point_input( + data = parse_xyz_point_input( data, name=name, item=item, diff --git a/src/modelskill/model/track.py b/src/modelskill/model/track.py index 612cb0063..9543e43bc 100644 --- a/src/modelskill/model/track.py +++ b/src/modelskill/model/track.py @@ -8,7 +8,8 @@ from ..types import TrackType from ..obs import TrackObservation from ..quantity import Quantity -from ..timeseries import TimeSeries, _parse_track_input +from ..timeseries import TimeSeries +from ..timeseries._track import parse_track_input class TrackModelResult(TimeSeries): @@ -55,7 +56,7 @@ def __init__( aux_items: Sequence[int | str] | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_track_input( + data = parse_track_input( data=data, name=name, item=item, diff --git a/src/modelskill/model/vertical.py b/src/modelskill/model/vertical.py index b1e529ccb..abd46a525 100644 --- a/src/modelskill/model/vertical.py +++ b/src/modelskill/model/vertical.py @@ -7,7 +7,8 @@ from ..types import VerticalType from ..quantity import Quantity -from ..timeseries import TimeSeries, _parse_vertical_input +from ..timeseries import TimeSeries +from ..timeseries._vertical import parse_vertical_input from ..obs import VerticalObservation @@ -58,7 +59,7 @@ def __init__( aux_items: Sequence[int | str] | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_vertical_input( + data = parse_vertical_input( data=data, name=name, item=item, diff --git a/src/modelskill/network.py b/src/modelskill/network.py index 071bc01f6..7b08c1a22 100644 --- a/src/modelskill/network.py +++ b/src/modelskill/network.py @@ -445,6 +445,7 @@ def from_res1d( list_of_reaches = cls._load_res1d_network(res, nodes_list, reaches_list) return cls(list_of_reaches) + @staticmethod def _load_res1d_network( res: Res1D, @@ -491,7 +492,9 @@ def _generate_alias_map(g: nx.Graph) -> dict[str | tuple[str, float], int]: return {g.nodes[id]["alias"]: id for id in g.nodes()} @staticmethod - def _generate_reaches_dict(reaches: Sequence[NetworkReach]) -> dict[str, NetworkReach]: + def _generate_reaches_dict( + reaches: Sequence[NetworkReach], + ) -> dict[str, NetworkReach]: return {r.id: r for r in reaches} @staticmethod diff --git a/src/modelskill/obs.py b/src/modelskill/obs.py index e4f653bbe..1325915b6 100644 --- a/src/modelskill/obs.py +++ b/src/modelskill/obs.py @@ -21,14 +21,14 @@ from .types import PointType, TrackType, VerticalType, GeometryType, DataInputType from . import Quantity -from .timeseries import ( - TimeSeries, - _parse_xyz_point_input, - _parse_track_input, - _parse_vertical_input, - _parse_network_node_input, - _parse_network_breakpoint_input, +from .timeseries import TimeSeries +from .timeseries._point import ( + parse_xyz_point_input, + parse_network_node_input, + parse_network_breakpoint_input, ) +from .timeseries._track import parse_track_input +from .timeseries._vertical import parse_vertical_input # NetCDF attributes can only be str, int, float https://unidata.github.io/netcdf4-python/#attributes-in-a-netcdf-file @@ -233,7 +233,7 @@ def __init__( attrs: dict | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_xyz_point_input( + data = parse_xyz_point_input( data, name=name, item=item, @@ -352,7 +352,7 @@ def __init__( attrs: dict | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_track_input( + data = parse_track_input( data=data, name=name, item=item, @@ -451,7 +451,7 @@ def __init__( attrs: dict | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_vertical_input( + data = parse_vertical_input( data, name=name, item=item, @@ -549,7 +549,7 @@ def __init__( if isinstance(at, tuple): reach, distance = str(at[0]), float(at[1]) if not self._is_input_validated(data): - data = _parse_network_breakpoint_input( + data = parse_network_breakpoint_input( data, name=name, item=item, @@ -560,7 +560,7 @@ def __init__( ) else: if not self._is_input_validated(data): - data = _parse_network_node_input( + data = parse_network_node_input( data, name=name, item=item, @@ -747,7 +747,7 @@ def __init__( attrs: dict | None = None, ) -> None: if not self._is_input_validated(data): - data = _parse_network_breakpoint_input( + data = parse_network_breakpoint_input( data, name=name, item=item, diff --git a/src/modelskill/plotting/_misc.py b/src/modelskill/plotting/_misc.py index e41a8e214..cc6d08ed4 100644 --- a/src/modelskill/plotting/_misc.py +++ b/src/modelskill/plotting/_misc.py @@ -11,13 +11,13 @@ from ..obs import unit_display_name -def _get_ax(ax=None, figsize=None): +def get_ax(ax=None, figsize=None): if ax is None: _, ax = plt.subplots(figsize=figsize) return ax -def _get_fig_ax(ax: Axes | None = None, figsize=None): +def get_fig_ax(ax: Axes | None = None, figsize=None): if ax is None: fig, ax = plt.subplots(figsize=figsize) else: @@ -25,7 +25,7 @@ def _get_fig_ax(ax: Axes | None = None, figsize=None): return fig, ax -def _xtick_directional(ax, xlim=None): +def xtick_directional(ax, xlim=None): """Set x-ticks for directional data""" ticks = ticks = _xyticks(lim=xlim) if len(ticks) > 2: @@ -34,7 +34,7 @@ def _xtick_directional(ax, xlim=None): ax.set_xlim(0, 360) -def _ytick_directional(ax, ylim=None): +def ytick_directional(ax, ylim=None): """Set y-ticks for directional data""" ticks = _xyticks(lim=ylim) if len(ticks) > 2: diff --git a/src/modelskill/plotting/_scatter.py b/src/modelskill/plotting/_scatter.py index e8d5c0600..5934462aa 100644 --- a/src/modelskill/plotting/_scatter.py +++ b/src/modelskill/plotting/_scatter.py @@ -17,8 +17,8 @@ import modelskill.settings as settings from modelskill.settings import options -from ..metrics import _linear_regression -from ._misc import quantiles_xy, sample_points, format_skill_table, _get_fig_ax +from .._utils import linear_regression +from ._misc import quantiles_xy, sample_points, format_skill_table, get_fig_ax def scatter( @@ -283,7 +283,7 @@ def _scatter_matplotlib( cmap=None, **kwargs, ) -> matplotlib.axes.Axes: - fig, ax = _get_fig_ax(ax, figsize) + fig, ax = get_fig_ax(ax, figsize) if len(x) < 2: raise ValueError("Not enough data to plot. At least 2 points are required.") @@ -334,11 +334,11 @@ def _scatter_matplotlib( if reg_method: if fit_to_quantiles: - slope, intercept = _linear_regression( + slope, intercept = linear_regression( obs=xq, model=yq, reg_method=reg_method ) else: - slope, intercept = _linear_regression(obs=x, model=y, reg_method=reg_method) + slope, intercept = linear_regression(obs=x, model=y, reg_method=reg_method) ax.plot( x_trend, @@ -456,11 +456,11 @@ def _scatter_plotly( if reg_method: if fit_to_quantiles: - slope, intercept = _linear_regression( + slope, intercept = linear_regression( obs=xq, model=yq, reg_method=reg_method ) else: - slope, intercept = _linear_regression(obs=x, model=y, reg_method=reg_method) + slope, intercept = linear_regression(obs=x, model=y, reg_method=reg_method) regression_line = go.Scatter( x=x_trend, diff --git a/src/modelskill/plotting/_spatial_overview.py b/src/modelskill/plotting/_spatial_overview.py index 739369147..c1f0df5d1 100644 --- a/src/modelskill/plotting/_spatial_overview.py +++ b/src/modelskill/plotting/_spatial_overview.py @@ -10,7 +10,7 @@ from ..model.track import TrackModelResult from ..model.vertical import VerticalModelResult from ..obs import Observation, PointObservation, TrackObservation, VerticalObservation -from ._misc import _get_ax +from ._misc import get_ax def spatial_overview( @@ -66,7 +66,7 @@ def spatial_overview( obs = [] if obs is None else list(obs) if isinstance(obs, Iterable) else [obs] # type: ignore mods = [] if mod is None else list(mod) if isinstance(mod, Iterable) else [mod] # type: ignore - ax = _get_ax(ax=ax, figsize=figsize) + ax = get_ax(ax=ax, figsize=figsize) # TODO: support Gridded ModelResults for m in mods: diff --git a/src/modelskill/plotting/_temporal_coverage.py b/src/modelskill/plotting/_temporal_coverage.py index a861728a5..5b5bfc639 100644 --- a/src/modelskill/plotting/_temporal_coverage.py +++ b/src/modelskill/plotting/_temporal_coverage.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import numpy as np -from ._misc import _get_fig_ax +from ._misc import get_fig_ax def temporal_coverage( @@ -79,7 +79,7 @@ def temporal_coverage( ysize = max(2.0, 0.45 * n_lines) figsize = (7, ysize) - fig, ax = _get_fig_ax(ax=ax, figsize=figsize) + fig, ax = get_fig_ax(ax=ax, figsize=figsize) y = np.repeat(0.0, 2) labels = [] diff --git a/src/modelskill/skill.py b/src/modelskill/skill.py index d16bda8dd..eeeb67a65 100644 --- a/src/modelskill/skill.py +++ b/src/modelskill/skill.py @@ -9,7 +9,7 @@ from matplotlib.axes import Axes from matplotlib.colors import Colormap -from .plotting._misc import _get_fig_ax +from .plotting._misc import get_fig_ax from .metrics import small_is_best, large_is_best, zero_is_best, one_is_best @@ -249,7 +249,7 @@ def grid( if figsize is None: figsize = (nx, ny) - fig, ax = _get_fig_ax(ax, figsize) + fig, ax = get_fig_ax(ax, figsize) assert ax is not None pcm = ax.pcolormesh(df, cmap=cmap, vmin=vmin, vmax=vmax) ax.set_xticks(np.arange(nx) + 0.5) diff --git a/src/modelskill/timeseries/__init__.py b/src/modelskill/timeseries/__init__.py index f52f17271..457a755cd 100644 --- a/src/modelskill/timeseries/__init__.py +++ b/src/modelskill/timeseries/__init__.py @@ -1,17 +1,3 @@ from ._timeseries import TimeSeries -from ._point import ( - _parse_xyz_point_input, - _parse_network_node_input, - _parse_network_breakpoint_input, -) -from ._track import _parse_track_input -from ._vertical import _parse_vertical_input -__all__ = [ - "TimeSeries", - "_parse_xyz_point_input", - "_parse_track_input", - "_parse_vertical_input", - "_parse_network_node_input", - "_parse_network_breakpoint_input", -] +__all__ = ["TimeSeries"] diff --git a/src/modelskill/timeseries/_point.py b/src/modelskill/timeseries/_point.py index f0742f4dc..a642163ed 100644 --- a/src/modelskill/timeseries/_point.py +++ b/src/modelskill/timeseries/_point.py @@ -10,8 +10,8 @@ from ..types import GeometryType, PointType from ..quantity import Quantity -from ..utils import _get_name -from ._timeseries import _normalize_time_to_ns, _validate_data_var_name +from .._utils import get_name +from ._timeseries import normalize_time_to_ns, validate_data_var_name from ._coords import XYZCoords, NodeCoords, ReachCoords @@ -39,10 +39,10 @@ def _parse_point_items( f"Input has more than 1 item, but item was not given! Available items: {items}" ) - item = _get_name(item, valid_names=items) + item = get_name(item, valid_names=items) if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_items_str = [_get_name(i, valid_names=items) for i in aux_items or []] + aux_items_str = [get_name(i, valid_names=items) for i in aux_items or []] # check that there are no duplicates res = PointItem(values=item, aux=aux_items_str) @@ -121,9 +121,9 @@ def _convert_to_dataset( data = data.rename({time_dim_name: "time"}) ds = data - ds = _normalize_time_to_ns(ds) + ds = normalize_time_to_ns(ds) - name = _validate_data_var_name(varname) + name = validate_data_var_name(varname) n_unique_times = len(ds.time.to_index().unique()) if n_unique_times < len(ds.time): @@ -266,7 +266,7 @@ def _parse_point_input( return ds -def _parse_xyz_point_input( +def parse_xyz_point_input( data: PointType, name: str | None, item: str | int | None, @@ -281,7 +281,7 @@ def _parse_xyz_point_input( return ds -def _parse_network_node_input( +def parse_network_node_input( data: PointType, name: str | None, item: str | int | None, @@ -296,7 +296,7 @@ def _parse_network_node_input( return ds -def _parse_network_breakpoint_input( +def parse_network_breakpoint_input( data: PointType, name: str | None, item: str | int | None, diff --git a/src/modelskill/timeseries/_timeseries.py b/src/modelskill/timeseries/_timeseries.py index bba5d78f7..5850d4a0e 100644 --- a/src/modelskill/timeseries/_timeseries.py +++ b/src/modelskill/timeseries/_timeseries.py @@ -28,7 +28,7 @@ ] -def _validate_data_var_name(name: str) -> str: +def validate_data_var_name(name: str) -> str: if not isinstance(name, str): raise TypeError("name must be a string") RESERVED_NAMES = ["x", "y", "z", "time"] @@ -39,7 +39,7 @@ def _validate_data_var_name(name: str) -> str: return name -def _normalize_time_to_ns(ds: xr.Dataset) -> xr.Dataset: +def normalize_time_to_ns(ds: xr.Dataset) -> xr.Dataset: """Cast a dataset's time coordinate to ``datetime64[ns]``. Under pandas 3.0 the default datetime resolution is no longer nanoseconds, @@ -128,7 +128,7 @@ def _validate_dataset(ds: xr.Dataset) -> xr.Dataset: name = "" n_primary = 0 for v in vars: - v = _validate_data_var_name(str(v)) + v = validate_data_var_name(str(v)) assert ( len(ds[v].dims) == 1 ), f"Only 0-dimensional data arrays are supported! {v} has {len(ds[v].dims)} dimensions" @@ -204,7 +204,7 @@ def name(self) -> str: @name.setter def name(self, name: str) -> None: - name = _validate_data_var_name(name) + name = validate_data_var_name(name) self.data = self.data.rename({self._val_item: name}) @property diff --git a/src/modelskill/timeseries/_track.py b/src/modelskill/timeseries/_track.py index 9dcee6dea..8ff5291f6 100644 --- a/src/modelskill/timeseries/_track.py +++ b/src/modelskill/timeseries/_track.py @@ -11,8 +11,9 @@ from ..types import GeometryType, TrackType from ..quantity import Quantity -from ..utils import _get_name, make_unique_index -from ._timeseries import _validate_data_var_name +from .._utils import get_name +from ..utils import make_unique_index +from ._timeseries import validate_data_var_name @dataclass @@ -47,12 +48,12 @@ def _parse_track_items( f"Input has more than 3 items, but item was not given! Available items: {items}" ) - item = _get_name(item, valid_names=items) - x_item = _get_name(x_item, valid_names=items) - y_item = _get_name(y_item, valid_names=items) + item = get_name(item, valid_names=items) + x_item = get_name(x_item, valid_names=items) + y_item = get_name(y_item, valid_names=items) if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_items_str = [_get_name(i, valid_names=items) for i in aux_items or []] + aux_items_str = [get_name(i, valid_names=items) for i in aux_items or []] # check that there are no duplicates res = TrackItem(x=x_item, y=y_item, values=item, aux=aux_items_str) @@ -62,7 +63,7 @@ def _parse_track_items( return res -def _parse_track_input( +def parse_track_input( data: TrackType, name: str | None, item: str | int | None, @@ -100,7 +101,7 @@ def _parse_track_input( valid_items, x_item=x_item, y_item=y_item, item=item, aux_items=aux_items ) name = name or sel_items.values - name = _validate_data_var_name(name) + name = validate_data_var_name(name) # parse quantity if isinstance(data, mikeio.Dataset): diff --git a/src/modelskill/timeseries/_vertical.py b/src/modelskill/timeseries/_vertical.py index 57226ad58..79193d002 100644 --- a/src/modelskill/timeseries/_vertical.py +++ b/src/modelskill/timeseries/_vertical.py @@ -13,8 +13,8 @@ from ..types import GeometryType, VerticalType from ..quantity import Quantity -from ..utils import _get_name -from ._timeseries import _validate_data_var_name +from .._utils import get_name +from ._timeseries import validate_data_var_name @dataclass @@ -47,11 +47,11 @@ def _parse_vertical_items( f"Input has more than 2 items, but item was not given! Available items: {items}" ) - item = _get_name(item, valid_names=items) - z_item = _get_name(z_item, valid_names=items) + item = get_name(item, valid_names=items) + z_item = get_name(z_item, valid_names=items) if isinstance(aux_items, (str, int)): aux_items = [aux_items] - aux_items_str = [_get_name(i, valid_names=items) for i in aux_items or []] + aux_items_str = [get_name(i, valid_names=items) for i in aux_items or []] # check that there are no duplicates res = VerticalItem(z=z_item, values=item, aux=aux_items_str) @@ -80,7 +80,7 @@ def _include_location( return ds -def _parse_vertical_input( +def parse_vertical_input( data: VerticalType, name: Optional[str], item: str | int | None, @@ -118,7 +118,7 @@ def _parse_vertical_input( ) name = name or sel_items.values - name = _validate_data_var_name(name) + name = validate_data_var_name(name) # parse quantity if isinstance(data, mikeio.Dataset): diff --git a/src/modelskill/utils.py b/src/modelskill/utils.py index 43435dd93..f21c3cff0 100644 --- a/src/modelskill/utils.py +++ b/src/modelskill/utils.py @@ -1,12 +1,9 @@ from __future__ import annotations -from typing import Sequence, cast import warnings import numpy as np import pandas as pd import xarray as xr -from collections.abc import Hashable, Iterable - -_RESERVED_NAMES = ["Observation", "time", "x", "y", "z"] +from collections.abc import Iterable POS_COORDINATE_NAME_MAPPING = { "lon": "x", @@ -156,39 +153,3 @@ def make_unique_index( new_index = df_index + offset_in_ns * tmp assert isinstance(new_index, pd.DatetimeIndex) return new_index - - -def _get_name(x: int | str | None, valid_names: Sequence[Hashable]) -> str: - """Parse name/idx from list of valid names (e.g. obs from obs_names), return name""" - return cast(str, valid_names[_get_idx(x, valid_names)]) - - -def _get_idx(x: int | str | None, valid_names: Sequence[Hashable]) -> int: - """Parse name/idx from list of valid names (e.g. obs from obs_names), return idx""" - - if x is None: - if len(valid_names) == 1: - return 0 - else: - raise ValueError( - f"Multiple items available. Must specify name or index. Available items: {valid_names}" - ) - - n = len(valid_names) - if n == 0: - raise ValueError(f"Cannot select {x} from empty list!") - elif isinstance(x, str): - if x in valid_names: - idx = valid_names.index(x) - else: - raise KeyError(f"Name {x} could not be found in {valid_names}") - elif isinstance(x, int): - if x < 0: # Handle negative indices - x += n - if x >= 0 and x < n: - idx = x - else: - raise IndexError(f"Id {x} is out of range for {valid_names}") - else: - raise TypeError(f"Input {x} invalid! Must be None, str or int, not {type(x)}") - return idx From 162c523734df63a5d4e4ac743afe870718e678af Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Thu, 14 May 2026 14:02:16 +0200 Subject: [PATCH 2/3] Refine ADR-012: factor RESERVED_NAMES, extract _regression, align docs - Soften ADR-012 + CLAUDE.md to describe what's actually enforced: PLC2701 is the hard rule; underscores on function names are an optional convention. - Split _utils.RESERVED_NAMES into RESERVED_COORD_NAMES (coord-name collisions) and RESERVED_COMPARER_VAR_NAMES (+Observation slot), expressing the superset relationship in code instead of by shadowed in-function constants. - Extract linear_regression from _utils.py into its own single-purpose _regression.py, leaving _utils.py as the name-handling helpers module. --- CLAUDE.md | 10 +++--- adr/012-public-private-api-convention.md | 8 ++--- src/modelskill/_regression.py | 40 ++++++++++++++++++++++++ src/modelskill/_utils.py | 38 ++-------------------- src/modelskill/comparison/_comparison.py | 10 +++--- src/modelskill/metrics.py | 2 +- src/modelskill/plotting/_scatter.py | 2 +- src/modelskill/timeseries/_timeseries.py | 4 +-- 8 files changed, 61 insertions(+), 53 deletions(-) create mode 100644 src/modelskill/_regression.py diff --git a/CLAUDE.md b/CLAUDE.md index 69ec4327e..8e536b6dc 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -57,13 +57,14 @@ This updates `roadmap/README.md` from the feature frontmatter. ### Public vs. Internal API A name is **private** if any segment of its import path starts with `_`. -Functions inside private modules use plain names (no leading `_`) — the path -already carries the privacy signal. Reserve leading-`_` function names for -symbols that live in a public module but are used in only one file. +Mechanically enforced by `PLC2701`: no module may import a `_foo` name from +another module. Underscores on function names are otherwise a convention — +authors may use them as a hint that a function is file-local, but this is +not required and not enforced. - Public: `modelskill.matching.match`, `modelskill.utils.rename_coords_xr` - Private (module path): `modelskill._utils.get_name`, `modelskill.timeseries._timeseries.validate_data_var_name` -- File-private: a `_helper` defined and used only inside `metrics.py` +- File-local hint: `_helper` inside any module, marking no cross-module use See [ADR-012](adr/012-public-private-api-convention.md). Enforced in CI by Ruff rule `PLC2701` (zero violations in `src/`; tests have a per-file ignore). @@ -183,7 +184,6 @@ Plots support both matplotlib (static) and plotly (interactive) backends. - Internal data storage uses xarray Datasets with standardized coordinate/variable names - Time coordinates use pandas datetime64 - Spatial coordinates: `x`, `y` (and `z` when applicable) -- Reserved names in `modelskill._utils.RESERVED_NAMES` should not be used for model/observation names - The `Quantity` class handles physical quantities with units and validation ### Testing Structure (`tests/`) diff --git a/adr/012-public-private-api-convention.md b/adr/012-public-private-api-convention.md index a7a18b567..c09400367 100644 --- a/adr/012-public-private-api-convention.md +++ b/adr/012-public-private-api-convention.md @@ -23,7 +23,7 @@ The conventional definition was chosen because the team did not want documentati ## Decision -**Privacy is set by the module path, not the function name.** A name is private if any segment of its import path starts with `_`. Functions inside private modules use plain names (no leading `_`); the path already carries the privacy signal. A leading `_` on a function or class name is reserved for symbols that live in a public module but are file-local — i.e., used only inside that one file. +**Privacy is set by the module path, not the function name.** A name is private if any segment of its import path starts with `_`. The only mechanically enforced rule is `PLC2701`: no module may import a name starting with `_` from another module. A leading `_` on a function name is otherwise a convention with no enforced meaning — authors may use it as a hint that a function is file-local. Inside a private module, where the module path already carries the privacy signal, this hint is optional. Examples: @@ -42,15 +42,15 @@ The convention is mechanically enforced by Ruff rule `PLC2701` (`import-private- **Introduce a `modelskill/_internal/` subpackage.** Concentrating all cross-cutting internal helpers in one location would make "shared internal API" structurally obvious. Rejected as pure churn: the helpers already live in natural subpackage homes, and moving them does not change the privacy story under this convention. -**Rename `utils.py` → `_utils.py` and rely on a deprecation shim.** Would break legitimately-public functions (`rename_coords_xr`, `rename_coords_pd`, `make_unique_index`) that are imported from non-underscored paths internally and could already be used by downstream consumers. The actual cross-cutting *private* helpers (`_get_name`, `_get_idx`, `_RESERVED_NAMES`) were split out into a new `modelskill/_utils.py` module instead, leaving `utils.py` as a public module. +**Rename `utils.py` → `_utils.py` and rely on a deprecation shim.** Would break legitimately-public functions (`rename_coords_xr`, `rename_coords_pd`, `make_unique_index`) that are imported from non-underscored paths internally and could already be used by downstream consumers. The actual cross-cutting *private* helpers (`_get_name`, `_get_idx`, `_RESERVED_NAMES`) were split out into a new `modelskill/_utils.py` module instead, leaving `utils.py` as a public module. The regression helper that previously lived as `metrics._linear_regression` was split into its own single-purpose private module `modelskill/_regression.py`. **Keep the status quo and configure Pyright to silence `reportPrivateUsage`.** Rejected because the warning is genuinely useful for catching future drift; removing the noise by suppression would also remove the signal. ## Consequences - **Cross-module use of internal helpers is no longer flagged.** `from modelskill._utils import get_name` is unambiguously legal; the module path says "internal." Pyright's `reportPrivateUsage` falls silent. -- **Function-name underscores carry a sharper meaning.** A `_function_name` now signals "private to this one file" — never "private to this subpackage." Reviewers can use that signal at face value. +- **Function-name underscores remain a convention.** `PLC2701` enforces that no `_foo` may be imported across modules. Beyond that, underscores carry no formal meaning: authors may use them as a navigation hint that a function is file-local, in either public or private modules. - **Future contributors have a clear default.** New internal helpers go in `_modules`, not in `_underscored_names` inside public modules. PRs that put internal helpers in `metrics.py` or `utils.py` with a leading underscore now stand out as the exception rather than the norm. - **One breaking change is accepted.** Five underscored names were previously re-exported from `modelskill.timeseries` via `__all__` (`_parse_track_input`, etc.). These names are no longer reachable from `modelskill.timeseries`; the renamed functions live in `modelskill.timeseries._point`, `_track`, `_vertical`. By Python convention, leading-underscore names were never part of the stable API, so this is consistent with the new policy. - **Tests retain access to internals** via a `tests/**/*.py` per-file ignore for `PLC2701`. White-box testing of private modules remains supported. -- **`metrics.py` no longer hosts `_parse_metric` or `_linear_regression`.** They moved to `comparison/_utils.py:parse_metric` and `_utils.py:linear_regression` respectively. The first is consumer-specific to the comparison subpackage; the second is a general regression helper used by both `metrics.lin_slope` and `plotting._scatter`. +- **`metrics.py` no longer hosts `_parse_metric` or `_linear_regression`.** They moved to `comparison/_utils.py:parse_metric` and `_regression.py:linear_regression` respectively. The first is consumer-specific to the comparison subpackage; the second is a general regression helper used by both `metrics.lin_slope` and `plotting._scatter`. diff --git a/src/modelskill/_regression.py b/src/modelskill/_regression.py new file mode 100644 index 000000000..2c550b18c --- /dev/null +++ b/src/modelskill/_regression.py @@ -0,0 +1,40 @@ +"""Linear regression helper shared by ``metrics.lin_slope`` and the scatter trendline. + +Private per ADR-012: the leading underscore on the module name signals internal use. +""" + +from __future__ import annotations +from typing import Tuple + +import numpy as np +from numpy.typing import ArrayLike + + +def linear_regression( + obs: ArrayLike, model: ArrayLike, reg_method: str = "ols" +) -> Tuple[float, float]: + """Fit a linear regression of ``model`` against ``obs`` and return (slope, intercept).""" + if len(obs) == 0: # type: ignore[arg-type] + return np.nan, np.nan + + if reg_method == "ols": + from scipy.stats import linregress + + reg = linregress(obs, model) + intercept = reg.intercept + slope = reg.slope + elif reg_method == "odr": + from scipy import odr + + data = odr.Data(obs, model) + odr_obj = odr.ODR(data, odr.unilinear) + output = odr_obj.run() + + intercept = output.beta[1] + slope = output.beta[0] + else: + raise NotImplementedError( + f"Regression method: {reg_method} not implemented, select 'ols' or 'odr'" + ) + + return slope, intercept diff --git a/src/modelskill/_utils.py b/src/modelskill/_utils.py index d9363b63e..93a80b484 100644 --- a/src/modelskill/_utils.py +++ b/src/modelskill/_utils.py @@ -6,14 +6,12 @@ """ from __future__ import annotations -from typing import Sequence, Tuple, cast +from typing import Sequence, cast from collections.abc import Hashable -import numpy as np -from numpy.typing import ArrayLike - -RESERVED_NAMES = ["Observation", "time", "x", "y", "z"] +RESERVED_COORD_NAMES = ["x", "y", "z", "time"] +RESERVED_COMPARER_VAR_NAMES = [*RESERVED_COORD_NAMES, "Observation"] def get_name(x: int | str | None, valid_names: Sequence[Hashable]) -> str: @@ -50,33 +48,3 @@ def get_idx(x: int | str | None, valid_names: Sequence[Hashable]) -> int: else: raise TypeError(f"Input {x} invalid! Must be None, str or int, not {type(x)}") return idx - - -def linear_regression( - obs: ArrayLike, model: ArrayLike, reg_method: str = "ols" -) -> Tuple[float, float]: - """Fit a linear regression of ``model`` against ``obs`` and return (slope, intercept).""" - if len(obs) == 0: # type: ignore[arg-type] - return np.nan, np.nan - - if reg_method == "ols": - from scipy.stats import linregress - - reg = linregress(obs, model) - intercept = reg.intercept - slope = reg.slope - elif reg_method == "odr": - from scipy import odr - - data = odr.Data(obs, model) - odr_obj = odr.ODR(data, odr.unilinear) - output = odr_obj.run() - - intercept = output.beta[1] - slope = output.beta[0] - else: - raise NotImplementedError( - f"Regression method: {reg_method} not implemented, select 'ols' or 'odr'" - ) - - return slope, intercept diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index 701972594..dbd3c666c 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -42,7 +42,7 @@ from ..skill import SkillTable from ..skill_grid import SkillGrid from ..settings import register_option -from .._utils import get_name, RESERVED_NAMES +from .._utils import get_name, RESERVED_COMPARER_VAR_NAMES from .. import __version__ from ._vertical_comparison import VerticalAccessor @@ -575,9 +575,9 @@ def name(self) -> str: @name.setter def name(self, name: str) -> None: - if name in RESERVED_NAMES: + if name in RESERVED_COMPARER_VAR_NAMES: raise ValueError( - f"Cannot rename to any of {RESERVED_NAMES}, these are reserved names!" + f"Cannot rename to any of {RESERVED_COMPARER_VAR_NAMES}, these are reserved names!" ) self.data.attrs["name"] = name @@ -748,10 +748,10 @@ def rename( # "ignore": silently remove keys that are not in allowed_keys mapping = {k: v for k, v in mapping.items() if k in allowed_keys} - if any([k in RESERVED_NAMES for k in mapping.values()]): + if any([k in RESERVED_COMPARER_VAR_NAMES for k in mapping.values()]): # TODO: also check for duplicates raise ValueError( - f"Cannot rename to any of {RESERVED_NAMES}, these are reserved names!" + f"Cannot rename to any of {RESERVED_COMPARER_VAR_NAMES}, these are reserved names!" ) # rename observation diff --git a/src/modelskill/metrics.py b/src/modelskill/metrics.py index fd19fd2a9..b5a9f45e5 100644 --- a/src/modelskill/metrics.py +++ b/src/modelskill/metrics.py @@ -19,7 +19,7 @@ from numpy.typing import ArrayLike from scipy import stats -from ._utils import linear_regression +from ._regression import linear_regression defined_metrics: Set[str] = set() metrics_with_units: Set[str] = set() diff --git a/src/modelskill/plotting/_scatter.py b/src/modelskill/plotting/_scatter.py index 5934462aa..564481476 100644 --- a/src/modelskill/plotting/_scatter.py +++ b/src/modelskill/plotting/_scatter.py @@ -17,7 +17,7 @@ import modelskill.settings as settings from modelskill.settings import options -from .._utils import linear_regression +from .._regression import linear_regression from ._misc import quantiles_xy, sample_points, format_skill_table, get_fig_ax diff --git a/src/modelskill/timeseries/_timeseries.py b/src/modelskill/timeseries/_timeseries.py index 5850d4a0e..4e4d46c5d 100644 --- a/src/modelskill/timeseries/_timeseries.py +++ b/src/modelskill/timeseries/_timeseries.py @@ -10,6 +10,7 @@ from ..types import GeometryType from ..quantity import Quantity +from .._utils import RESERVED_COORD_NAMES from ._plotter import TimeSeriesPlotter, MatplotlibTimeSeriesPlotter from .. import __version__ @@ -31,8 +32,7 @@ def validate_data_var_name(name: str) -> str: if not isinstance(name, str): raise TypeError("name must be a string") - RESERVED_NAMES = ["x", "y", "z", "time"] - if name in RESERVED_NAMES: + if name in RESERVED_COORD_NAMES: raise ValueError( f"name '{name}' is reserved and cannot be used! Please choose another name." ) From 0761c01a46302268f65296f5c5411d1356c8121f Mon Sep 17 00:00:00 2001 From: Henrik Andersson Date: Thu, 14 May 2026 17:38:07 +0200 Subject: [PATCH 3/3] Rename _utils.py to _names.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit "utils" is non-informative — the module's actual concern is reserved coordinate/variable names and name-or-index resolution. Naming it _names.py says what's inside and what isn't. ADR-012 updated to reflect the rename and justify the choice; CLAUDE.md example updated. --- CLAUDE.md | 4 ++-- adr/012-public-private-api-convention.md | 6 +++--- src/modelskill/{_utils.py => _names.py} | 8 ++++---- src/modelskill/comparison/_collection.py | 2 +- src/modelskill/comparison/_collection_plotter.py | 2 +- src/modelskill/comparison/_comparer_plotter.py | 2 +- src/modelskill/comparison/_comparison.py | 2 +- src/modelskill/model/_base.py | 2 +- src/modelskill/model/dfsu.py | 2 +- src/modelskill/timeseries/_point.py | 2 +- src/modelskill/timeseries/_timeseries.py | 2 +- src/modelskill/timeseries/_track.py | 2 +- src/modelskill/timeseries/_vertical.py | 2 +- 13 files changed, 19 insertions(+), 19 deletions(-) rename src/modelskill/{_utils.py => _names.py} (84%) diff --git a/CLAUDE.md b/CLAUDE.md index 8e536b6dc..6d4bf185f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -63,8 +63,8 @@ authors may use them as a hint that a function is file-local, but this is not required and not enforced. - Public: `modelskill.matching.match`, `modelskill.utils.rename_coords_xr` -- Private (module path): `modelskill._utils.get_name`, `modelskill.timeseries._timeseries.validate_data_var_name` -- File-local hint: `_helper` inside any module, marking no cross-module use +- Private (module path): `modelskill._names.get_name`, `modelskill.timeseries._timeseries.validate_data_var_name` +- File-local hint: a leading `_` on a function name inside any module, marking no cross-module use See [ADR-012](adr/012-public-private-api-convention.md). Enforced in CI by Ruff rule `PLC2701` (zero violations in `src/`; tests have a per-file ignore). diff --git a/adr/012-public-private-api-convention.md b/adr/012-public-private-api-convention.md index c09400367..cd8054c78 100644 --- a/adr/012-public-private-api-convention.md +++ b/adr/012-public-private-api-convention.md @@ -28,7 +28,7 @@ The conventional definition was chosen because the team did not want documentati Examples: - `modelskill.timeseries._timeseries.validate_data_var_name` — private (module path). -- `modelskill._utils.get_name` — private (module path). +- `modelskill._names.get_name` — private (module path). - `modelskill.metrics._format_directional_label` (hypothetical) — private (function name); used only inside `metrics.py`. - `modelskill.matching.match` — public. @@ -42,13 +42,13 @@ The convention is mechanically enforced by Ruff rule `PLC2701` (`import-private- **Introduce a `modelskill/_internal/` subpackage.** Concentrating all cross-cutting internal helpers in one location would make "shared internal API" structurally obvious. Rejected as pure churn: the helpers already live in natural subpackage homes, and moving them does not change the privacy story under this convention. -**Rename `utils.py` → `_utils.py` and rely on a deprecation shim.** Would break legitimately-public functions (`rename_coords_xr`, `rename_coords_pd`, `make_unique_index`) that are imported from non-underscored paths internally and could already be used by downstream consumers. The actual cross-cutting *private* helpers (`_get_name`, `_get_idx`, `_RESERVED_NAMES`) were split out into a new `modelskill/_utils.py` module instead, leaving `utils.py` as a public module. The regression helper that previously lived as `metrics._linear_regression` was split into its own single-purpose private module `modelskill/_regression.py`. +**Rename `utils.py` → `_utils.py` and rely on a deprecation shim.** Would break legitimately-public functions (`rename_coords_xr`, `rename_coords_pd`, `make_unique_index`) that are imported from non-underscored paths internally and could already be used by downstream consumers. The actual cross-cutting *private* helpers (`_get_name`, `_get_idx`, `_RESERVED_NAMES`) were split out into a new `modelskill/_names.py` module instead, leaving `utils.py` as a public module. The regression helper that previously lived as `metrics._linear_regression` was split into its own single-purpose private module `modelskill/_regression.py`. Naming the new module `_names.py` rather than `_utils.py` is deliberate: it describes the concern (reserved names plus name/index resolution) instead of a contentless role. **Keep the status quo and configure Pyright to silence `reportPrivateUsage`.** Rejected because the warning is genuinely useful for catching future drift; removing the noise by suppression would also remove the signal. ## Consequences -- **Cross-module use of internal helpers is no longer flagged.** `from modelskill._utils import get_name` is unambiguously legal; the module path says "internal." Pyright's `reportPrivateUsage` falls silent. +- **Cross-module use of internal helpers is no longer flagged.** `from modelskill._names import get_name` is unambiguously legal; the module path says "internal." Pyright's `reportPrivateUsage` falls silent. - **Function-name underscores remain a convention.** `PLC2701` enforces that no `_foo` may be imported across modules. Beyond that, underscores carry no formal meaning: authors may use them as a navigation hint that a function is file-local, in either public or private modules. - **Future contributors have a clear default.** New internal helpers go in `_modules`, not in `_underscored_names` inside public modules. PRs that put internal helpers in `metrics.py` or `utils.py` with a leading underscore now stand out as the exception rather than the norm. - **One breaking change is accepted.** Five underscored names were previously re-exported from `modelskill.timeseries` via `__all__` (`_parse_track_input`, etc.). These names are no longer reachable from `modelskill.timeseries`; the renamed functions live in `modelskill.timeseries._point`, `_track`, `_vertical`. By Python convention, leading-underscore names were never part of the stable API, so this is consistent with the new policy. diff --git a/src/modelskill/_utils.py b/src/modelskill/_names.py similarity index 84% rename from src/modelskill/_utils.py rename to src/modelskill/_names.py index 93a80b484..bead22254 100644 --- a/src/modelskill/_utils.py +++ b/src/modelskill/_names.py @@ -1,8 +1,8 @@ -"""Package-internal helpers shared across modelskill subpackages. +"""Reserved coordinate/variable names and name-or-index resolution. -The leading underscore on the module name signals that this is internal API: -modelskill itself imports freely from here, but downstream consumers must not. -See ADR-012 for the public/private convention. +Centralises the names ModelSkill reserves on its xarray data structures and +the small resolver used wherever a user supplies a name or positional index +against a list of valid names. Private per ADR-012. """ from __future__ import annotations diff --git a/src/modelskill/comparison/_collection.py b/src/modelskill/comparison/_collection.py index 74592299a..4a4d5ad27 100644 --- a/src/modelskill/comparison/_collection.py +++ b/src/modelskill/comparison/_collection.py @@ -28,7 +28,7 @@ from ..skill import SkillTable from ..skill_grid import SkillGrid -from .._utils import get_name +from .._names import get_name from ._comparison import Comparer from ._utils import ( add_spatial_grid_to_df, diff --git a/src/modelskill/comparison/_collection_plotter.py b/src/modelskill/comparison/_collection_plotter.py index d7b0af405..9960d7b1a 100644 --- a/src/modelskill/comparison/_collection_plotter.py +++ b/src/modelskill/comparison/_collection_plotter.py @@ -25,7 +25,7 @@ from ..plotting import TaylorPoint, scatter, taylor_diagram from ..plotting._misc import get_fig_ax, xtick_directional, ytick_directional from ..settings import options -from .._utils import get_idx +from .._names import get_idx from ._comparer_plotter import quantiles_xy diff --git a/src/modelskill/comparison/_comparer_plotter.py b/src/modelskill/comparison/_comparer_plotter.py index 7ec75486e..d7458ba89 100644 --- a/src/modelskill/comparison/_comparer_plotter.py +++ b/src/modelskill/comparison/_comparer_plotter.py @@ -18,7 +18,7 @@ import numpy as np # type: ignore from .. import metrics as mtr -from .._utils import get_idx +from .._names import get_idx import matplotlib.colors as colors from ..plotting._misc import ( get_fig_ax, diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index dbd3c666c..ef8b0adf4 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -42,7 +42,7 @@ from ..skill import SkillTable from ..skill_grid import SkillGrid from ..settings import register_option -from .._utils import get_name, RESERVED_COMPARER_VAR_NAMES +from .._names import get_name, RESERVED_COMPARER_VAR_NAMES from .. import __version__ from ._vertical_comparison import VerticalAccessor diff --git a/src/modelskill/model/_base.py b/src/modelskill/model/_base.py index 3040520c5..15a51fe91 100644 --- a/src/modelskill/model/_base.py +++ b/src/modelskill/model/_base.py @@ -12,7 +12,7 @@ from .track import TrackModelResult from .vertical import VerticalModelResult -from .._utils import get_name +from .._names import get_name from ..obs import Observation, PointObservation, TrackObservation, VerticalObservation diff --git a/src/modelskill/model/dfsu.py b/src/modelskill/model/dfsu.py index 32c705b9a..a8db8e8e0 100644 --- a/src/modelskill/model/dfsu.py +++ b/src/modelskill/model/dfsu.py @@ -10,7 +10,7 @@ from ._base import SpatialField, validate_overlap_in_time, SelectedItems from ..types import UnstructuredType from ..quantity import Quantity -from .._utils import get_idx +from .._names import get_idx from .point import PointModelResult from .track import TrackModelResult from .vertical import VerticalModelResult diff --git a/src/modelskill/timeseries/_point.py b/src/modelskill/timeseries/_point.py index a642163ed..798eddbcd 100644 --- a/src/modelskill/timeseries/_point.py +++ b/src/modelskill/timeseries/_point.py @@ -10,7 +10,7 @@ from ..types import GeometryType, PointType from ..quantity import Quantity -from .._utils import get_name +from .._names import get_name from ._timeseries import normalize_time_to_ns, validate_data_var_name from ._coords import XYZCoords, NodeCoords, ReachCoords diff --git a/src/modelskill/timeseries/_timeseries.py b/src/modelskill/timeseries/_timeseries.py index 4e4d46c5d..d840e54b3 100644 --- a/src/modelskill/timeseries/_timeseries.py +++ b/src/modelskill/timeseries/_timeseries.py @@ -10,7 +10,7 @@ from ..types import GeometryType from ..quantity import Quantity -from .._utils import RESERVED_COORD_NAMES +from .._names import RESERVED_COORD_NAMES from ._plotter import TimeSeriesPlotter, MatplotlibTimeSeriesPlotter from .. import __version__ diff --git a/src/modelskill/timeseries/_track.py b/src/modelskill/timeseries/_track.py index 8ff5291f6..d12c4f969 100644 --- a/src/modelskill/timeseries/_track.py +++ b/src/modelskill/timeseries/_track.py @@ -11,7 +11,7 @@ from ..types import GeometryType, TrackType from ..quantity import Quantity -from .._utils import get_name +from .._names import get_name from ..utils import make_unique_index from ._timeseries import validate_data_var_name diff --git a/src/modelskill/timeseries/_vertical.py b/src/modelskill/timeseries/_vertical.py index 79193d002..a684cd14a 100644 --- a/src/modelskill/timeseries/_vertical.py +++ b/src/modelskill/timeseries/_vertical.py @@ -13,7 +13,7 @@ from ..types import GeometryType, VerticalType from ..quantity import Quantity -from .._utils import get_name +from .._names import get_name from ._timeseries import validate_data_var_name