Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions pointblank/_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,20 @@ def _generic_between(real: Any, lower: Any, upper: Any) -> bool:
return bool(lower <= real <= upper)


def split_agg_name(name: str) -> tuple[str, str]:
    """Split an aggregation method name into aggregator and comparator names.

    A leading "col_" prefix is dropped, then the remainder is split on its last underscore:
    everything before that underscore is the aggregator name and everything after it is the
    comparator name.

    Args:
        name (str): The aggregation method name (e.g., "col_sum_eq" or "sum_eq").

    Returns:
        tuple[str, str]: A tuple of (agg_name, comp_name) e.g., ("sum", "eq").

    Raises:
        ValueError: If the name, after removing "col_", contains no underscore and so cannot
            be split into two parts.
    """
    stripped = name.removeprefix("col_")
    parts = stripped.rsplit("_", 1)
    if len(parts) != 2:
        # Bare tuple unpacking would fail here with an opaque "not enough values to unpack"
        # message; report the offending name instead (still a ValueError for callers).
        raise ValueError(
            f"Cannot split {name!r} into aggregator and comparator names; "
            "expected the form '<agg>_<comp>' with an optional 'col_' prefix."
        )
    agg_name, comp_name = parts
    return agg_name, comp_name


def resolve_agg_registries(name: str) -> tuple[Aggregator, Comparator]:
"""Resolve the assertion name to a valid aggregator

Expand All @@ -85,8 +99,7 @@ def resolve_agg_registries(name: str) -> tuple[Aggregator, Comparator]:
Returns:
tuple[Aggregator, Comparator]: The aggregator and comparator functions.
"""
name = name.removeprefix("col_")
agg_name, comp_name = name.split("_")[-2:]
agg_name, comp_name = split_agg_name(name)

aggregator = AGGREGATOR_REGISTRY.get(agg_name)
comparator = COMPARATOR_REGISTRY.get(comp_name)
Expand Down
62 changes: 61 additions & 1 deletion pointblank/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
from great_tables.vals import fmt_integer, fmt_number
from importlib_resources import files

