Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,28 +145,27 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
return _normalize_temp_path(orig) == _normalize_temp_path(new)
return False

if isinstance(
orig,
(
int,
bool,
complex,
type(None),
type(Ellipsis),
decimal.Decimal,
set,
bytes,
bytearray,
memoryview,
frozenset,
enum.Enum,
type,
range,
slice,
OrderedDict,
types.GenericAlias,
),
):
_equality_types = (
int,
bool,
complex,
type(None),
type(Ellipsis),
decimal.Decimal,
set,
bytes,
bytearray,
memoryview,
frozenset,
enum.Enum,
type,
range,
slice,
OrderedDict,
types.GenericAlias,
*((_union_type,) if (_union_type := getattr(types, "UnionType", None)) else ()),
)
if isinstance(orig, _equality_types):
return orig == new
if isinstance(orig, float):
if math.isnan(orig) and math.isnan(new):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5216,3 +5216,27 @@ def test_python_tempfile_pattern_regex(self):
assert PYTHON_TEMPFILE_PATTERN.search("/tmp/tmp123456/")
assert not PYTHON_TEMPFILE_PATTERN.search("/tmp/mydir/file.txt")
assert not PYTHON_TEMPFILE_PATTERN.search("/home/tmp123/file.txt")


@pytest.mark.skipif(sys.version_info < (3, 10), reason="types.UnionType requires Python 3.10+")
class TestUnionType:
def test_union_type_equal(self):
assert comparator(int | str, int | str)

def test_union_type_not_equal(self):
assert not comparator(int | str, int | float)

def test_union_type_order_independent(self):
assert comparator(int | str, str | int)

def test_union_type_multiple_args(self):
assert comparator(int | str | float, int | str | float)

def test_union_type_in_list(self):
assert comparator([int | str, 1], [int | str, 1])

def test_union_type_in_dict(self):
assert comparator({"key": int | str}, {"key": int | str})

def test_union_type_vs_none(self):
assert not comparator(int | str, None)
Loading