diff --git a/sql_compare/__init__.py b/sql_compare/__init__.py index b67c1cb..575a2bd 100644 --- a/sql_compare/__init__.py +++ b/sql_compare/__init__.py @@ -19,6 +19,7 @@ import dataclasses import itertools +import re import typing import sqlparse @@ -29,16 +30,41 @@ import pathlib +_PSQL_META_COMMAND_RE = re.compile(r"^\\[^\n]*$", re.MULTILINE) + + +def _parse_statements(sql: str) -> set[Statement]: + """Parse SQL string into a set of meaningful SQL statements. + + Non-SQL content (psql meta-commands, empty statements) is filtered out. + """ + # Strip psql meta-commands (e.g. \unrestrict) before parsing + sql = _PSQL_META_COMMAND_RE.sub("", sql) + return { + stmt + for t in sqlparse.parse(sql) + if (stmt := Statement(t)).statement_type != Statement.UNKNOWN_TYPE + } + + def compare_files(first_file: pathlib.Path, second_file: pathlib.Path) -> bool: """Compare two SQL files.""" return compare(first_file.read_text(), second_file.read_text()) +def diff_files(first_file: pathlib.Path, second_file: pathlib.Path) -> set[Statement]: + """Return the set of statements that differ between two SQL files.""" + return diff(first_file.read_text(), second_file.read_text()) + + def compare(first_sql: str, second_sql: str) -> bool: """Compare two SQL strings.""" - first_sql_statements = [Statement(t) for t in sqlparse.parse(first_sql)] - second_sql_statements = [Statement(t) for t in sqlparse.parse(second_sql)] - return first_sql_statements == second_sql_statements + return not diff(first_sql, second_sql) + + +def diff(first_sql: str, second_sql: str) -> set[Statement]: + """Return the set of statements that differ between two SQL strings.""" + return _parse_statements(first_sql) ^ _parse_statements(second_sql) def get_diff( @@ -46,8 +72,8 @@ def get_diff( second_sql: str, ) -> list[list[list[str]]]: """Show the difference between two SQL schemas, ignoring differences due to column order and other non-significant SQL changes.""" - first_set = {Statement(t) for t in sqlparse.parse(first_sql)} - second_set = {Statement(t) for t in sqlparse.parse(second_sql)} + first_set = _parse_statements(first_sql) + second_set = _parse_statements(second_sql) first_diffs = sorted([stmt.str_tokens for stmt in first_set - second_set]) second_diffs = sorted([stmt.str_tokens for stmt in second_set - first_set]) @@ -207,6 +233,7 @@ def tokens(self) -> Generator[Token | TokenList, None, None]: # Tokens are "()", no need to sort them if not split_result.identifier_groups and not split_result.separators: yield from filtered_tokens + return # Sort identifier groups by their hash split_result.identifier_groups.sort(key=lambda g: "".join(t.hash for t in g)) diff --git a/tests/test_sql_compare.py b/tests/test_sql_compare.py index a334dca..5d3f73d 100644 --- a/tests/test_sql_compare.py +++ b/tests/test_sql_compare.py @@ -83,10 +83,26 @@ "CREATE INDEX foo_idx ON foo (id1, id2)", "CREATE INDEX foo_idx ON foo (id1, id2)", ), + # Ignore statement order + ( + "CREATE TABLE foo (id INT); CREATE TABLE bar (id INT);", + "CREATE TABLE bar (id INT); CREATE TABLE foo (id INT);", + ), + # Ignore non-SQL content (psql meta-commands) + ( + "\\unrestrict abc123\nCREATE TABLE foo (id INT)", + "\\unrestrict xyz789\nCREATE TABLE foo (id INT)", + ), + # Ignore non-SQL content with different tokens + ( + "CREATE TABLE foo (id INT);\n\\unrestrict abc123\n", + "CREATE TABLE foo (id INT);\n\\unrestrict xyz789\n", + ), ], ) def test_compare_eq(first_sql: str, second_sql: str) -> None: assert sql_compare.compare(first_sql, second_sql) + assert not sql_compare.diff(first_sql, second_sql) @pytest.mark.parametrize( @@ -149,6 +165,7 @@ def test_compare_eq(first_sql: str, second_sql: str) -> None: ) def test_compare_neq(first_sql: str, second_sql: str) -> None: assert not sql_compare.compare(first_sql, second_sql) + assert sql_compare.diff(first_sql, second_sql) @pytest.mark.parametrize(