def split_agg_name(name: str) -> tuple[str, str]:
    """Split an aggregation method name into aggregator and comparator names.

    An optional leading ``"col_"`` is dropped, then the remainder is divided at its
    last underscore: everything before it is the aggregator part, everything after it
    is the comparator part.

    Args:
        name (str): The aggregation method name (e.g., "col_sum_eq" or "sum_eq").

    Returns:
        tuple[str, str]: A tuple of (agg_name, comp_name) e.g., ("sum", "eq").

    Raises:
        ValueError: If no underscore remains once the ``"col_"`` prefix is removed, so
            the name cannot be divided into two parts.
    """
    stripped = name.removeprefix("col_")
    agg_name, separator, comp_name = stripped.rpartition("_")
    if not separator:
        # Tuple-unpacking the result of `rsplit` raised ValueError here as well, but
        # with a generic "not enough values to unpack" message; keep the exception
        # type and say which input was malformed.
        raise ValueError(
            f"Cannot split {name!r} into aggregator and comparator names: expected "
            "'<agg>_<comp>' with an optional 'col_' prefix."
        )
    return agg_name, comp_name
def _create_text_agg(
    lang: str,
    assertion_type: str,
    column: str | list[str],
    values: dict[str, Any],
    for_failure: bool = False,
) -> str:
    """Create autobrief text for aggregation methods like col_sum_eq, col_avg_gt, etc.

    The text is built from the shared ``compare_*`` template in `EXPECT_FAIL_TEXT`; the
    template's column slot is filled with "the <aggregation> of <column>" so the
    sentence reads as a statement about the aggregate rather than individual values.

    Args:
        lang (str): Language code used to index the text template and to pick the
            operator table.
        assertion_type (str): Aggregation method name, e.g. "col_sum_eq".
        column (str | list[str]): Column(s) the aggregation is applied to.
        values (dict[str, Any]): Step values; the "value" entry is used when present,
            otherwise the whole object is stringified.
        for_failure (bool): Forwarded to `_expect_failure_type` to choose which
            template variant is looked up.

    Returns:
        str: The formatted brief text.

    Raises:
        AssertionError: If the aggregator parsed from `assertion_type` has no entry in
            the display-name table below.
    """
    type_ = _expect_failure_type(for_failure=for_failure)

    agg_type, comp_type = split_agg_name(assertion_type)

    # this is covered by the test `test_brief_auto_all_agg_methods` to make sure we don't
    # create any weird secret agg constants.
    agg_display_names: dict[str, str] = {
        "sum": "sum",
        "avg": "average",
        "sd": "standard deviation",
    }
    try:
        agg_display: str = agg_display_names[agg_type]
    except KeyError as ke:  # pragma: no cover
        # This should never happen in prod, it's caught in CI. Name the offending
        # aggregator so the CI failure is actionable.
        raise AssertionError(
            f"No display name registered for aggregation type {agg_type!r} "
            f"(assertion type {assertion_type!r})."
        ) from ke

    # Get the operator by reusing the tables keyed on the `col_vals_*` assertion names;
    # fall back to the raw comparator suffix when there is no entry.
    comparison_assertion = f"col_vals_{comp_type}"
    if lang == "ar":  # pragma: no cover
        operator = COMPARISON_OPERATORS_AR.get(comparison_assertion, comp_type)
    else:
        operator = COMPARISON_OPERATORS.get(comparison_assertion, comp_type)

    column_text = _prep_column_text(column=column)

    value = values.get("value", values) if isinstance(values, dict) else values
    values_text = _prep_values_text(values=str(value), lang=lang, limit=3)

    # "Expect that the {agg} of {column} should be {operator} {value}."
    agg_expectation_text = EXPECT_FAIL_TEXT[f"compare_{type_}_text"][lang]

    # NOTE(review): the "the ... of" wrapper and the display names above are English
    # for every `lang`; presumably translations are still to come -- confirm.
    return agg_expectation_text.format(
        column_text=f"the {agg_display} of {column_text}",
        operator=operator,
        values_text=values_text,
    )
def test_brief_auto():
    """Test that auto briefs are generated correctly for aggregation methods."""
    data = pl.DataFrame({"amount": [100, 200, 300]})

    validation = Validate(data).col_sum_gt(columns="amount", value=500, brief=True).interrogate()

    # Check that brief is set to auto template
    assert validation.validation_info[0].brief == "{auto}"

    # Check that the HTML report generates auto brief text
    html = validation.get_tabular_report().as_raw_html()
    assert html

    # The column name must be present, and the brief itself must be present: a bare
    # "sum" would also match the method name `col_sum_gt`, so look for the phrase the
    # auto brief builds ("the sum of <column>").
    assert "amount" in html
    assert "the sum of" in html.lower()


def test_brief_custom():
    """Test that custom briefs are stored and displayed correctly."""
    data = pl.DataFrame({"sales": [1000, 2000, 3000]})

    custom_brief = "Validating that total sales exceeds minimum threshold"

    validation = (
        Validate(data).col_avg_eq(columns="sales", value=2000, brief=custom_brief).interrogate()
    )

    # Check that custom brief is stored
    assert validation.validation_info[0].brief == custom_brief

    # Check that custom brief appears in HTML report
    html = validation.get_tabular_report().as_raw_html()
    assert custom_brief in html


def test_brief_mixed():
    """Test mixing custom and auto brief templates across multiple validation steps."""
    data = pl.DataFrame({"value_a": [10, 20, 30], "value_b": [100, 200, 300]})

    custom_brief_1 = "First check: sum validation"

    validation = (
        Validate(data)
        .col_sum_gt(columns="value_a", value=50, brief=custom_brief_1)
        .col_avg_lt(columns="value_b", value=400, brief=True)  # auto brief template
        .interrogate()
    )

    # First step should have custom brief
    assert validation.validation_info[0].brief == custom_brief_1

    # Second step should have auto brief template
    assert validation.validation_info[1].brief == "{auto}"

    # Both should appear in HTML report
    html = validation.get_tabular_report().as_raw_html()
    assert custom_brief_1 in html
    # Auto brief should mention the column and the aggregation phrase it generates
    assert "value_b" in html
    assert "the average of" in html.lower()


@pytest.mark.parametrize("method", load_validation_method_grid())
def test_brief_auto_all_agg_methods(method: str):
    """Test that auto briefs are generated for all aggregation methods.

    This ensures that the agg_display_names mapping in _create_text_agg
    has coverage for all aggregation types (sum, avg, sd).
    """
    from pointblank._agg import split_agg_name

    # A distinctive column name: a short name such as "col" is a substring of ordinary
    # HTML/CSS text ("color", "column"), which would make the containment check vacuous.
    column_name = "revenue"
    data = pl.DataFrame({column_name: [10.0, 20.0, 30.0, 40.0, 50.0]})

    v = Validate(data)
    v = getattr(v, method)(columns=column_name, value=100, brief=True)
    v = v.interrogate()

    assert v.validation_info[0].brief == "{auto}"

    html = v.get_tabular_report().as_raw_html()
    assert html
    assert column_name in html

    # Extract agg type from method name to verify it appears in report
    # e.g., "col_sum_eq" -> agg_type="sum"
    agg_type, _ = split_agg_name(method)
    agg_display_map = {"sum": "sum", "avg": "average", "sd": "standard deviation"}
    # Index directly (no fallback) so an aggregation type missing from this map fails
    # loudly instead of silently comparing against the raw type name.
    agg_display = agg_display_map[agg_type]

    assert f"the {agg_display} of" in html.lower()