Skip to content

Commit 566c7ad

Browse files
authored
Merge pull request #1797 from dbcli/RW/add-parseutils-tests-02
Increase test coverage for `parseutils.py`
2 parents c49798f + 890ce31 commit 566c7ad

1 file changed

Lines changed: 132 additions & 0 deletions

File tree

test/pytests/test_parseutils.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22

33
import pytest
44
import sqlparse
5+
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
6+
from sqlparse.tokens import DML, Keyword, Punctuation
57

8+
from mycli.packages import parseutils
69
from mycli.packages.parseutils import (
710
extract_columns_from_select,
811
extract_from_part,
@@ -280,6 +283,18 @@ def test_extract_from_part_handles_multiple_joins_and_skips_on_clause():
280283
assert token_values(tokens) == ['abc', 'join', 'def', 'ghi']
281284

282285

286+
def test_extract_from_part_recurses_into_subselect_and_stops_at_punctuation():
287+
parsed = sqlparse.parse('select * from (select * from inner_table), outer_table')[0]
288+
tokens = extract_from_part(parsed)
289+
assert token_values(tokens) == ['inner_table']
290+
291+
292+
def test_extract_from_part_stops_at_punctuation_when_requested():
293+
parsed = TokenList([Token(Keyword, 'FROM'), Token(Punctuation, ','), Token(Keyword, 'SELECT')])
294+
tokens = extract_from_part(parsed, stop_at_punctuation=True)
295+
assert token_values(tokens) == []
296+
297+
283298
def test_extract_table_identifiers_handles_identifier_list():
284299
parsed = sqlparse.parse('select * from abc a, def d')[0]
285300
token_stream = extract_from_part(parsed)
@@ -301,6 +316,33 @@ def test_extract_table_identifiers_handles_function_tokens():
301316
assert list(extract_table_identifiers(token_stream)) == [(None, 'my_func', 'my_func')]
302317

303318

319+
def test_extract_table_identifiers_skips_identifier_list_entries_without_identifier_methods():
320+
class BrokenIdentifierList(IdentifierList):
321+
def get_identifiers(self):
322+
return [object()]
323+
324+
assert list(extract_table_identifiers(iter([BrokenIdentifierList([])]))) == []
325+
326+
327+
def test_extract_table_identifiers_uses_name_when_identifier_has_no_real_name():
328+
class NamelessIdentifier(Identifier):
329+
def get_real_name(self):
330+
return None
331+
332+
def get_parent_name(self):
333+
return None
334+
335+
def get_name(self):
336+
return 'fallback_name'
337+
338+
def get_alias(self):
339+
return None
340+
341+
assert list(extract_table_identifiers(iter([NamelessIdentifier([])]))) == [
342+
(None, 'fallback_name', 'fallback_name'),
343+
]
344+
345+
304346
@pytest.mark.parametrize(
305347
('sql', 'expected_keyword', 'expected_text'),
306348
[
@@ -335,6 +377,82 @@ def test_query_is_single_table_update(sql, is_single_table):
335377
assert query_is_single_table_update(sql) is is_single_table
336378

337379

380+
def test_extract_columns_from_select_handles_falsey_last_select(monkeypatch):
381+
monkeypatch.setattr(parseutils, 'get_last_select', lambda _parsed: [])
382+
assert extract_columns_from_select('select 1') == []
383+
384+
385+
def test_extract_columns_from_select_handles_single_identifier(monkeypatch):
386+
class SingleIdentifier(Identifier):
387+
def get_real_name(self):
388+
return 'column_name'
389+
390+
monkeypatch.setattr(
391+
parseutils,
392+
'get_last_select',
393+
lambda _parsed: TokenList([Token(DML, 'SELECT'), SingleIdentifier([])]),
394+
)
395+
396+
assert extract_columns_from_select('select column_name') == ['column_name']
397+
398+
399+
def test_extract_columns_from_select_ignores_unhandled_identifier_list_entries(monkeypatch):
400+
class WeirdIdentifierList(IdentifierList):
401+
def get_identifiers(self):
402+
return [object()]
403+
404+
monkeypatch.setattr(
405+
parseutils,
406+
'get_last_select',
407+
lambda _parsed: TokenList([Token(DML, 'SELECT'), WeirdIdentifierList([])]),
408+
)
409+
410+
assert extract_columns_from_select('select 1') == []
411+
412+
413+
def test_extract_columns_from_select_stops_at_keyword_before_collecting_columns(monkeypatch):
414+
monkeypatch.setattr(
415+
parseutils,
416+
'get_last_select',
417+
lambda _parsed: TokenList([Token(DML, 'SELECT'), Token(Keyword, 'FROM')]),
418+
)
419+
420+
assert extract_columns_from_select('select 1') == []
421+
422+
423+
def test_extract_tables_from_complete_statements_returns_empty_for_falsey_rough_parse(monkeypatch):
424+
monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: [])
425+
426+
assert extract_tables_from_complete_statements('select * from t') == []
427+
428+
429+
def test_extract_tables_from_complete_statements_skips_cte_table_identifiers(monkeypatch):
430+
class FakeParentSelect:
431+
def sql(self):
432+
return 'WITH cte AS (SELECT 1) SELECT * FROM cte'
433+
434+
class FakeIdentifier:
435+
parent_select = FakeParentSelect()
436+
db = ''
437+
name = 'cte'
438+
alias = ''
439+
440+
class FakeStatement:
441+
def find_all(self, _table_type):
442+
return [FakeIdentifier()]
443+
444+
monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: ['stmt'])
445+
monkeypatch.setattr(parseutils.sqlglot, 'parse_one', lambda *_args, **_kwargs: FakeStatement())
446+
447+
assert extract_tables_from_complete_statements('with cte as (select 1) select * from cte') == []
448+
449+
450+
def test_query_is_single_table_update_returns_false_when_parse_result_is_empty(monkeypatch):
451+
monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: [])
452+
453+
assert query_is_single_table_update('update test set x = 1') is False
454+
455+
338456
def test_is_destructive():
339457
sql = "use test;\nshow databases;\ndrop database foo;"
340458
assert is_destructive(["drop"], sql) is True
@@ -360,6 +478,16 @@ def test_is_destructive_update_without_where_clause():
360478
assert is_destructive(["update"], sql) is True
361479

362480

481+
def test_is_destructive_skips_empty_split_queries(monkeypatch):
482+
monkeypatch.setattr(parseutils.sqlparse, 'split', lambda _queries: ['', ''])
483+
484+
assert is_destructive(['drop'], 'ignored') is False
485+
486+
487+
def test_is_destructive_returns_false_when_no_query_matches_keywords() -> None:
488+
assert is_destructive(['drop'], 'select 1; show databases;') is False
489+
490+
363491
@pytest.mark.parametrize(
364492
("sql", "has_where_clause"),
365493
[
@@ -389,3 +517,7 @@ def test_query_has_where_clause(sql, has_where_clause):
389517
)
390518
def test_is_dropping_database(sql, dbname, is_dropping):
391519
assert is_dropping_database(sql, dbname) == is_dropping
520+
521+
522+
def test_is_dropping_database_skips_statements_without_enough_keywords():
523+
assert is_dropping_database('drop foo', 'foo') is False

0 commit comments

Comments
 (0)