diff --git a/test/pytests/test_parseutils.py b/test/pytests/test_parseutils.py index 9e9d2ae9..df53c06c 100644 --- a/test/pytests/test_parseutils.py +++ b/test/pytests/test_parseutils.py @@ -2,7 +2,10 @@ import pytest import sqlparse +from sqlparse.sql import Identifier, IdentifierList, Token, TokenList +from sqlparse.tokens import DML, Keyword, Punctuation +from mycli.packages import parseutils from mycli.packages.parseutils import ( extract_columns_from_select, extract_from_part, @@ -280,6 +283,18 @@ def test_extract_from_part_handles_multiple_joins_and_skips_on_clause(): assert token_values(tokens) == ['abc', 'join', 'def', 'ghi'] +def test_extract_from_part_recurses_into_subselect_and_stops_at_punctuation(): + parsed = sqlparse.parse('select * from (select * from inner_table), outer_table')[0] + tokens = extract_from_part(parsed) + assert token_values(tokens) == ['inner_table'] + + +def test_extract_from_part_stops_at_punctuation_when_requested(): + parsed = TokenList([Token(Keyword, 'FROM'), Token(Punctuation, ','), Token(Keyword, 'SELECT')]) + tokens = extract_from_part(parsed, stop_at_punctuation=True) + assert token_values(tokens) == [] + + def test_extract_table_identifiers_handles_identifier_list(): parsed = sqlparse.parse('select * from abc a, def d')[0] token_stream = extract_from_part(parsed) @@ -301,6 +316,33 @@ def test_extract_table_identifiers_handles_function_tokens(): assert list(extract_table_identifiers(token_stream)) == [(None, 'my_func', 'my_func')] +def test_extract_table_identifiers_skips_identifier_list_entries_without_identifier_methods(): + class BrokenIdentifierList(IdentifierList): + def get_identifiers(self): + return [object()] + + assert list(extract_table_identifiers(iter([BrokenIdentifierList([])]))) == [] + + +def test_extract_table_identifiers_uses_name_when_identifier_has_no_real_name(): + class NamelessIdentifier(Identifier): + def get_real_name(self): + return None + + def get_parent_name(self): + return None + + def get_name(self): + return 'fallback_name' + + def get_alias(self): + return None + + assert list(extract_table_identifiers(iter([NamelessIdentifier([])]))) == [ + (None, 'fallback_name', 'fallback_name'), + ] + + @pytest.mark.parametrize( ('sql', 'expected_keyword', 'expected_text'), [ @@ -335,6 +377,82 @@ def test_query_is_single_table_update(sql, is_single_table): assert query_is_single_table_update(sql) is is_single_table +def test_extract_columns_from_select_handles_falsey_last_select(monkeypatch): + monkeypatch.setattr(parseutils, 'get_last_select', lambda _parsed: []) + assert extract_columns_from_select('select 1') == [] + + +def test_extract_columns_from_select_handles_single_identifier(monkeypatch): + class SingleIdentifier(Identifier): + def get_real_name(self): + return 'column_name' + + monkeypatch.setattr( + parseutils, + 'get_last_select', + lambda _parsed: TokenList([Token(DML, 'SELECT'), SingleIdentifier([])]), + ) + + assert extract_columns_from_select('select column_name') == ['column_name'] + + +def test_extract_columns_from_select_ignores_unhandled_identifier_list_entries(monkeypatch): + class WeirdIdentifierList(IdentifierList): + def get_identifiers(self): + return [object()] + + monkeypatch.setattr( + parseutils, + 'get_last_select', + lambda _parsed: TokenList([Token(DML, 'SELECT'), WeirdIdentifierList([])]), + ) + + assert extract_columns_from_select('select 1') == [] + + +def test_extract_columns_from_select_stops_at_keyword_before_collecting_columns(monkeypatch): + monkeypatch.setattr( + parseutils, + 'get_last_select', + lambda _parsed: TokenList([Token(DML, 'SELECT'), Token(Keyword, 'FROM')]), + ) + + assert extract_columns_from_select('select 1') == [] + + +def test_extract_tables_from_complete_statements_returns_empty_for_falsey_rough_parse(monkeypatch): + monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: []) + + assert extract_tables_from_complete_statements('select * from t') == [] + + +def test_extract_tables_from_complete_statements_skips_cte_table_identifiers(monkeypatch): + class FakeParentSelect: + def sql(self): + return 'WITH cte AS (SELECT 1) SELECT * FROM cte' + + class FakeIdentifier: + parent_select = FakeParentSelect() + db = '' + name = 'cte' + alias = '' + + class FakeStatement: + def find_all(self, _table_type): + return [FakeIdentifier()] + + monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: ['stmt']) + monkeypatch.setattr(parseutils.sqlglot, 'parse_one', lambda *_args, **_kwargs: FakeStatement()) + + assert extract_tables_from_complete_statements('with cte as (select 1) select * from cte') == [] + + +def test_query_is_single_table_update_returns_false_when_parse_result_is_empty(monkeypatch): + monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: []) + + assert query_is_single_table_update('update test set x = 1') is False + + def test_is_destructive(): sql = "use test;\nshow databases;\ndrop database foo;" assert is_destructive(["drop"], sql) is True @@ -360,6 +478,16 @@ def test_is_destructive_update_without_where_clause(): assert is_destructive(["update"], sql) is True +def test_is_destructive_skips_empty_split_queries(monkeypatch): + monkeypatch.setattr(parseutils.sqlparse, 'split', lambda _queries: ['', '']) + + assert is_destructive(['drop'], 'ignored') is False + + +def test_is_destructive_returns_false_when_no_query_matches_keywords() -> None: + assert is_destructive(['drop'], 'select 1; show databases;') is False + + @pytest.mark.parametrize( ("sql", "has_where_clause"), [ @@ -389,3 +517,7 @@ def test_query_has_where_clause(sql, has_where_clause): ) def test_is_dropping_database(sql, dbname, is_dropping): assert is_dropping_database(sql, dbname) == is_dropping + + +def test_is_dropping_database_skips_statements_without_enough_keywords(): + assert is_dropping_database('drop foo', 'foo') is False