Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions test/pytests/test_parseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import pytest
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import DML, Keyword, Punctuation

from mycli.packages import parseutils
from mycli.packages.parseutils import (
extract_columns_from_select,
extract_from_part,
Expand Down Expand Up @@ -280,6 +283,18 @@ def test_extract_from_part_handles_multiple_joins_and_skips_on_clause():
assert token_values(tokens) == ['abc', 'join', 'def', 'ghi']


def test_extract_from_part_recurses_into_subselect_and_stops_at_punctuation():
parsed = sqlparse.parse('select * from (select * from inner_table), outer_table')[0]
tokens = extract_from_part(parsed)
assert token_values(tokens) == ['inner_table']


def test_extract_from_part_stops_at_punctuation_when_requested():
parsed = TokenList([Token(Keyword, 'FROM'), Token(Punctuation, ','), Token(Keyword, 'SELECT')])
tokens = extract_from_part(parsed, stop_at_punctuation=True)
assert token_values(tokens) == []


def test_extract_table_identifiers_handles_identifier_list():
parsed = sqlparse.parse('select * from abc a, def d')[0]
token_stream = extract_from_part(parsed)
Expand All @@ -301,6 +316,33 @@ def test_extract_table_identifiers_handles_function_tokens():
assert list(extract_table_identifiers(token_stream)) == [(None, 'my_func', 'my_func')]


def test_extract_table_identifiers_skips_identifier_list_entries_without_identifier_methods():
class BrokenIdentifierList(IdentifierList):
def get_identifiers(self):
return [object()]

assert list(extract_table_identifiers(iter([BrokenIdentifierList([])]))) == []


def test_extract_table_identifiers_uses_name_when_identifier_has_no_real_name():
class NamelessIdentifier(Identifier):
def get_real_name(self):
return None

def get_parent_name(self):
return None

def get_name(self):
return 'fallback_name'

def get_alias(self):
return None

assert list(extract_table_identifiers(iter([NamelessIdentifier([])]))) == [
(None, 'fallback_name', 'fallback_name'),
]


@pytest.mark.parametrize(
('sql', 'expected_keyword', 'expected_text'),
[
Expand Down Expand Up @@ -335,6 +377,82 @@ def test_query_is_single_table_update(sql, is_single_table):
assert query_is_single_table_update(sql) is is_single_table


def test_extract_columns_from_select_handles_falsey_last_select(monkeypatch):
monkeypatch.setattr(parseutils, 'get_last_select', lambda _parsed: [])
assert extract_columns_from_select('select 1') == []


def test_extract_columns_from_select_handles_single_identifier(monkeypatch):
class SingleIdentifier(Identifier):
def get_real_name(self):
return 'column_name'

monkeypatch.setattr(
parseutils,
'get_last_select',
lambda _parsed: TokenList([Token(DML, 'SELECT'), SingleIdentifier([])]),
)

assert extract_columns_from_select('select column_name') == ['column_name']


def test_extract_columns_from_select_ignores_unhandled_identifier_list_entries(monkeypatch):
class WeirdIdentifierList(IdentifierList):
def get_identifiers(self):
return [object()]

monkeypatch.setattr(
parseutils,
'get_last_select',
lambda _parsed: TokenList([Token(DML, 'SELECT'), WeirdIdentifierList([])]),
)

assert extract_columns_from_select('select 1') == []


def test_extract_columns_from_select_stops_at_keyword_before_collecting_columns(monkeypatch):
monkeypatch.setattr(
parseutils,
'get_last_select',
lambda _parsed: TokenList([Token(DML, 'SELECT'), Token(Keyword, 'FROM')]),
)

assert extract_columns_from_select('select 1') == []


def test_extract_tables_from_complete_statements_returns_empty_for_falsey_rough_parse(monkeypatch):
monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: [])

assert extract_tables_from_complete_statements('select * from t') == []


def test_extract_tables_from_complete_statements_skips_cte_table_identifiers(monkeypatch):
class FakeParentSelect:
def sql(self):
return 'WITH cte AS (SELECT 1) SELECT * FROM cte'

class FakeIdentifier:
parent_select = FakeParentSelect()
db = ''
name = 'cte'
alias = ''

class FakeStatement:
def find_all(self, _table_type):
return [FakeIdentifier()]

monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: ['stmt'])
monkeypatch.setattr(parseutils.sqlglot, 'parse_one', lambda *_args, **_kwargs: FakeStatement())

assert extract_tables_from_complete_statements('with cte as (select 1) select * from cte') == []


def test_query_is_single_table_update_returns_false_when_parse_result_is_empty(monkeypatch):
monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: [])

assert query_is_single_table_update('update test set x = 1') is False


def test_is_destructive():
sql = "use test;\nshow databases;\ndrop database foo;"
assert is_destructive(["drop"], sql) is True
Expand All @@ -360,6 +478,16 @@ def test_is_destructive_update_without_where_clause():
assert is_destructive(["update"], sql) is True


def test_is_destructive_skips_empty_split_queries(monkeypatch):
monkeypatch.setattr(parseutils.sqlparse, 'split', lambda _queries: ['', ''])

assert is_destructive(['drop'], 'ignored') is False


def test_is_destructive_returns_false_when_no_query_matches_keywords() -> None:
assert is_destructive(['drop'], 'select 1; show databases;') is False


@pytest.mark.parametrize(
("sql", "has_where_clause"),
[
Expand Down Expand Up @@ -389,3 +517,7 @@ def test_query_has_where_clause(sql, has_where_clause):
)
def test_is_dropping_database(sql, dbname, is_dropping):
assert is_dropping_database(sql, dbname) == is_dropping


def test_is_dropping_database_skips_statements_without_enough_keywords():
assert is_dropping_database('drop foo', 'foo') is False
Loading