From 5ec11f729b19cb54ffdcf88595bcacf6066b9f38 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Thu, 19 Mar 2026 20:20:03 +0100 Subject: [PATCH 1/7] feat: Perform tolerance-based comparison for lists and arrays --- diffly/_conditions.py | 93 ++++++++++++++++++++++++-- diffly/comparison.py | 23 +++++++ multi_array.py | 6 ++ tests/test_conditions.py | 141 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 257 insertions(+), 6 deletions(-) create mode 100644 multi_array.py diff --git a/diffly/_conditions.py b/diffly/_conditions.py index 4c87359..51e7ccc 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -22,12 +22,14 @@ def condition_equal_rows( abs_tol_by_column: Mapping[str, float], rel_tol_by_column: Mapping[str, float], abs_tol_temporal_by_column: Mapping[str, dt.timedelta], + max_list_lengths_by_column: Mapping[str, int] | None = None, ) -> pl.Expr: """Build an expression whether two rows are equal, based on all columns' data types.""" if not columns: return pl.lit(True) + _max_list_lengths = max_list_lengths_by_column or {} return pl.all_horizontal( [ condition_equal_columns( @@ -37,6 +39,7 @@ def condition_equal_rows( abs_tol=abs_tol_by_column[column], rel_tol=rel_tol_by_column[column], abs_tol_temporal=abs_tol_temporal_by_column[column], + max_list_length=_max_list_lengths.get(column, 0), ) for column in columns ] @@ -50,6 +53,7 @@ def condition_equal_columns( abs_tol: float = ABS_TOL_DEFAULT, rel_tol: float = REL_TOL_DEFAULT, abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT, + max_list_length: int = 0, ) -> pl.Expr: """Build an expression whether two columns are equal, depending on the columns' data types.""" @@ -61,6 +65,7 @@ def condition_equal_columns( abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, + max_list_length=max_list_length, ) @@ -95,6 +100,7 @@ def _compare_columns( abs_tol: float, rel_tol: float, abs_tol_temporal: dt.timedelta, + max_list_length: int = 0, ) -> pl.Expr: """Build an expression whether two expressions yield the same value. @@ -133,10 +139,18 @@ def _compare_columns( elif isinstance(dtype_left, pl.List | pl.Array) and isinstance( dtype_right, pl.List | pl.Array ): - # As of polars 1.28, there is no way to access another column within - # `list.eval`. Hence, we necessarily need to resort to a primitive - # comparison in this case. - pass + result = _compare_sequence_columns( + col_left=col_left, + col_right=col_right, + dtype_left=dtype_left, + dtype_right=dtype_right, + max_list_length=max_list_length, + abs_tol=abs_tol, + rel_tol=rel_tol, + abs_tol_temporal=abs_tol_temporal, + ) + if result is not None: + return result if ( isinstance(dtype_left, pl.Enum) @@ -167,6 +181,77 @@ def _compare_columns( ) +def _compare_sequence_columns( + col_left: pl.Expr, + col_right: pl.Expr, + dtype_left: DataType | DataTypeClass, + dtype_right: DataType | DataTypeClass, + max_list_length: int, + abs_tol: float, + rel_tol: float, + abs_tol_temporal: dt.timedelta, +) -> pl.Expr | None: + """Compare Array/List columns element-wise with tolerance. + + Returns ``None`` if the comparison cannot be performed element-wise (e.g. List vs + List without a known ``max_list_length``), signalling to the caller that it should + fall back to primitive comparison. + """ + assert isinstance(dtype_left, pl.List | pl.Array) + assert isinstance(dtype_right, pl.List | pl.Array) + inner_left = dtype_left.inner + inner_right = dtype_right.inner + + def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Expr: + if isinstance(dtype, pl.Array): + return col.arr.get(i) + return col.list.get(i, null_on_oob=True) + + n: int | None = None + length_check: pl.Expr | None = None + + if isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.Array): + if dtype_left.shape != dtype_right.shape: + return pl.repeat(pl.lit(False), pl.len()) + n = dtype_left.shape[0] + elif isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.List): + n = dtype_left.shape[0] + length_check = col_right.list.len().eq(pl.lit(n)) + elif isinstance(dtype_left, pl.List) and isinstance(dtype_right, pl.Array): + n = dtype_right.shape[0] + length_check = col_left.list.len().eq(pl.lit(n)) + else: + # List vs List + if max_list_length == 0: + return None + n = max_list_length + length_check = col_left.list.len().eq_missing(col_right.list.len()) + + if n == 0: + if length_check is not None: + return _eq_missing(length_check, col_left, col_right) + return _eq_missing(pl.lit(True), col_left, col_right) + + elements_match = pl.all_horizontal( + [ + _compare_columns( + col_left=_get_element(col_left, dtype_left, i), + col_right=_get_element(col_right, dtype_right, i), + dtype_left=inner_left, + dtype_right=inner_right, + abs_tol=abs_tol, + rel_tol=rel_tol, + abs_tol_temporal=abs_tol_temporal, + ) + for i in range(n) + ] + ) + + if length_check is not None: + return _eq_missing(length_check & elements_match, col_left, col_right) + return elements_match + + def _compare_primitive_columns( col_left: pl.Expr, col_right: pl.Expr, diff --git a/diffly/comparison.py b/diffly/comparison.py index ca126df..307d363 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -511,6 +511,7 @@ def equal(self, *, check_dtypes: bool = True) -> bool: abs_tol_by_column=self.abs_tol_by_column, rel_tol_by_column=self.rel_tol_by_column, abs_tol_temporal_by_column=self.abs_tol_temporal_by_column, + max_list_lengths_by_column=self._max_list_lengths, ).all() ) .item() @@ -708,6 +709,26 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str] raise ValueError(f"{difference} are not common columns.") return list(subset) + @cached_property + def _max_list_lengths(self) -> dict[str, int]: + list_columns = [ + col + for col in self._other_common_columns + if isinstance(self.left_schema[col], pl.List) + and isinstance(self.right_schema[col], pl.List) + ] + if not list_columns: + return {} + + exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns] + [left_max, right_max] = pl.collect_all( + [self.left.select(exprs), self.right.select(exprs)] + ) + return { + col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0)) + for col in list_columns + } + def _condition_equal_rows(self, columns: list[str]) -> pl.Expr: return condition_equal_rows( columns=columns, @@ -716,6 +737,7 @@ def _condition_equal_rows(self, columns: list[str]) -> pl.Expr: abs_tol_by_column=self.abs_tol_by_column, rel_tol_by_column=self.rel_tol_by_column, abs_tol_temporal_by_column=self.abs_tol_temporal_by_column, + max_list_lengths_by_column=self._max_list_lengths, ) def _condition_equal_columns(self, column: str) -> pl.Expr: @@ -726,6 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr: abs_tol=self.abs_tol_by_column[column], rel_tol=self.rel_tol_by_column[column], abs_tol_temporal=self.abs_tol_temporal_by_column[column], + max_list_length=self._max_list_lengths.get(column, 0), ) def _equal_rows(self) -> bool: diff --git a/multi_array.py b/multi_array.py new file mode 100644 index 0000000..f778317 --- /dev/null +++ b/multi_array.py @@ -0,0 +1,6 @@ +# %% +import polars as pl +# %% +df = pl.DataFrame({"a": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]}, schema={"a": pl.Array(inner=pl.UInt8, shape=(2, 2))}) +# %% +df.select(pl.col("a").arr.get(1)) \ No newline at end of file diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 772e61c..a161600 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -81,7 +81,7 @@ def test_condition_equal_columns_different_struct_fields() -> None: @pytest.mark.parametrize( "rhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)] ) -def test_condition_equal_columns_list_array_equal_exact( +def test_condition_equal_columns_list_array_with_tolerance( lhs_type: pl.DataType, rhs_type: pl.DataType ) -> None: # Arrange @@ -110,12 +110,58 @@ def test_condition_equal_columns_list_array_equal_exact( dtype_right=rhs.schema["a_right"], abs_tol=0.5, rel_tol=0, + max_list_length=2, ) ) .to_series() ) - # Assert + # Assert: diff is 0.1, within abs_tol=0.5 + assert actual.to_list() == [True, True] + + +@pytest.mark.parametrize( + "lhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)] +) +@pytest.mark.parametrize( + "rhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)] +) +def test_condition_equal_columns_list_array_exceeds_tolerance( + lhs_type: pl.DataType, rhs_type: pl.DataType +) -> None: + # Arrange + lhs = pl.DataFrame( + { + "pk": [1, 2], + "a_left": [[1.0, 1.1], [2.0, 2.1]], + }, + schema={"pk": pl.Int64, "a_left": lhs_type}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2], + "a_right": [[1.0, 1.1], [2.0, 2.8]], + }, + schema={"pk": pl.Int64, "a_right": rhs_type}, + ) + + # Act + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + abs_tol=0.5, + rel_tol=0, + max_list_length=2, + ) + ) + .to_series() + ) + + # Assert: diff is 0.7, exceeds abs_tol=0.5 assert actual.to_list() == [True, False] @@ -226,6 +272,97 @@ def test_condition_equal_columns_temporal_tolerance() -> None: assert actual.to_list() == [True, False, False, True] +def test_condition_equal_columns_list_different_lengths() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2], + "a_left": [[1.0, 2.0], [3.0]], + }, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2], + "a_right": [[1.0, 2.0], [3.0, 4.0]], + }, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + abs_tol=0.5, + rel_tol=0, + max_list_length=2, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, False] + + +def test_condition_equal_columns_list_nulls() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_left": [[1.0, 2.0], None, None], + }, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_right": [[1.0, 2.0], [3.0], None], + }, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=2, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, False, True] + + +def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2], + "a_left": [[1.0, 2.0], [3.0, 4.0]], + }, + schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2], + "a_right": [[1.0, 2.0], [3.0]], + }, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + abs_tol=0.5, + rel_tol=0, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, False] + + @pytest.mark.parametrize( ("dtype_left", "dtype_right", "can_compare_dtypes"), [ From 1c0050f23b46118bbae8552719ae6d0d7546e514 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Thu, 19 Mar 2026 20:21:26 +0100 Subject: [PATCH 2/7] remove multi-array --- multi_array.py | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 multi_array.py diff --git a/multi_array.py b/multi_array.py deleted file mode 100644 index f778317..0000000 --- a/multi_array.py +++ /dev/null @@ -1,6 +0,0 @@ -# %% -import polars as pl -# %% -df = pl.DataFrame({"a": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]}, schema={"a": pl.Array(inner=pl.UInt8, shape=(2, 2))}) -# %% -df.select(pl.col("a").arr.get(1)) \ No newline at end of file From 68dc630245f63a99eb832ed68f53e60431299f33 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Thu, 19 Mar 2026 20:35:55 +0100 Subject: [PATCH 3/7] clean up --- diffly/_conditions.py | 61 ++++++++++++++++++---------------------- diffly/comparison.py | 4 +-- tests/test_conditions.py | 6 ++++ 3 files changed, 35 insertions(+), 36 deletions(-) diff --git a/diffly/_conditions.py b/diffly/_conditions.py index 51e7ccc..b440de5 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -19,27 +19,26 @@ def condition_equal_rows( columns: list[str], schema_left: pl.Schema, schema_right: pl.Schema, + max_list_lengths_by_column: Mapping[str, int], abs_tol_by_column: Mapping[str, float], rel_tol_by_column: Mapping[str, float], abs_tol_temporal_by_column: Mapping[str, dt.timedelta], - max_list_lengths_by_column: Mapping[str, int] | None = None, ) -> pl.Expr: """Build an expression whether two rows are equal, based on all columns' data types.""" if not columns: return pl.lit(True) - _max_list_lengths = max_list_lengths_by_column or {} return pl.all_horizontal( [ condition_equal_columns( column=column, dtype_left=schema_left[column], dtype_right=schema_right[column], + max_list_length=max_list_lengths_by_column.get(column, 0), abs_tol=abs_tol_by_column[column], rel_tol=rel_tol_by_column[column], abs_tol_temporal=abs_tol_temporal_by_column[column], - max_list_length=_max_list_lengths.get(column, 0), ) for column in columns ] @@ -50,10 +49,10 @@ def condition_equal_columns( column: str, dtype_left: pl.DataType, dtype_right: pl.DataType, + max_list_length: int, abs_tol: float = ABS_TOL_DEFAULT, rel_tol: float = REL_TOL_DEFAULT, abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT, - max_list_length: int = 0, ) -> pl.Expr: """Build an expression whether two columns are equal, depending on the columns' data types.""" @@ -62,10 +61,10 @@ def condition_equal_columns( col_right=pl.col(f"{column}_{Side.RIGHT}"), dtype_left=dtype_left, dtype_right=dtype_right, + max_list_length=max_list_length, abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, - max_list_length=max_list_length, ) @@ -97,10 +96,10 @@ def _compare_columns( col_right: pl.Expr, dtype_left: DataType | DataTypeClass, dtype_right: DataType | DataTypeClass, + max_list_length: int, abs_tol: float, rel_tol: float, abs_tol_temporal: dt.timedelta, - max_list_length: int = 0, ) -> pl.Expr: """Build an expression whether two expressions yield the same value. @@ -129,6 +128,7 @@ def _compare_columns( col_right=col_right.struct[field], dtype_left=fields_left[field], dtype_right=fields_right[field], + max_list_length=max_list_length, abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, @@ -139,7 +139,7 @@ def _compare_columns( elif isinstance(dtype_left, pl.List | pl.Array) and isinstance( dtype_right, pl.List | pl.Array ): - result = _compare_sequence_columns( + return _compare_sequence_columns( col_left=col_left, col_right=col_right, dtype_left=dtype_left, @@ -149,8 +149,6 @@ def _compare_columns( rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, ) - if result is not None: - return result if ( isinstance(dtype_left, pl.Enum) @@ -168,6 +166,7 @@ def _compare_columns( abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, + max_list_length=max_list_length, ) return _compare_primitive_columns( @@ -190,13 +189,8 @@ def _compare_sequence_columns( abs_tol: float, rel_tol: float, abs_tol_temporal: dt.timedelta, -) -> pl.Expr | None: - """Compare Array/List columns element-wise with tolerance. - - Returns ``None`` if the comparison cannot be performed element-wise (e.g. List vs - List without a known ``max_list_length``), signalling to the caller that it should - fall back to primitive comparison. - """ +) -> pl.Expr: + """Compare Array/List columns element-wise with tolerance.""" assert isinstance(dtype_left, pl.List | pl.Array) assert isinstance(dtype_right, pl.List | pl.Array) inner_left = dtype_left.inner @@ -207,29 +201,27 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex return col.arr.get(i) return col.list.get(i, null_on_oob=True) - n: int | None = None - length_check: pl.Expr | None = None + n_elements: int | None = None + has_same_length: pl.Expr | None = None if isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.Array): if dtype_left.shape != dtype_right.shape: return pl.repeat(pl.lit(False), pl.len()) - n = dtype_left.shape[0] + n_elements = dtype_left.shape[0] elif isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.List): - n = dtype_left.shape[0] - length_check = col_right.list.len().eq(pl.lit(n)) + n_elements = dtype_left.shape[0] + has_same_length = col_right.list.len().eq(pl.lit(n_elements)) elif isinstance(dtype_left, pl.List) and isinstance(dtype_right, pl.Array): - n = dtype_right.shape[0] - length_check = col_left.list.len().eq(pl.lit(n)) + n_elements = dtype_right.shape[0] + has_same_length = col_left.list.len().eq(pl.lit(n_elements)) else: # List vs List - if max_list_length == 0: - return None - n = max_list_length - length_check = col_left.list.len().eq_missing(col_right.list.len()) - - if n == 0: - if length_check is not None: - return _eq_missing(length_check, col_left, col_right) + n_elements = max_list_length + has_same_length = col_left.list.len().eq_missing(col_right.list.len()) + + if n_elements == 0: + if has_same_length is not None: + return _eq_missing(has_same_length, col_left, col_right) return _eq_missing(pl.lit(True), col_left, col_right) elements_match = pl.all_horizontal( @@ -242,13 +234,14 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, + max_list_length=max_list_length, ) - for i in range(n) + for i in range(n_elements) ] ) - if length_check is not None: - return _eq_missing(length_check & elements_match, col_left, col_right) + if has_same_length is not None: + return _eq_missing(has_same_length & elements_match, col_left, col_right) return elements_match diff --git a/diffly/comparison.py b/diffly/comparison.py index 307d363..045cde6 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -508,10 +508,10 @@ def equal(self, *, check_dtypes: bool = True) -> bool: columns=common_columns, schema_left=self.left_schema, schema_right=self.right_schema, + max_list_lengths_by_column=self._max_list_lengths, abs_tol_by_column=self.abs_tol_by_column, rel_tol_by_column=self.rel_tol_by_column, abs_tol_temporal_by_column=self.abs_tol_temporal_by_column, - max_list_lengths_by_column=self._max_list_lengths, ).all() ) .item() @@ -734,10 +734,10 @@ def _condition_equal_rows(self, columns: list[str]) -> pl.Expr: columns=columns, schema_left=self.left_schema, schema_right=self.right_schema, + max_list_lengths_by_column=self._max_list_lengths, abs_tol_by_column=self.abs_tol_by_column, rel_tol_by_column=self.rel_tol_by_column, abs_tol_temporal_by_column=self.abs_tol_temporal_by_column, - max_list_lengths_by_column=self._max_list_lengths, ) def _condition_equal_columns(self, column: str) -> pl.Expr: diff --git a/tests/test_conditions.py b/tests/test_conditions.py index a161600..1d313d8 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -32,6 +32,7 @@ def test_condition_equal_columns_struct() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=0, abs_tol=0.5, rel_tol=0, ) @@ -66,6 +67,7 @@ def test_condition_equal_columns_different_struct_fields() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=0, ) ) .to_series() @@ -188,6 +190,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=0, ) ) .to_series() @@ -220,6 +223,7 @@ def test_condition_equal_columns_exactly_one_nested() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=0, ) ) .to_series() @@ -262,6 +266,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=0, abs_tol_temporal=dt.timedelta(seconds=2), ) ) @@ -354,6 +359,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], + max_list_length=0, abs_tol=0.5, rel_tol=0, ) From 1d4df7b119817b6a6c13d92aed27b161d11b2785 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 20 Mar 2026 08:40:07 +0100 Subject: [PATCH 4/7] improve --- diffly/_conditions.py | 9 ++- diffly/comparison.py | 2 +- tests/test_conditions.py | 161 +++++++++++++++++++++++++-------------- 3 files changed, 110 insertions(+), 62 deletions(-) diff --git a/diffly/_conditions.py b/diffly/_conditions.py index b440de5..acdfad0 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -35,7 +35,7 @@ def condition_equal_rows( column=column, dtype_left=schema_left[column], dtype_right=schema_right[column], - max_list_length=max_list_lengths_by_column.get(column, 0), + max_list_length=max_list_lengths_by_column.get(column), abs_tol=abs_tol_by_column[column], rel_tol=rel_tol_by_column[column], abs_tol_temporal=abs_tol_temporal_by_column[column], @@ -49,7 +49,7 @@ def condition_equal_columns( column: str, dtype_left: pl.DataType, dtype_right: pl.DataType, - max_list_length: int, + max_list_length: int | None = None, abs_tol: float = ABS_TOL_DEFAULT, rel_tol: float = REL_TOL_DEFAULT, abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT, @@ -96,7 +96,7 @@ def _compare_columns( col_right: pl.Expr, dtype_left: DataType | DataTypeClass, dtype_right: DataType | DataTypeClass, - max_list_length: int, + max_list_length: int | None, abs_tol: float, rel_tol: float, abs_tol_temporal: dt.timedelta, @@ -185,7 +185,7 @@ def _compare_sequence_columns( col_right: pl.Expr, dtype_left: DataType | DataTypeClass, dtype_right: DataType | DataTypeClass, - max_list_length: int, + max_list_length: int | None, abs_tol: float, rel_tol: float, abs_tol_temporal: dt.timedelta, @@ -216,6 +216,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex has_same_length = col_left.list.len().eq(pl.lit(n_elements)) else: # List vs List + assert max_list_length is not None n_elements = max_list_length has_same_length = col_left.list.len().eq_missing(col_right.list.len()) diff --git a/diffly/comparison.py b/diffly/comparison.py index 045cde6..e8c55e1 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -748,7 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr: abs_tol=self.abs_tol_by_column[column], rel_tol=self.rel_tol_by_column[column], abs_tol_temporal=self.abs_tol_temporal_by_column[column], - max_list_length=self._max_list_lengths.get(column, 0), + max_list_length=self._max_list_lengths.get(column), ) def _equal_rows(self) -> bool: diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 1d313d8..e473c7e 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -32,7 +32,7 @@ def test_condition_equal_columns_struct() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=0, + max_list_length=None, abs_tol=0.5, rel_tol=0, ) @@ -67,7 +67,7 @@ def test_condition_equal_columns_different_struct_fields() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=0, + max_list_length=None, ) ) .to_series() @@ -89,60 +89,15 @@ def test_condition_equal_columns_list_array_with_tolerance( # Arrange lhs = pl.DataFrame( { - "pk": [1, 2], - "a_left": [[1.0, 1.1], [2.0, 2.1]], - }, - schema={"pk": pl.Int64, "a_left": lhs_type}, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [[1.0, 1.1], [2.0, 2.2]], - }, - schema={"pk": pl.Int64, "a_right": rhs_type}, - ) - - # Act - actual = ( - lhs.join(rhs, on="pk", maintain_order="left") - .select( - condition_equal_columns( - "a", - dtype_left=lhs.schema["a_left"], - dtype_right=rhs.schema["a_right"], - abs_tol=0.5, - rel_tol=0, - max_list_length=2, - ) - ) - .to_series() - ) - - # Assert: diff is 0.1, within abs_tol=0.5 - assert actual.to_list() == [True, True] - - -@pytest.mark.parametrize( - "lhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)] -) -@pytest.mark.parametrize( - "rhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)] -) -def test_condition_equal_columns_list_array_exceeds_tolerance( - lhs_type: pl.DataType, rhs_type: pl.DataType -) -> None: - # Arrange - lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [[1.0, 1.1], [2.0, 2.1]], + "pk": [1, 2, 3], + "a_left": [[1.0, 1.1], [2.0, 2.1], [3.0, 3.0]], }, schema={"pk": pl.Int64, "a_left": lhs_type}, ) rhs = pl.DataFrame( { - "pk": [1, 2], - "a_right": [[1.0, 1.1], [2.0, 2.8]], + "pk": [1, 2, 3], + "a_right": [[1.0, 1.1], [2.0, 2.2], [3.0, 3.7]], }, schema={"pk": pl.Int64, "a_right": rhs_type}, ) @@ -163,8 +118,7 @@ def test_condition_equal_columns_list_array_exceeds_tolerance( .to_series() ) - # Assert: diff is 0.7, exceeds abs_tol=0.5 - assert actual.to_list() == [True, False] + assert actual.to_list() == [True, True, False] def test_condition_equal_columns_nested_dtype_mismatch() -> None: @@ -190,7 +144,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=0, + max_list_length=None, ) ) .to_series() @@ -223,7 +177,7 @@ def test_condition_equal_columns_exactly_one_nested() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=0, + max_list_length=None, ) ) .to_series() @@ -266,7 +220,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=0, + max_list_length=None, abs_tol_temporal=dt.timedelta(seconds=2), ) ) @@ -359,7 +313,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=0, + max_list_length=None, abs_tol=0.5, rel_tol=0, ) @@ -369,6 +323,99 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: assert actual.to_list() == [True, False] +def test_condition_equal_columns_array_different_shapes() -> None: + lhs = pl.DataFrame( + { + "pk": [1], + "a_left": [[1.0, 2.0]], + }, + schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)}, + ) + rhs = pl.DataFrame( + { + "pk": [1], + "a_right": [[1.0, 2.0, 3.0]], + }, + schema={"pk": pl.Int64, "a_right": pl.Array(pl.Float64, shape=3)}, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=None, + ) + ) + .to_series() + ) + assert actual.to_list() == [False] + + +def test_condition_equal_columns_empty_arrays() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2], + "a_left": [[], None], + }, + schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=0)}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2], + "a_right": [[], None], + }, + schema={"pk": pl.Int64, "a_right": pl.Array(pl.Float64, shape=0)}, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=None, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, True] + + +def test_condition_equal_columns_empty_lists() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_left": [[], None, []], + }, + schema={"pk": pl.Int64, "a_left": pl.List(pl.Float64)}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_right": [[], None, None], + }, + schema={"pk": pl.Int64, "a_right": pl.List(pl.Float64)}, + ) + + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=0, + ) + ) + .to_series() + ) + assert actual.to_list() == [True, True, False] + + @pytest.mark.parametrize( ("dtype_left", "dtype_right", "can_compare_dtypes"), [ From d528ecdc420d21c79e1818aec2e63cd917b989c6 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 20 Mar 2026 09:03:55 +0100 Subject: [PATCH 5/7] clean compare_sequence_columns --- diffly/_conditions.py | 18 +++++++---------- diffly/comparison.py | 2 +- tests/test_conditions.py | 43 +++++++--------------------------------- 3 files changed, 15 insertions(+), 48 deletions(-) diff --git a/diffly/_conditions.py b/diffly/_conditions.py index acdfad0..789c7ef 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -49,7 +49,7 @@ def condition_equal_columns( column: str, dtype_left: pl.DataType, dtype_right: pl.DataType, - max_list_length: int | None = None, + max_list_length: int | None, abs_tol: float = ABS_TOL_DEFAULT, rel_tol: float = REL_TOL_DEFAULT, abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT, @@ -201,28 +201,26 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex return col.arr.get(i) return col.list.get(i, null_on_oob=True) - n_elements: int | None = None - has_same_length: pl.Expr | None = None + n_elements: int + has_same_length: pl.Expr if isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.Array): if dtype_left.shape != dtype_right.shape: return pl.repeat(pl.lit(False), pl.len()) n_elements = dtype_left.shape[0] + has_same_length = pl.repeat(pl.lit(True), pl.len()) elif isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.List): n_elements = dtype_left.shape[0] has_same_length = col_right.list.len().eq(pl.lit(n_elements)) elif isinstance(dtype_left, pl.List) and isinstance(dtype_right, pl.Array): n_elements = dtype_right.shape[0] has_same_length = col_left.list.len().eq(pl.lit(n_elements)) - else: - # List vs List - assert max_list_length is not None + else: # pl.List vs pl.List + assert isinstance(max_list_length, int) n_elements = max_list_length has_same_length = col_left.list.len().eq_missing(col_right.list.len()) if n_elements == 0: - if has_same_length is not None: - return _eq_missing(has_same_length, col_left, col_right) return _eq_missing(pl.lit(True), col_left, col_right) elements_match = pl.all_horizontal( @@ -241,9 +239,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex ] ) - if has_same_length is not None: - return _eq_missing(has_same_length & elements_match, col_left, col_right) - return elements_match + return _eq_missing(has_same_length & elements_match, col_left, col_right) def _compare_primitive_columns( diff --git a/diffly/comparison.py b/diffly/comparison.py index e8c55e1..045cde6 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -748,7 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr: abs_tol=self.abs_tol_by_column[column], rel_tol=self.rel_tol_by_column[column], abs_tol_temporal=self.abs_tol_temporal_by_column[column], - max_list_length=self._max_list_lengths.get(column), + max_list_length=self._max_list_lengths.get(column, 0), ) def _equal_rows(self) -> bool: diff --git a/tests/test_conditions.py b/tests/test_conditions.py index e473c7e..5cae246 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -231,17 +231,17 @@ def test_condition_equal_columns_temporal_tolerance() -> None: assert actual.to_list() == [True, False, False, True] -def test_condition_equal_columns_list_different_lengths() -> None: +def test_condition_equal_columns_two_lists() -> None: lhs = pl.DataFrame( { - "pk": [1, 2], - "a_left": [[1.0, 2.0], [3.0]], + "pk": [1, 2, 3, 4, 5], + "a_left": [[1.0, 2.0], [3.0], [5.0, None], None, None], }, ) rhs = pl.DataFrame( { - "pk": [1, 2], - "a_right": [[1.0, 2.0], [3.0, 4.0]], + "pk": [1, 2, 3, 4, 5], + "a_right": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0], None], }, ) @@ -259,36 +259,7 @@ def test_condition_equal_columns_list_different_lengths() -> None: ) .to_series() ) - assert actual.to_list() == [True, False] - - -def test_condition_equal_columns_list_nulls() -> None: - lhs = pl.DataFrame( - { - "pk": [1, 2, 3], - "a_left": [[1.0, 2.0], None, None], - }, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2, 3], - "a_right": [[1.0, 2.0], [3.0], None], - }, - ) - - actual = ( - lhs.join(rhs, on="pk", maintain_order="left") - .select( - condition_equal_columns( - "a", - dtype_left=lhs.schema["a_left"], - dtype_right=rhs.schema["a_right"], - max_list_length=2, - ) - ) - .to_series() - ) - assert actual.to_list() == [True, False, True] + assert actual.to_list() == [True, False, False, False, True] def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: @@ -323,7 +294,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: assert actual.to_list() == [True, False] -def test_condition_equal_columns_array_different_shapes() -> None: +def test_condition_equal_columns_two_arrays_different_shapes() -> None: lhs = pl.DataFrame( { "pk": [1], From d7da2507d0fb375bf5f152ab6ba5a81f3010634d Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 20 Mar 2026 09:24:40 +0100 Subject: [PATCH 6/7] clean up --- diffly/comparison.py | 8 +++--- tests/test_conditions.py | 54 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/diffly/comparison.py b/diffly/comparison.py index 045cde6..ee46e55 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -508,7 +508,7 @@ def equal(self, *, check_dtypes: bool = True) -> bool: columns=common_columns, schema_left=self.left_schema, schema_right=self.right_schema, - max_list_lengths_by_column=self._max_list_lengths, + max_list_lengths_by_column=self._max_list_lengths_by_column, abs_tol_by_column=self.abs_tol_by_column, rel_tol_by_column=self.rel_tol_by_column, abs_tol_temporal_by_column=self.abs_tol_temporal_by_column, @@ -710,7 +710,7 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str] return list(subset) @cached_property - def _max_list_lengths(self) -> dict[str, int]: + def _max_list_lengths_by_column(self) -> dict[str, int]: list_columns = [ col for col in self._other_common_columns @@ -734,7 +734,7 @@ def _condition_equal_rows(self, columns: list[str]) -> pl.Expr: columns=columns, schema_left=self.left_schema, schema_right=self.right_schema, - max_list_lengths_by_column=self._max_list_lengths, + max_list_lengths_by_column=self._max_list_lengths_by_column, abs_tol_by_column=self.abs_tol_by_column, rel_tol_by_column=self.rel_tol_by_column, abs_tol_temporal_by_column=self.abs_tol_temporal_by_column, @@ -748,7 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr: abs_tol=self.abs_tol_by_column[column], rel_tol=self.rel_tol_by_column[column], abs_tol_temporal=self.abs_tol_temporal_by_column[column], - max_list_length=self._max_list_lengths.get(column, 0), + max_list_length=self._max_list_lengths_by_column.get(column), ) def _equal_rows(self) -> bool: diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 5cae246..5abca49 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -121,6 +121,60 @@ def test_condition_equal_columns_list_array_with_tolerance( assert actual.to_list() == [True, True, False] +@pytest.mark.parametrize( + "lhs_type", + [pl.Array(pl.Float64, shape=(2, 2)), pl.List(pl.List(pl.Float64))], +) +@pytest.mark.parametrize( + "rhs_type", + [pl.Array(pl.Float64, shape=(2, 2)), pl.List(pl.List(pl.Float64))], +) +def test_condition_equal_columns_nested_list_array_with_tolerance( + lhs_type: pl.DataType, rhs_type: pl.DataType +) -> None: + # Arrange + lhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_left": [ + [[1.0, 1.1], [2.0, 2.1]], + [[3.0, 3.0], [4.0, 4.0]], + [[5.0, 5.0], [6.0, 6.0]], + ], + }, + schema={"pk": pl.Int64, "a_left": lhs_type}, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2, 3], + "a_right": [ + [[1.0, 1.1], [2.0, 2.1]], + [[3.0, 3.0], [4.0, 4.4]], + [[5.0, 5.0], [6.0, 6.8]], + ], + }, + schema={"pk": pl.Int64, "a_right": rhs_type}, + ) + + # Act + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + abs_tol=0.5, + rel_tol=0, + max_list_length=2, + ) + ) + .to_series() + ) + + assert actual.to_list() == [True, True, False] + + def test_condition_equal_columns_nested_dtype_mismatch() -> None: # Arrange lhs = pl.DataFrame( From 955517c7f28eeb735613d8db98d28b926ed1266a Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 20 Mar 2026 10:16:13 +0100 Subject: [PATCH 7/7] feedback copilot --- diffly/_conditions.py | 18 +++++++++++------- tests/test_conditions.py | 21 ++++++++++++--------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/diffly/_conditions.py b/diffly/_conditions.py index 789c7ef..ce11363 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -196,11 +196,6 @@ def _compare_sequence_columns( inner_left = dtype_left.inner inner_right = dtype_right.inner - def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Expr: - if isinstance(dtype, pl.Array): - return col.arr.get(i) - return col.list.get(i, null_on_oob=True) - n_elements: int has_same_length: pl.Expr @@ -216,13 +211,22 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex n_elements = dtype_right.shape[0] has_same_length = col_left.list.len().eq(pl.lit(n_elements)) else: # pl.List vs pl.List - assert isinstance(max_list_length, int) + if not isinstance(max_list_length, int): + # Fallback for nested list comparisons where no max_list_length is + # available: perform a direct equality comparison without element-wise + # unrolling. + return _eq_missing(col_left.eq_missing(col_right), col_left, col_right) n_elements = max_list_length has_same_length = col_left.list.len().eq_missing(col_right.list.len()) if n_elements == 0: return _eq_missing(pl.lit(True), col_left, col_right) + def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Expr: + if isinstance(dtype, pl.Array): + return col.arr.get(i) + return col.list.get(i, null_on_oob=True) + elements_match = pl.all_horizontal( [ _compare_columns( @@ -233,7 +237,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, - max_list_length=max_list_length, + max_list_length=None, ) for i in range(n_elements) ] diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 5abca49..9093ee6 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -123,11 +123,11 @@ def test_condition_equal_columns_list_array_with_tolerance( @pytest.mark.parametrize( "lhs_type", - [pl.Array(pl.Float64, shape=(2, 2)), pl.List(pl.List(pl.Float64))], + [pl.Array(pl.Float64, shape=(2, 3)), pl.List(pl.List(pl.Float64))], ) @pytest.mark.parametrize( "rhs_type", - [pl.Array(pl.Float64, shape=(2, 2)), pl.List(pl.List(pl.Float64))], + [pl.Array(pl.Float64, shape=(2, 3)), pl.List(pl.List(pl.Float64))], ) def test_condition_equal_columns_nested_list_array_with_tolerance( lhs_type: pl.DataType, rhs_type: pl.DataType @@ -137,9 +137,9 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( { "pk": [1, 2, 3], "a_left": [ - [[1.0, 1.1], [2.0, 2.1]], - [[3.0, 3.0], [4.0, 4.0]], - [[5.0, 5.0], [6.0, 6.0]], + [[1.0, 1.1, 1.3], [2.0, 2.1, 2.2]], + [[3.0, 3.0, 3.1], [4.0, 4.0, 4.1]], + [[5.0, 5.0, 5.1], [6.0, 6.0, 6.1]], ], }, schema={"pk": pl.Int64, "a_left": lhs_type}, @@ -148,9 +148,9 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( { "pk": [1, 2, 3], "a_right": [ - [[1.0, 1.1], [2.0, 2.1]], - [[3.0, 3.0], [4.0, 4.4]], - [[5.0, 5.0], [6.0, 6.8]], + [[1.0, 1.1, 1.3], [2.0, 2.1, 2.2]], + [[3.0, 3.0, 3.1], [4.0, 4.4, 4.1]], + [[5.0, 5.0, 5.1], [6.0, 6.8, 6.1]], ], }, schema={"pk": pl.Int64, "a_right": rhs_type}, @@ -172,7 +172,10 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( .to_series() ) - assert actual.to_list() == [True, True, False] + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): + assert actual.to_list() == [True, False, False] + else: + assert actual.to_list() == [True, True, False] def test_condition_equal_columns_nested_dtype_mismatch() -> None: