diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 6429b5520..5b9c8d7b0 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -145,28 +145,27 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return _normalize_temp_path(orig) == _normalize_temp_path(new) return False - if isinstance( - orig, - ( - int, - bool, - complex, - type(None), - type(Ellipsis), - decimal.Decimal, - set, - bytes, - bytearray, - memoryview, - frozenset, - enum.Enum, - type, - range, - slice, - OrderedDict, - types.GenericAlias, - ), - ): + _equality_types = ( + int, + bool, + complex, + type(None), + type(Ellipsis), + decimal.Decimal, + set, + bytes, + bytearray, + memoryview, + frozenset, + enum.Enum, + type, + range, + slice, + OrderedDict, + types.GenericAlias, + *((_union_type,) if (_union_type := getattr(types, "UnionType", None)) else ()), + ) + if isinstance(orig, _equality_types): return orig == new if isinstance(orig, float): if math.isnan(orig) and math.isnan(new): diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 100e385fd..cc42956bf 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -5216,3 +5216,27 @@ def test_python_tempfile_pattern_regex(self): assert PYTHON_TEMPFILE_PATTERN.search("/tmp/tmp123456/") assert not PYTHON_TEMPFILE_PATTERN.search("/tmp/mydir/file.txt") assert not PYTHON_TEMPFILE_PATTERN.search("/home/tmp123/file.txt") + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="types.UnionType requires Python 3.10+") +class TestUnionType: + def test_union_type_equal(self): + assert comparator(int | str, int | str) + + def test_union_type_not_equal(self): + assert not comparator(int | str, int | float) + + def test_union_type_order_independent(self): + assert comparator(int | str, str | int) + + def test_union_type_multiple_args(self): + assert comparator(int | str | float, int | str | float) + + def test_union_type_in_list(self): + assert comparator([int | str, 1], [int | str, 1]) + + def test_union_type_in_dict(self): + assert comparator({"key": int | str}, {"key": int | str}) + + def test_union_type_vs_none(self): + assert not comparator(int | str, None)