from pointblank._agg import is_valid_agg, load_validation_method_grid, resolve_agg_registries
from pointblank._agg import (
is_valid_agg,
load_validation_method_grid,
resolve_agg_registries,
split_agg_name,
)
from pointblank._constants import (
ASSERTION_TYPE_METHOD_MAP,
CHECK_MARK_SPAN,
Expand Down Expand Up @@ -18869,6 +18874,15 @@ def _create_autobrief_or_failure_text(
for_failure=for_failure,
)

if is_valid_agg(assertion_type):
return _create_text_agg(
lang=lang,
assertion_type=assertion_type,
column=column,
values=values,
for_failure=for_failure,
)

return None


Expand Down Expand Up @@ -18903,6 +18917,52 @@ def _create_text_comparison(
)


def _create_text_agg(
    lang: str,
    assertion_type: str,
    column: str | list[str],
    values: dict[str, Any],
    for_failure: bool = False,
) -> str:
    """Build the autobrief/failure text for an aggregation step (col_sum_eq, col_avg_gt, ...).

    The text is produced from the `compare_*_text` template in `EXPECT_FAIL_TEXT`, with the
    column slot filled by "the <aggregation> of <column>".
    """
    text_kind = _expect_failure_type(for_failure=for_failure)
    agg_key, comp_key = split_agg_name(assertion_type)

    # Display names for each aggregator. `test_brief_auto_all_agg_methods` runs every
    # aggregation method through this function, so a missing entry is caught in CI.
    display_by_agg: dict[str, str] = {
        "sum": "sum",
        "avg": "average",
        "sd": "standard deviation",
    }
    try:
        agg_label: str = display_by_agg[agg_key]
    except KeyError as missing:  # pragma: no cover
        # Not expected to be reachable in production; see the test named above.
        raise AssertionError from missing

    # Look up the operator under the matching `col_vals_*` key; fall back to the raw
    # comparator name when the table has no entry for it.
    operator_table = COMPARISON_OPERATORS_AR if lang == "ar" else COMPARISON_OPERATORS
    operator = operator_table.get(f"col_vals_{comp_key}", comp_key)

    column_text = _prep_column_text(column=column)

    # `values` may be a dict carrying the target under "value"; anything else is used as is.
    target = values
    if isinstance(values, dict):
        target = values.get("value", values)
    values_text = _prep_values_text(values=str(target), lang=lang, limit=3)

    template = EXPECT_FAIL_TEXT[f"compare_{text_kind}_text"][lang]
    return template.format(
        column_text=f"the {agg_label} of {column_text}",
        operator=operator,
        values_text=values_text,
    )


def _create_text_between(
lang: str,
column: str,
Expand Down
96 changes: 96 additions & 0 deletions tests/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,3 +1139,99 @@ def test_agg_report_multiple_steps_formatting():

# Step 3: Value with asymmetric tolerance
assert "2.0<br/>tol=(0.1, 0.2)" in html


def test_brief_auto():
    """Test that auto briefs are generated correctly for aggregation methods."""
    data = pl.DataFrame({"amount": [100, 200, 300]})

    validation = Validate(data).col_sum_gt(columns="amount", value=500, brief=True).interrogate()

    # `brief=True` is stored as the auto template placeholder
    assert validation.validation_info[0].brief == "{auto}"

    # The HTML report must render to a non-empty string
    html = validation.get_tabular_report().as_raw_html()
    assert html

    # The column under test should be mentioned
    assert "amount" in html

    # A bare "sum" is also a substring of the method name `col_sum_gt`, which presumably
    # shows up in the report regardless of the brief. Match the "the <agg> of" phrase that
    # only the generated aggregation brief contains.
    assert "the sum of" in html.lower()


def test_brief_custom():
    """Test that a user-supplied brief is kept verbatim and rendered in the report."""
    custom_brief = "Validating that total sales exceeds minimum threshold"
    data = pl.DataFrame({"sales": [1000, 2000, 3000]})

    plan = Validate(data).col_avg_eq(columns="sales", value=2000, brief=custom_brief)
    validation = plan.interrogate()

    # The step keeps the brief exactly as given
    first_step = validation.validation_info[0]
    assert first_step.brief == custom_brief

    # The same text shows up in the rendered HTML report
    report_html = validation.get_tabular_report().as_raw_html()
    assert custom_brief in report_html


def test_brief_mixed():
    """Test a plan in which one step has a custom brief and the next uses the auto brief."""
    custom_brief_1 = "First check: sum validation"
    data = pl.DataFrame({"value_a": [10, 20, 30], "value_b": [100, 200, 300]})

    plan = Validate(data)
    plan = plan.col_sum_gt(columns="value_a", value=50, brief=custom_brief_1)
    plan = plan.col_avg_lt(columns="value_b", value=400, brief=True)  # auto brief template
    validation = plan.interrogate()

    steps = validation.validation_info

    # Step 1 keeps the custom text
    assert steps[0].brief == custom_brief_1

    # Step 2 stores the auto template placeholder
    assert steps[1].brief == "{auto}"

    # The report contains the custom brief as well as the generated one
    report_html = validation.get_tabular_report().as_raw_html()
    assert custom_brief_1 in report_html
    # The generated brief names the column and spells out the aggregation
    assert "value_b" in report_html
    assert "average" in report_html.lower()


@pytest.mark.parametrize("method", load_validation_method_grid())
def test_brief_auto_all_agg_methods(method: str):
    """Test that auto briefs are generated for all aggregation methods.

    This ensures that the agg_display_names mapping in _create_text_agg
    has coverage for all aggregation types (sum, avg, sd).
    """
    from pointblank._agg import split_agg_name

    data = pl.DataFrame({"col": [10.0, 20.0, 30.0, 40.0, 50.0]})

    v = Validate(data)
    v = getattr(v, method)(columns="col", value=100, brief=True)
    v = v.interrogate()

    assert v.validation_info[0].brief == "{auto}"

    html = v.get_tabular_report().as_raw_html()
    assert html  # rendered report is a non-empty string
    assert "col" in html

    # Extract agg type from method name to verify it appears in report
    # e.g., "col_sum_eq" -> agg_type="sum"
    agg_type, _ = split_agg_name(method)
    agg_display_map = {"sum": "sum", "avg": "average", "sd": "standard deviation"}
    agg_display = agg_display_map.get(agg_type, agg_type)

    # Check the full "the <agg> of" phrase rather than the bare display name: for the sum
    # methods, "sum" is already a substring of the method name itself (e.g. "col_sum_eq"),
    # so a bare substring check would not prove that the brief was generated.
    assert f"the {agg_display.lower()} of" in html.lower()
Loading