diff --git a/diffly/_conditions.py b/diffly/_conditions.py index 4c87359..ce11363 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -19,6 +19,7 @@ def condition_equal_rows( columns: list[str], schema_left: pl.Schema, schema_right: pl.Schema, + max_list_lengths_by_column: Mapping[str, int], abs_tol_by_column: Mapping[str, float], rel_tol_by_column: Mapping[str, float], abs_tol_temporal_by_column: Mapping[str, dt.timedelta], @@ -34,6 +35,7 @@ def condition_equal_rows( column=column, dtype_left=schema_left[column], dtype_right=schema_right[column], + max_list_length=max_list_lengths_by_column.get(column), abs_tol=abs_tol_by_column[column], rel_tol=rel_tol_by_column[column], abs_tol_temporal=abs_tol_temporal_by_column[column], @@ -47,6 +49,7 @@ def condition_equal_columns( column: str, dtype_left: pl.DataType, dtype_right: pl.DataType, + max_list_length: int | None, abs_tol: float = ABS_TOL_DEFAULT, rel_tol: float = REL_TOL_DEFAULT, abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT, @@ -58,6 +61,7 @@ def condition_equal_columns( col_right=pl.col(f"{column}_{Side.RIGHT}"), dtype_left=dtype_left, dtype_right=dtype_right, + max_list_length=max_list_length, abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, @@ -92,6 +96,7 @@ def _compare_columns( col_right: pl.Expr, dtype_left: DataType | DataTypeClass, dtype_right: DataType | DataTypeClass, + max_list_length: int | None, abs_tol: float, rel_tol: float, abs_tol_temporal: dt.timedelta, @@ -123,6 +128,7 @@ def _compare_columns( col_right=col_right.struct[field], dtype_left=fields_left[field], dtype_right=fields_right[field], + max_list_length=max_list_length, abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, @@ -133,10 +139,16 @@ def _compare_columns( elif isinstance(dtype_left, pl.List | pl.Array) and isinstance( dtype_right, pl.List | pl.Array ): - # As of polars 1.28, there is no way to access another column within - # `list.eval`. Hence, we necessarily need to resort to a primitive - # comparison in this case. - pass + return _compare_sequence_columns( + col_left=col_left, + col_right=col_right, + dtype_left=dtype_left, + dtype_right=dtype_right, + max_list_length=max_list_length, + abs_tol=abs_tol, + rel_tol=rel_tol, + abs_tol_temporal=abs_tol_temporal, + ) if ( isinstance(dtype_left, pl.Enum) @@ -154,6 +166,7 @@ def _compare_columns( abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, + max_list_length=max_list_length, ) return _compare_primitive_columns( @@ -167,6 +180,72 @@ def _compare_columns( ) +def _compare_sequence_columns( + col_left: pl.Expr, + col_right: pl.Expr, + dtype_left: DataType | DataTypeClass, + dtype_right: DataType | DataTypeClass, + max_list_length: int | None, + abs_tol: float, + rel_tol: float, + abs_tol_temporal: dt.timedelta, +) -> pl.Expr: + """Compare Array/List columns element-wise with tolerance.""" + assert isinstance(dtype_left, pl.List | pl.Array) + assert isinstance(dtype_right, pl.List | pl.Array) + inner_left = dtype_left.inner + inner_right = dtype_right.inner + + n_elements: int + has_same_length: pl.Expr + + if isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.Array): + if dtype_left.shape != dtype_right.shape: + return pl.repeat(pl.lit(False), pl.len()) + n_elements = dtype_left.shape[0] + has_same_length = pl.repeat(pl.lit(True), pl.len()) + elif isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.List): + n_elements = dtype_left.shape[0] + has_same_length = col_right.list.len().eq(pl.lit(n_elements)) + elif isinstance(dtype_left, pl.List) and isinstance(dtype_right, pl.Array): + n_elements = dtype_right.shape[0] + has_same_length = col_left.list.len().eq(pl.lit(n_elements)) + else: # pl.List vs pl.List + if not isinstance(max_list_length, int): + # Fallback for nested list comparisons where no max_list_length is + # available: perform a direct equality comparison without element-wise + # unrolling. + return _eq_missing(col_left.eq_missing(col_right), col_left, col_right) + n_elements = max_list_length + has_same_length = col_left.list.len().eq_missing(col_right.list.len()) + + if n_elements == 0: + return _eq_missing(pl.lit(True), col_left, col_right) + + def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Expr: + if isinstance(dtype, pl.Array): + return col.arr.get(i) + return col.list.get(i, null_on_oob=True) + + elements_match = pl.all_horizontal( + [ + _compare_columns( + col_left=_get_element(col_left, dtype_left, i), + col_right=_get_element(col_right, dtype_right, i), + dtype_left=inner_left, + dtype_right=inner_right, + abs_tol=abs_tol, + rel_tol=rel_tol, + abs_tol_temporal=abs_tol_temporal, + max_list_length=None, + ) + for i in range(n_elements) + ] + ) + + return _eq_missing(has_same_length & elements_match, col_left, col_right) + + def _compare_primitive_columns( col_left: pl.Expr, col_right: pl.Expr, diff --git a/diffly/comparison.py b/diffly/comparison.py index ca126df..ee46e55 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -508,6 +508,7 @@ def equal(self, *, check_dtypes: bool = True) -> bool: columns=common_columns, schema_left=self.left_schema, schema_right=self.right_schema, + max_list_lengths_by_column=self._max_list_lengths_by_column, abs_tol_by_column=self.abs_tol_by_column, rel_tol_by_column=self.rel_tol_by_column, abs_tol_temporal_by_column=self.abs_tol_temporal_by_column, @@ -708,11 +709,32 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str] raise ValueError(f"{difference} are not common columns.") return list(subset) + @cached_property + def _max_list_lengths_by_column(self) -> dict[str, int]: + list_columns = [ + col + for col in self._other_common_columns + if isinstance(self.left_schema[col], pl.List) + and isinstance(self.right_schema[col], pl.List) + ] + if not list_columns: + return {} + + exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns] + [left_max, right_max] = pl.collect_all( + [self.left.select(exprs), self.right.select(exprs)] + ) + return { + col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0)) + for col in list_columns + } + def _condition_equal_rows(self, columns: list[str]) -> pl.Expr: return condition_equal_rows( columns=columns, schema_left=self.left_schema, schema_right=self.right_schema, + max_list_lengths_by_column=self._max_list_lengths_by_column, abs_tol_by_column=self.abs_tol_by_column, rel_tol_by_column=self.rel_tol_by_column, abs_tol_temporal_by_column=self.abs_tol_temporal_by_column, @@ -726,6 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr: abs_tol=self.abs_tol_by_column[column], rel_tol=self.rel_tol_by_column[column], abs_tol_temporal=self.abs_tol_temporal_by_column[column], + max_list_length=self._max_list_lengths_by_column.get(column), ) def _equal_rows(self) -> bool: diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 772e61c..9093ee6 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -32,6 +32,7 @@ def test_condition_equal_columns_struct() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=None, abs_tol=0.5, rel_tol=0, ) @@ -66,6 +67,7 @@ def test_condition_equal_columns_different_struct_fields() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=None, ) ) .to_series() @@ -81,21 +83,21 @@ def test_condition_equal_columns_different_struct_fields() -> None: @pytest.mark.parametrize( "rhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)] ) -def test_condition_equal_columns_list_array_equal_exact( +def test_condition_equal_columns_list_array_with_tolerance( lhs_type: pl.DataType, rhs_type: pl.DataType ) -> None: # Arrange lhs = pl.DataFrame( { - "pk": [1, 2], - "a_left": [[1.0, 1.1], [2.0, 2.1]], + "pk": [1, 2, 3], + "a_left": [[1.0, 1.1], [2.0, 2.1], [3.0, 3.0]], }, schema={"pk": pl.Int64, "a_left": lhs_type}, ) rhs = pl.DataFrame( { - "pk": [1, 2], - "a_right": [[1.0, 1.1], [2.0, 2.2]], + "pk": [1, 2, 3], + "a_right": [[1.0, 1.1], [2.0, 2.2], [3.0, 3.7]], }, schema={"pk": pl.Int64, "a_right": rhs_type}, ) @@ -110,13 +112,70 @@ def test_condition_equal_columns_list_array_equal_exact( dtype_right=rhs.schema["a_right"], abs_tol=0.5, rel_tol=0, + max_list_length=2, ) ) .to_series() ) - # Assert - assert actual.to_list() == [True, False] + assert actual.to_list() == [True, True, False] + + +@pytest.mark.parametrize( + "lhs_type", + [pl.Array(pl.Float64, shape=(2, 3)), pl.List(pl.List(pl.Float64))], +) +@pytest.mark.parametrize( + "rhs_type", + [pl.Array(pl.Float64, shape=(2, 3)), pl.List(pl.List(pl.Float64))], +) +def test_condition_equal_columns_nested_list_array_with_tolerance( + lhs_type: pl.DataType, rhs_type: pl.DataType +) -> None: + # Arrange + lhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_left": [ + [[1.0, 1.1, 1.3], [2.0, 2.1, 2.2]], + [[3.0, 3.0, 3.1], [4.0, 4.0, 4.1]], + [[5.0, 5.0, 5.1], [6.0, 6.0, 6.1]], + ], + }, + schema={"pk": pl.Int64, "a_left": lhs_type}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_right": [ + [[1.0, 1.1, 1.3], [2.0, 2.1, 2.2]], + [[3.0, 3.0, 3.1], [4.0, 4.4, 4.1]], + [[5.0, 5.0, 5.1], [6.0, 6.8, 6.1]], + ], + }, + schema={"pk": pl.Int64, "a_right": rhs_type}, + ) + + # Act + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + abs_tol=0.5, + rel_tol=0, + max_list_length=2, + ) + ) + .to_series() + ) + + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): + assert actual.to_list() == [True, False, False] + else: + assert actual.to_list() == [True, True, False] def test_condition_equal_columns_nested_dtype_mismatch() -> None: @@ -142,6 +201,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=None, ) ) .to_series() @@ -174,6 +234,7 @@ def test_condition_equal_columns_exactly_one_nested() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=None, ) ) .to_series() @@ -216,6 +277,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=None, abs_tol_temporal=dt.timedelta(seconds=2), ) ) @@ -226,6 +288,162 @@ def test_condition_equal_columns_temporal_tolerance() -> None: assert actual.to_list() == [True, False, False, True] +def test_condition_equal_columns_two_lists() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2, 3, 4, 5], + "a_left": [[1.0, 2.0], [3.0], [5.0, None], None, None], + }, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2, 3, 4, 5], + "a_right": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0], None], + }, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + abs_tol=0.5, + rel_tol=0, + max_list_length=2, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, False, False, False, True] + + +def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2], + "a_left": [[1.0, 2.0], [3.0, 4.0]], + }, + schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2], + "a_right": [[1.0, 2.0], [3.0]], + }, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=None, + abs_tol=0.5, + rel_tol=0, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, False] + + +def test_condition_equal_columns_two_arrays_different_shapes() -> None: + lhs = pl.DataFrame( + { + "pk": [1], + "a_left": [[1.0, 2.0]], + }, + schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)}, + ) + rhs = pl.DataFrame( + { + "pk": [1], + "a_right": [[1.0, 2.0, 3.0]], + }, + schema={"pk": pl.Int64, "a_right": pl.Array(pl.Float64, shape=3)}, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=None, + ) + ) + .to_series() + ) + assert actual.to_list() == [False] + + +def test_condition_equal_columns_empty_arrays() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2], + "a_left": [[], None], + }, + schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=0)}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2], + "a_right": [[], None], + }, + schema={"pk": pl.Int64, "a_right": pl.Array(pl.Float64, shape=0)}, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=None, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, True] + + +def test_condition_equal_columns_empty_lists() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_left": [[], None, []], + }, + schema={"pk": pl.Int64, "a_left": pl.List(pl.Float64)}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_right": [[], None, None], + }, + schema={"pk": pl.Int64, "a_right": pl.List(pl.Float64)}, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=0, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, True, False] + + @pytest.mark.parametrize( ("dtype_left", "dtype_right", "can_compare_dtypes"), [