feat(metrics): add WelfordVariance and WelfordCovariance helpers (PR 1 of #3748) #3750
joemunene-by wants to merge 4 commits into
Conversation
Introduces ignite/metrics/_running_stats.py with two numerically
stable running-statistics primitives that variance- and
covariance-bearing metrics can share, instead of each one rolling
its own naive Σx² − (Σx)²/n implementation.
- `WelfordVariance`: running `mean`, `variance`, `std` for a single variable.
- `WelfordCovariance`: running `variance_x`, `variance_y`, `covariance`, and
  Pearson `correlation` for a paired `(x, y)` stream.
Both classes:
- keep internal state in float64 regardless of input dtype, so the
classic E[X²] − E[X]² cancellation does not bite at large means;
- update incrementally via Welford's online algorithm;
- expose merge() implementing the Chan / Welford parallel formula,
suitable for cross-rank distributed reduction or any other case
where two accumulators need to be combined without re-iterating
the raw data.
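As a torch-free illustration of the two formulas described above (Welford's per-sample update and the Chan-style parallel merge), here is a minimal sketch with plain Python floats; the class and method names are illustrative, not the PR's actual API:

```python
class RunningVar:
    """Sketch of Welford running variance with a Chan-style parallel merge."""

    def __init__(self):
        self.n = 0          # sample count
        self.mean = 0.0     # running mean
        self.m2 = 0.0       # sum of squared deviations from the running mean

    def update(self, x):
        # Welford's single-sample update: never forms sum(x^2) explicitly.
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    def merge(self, other):
        # Chan et al. parallel combination of two independent accumulators.
        if other.n == 0:
            return
        n_ab = self.n + other.n
        delta = other.mean - self.mean
        self.m2 += other.m2 + delta * delta * self.n * other.n / n_ab
        self.mean += delta * other.n / n_ab
        self.n = n_ab

    def variance(self):
        return self.m2 / self.n  # population (biased) variance
```

Merging two accumulators built from the two halves of a stream produces the same state as one accumulator fed the whole stream, which is exactly the property that makes cross-rank reduction possible without touching the raw data.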
This is PR 1 of the plan in pytorch#3748. Follow-ups:
PR 2 will port R2Score (pytorch#3662-style regression test attached).
PR 3 will refactor pytorch#3741 to consume WelfordCovariance instead of
its current inline Welford state.
Tests (20 total, all passing):
- per-class correctness vs numpy mean / var / cov / corrcoef
- multi-batch update matches single-batch update
- merge matches concatenated update
- merge with empty accumulators on either side
- numerical-stability regression (mean=1e6 in float32) for both
classes, with an assertion that the naive float32 formula
actually does fail on the same data so the test documents what
we're protecting against
- shape-mismatch raises ValueError
- empty-batch update is a no-op
- reset clears state
- input dtypes (int32) upcast to float64 correctly
- cross-class sanity: WelfordCovariance.variance_x matches
WelfordVariance.variance fed the same x
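The numerical-stability regression described in this list can be reproduced outside the test suite. The stdlib-only sketch below emulates float32 accumulation with `struct` (the sample count, seed, and thresholds are arbitrary choices for illustration, not the test's actual values):

```python
import random
import struct

def f32(x: float) -> float:
    # Round a Python float to IEEE-754 float32 precision,
    # emulating float32 accumulation without numpy/torch.
    return struct.unpack("f", struct.pack("f", x))[0]

random.seed(0)
n = 10_000
data = [1e6 + random.gauss(0.0, 1.0) for _ in range(n)]  # mean=1e6, std=1

# Naive formula with float32 accumulators: sum(x^2) is ~1e16, where float32's
# spacing is ~1e9, so the true spread (~1e4) drowns in rounding error.
s = sq = 0.0
for x in data:
    x32 = f32(x)
    s = f32(s + x32)
    sq = f32(sq + x32 * x32)
naive_var = (sq - s * s / n) / n

# Welford in float64 (plain Python floats): no large-magnitude cancellation.
mean = m2 = 0.0
for i, x in enumerate(data, start=1):
    delta = x - mean
    mean += delta / i
    m2 += delta * (x - mean)
welford_var = m2 / n

# Two-pass float64 ground truth for comparison.
mu = sum(data) / n
true_var = sum((x - mu) ** 2 for x in data) / n
```

Running this, `naive_var` is off by many orders of magnitude while `welford_var` agrees with the two-pass ground truth, which is the failure mode the regression test documents.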
```python
mean: torch.Tensor
sum_sq_dev_from_mean: torch.Tensor

def __init__(self, device: Union[str, torch.device] = "cpu") -> None:
```
This class should handle neither device nor dtype.
```python
import torch
```
```python
class WelfordVariance:
```
Let's make it a dataclass?
```python
"""
if batch.numel() == 0:
    return
batch64 = batch.detach().to(dtype=torch.float64).flatten()
```
I do not think we should flatten it, and for the mean computation we may need to specify the axis (to confirm).
Addressing @vfdev-5's inline review on pytorch#3750:

- Drop the device and dtype constructor args. The helper now leaves placement and precision to the caller; state takes the dtype and device of the first batch passed to update(). PearsonCorrelation and R2Score already do their own float64 upcast before handing inputs to the helper, so this is a no-op for the planned consumers.
- Switch both classes to `@dataclass` with `field(default_factory=...)` for the tensor fields. Drops the manual `__init__` / `reset()` plumbing; "reset" is now reconstruction (`m.welford = WelfordVariance()`), which is the natural fit for how the consumer `Metric.reset()` methods already work.
- Drop the explicit `.flatten()` on update inputs. `batch.mean()` and `batch.numel()` both reduce over the full tensor regardless of shape, so behavior for the current scalar-reduction consumers is unchanged, and the code reads more naturally for any shape.

Tests adjusted accordingly:

- `test_reset` replaced by `test_fresh_instance_has_zero_state`, which documents the dataclass default-factory behavior.
- `test_input_dtype_upcast_to_float64` replaced by `test_state_dtype_follows_first_batch`, which verifies dtype is preserved (the design change).
- Stability tests upcast inputs caller-side before handing to the helper, matching how the metric classes will use it.
- `test_multi_batch_matches_single_batch` switched to float64 inputs so it exercises the algorithm rather than float32 noise.

All 20 tests still pass, ruff format / check clean. The question about axis-aware reduction is deferred to the review thread; I'll follow up once @vfdev-5 confirms whether it lands here or as a follow-up.
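The `field(default_factory=...)` choice in the commit above matters because dataclass fields holding mutable objects (tensors, in the PR) cannot share a plain class-level default. A stdlib sketch of the pattern, with single-element lists standing in for `torch.tensor(0.0)` (the list stand-ins are illustrative only):

```python
from dataclasses import dataclass, field

@dataclass
class Welford:
    """Sketch: mutable (tensor-like) fields need field(default_factory=...),
    so each instance gets its own fresh zero state."""
    n_samples: int = 0
    mean: list = field(default_factory=lambda: [0.0])                 # stand-in for torch.tensor(0.0)
    sum_sq_dev_from_mean: list = field(default_factory=lambda: [0.0])  # ditto

# "reset" is reconstruction: rebuilding the dataclass restores zero state,
# and mutating one instance never leaks into another.
a, b = Welford(), Welford()
a.mean[0] = 5.0
```

Here `b.mean[0]` stays `0.0` even after mutating `a`, which is exactly why the tensor fields use a factory rather than a shared default value.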
Thanks for the review. Pushed 1b82442 addressing all three points.

1. No device / dtype in the helper. Dropped the constructor args and all internal device/dtype handling.
2. Dataclasses. Both classes are now `@dataclass`es with `field(default_factory=...)` tensor fields.
3. Flatten. Removed the explicit `.flatten()` on update inputs.

On axis-aware reduction: the current consumers (R2Score, PearsonCorrelation) both produce a single scalar variance / covariance, so they only need the full-reduction behavior. If we want axis support for future per-channel use cases (running variance over …), the sketch would be:

```python
def update(self, batch: torch.Tensor, dim: Optional[Union[int, tuple[int, ...]]] = None) -> None:
    ...
    n_b = batch.numel() if dim is None else _samples_along(batch.shape, dim)
    mean_b = batch.mean(dim=dim, keepdim=False) if dim is not None else batch.mean()
    ...
```

The state on first update would take the shape of the reduced mean. Two ways we can take it: (a) land `dim` support in this PR, or (b) defer it to a follow-up.

I lean toward (b): the consumers don't need it, and axis-aware running stats have a few subtleties (Bessel correction conventions, mixed-shape merge, axis reordering) that are easier to design alongside a concrete user. Happy to go either way; let me know which you prefer.
@joemunene-by hi, I was trying to review but I find the code to be quite confusing; I think we can try to make it more readable :)
```python
def _zero() -> torch.Tensor:
    return torch.tensor(0.0)
```
Why are we writing a function for code that is literally 1 line?
```python
if self.n_samples == 0:
    self.mean = mean_b
    self.sum_sq_dev_from_mean = m2_b
    self.n_samples = n_b
    return

n_a = self.n_samples
n_ab = n_a + n_b
delta = mean_b - self.mean
self.mean = self.mean + delta * n_b / n_ab
self.sum_sq_dev_from_mean = self.sum_sq_dev_from_mean + m2_b + delta * delta * n_a * n_b / n_ab
self.n_samples = n_ab
```
This part is redundant, as merge is already doing it; we can just use that.
```python
self.sum_sq_dev_from_mean = self.sum_sq_dev_from_mean + m2_b + delta * delta * n_a * n_b / n_ab
self.n_samples = n_ab

def merge(self, other: "WelfordVariance") -> None:
```
Why do we need `merge`?
```python
if self.n_samples == 0:
    self.mean_x = mean_x_b
    self.mean_y = mean_y_b
    self.sum_sq_dev_x = m2_x_b
    self.sum_sq_dev_y = m2_y_b
    self.sum_product_of_devs = cxy_b
    self.n_samples = n_b
    return

n_a = self.n_samples
n_ab = n_a + n_b
cross = n_a * n_b / n_ab
delta_x = mean_x_b - self.mean_x
delta_y = mean_y_b - self.mean_y

self.mean_x = self.mean_x + delta_x * n_b / n_ab
self.mean_y = self.mean_y + delta_y * n_b / n_ab
self.sum_sq_dev_x = self.sum_sq_dev_x + m2_x_b + delta_x * delta_x * cross
self.sum_sq_dev_y = self.sum_sq_dev_y + m2_y_b + delta_y * delta_y * cross
self.sum_product_of_devs = self.sum_product_of_devs + cxy_b + delta_x * delta_y * cross
```
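The bivariate merge in the hunk above can be sanity-checked against a two-pass computation; a stdlib sketch with a flat tuple state `(n, mean_x, mean_y, m2x, m2y, cxy)` (the tuple layout and function names are illustrative, not the class's actual fields):

```python
def cov_merge(a, b):
    """Combine two covariance accumulators via the Chan/Welford parallel formula."""
    n_a, mx_a, my_a, m2x_a, m2y_a, cxy_a = a
    n_b, mx_b, my_b, m2x_b, m2y_b, cxy_b = b
    if n_b == 0:
        return a
    if n_a == 0:
        return b
    n_ab = n_a + n_b
    cross = n_a * n_b / n_ab
    dx, dy = mx_b - mx_a, my_b - my_a
    return (
        n_ab,
        mx_a + dx * n_b / n_ab,
        my_a + dy * n_b / n_ab,
        m2x_a + m2x_b + dx * dx * cross,
        m2y_a + m2y_b + dy * dy * cross,
        cxy_a + cxy_b + dx * dy * cross,  # the delta_x * delta_y correction term
    )

def cov_state(xs, ys):
    """Two-pass single-batch state, used here only as a reference."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    return (
        n, mx, my,
        sum((x - mx) ** 2 for x in xs),
        sum((y - my) ** 2 for y in ys),
        sum((x - mx) * (y - my) for x, y in zip(xs, ys)),
    )
```

Merging the states of two halves of a paired stream reproduces, component by component, the state of the full stream, which is the "merge matches concatenated update" property the test suite checks.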
Three readability changes responding to @aaishwarymishra's inline review:

1. `_zero()` helper removed; each tensor field uses `field(default_factory=lambda: torch.tensor(0.0))` directly.
2. `update()` is now the degenerate case of `merge()` where "other" is a freshly-built single-batch accumulator. The Chan / Welford parallel formula lives in exactly one place. Same refactor applied to `WelfordCovariance.update`.
3. `merge()` docstring grew a paragraph explaining the distributed-reduction motivation -- without an explicit merge, cross-rank reduction has to re-iterate the raw data, which defeats the point of an online algorithm.

Behavior is bit-equivalent: the only delta is an extra detach/clone on the first-batch path (via merge's first-time-absorb branch), which is a no-op for correctness.
@aaishwarymishra thanks, your inline points are all fair. Pushed e153f57 addressing each.

1. `_zero()` is gone; the tensor fields use `field(default_factory=lambda: torch.tensor(0.0))` directly.
2. `update()` now delegates to `merge()` with a freshly-built single-batch accumulator. The mental model that fell out of this is worth stating: an update is just a merge with a one-batch accumulator, so the parallel formula lives in exactly one place.
3. "Why do we need `merge`?" -- distributed reduction: without it, combining per-rank state means re-iterating the raw data. The `merge()` docstring now explains this.

The bonus is that the two-path design (…) collapses into one.

cc @vfdev-5 — points from your earlier review are still addressed in 1b82442; this commit is a follow-up to @aaishwarymishra's readability pass. Ready for another look when you have a moment.
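The "update is a degenerate merge" refactor described in this exchange can be sketched with plain floats. Field names loosely mirror the PR's; `from_batch` is an illustrative constructor, not the PR's API:

```python
from dataclasses import dataclass

@dataclass
class Acc:
    """Minimal sketch: update() delegates to merge(), so the
    Chan/Welford parallel formula lives in exactly one place."""
    n: int = 0
    mean: float = 0.0
    m2: float = 0.0  # sum of squared deviations from the mean

    @classmethod
    def from_batch(cls, batch):
        # Build a single-batch accumulator with a two-pass computation.
        n = len(batch)
        mean = sum(batch) / n
        m2 = sum((x - mean) ** 2 for x in batch)
        return cls(n, mean, m2)

    def merge(self, other):
        if other.n == 0:
            return
        if self.n == 0:  # first-time absorb branch
            self.n, self.mean, self.m2 = other.n, other.mean, other.m2
            return
        n_ab = self.n + other.n
        delta = other.mean - self.mean
        self.m2 += other.m2 + delta * delta * self.n * other.n / n_ab
        self.mean += delta * other.n / n_ab
        self.n = n_ab

    def update(self, batch):
        # An update is just a merge with a freshly-built one-batch accumulator.
        if len(batch) == 0:
            return  # empty batch is a no-op
        self.merge(Acc.from_batch(batch))
```

Two successive `update` calls then accumulate exactly the state of one call on the concatenated data, and `merge` needs no separate code path for distributed reduction.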
Pull request overview
Adds internal running-statistics primitives intended to centralize numerically stable variance/covariance accumulation for future metric refactors (per #3748), along with a dedicated unit-test suite to validate update/merge correctness and numerical-stability behavior.
Changes:
- Introduce `WelfordVariance` and `WelfordCovariance` helpers implementing Welford/Chan online + merge formulas.
- Add unit tests covering empty/single/multi-batch updates, merges, and large-mean numerical-stability regressions.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `ignite/metrics/_running_stats.py` | New internal Welford-based running variance/covariance accumulators with merge support. |
| `tests/ignite/metrics/test_running_stats.py` | New test suite validating correctness, merge equivalence, and numerical-stability scenarios. |
Comments suppressed due to low confidence (1)
ignite/metrics/_running_stats.py:224
- merge() is not wrapped in torch.no_grad() and does not defensively detach incoming tensors unless self is empty. If callers merge state that still requires grad, this can create an autograd graph. Consider adding @torch.no_grad() on merge() and/or detaching other.mean_x/mean_y and the sum_* tensors at the start of the method.
```python
def merge(self, other: "WelfordCovariance") -> None:
    """Combine ``other`` into ``self`` using the Chan / Welford parallel formula.

    Same correction term as the univariate version, applied three
    times: once for ``sum_sq_dev_x``, once for ``sum_sq_dev_y``, and
    once for ``sum_product_of_devs`` (using ``delta_x * delta_y``
    instead of ``delta * delta``). See
    :meth:`WelfordVariance.merge` for the derivation.
    """
```
```
Shared by metrics that need to accumulate variance / covariance from
streaming batches without falling into the catastrophic-cancellation
trap of the naive ``E[X^2] - E[X]^2`` formula. Used by

Both classes are tensor-type-agnostic dataclasses: callers supply
tensors in whatever dtype and device they want, and the helper
preserves both. For numerical stability under large means, callers
should pre-cast inputs to ``float64`` (the consumer metric classes
already do this in their own ``update`` methods).
```
```python
def merge(self, other: "WelfordVariance") -> None:
    """Combine ``other`` into ``self`` using the Chan / Welford parallel formula.

    Used in two places: by :meth:`update` (where ``other`` is a
    freshly-built single-batch accumulator), and by callers that
    need to combine independently-accumulated state from elsewhere.
    The motivating second case is distributed training: each rank
```
```python
# The correction-term coefficient n_a * n_b / n_ab shows up in
# every parallel-formula line below; compute it once.
cross_coef = n_a * n_b / n_ab
delta_x = other.mean_x - self.mean_x
delta_y = other.mean_y - self.mean_y
```
…ined cross-coef

Four points from the Copilot auto-review on pytorch#3750:

1. Module docstring claimed the helpers are *used by* PearsonCorrelation and R2Score, but those metrics are wired up in follow-up PRs and still use their own running-sum state at HEAD. Reworded to "intended consumers (in follow-up PRs of pytorch#3748)" so the doc matches reality at this commit.
2. Same docstring claimed internal state is kept in float64 regardless of input dtype, contradicting the actual implementation (dtype- and device-agnostic; caller supplies the float64 cast when stability matters). vfdev-5's earlier review specifically asked for this contract — aligning the prose with the code so users don't get a false sense of safety on float32 inputs.
3. `WelfordVariance.merge` and `WelfordCovariance.merge` are now both wrapped in `@torch.no_grad()` (mirroring `update`). Without it, a caller that merges an accumulator whose tensors still require grad would build an autograd graph and leak memory across the lifetime of the metric. Belt to the existing detach/clone-on-first-time-absorb suspenders.
4. `WelfordCovariance.merge` precomputed `cross_coef = n_a * n_b / n_ab` as a Python float and reused it three times. Dropped the temporary and inlined `n_a * n_b / n_ab` directly into each parallel-formula line — mirrors `WelfordVariance.merge`'s existing style and keeps the arithmetic on the same dtype/device as the tensor operands rather than promoting through Python scalar land.

20 / 20 tests still passing.
Copilot's auto-review surfaced four concrete points — pushed 89d5b36 addressing each:
20 / 20 tests still passing locally.
Summary
PR 1 of the plan in #3748: introduces `ignite/metrics/_running_stats.py` with two numerically stable running-statistics primitives that variance- and covariance-bearing metrics can share, instead of each one rolling its own naive `Σx² − (Σx)²/n` implementation.

- `WelfordVariance`: running `mean`, `variance`, `std` for a single variable.
- `WelfordCovariance`: running `variance_x`, `variance_y`, `covariance`, and Pearson `correlation()` for a paired `(x, y)` stream.

Both classes:

- keep internal state in `float64` regardless of input dtype, so the classic `E[X²] − E[X]²` cancellation does not bite at large means (the failure mode from [Bug] Numerical instability in `PearsonCorrelation` due to naive variance formula #3662);
- update incrementally via Welford's online algorithm;
- expose `merge()` implementing the Chan / Welford parallel formula, suitable for cross-rank distributed reduction or any other case where two accumulators need to be combined without re-iterating the raw data.

No existing metric is touched in this PR. The helper lands first so the two consumer PRs can each diff against a stable shared API.
Follow-ups (already scoped in #3748)

- PR 2: port `R2Score` to `WelfordVariance`, with a regression test mirroring [Bug] Numerical instability in `PearsonCorrelation` due to naive variance formula #3662's `mean=1e6, std=1` failure case.
- PR 3: refactor `PearsonCorrelation` (#3741) to use `WelfordCovariance` instead. Pure refactor on top of the existing review.

API
Tests

20 unit tests in `tests/ignite/metrics/test_running_stats.py`, all passing locally with `pytest tests/ignite/metrics/test_running_stats.py`:

WelfordVariance
- matches `numpy.mean` and `numpy.var` to 1e-12
- `merge` of two accumulators matches `update` on the concatenated data
- `merge` with an empty accumulator on either side is a no-op or absorbs the other side
- `mean=1e6, std=1` in float32; Welford in float64 recovers the true variance, and the same data through the naive `Σx² − (Σx)²/n` formula in float32 is asserted to fail by ≥ 0.1, so the test documents the failure mode it is protecting against
- empty-batch `update` is a no-op
- `reset` clears state

WelfordCovariance

- matches `numpy.cov(..., bias=True)` and `numpy.corrcoef` to 1e-10 to 1e-12
- `merge` matches concatenated `update`
- `mean=1e6`, true correlation `> 0.99`; Welford recovers `r` within 1e-4 of the float64 ground truth
- shape mismatch raises `ValueError`
- zero-variance input yields `r = 0.0` (not NaN) via the `eps` clamp
- `reset` clears state

Cross-class sanity

`WelfordCovariance.variance_x` matches `WelfordVariance.variance` fed the same `x`. Catches future drift between the two implementations.

Conventions

- `ruff format` clean, `ruff check` clean.
- State is kept as `torch.Tensor` on the user-supplied device, mirroring the dtype / device convention used elsewhere in `ignite.metrics`.
- The module is `_running_stats.py` (leading underscore) and not re-exported from `ignite/metrics/__init__.py`, so it is internal to the metrics module. Public access stays through individual metric classes; the helper can graduate to a public API later if there is demand.

Test plan

- `pytest tests/ignite/metrics/test_running_stats.py`: 20 / 20 passing locally
- `ruff format --check`, `ruff check`

cc @vfdev-5. Opening PR 2 (R2Score port) and PR 3 (PearsonCorrelation refactor on top of #3741) once this lands.
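For reference, the `eps`-clamp behavior the covariance tests describe (zero-variance input yields `r = 0.0` instead of NaN) can be sketched from Welford-style state; the function name and `eps` default are assumptions for illustration, not the PR's actual code:

```python
import math

def correlation(sum_sq_dev_x, sum_sq_dev_y, sum_product_of_devs, eps=1e-12):
    """Pearson r from Welford-style state (sketch, not the PR's code).

    The 1/n normalization cancels between numerator and denominator, so
    r can be computed directly from the accumulated sums. Clamping the
    denominator makes constant (zero-variance) inputs return 0.0
    instead of dividing by zero and producing NaN.
    """
    denom = math.sqrt(sum_sq_dev_x * sum_sq_dev_y)
    if denom < eps:
        return 0.0
    return sum_product_of_devs / denom

# For the perfectly correlated stream xs = [1, 2, 3], ys = [2, 4, 6],
# the accumulated state is m2x = 2, m2y = 8, cxy = 4, giving r = 1.0.
```

A degenerate stream with constant `x` accumulates `sum_sq_dev_x = 0`, so the clamp path fires and the metric reports `r = 0.0` rather than NaN.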