Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions sql_compare/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import dataclasses
import itertools
import re
import typing

import sqlparse
Expand All @@ -29,25 +30,50 @@
import pathlib


# Matches one whole line whose first character is a backslash, up to (but not
# including) the newline. re.MULTILINE makes ^ and $ anchor at every line, so
# each such line inside a multi-line SQL text is matched on its own.
# _parse_statements() replaces the matches with "" to drop psql meta-commands
# (e.g. \unrestrict) before the text is handed to sqlparse.
# NOTE(review): a backslash-led line inside a multi-line string literal or a
# dollar-quoted body would be blanked as well -- confirm that is acceptable.
_PSQL_META_COMMAND_RE = re.compile(r"^\\[^\n]*$", re.MULTILINE)


def _parse_statements(sql: str) -> set[Statement]:
    """Collect the recognised statements found in *sql* into a set.

    Lines matched by ``_PSQL_META_COMMAND_RE`` (psql meta-commands) are
    removed first; after parsing, every statement whose ``statement_type``
    equals ``Statement.UNKNOWN_TYPE`` is left out of the result.
    """
    # Meta-command lines (e.g. \unrestrict) are not SQL: blank them out so
    # they never reach the parser.
    cleaned_sql = _PSQL_META_COMMAND_RE.sub("", sql)

    statements: set[Statement] = set()
    for parsed in sqlparse.parse(cleaned_sql):
        candidate = Statement(parsed)
        # Keep only statements whose type was recognised.
        if candidate.statement_type != Statement.UNKNOWN_TYPE:
            statements.add(candidate)
    return statements


def compare_files(first_file: pathlib.Path, second_file: pathlib.Path) -> bool:
    """Read both files as text and return what compare() says about them."""
    first_sql = first_file.read_text()
    second_sql = second_file.read_text()
    return compare(first_sql, second_sql)


def diff_files(first_file: pathlib.Path, second_file: pathlib.Path) -> set[Statement]:
    """Read both files as text and return the result of diff() on their contents."""
    first_sql = first_file.read_text()
    second_sql = second_file.read_text()
    return diff(first_sql, second_sql)


def compare(first_sql: str, second_sql: str) -> bool:
    """Report whether two SQL strings contain the same statements.

    The comparison is delegated to diff(): the inputs are considered equal
    exactly when diff() returns an empty set. Because diff() works on sets
    built by _parse_statements(), statement order is irrelevant and psql
    meta-command lines are ignored.

    Returns:
        True when diff() finds no differing statement, False otherwise.
    """
    # The earlier ordered list comparison on raw sqlparse output is dropped:
    # it returned first and made this delegation unreachable.
    return not diff(first_sql, second_sql)


def diff(first_sql: str, second_sql: str) -> set[Statement]:
    """Return the statements present in exactly one of the two SQL strings."""
    first_statements = _parse_statements(first_sql)
    second_statements = _parse_statements(second_sql)
    return first_statements.symmetric_difference(second_statements)


def get_diff(
first_sql: str,
second_sql: str,
) -> list[list[list[str]]]:
"""Show the difference between two SQL schemas, ignoring differences due to column order and other non-significant SQL changes."""
first_set = {Statement(t) for t in sqlparse.parse(first_sql)}
second_set = {Statement(t) for t in sqlparse.parse(second_sql)}
first_set = _parse_statements(first_sql)
second_set = _parse_statements(second_sql)
first_diffs = sorted([stmt.str_tokens for stmt in first_set - second_set])
second_diffs = sorted([stmt.str_tokens for stmt in second_set - first_set])

Expand Down Expand Up @@ -207,6 +233,7 @@ def tokens(self) -> Generator[Token | TokenList, None, None]:
# Tokens are "()", no need to sort them
if not split_result.identifier_groups and not split_result.separators:
yield from filtered_tokens
return

# Sort identifier groups by their hash
split_result.identifier_groups.sort(key=lambda g: "".join(t.hash for t in g))
Expand Down
17 changes: 17 additions & 0 deletions tests/test_sql_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,26 @@
"CREATE INDEX foo_idx ON foo (id1, id2)",
"CREATE INDEX foo_idx ON foo (id1, id2)",
),
# Ignore statement order
(
"CREATE TABLE foo (id INT); CREATE TABLE bar (id INT);",
"CREATE TABLE bar (id INT); CREATE TABLE foo (id INT);",
),
# Ignore non-SQL content (psql meta-commands)
(
"\\unrestrict abc123\nCREATE TABLE foo (id INT)",
"\\unrestrict xyz789\nCREATE TABLE foo (id INT)",
),
# Ignore non-SQL content with different tokens
(
"CREATE TABLE foo (id INT);\n\\unrestrict abc123\n",
"CREATE TABLE foo (id INT);\n\\unrestrict xyz789\n",
),
],
)
def test_compare_eq(first_sql: str, second_sql: str) -> None:
assert sql_compare.compare(first_sql, second_sql)
assert not sql_compare.diff(first_sql, second_sql)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -149,6 +165,7 @@ def test_compare_eq(first_sql: str, second_sql: str) -> None:
)
def test_compare_neq(first_sql: str, second_sql: str) -> None:
assert not sql_compare.compare(first_sql, second_sql)
assert sql_compare.diff(first_sql, second_sql)


@pytest.mark.parametrize(
Expand Down