Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 83 additions & 4 deletions diffly/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def condition_equal_rows(
columns: list[str],
schema_left: pl.Schema,
schema_right: pl.Schema,
max_list_lengths_by_column: Mapping[str, int],
abs_tol_by_column: Mapping[str, float],
rel_tol_by_column: Mapping[str, float],
abs_tol_temporal_by_column: Mapping[str, dt.timedelta],
Expand All @@ -34,6 +35,7 @@ def condition_equal_rows(
column=column,
dtype_left=schema_left[column],
dtype_right=schema_right[column],
max_list_length=max_list_lengths_by_column.get(column),
abs_tol=abs_tol_by_column[column],
rel_tol=rel_tol_by_column[column],
abs_tol_temporal=abs_tol_temporal_by_column[column],
Expand All @@ -47,6 +49,7 @@ def condition_equal_columns(
column: str,
dtype_left: pl.DataType,
dtype_right: pl.DataType,
max_list_length: int | None,
abs_tol: float = ABS_TOL_DEFAULT,
rel_tol: float = REL_TOL_DEFAULT,
abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT,
Expand All @@ -58,6 +61,7 @@ def condition_equal_columns(
col_right=pl.col(f"{column}_{Side.RIGHT}"),
dtype_left=dtype_left,
dtype_right=dtype_right,
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
Expand Down Expand Up @@ -92,6 +96,7 @@ def _compare_columns(
col_right: pl.Expr,
dtype_left: DataType | DataTypeClass,
dtype_right: DataType | DataTypeClass,
max_list_length: int | None,
abs_tol: float,
rel_tol: float,
abs_tol_temporal: dt.timedelta,
Expand Down Expand Up @@ -123,6 +128,7 @@ def _compare_columns(
col_right=col_right.struct[field],
dtype_left=fields_left[field],
dtype_right=fields_right[field],
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
Expand All @@ -133,10 +139,16 @@ def _compare_columns(
elif isinstance(dtype_left, pl.List | pl.Array) and isinstance(
dtype_right, pl.List | pl.Array
):
# As of polars 1.28, there is no way to access another column within
# `list.eval`. Hence, we necessarily need to resort to a primitive
# comparison in this case.
pass
return _compare_sequence_columns(
col_left=col_left,
col_right=col_right,
dtype_left=dtype_left,
dtype_right=dtype_right,
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
)

if (
isinstance(dtype_left, pl.Enum)
Expand All @@ -154,6 +166,7 @@ def _compare_columns(
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
max_list_length=max_list_length,
)

return _compare_primitive_columns(
Expand All @@ -167,6 +180,72 @@ def _compare_columns(
)


def _compare_sequence_columns(
    col_left: pl.Expr,
    col_right: pl.Expr,
    dtype_left: DataType | DataTypeClass,
    dtype_right: DataType | DataTypeClass,
    max_list_length: int | None,
    abs_tol: float,
    rel_tol: float,
    abs_tol_temporal: dt.timedelta,
) -> pl.Expr:
    """Compare Array/List columns element-wise with tolerance.

    The sequences are unrolled into ``n_elements`` positional comparisons, each
    delegated to ``_compare_columns`` on the inner dtypes, and combined with a
    check that both sides have the same length.

    Args:
        col_left: Expression selecting the left sequence column.
        col_right: Expression selecting the right sequence column.
        dtype_left: Dtype of the left column; must be ``pl.List`` or ``pl.Array``.
        dtype_right: Dtype of the right column; must be ``pl.List`` or ``pl.Array``.
        max_list_length: Number of positions to unroll when both sides are
            ``pl.List``. When it is not an ``int`` (e.g. ``None``), List-vs-List
            falls back to a direct ``eq_missing`` comparison without tolerances.
            Ignored whenever at least one side is a ``pl.Array``, whose width
            is taken from the dtype instead.
        abs_tol: Absolute tolerance forwarded to the element comparison.
        rel_tol: Relative tolerance forwarded to the element comparison.
        abs_tol_temporal: Temporal tolerance forwarded to the element comparison.

    Returns:
        A boolean expression describing whether the two sequences match.
    """
    assert isinstance(dtype_left, pl.List | pl.Array)
    assert isinstance(dtype_right, pl.List | pl.Array)
    inner_left = dtype_left.inner
    inner_right = dtype_right.inner

    n_elements: int
    has_same_length: pl.Expr

    if isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.Array):
        # Fixed-width arrays of different shapes can never be equal.
        if dtype_left.shape != dtype_right.shape:
            return pl.repeat(pl.lit(False), pl.len())
        n_elements = dtype_left.shape[0]
        has_same_length = pl.repeat(pl.lit(True), pl.len())
    elif isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.List):
        # The array width is static; the list length must match it per row.
        n_elements = dtype_left.shape[0]
        has_same_length = col_right.list.len().eq(pl.lit(n_elements))
    elif isinstance(dtype_left, pl.List) and isinstance(dtype_right, pl.Array):
        n_elements = dtype_right.shape[0]
        has_same_length = col_left.list.len().eq(pl.lit(n_elements))
    else:  # pl.List vs pl.List
        if not isinstance(max_list_length, int):
            # Fallback for nested list comparisons where no max_list_length is
            # available: perform a direct equality comparison without
            # element-wise unrolling.
            return _eq_missing(col_left.eq_missing(col_right), col_left, col_right)
        n_elements = max_list_length
        has_same_length = col_left.list.len().eq_missing(col_right.list.len())

    if n_elements == 0:
        # There are no positions to compare, so equality is decided by the
        # length check alone. Returning a constant `True` here would wrongly
        # treat a zero-width Array and a non-empty List as equal.
        return _eq_missing(has_same_length, col_left, col_right)

    def _get_element(
        col: pl.Expr, dtype: DataType | DataTypeClass, i: int
    ) -> pl.Expr:
        """Select position ``i`` of a sequence column of the given dtype."""
        if isinstance(dtype, pl.Array):
            return col.arr.get(i)
        # Lists may be shorter than `n_elements`; request null instead of an
        # error for positions past the end of a row's list.
        return col.list.get(i, null_on_oob=True)

    elements_match = pl.all_horizontal(
        [
            _compare_columns(
                col_left=_get_element(col_left, dtype_left, i),
                col_right=_get_element(col_right, dtype_right, i),
                dtype_left=inner_left,
                dtype_right=inner_right,
                abs_tol=abs_tol,
                rel_tol=rel_tol,
                abs_tol_temporal=abs_tol_temporal,
                # No per-element maximum length is known for nested lists, so
                # inner List-vs-List comparisons take the direct-equality path.
                max_list_length=None,
            )
            for i in range(n_elements)
        ]
    )

    return _eq_missing(has_same_length & elements_match, col_left, col_right)


def _compare_primitive_columns(
col_left: pl.Expr,
col_right: pl.Expr,
Expand Down
23 changes: 23 additions & 0 deletions diffly/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def equal(self, *, check_dtypes: bool = True) -> bool:
columns=common_columns,
schema_left=self.left_schema,
schema_right=self.right_schema,
max_list_lengths_by_column=self._max_list_lengths_by_column,
abs_tol_by_column=self.abs_tol_by_column,
rel_tol_by_column=self.rel_tol_by_column,
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
Expand Down Expand Up @@ -708,11 +709,32 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
raise ValueError(f"{difference} are not common columns.")
return list(subset)

@cached_property
def _max_list_lengths_by_column(self) -> dict[str, int]:
    """Return the longest list length per List-typed common column.

    Only columns from ``self._other_common_columns`` whose dtype is ``pl.List``
    in both the left and the right schema are considered. For each of them the
    result holds the larger of the two sides' maximum list lengths; a missing
    maximum (``None``) counts as ``0``. An empty dict is returned when no
    column qualifies, in which case nothing is collected.
    """
    names: list[str] = []
    for name in self._other_common_columns:
        if not isinstance(self.left_schema[name], pl.List):
            continue
        if not isinstance(self.right_schema[name], pl.List):
            continue
        names.append(name)

    if len(names) == 0:
        return {}

    # One aggregate per column, aliased back to the column name so the
    # single-row results can be looked up by name below.
    max_len_exprs = [
        pl.col(name).list.len().max().alias(name) for name in names
    ]
    # Collect both sides in a single `collect_all` call.
    left_frame, right_frame = pl.collect_all(
        [self.left.select(max_len_exprs), self.right.select(max_len_exprs)]
    )

    lengths: dict[str, int] = {}
    for name in names:
        left_len = int(left_frame[name].item() or 0)
        right_len = int(right_frame[name].item() or 0)
        lengths[name] = max(left_len, right_len)
    return lengths

def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
return condition_equal_rows(
columns=columns,
schema_left=self.left_schema,
schema_right=self.right_schema,
max_list_lengths_by_column=self._max_list_lengths_by_column,
abs_tol_by_column=self.abs_tol_by_column,
rel_tol_by_column=self.rel_tol_by_column,
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
Expand All @@ -726,6 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr:
abs_tol=self.abs_tol_by_column[column],
rel_tol=self.rel_tol_by_column[column],
abs_tol_temporal=self.abs_tol_temporal_by_column[column],
max_list_length=self._max_list_lengths_by_column.get(column),
)

def _equal_rows(self) -> bool:
Expand Down
Loading
Loading