22
33import pytest
44import sqlparse
5+ from sqlparse .sql import Identifier , IdentifierList , Token , TokenList
6+ from sqlparse .tokens import DML , Keyword , Punctuation
57
8+ from mycli .packages import parseutils
69from mycli .packages .parseutils import (
710 extract_columns_from_select ,
811 extract_from_part ,
@@ -280,6 +283,18 @@ def test_extract_from_part_handles_multiple_joins_and_skips_on_clause():
280283 assert token_values (tokens ) == ['abc' , 'join' , 'def' , 'ghi' ]
281284
282285
286+ def test_extract_from_part_recurses_into_subselect_and_stops_at_punctuation ():
287+ parsed = sqlparse .parse ('select * from (select * from inner_table), outer_table' )[0 ]
288+ tokens = extract_from_part (parsed )
289+ assert token_values (tokens ) == ['inner_table' ]
290+
291+
292+ def test_extract_from_part_stops_at_punctuation_when_requested ():
293+ parsed = TokenList ([Token (Keyword , 'FROM' ), Token (Punctuation , ',' ), Token (Keyword , 'SELECT' )])
294+ tokens = extract_from_part (parsed , stop_at_punctuation = True )
295+ assert token_values (tokens ) == []
296+
297+
283298def test_extract_table_identifiers_handles_identifier_list ():
284299 parsed = sqlparse .parse ('select * from abc a, def d' )[0 ]
285300 token_stream = extract_from_part (parsed )
@@ -301,6 +316,33 @@ def test_extract_table_identifiers_handles_function_tokens():
301316 assert list (extract_table_identifiers (token_stream )) == [(None , 'my_func' , 'my_func' )]
302317
303318
319+ def test_extract_table_identifiers_skips_identifier_list_entries_without_identifier_methods ():
320+ class BrokenIdentifierList (IdentifierList ):
321+ def get_identifiers (self ):
322+ return [object ()]
323+
324+ assert list (extract_table_identifiers (iter ([BrokenIdentifierList ([])]))) == []
325+
326+
327+ def test_extract_table_identifiers_uses_name_when_identifier_has_no_real_name ():
328+ class NamelessIdentifier (Identifier ):
329+ def get_real_name (self ):
330+ return None
331+
332+ def get_parent_name (self ):
333+ return None
334+
335+ def get_name (self ):
336+ return 'fallback_name'
337+
338+ def get_alias (self ):
339+ return None
340+
341+ assert list (extract_table_identifiers (iter ([NamelessIdentifier ([])]))) == [
342+ (None , 'fallback_name' , 'fallback_name' ),
343+ ]
344+
345+
304346@pytest .mark .parametrize (
305347 ('sql' , 'expected_keyword' , 'expected_text' ),
306348 [
@@ -335,6 +377,82 @@ def test_query_is_single_table_update(sql, is_single_table):
335377 assert query_is_single_table_update (sql ) is is_single_table
336378
337379
380+ def test_extract_columns_from_select_handles_falsey_last_select (monkeypatch ):
381+ monkeypatch .setattr (parseutils , 'get_last_select' , lambda _parsed : [])
382+ assert extract_columns_from_select ('select 1' ) == []
383+
384+
385+ def test_extract_columns_from_select_handles_single_identifier (monkeypatch ):
386+ class SingleIdentifier (Identifier ):
387+ def get_real_name (self ):
388+ return 'column_name'
389+
390+ monkeypatch .setattr (
391+ parseutils ,
392+ 'get_last_select' ,
393+ lambda _parsed : TokenList ([Token (DML , 'SELECT' ), SingleIdentifier ([])]),
394+ )
395+
396+ assert extract_columns_from_select ('select column_name' ) == ['column_name' ]
397+
398+
399+ def test_extract_columns_from_select_ignores_unhandled_identifier_list_entries (monkeypatch ):
400+ class WeirdIdentifierList (IdentifierList ):
401+ def get_identifiers (self ):
402+ return [object ()]
403+
404+ monkeypatch .setattr (
405+ parseutils ,
406+ 'get_last_select' ,
407+ lambda _parsed : TokenList ([Token (DML , 'SELECT' ), WeirdIdentifierList ([])]),
408+ )
409+
410+ assert extract_columns_from_select ('select 1' ) == []
411+
412+
413+ def test_extract_columns_from_select_stops_at_keyword_before_collecting_columns (monkeypatch ):
414+ monkeypatch .setattr (
415+ parseutils ,
416+ 'get_last_select' ,
417+ lambda _parsed : TokenList ([Token (DML , 'SELECT' ), Token (Keyword , 'FROM' )]),
418+ )
419+
420+ assert extract_columns_from_select ('select 1' ) == []
421+
422+
423+ def test_extract_tables_from_complete_statements_returns_empty_for_falsey_rough_parse (monkeypatch ):
424+ monkeypatch .setattr (parseutils .sqlparse , 'parse' , lambda _sql : [])
425+
426+ assert extract_tables_from_complete_statements ('select * from t' ) == []
427+
428+
429+ def test_extract_tables_from_complete_statements_skips_cte_table_identifiers (monkeypatch ):
430+ class FakeParentSelect :
431+ def sql (self ):
432+ return 'WITH cte AS (SELECT 1) SELECT * FROM cte'
433+
434+ class FakeIdentifier :
435+ parent_select = FakeParentSelect ()
436+ db = ''
437+ name = 'cte'
438+ alias = ''
439+
440+ class FakeStatement :
441+ def find_all (self , _table_type ):
442+ return [FakeIdentifier ()]
443+
444+ monkeypatch .setattr (parseutils .sqlparse , 'parse' , lambda _sql : ['stmt' ])
445+ monkeypatch .setattr (parseutils .sqlglot , 'parse_one' , lambda * _args , ** _kwargs : FakeStatement ())
446+
447+ assert extract_tables_from_complete_statements ('with cte as (select 1) select * from cte' ) == []
448+
449+
450+ def test_query_is_single_table_update_returns_false_when_parse_result_is_empty (monkeypatch ):
451+ monkeypatch .setattr (parseutils .sqlparse , 'parse' , lambda _sql : [])
452+
453+ assert query_is_single_table_update ('update test set x = 1' ) is False
454+
455+
338456def test_is_destructive ():
339457 sql = "use test;\n show databases;\n drop database foo;"
340458 assert is_destructive (["drop" ], sql ) is True
@@ -360,6 +478,16 @@ def test_is_destructive_update_without_where_clause():
360478 assert is_destructive (["update" ], sql ) is True
361479
362480
481+ def test_is_destructive_skips_empty_split_queries (monkeypatch ):
482+ monkeypatch .setattr (parseutils .sqlparse , 'split' , lambda _queries : ['' , '' ])
483+
484+ assert is_destructive (['drop' ], 'ignored' ) is False
485+
486+
487+ def test_is_destructive_returns_false_when_no_query_matches_keywords () -> None :
488+ assert is_destructive (['drop' ], 'select 1; show databases;' ) is False
489+
490+
363491@pytest .mark .parametrize (
364492 ("sql" , "has_where_clause" ),
365493 [
@@ -389,3 +517,7 @@ def test_query_has_where_clause(sql, has_where_clause):
389517)
390518def test_is_dropping_database (sql , dbname , is_dropping ):
391519 assert is_dropping_database (sql , dbname ) == is_dropping
520+
521+
522+ def test_is_dropping_database_skips_statements_without_enough_keywords ():
523+ assert is_dropping_database ('drop foo' , 'foo' ) is False
0 commit comments