From d708afc345cd265322ad25b74763b7c19392adbc Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 4 Apr 2026 08:14:11 -0400 Subject: [PATCH] move SQL utilities to sql_utils.py * most of the functions in parseutils.py were SQL utilities * move four functions from main.py to sql_utils.py * create cli_utils.py with the remainder of parseutils.py --- AGENTS.md | 3 +- changelog.md | 2 + mycli/main.py | 56 ++------ mycli/main_modes/batch.py | 2 +- mycli/packages/cli_utils.py | 12 ++ mycli/packages/completion_engine.py | 2 +- mycli/packages/prompt_utils.py | 2 +- .../packages/{parseutils.py => sql_utils.py} | 58 ++++++-- mycli/packages/tabular_output/sql_format.py | 2 +- mycli/sqlcompleter.py | 2 +- test/pytests/test_cli_utils.py | 24 ++++ test/pytests/test_main.py | 13 -- test/pytests/test_main_regression.py | 21 +-- .../{test_parseutils.py => test_sql_utils.py} | 127 ++++++++++++++---- 14 files changed, 199 insertions(+), 127 deletions(-) create mode 100644 mycli/packages/cli_utils.py rename mycli/packages/{parseutils.py => sql_utils.py} (90%) create mode 100644 test/pytests/test_cli_utils.py rename test/pytests/{test_parseutils.py => test_sql_utils.py} (85%) diff --git a/AGENTS.md b/AGENTS.md index dc4e860f..3920084d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -27,11 +27,12 @@ A command line client for MySQL with auto-completion and syntax highlighting. ├── mycli/packages/ # application packages ├── mycli/packages/batch_utils.py # utilities for `--batch` mode ├── mycli/packages/checkup.py # implementation of `--checkup` mode +├── mycli/packages/cli_utils.py # utilities for parsing CLI arguments ├── mycli/packages/completion_engine.py # implementation of completion suggestions ├── mycli/packages/filepaths.py # utilities for files, including completion suggestions ├── mycli/packages/hybrid_redirection.py # implementation of shell-style redirects ├── mycli/packages/paramiko_stub/ # stub in case the Paramiko library is not installed -├── mycli/packages/parseutils.py # utilities for parsing SQL statements +├── mycli/packages/sql_utils.py # utilities for parsing SQL statements ├── mycli/packages/prompt_utils.py # utilities for confirming on destructive statements ├── mycli/packages/ptoolkit/ # extends prompt_toolkit ├── mycli/packages/shortcuts.py # utilities for keyboard shortcuts diff --git a/changelog.md b/changelog.md index 2a6e4ad1..4f1f3456 100644 --- a/changelog.md +++ b/changelog.md @@ -35,6 +35,8 @@ Internal * Move `--checkup` logic to the new `main_modes` with `--batch`. * Sort coverage report in tox suite. * Skip more tests when a database connection is not present. +* Move SQL utilities to a new `sql_utils.py`. +* Move CLI utilities to a new `cli_utils.py`. 1.67.1 (2026/03/28) diff --git a/mycli/main.py b/mycli/main.py index 7f47e769..f3f8d504 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -87,14 +87,21 @@ ) from mycli.main_modes.checkup import main_checkup from mycli.packages import special +from mycli.packages.cli_utils import is_valid_connection_scheme from mycli.packages.filepaths import dir_path_exists, guess_socket_location from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command -from mycli.packages.parseutils import is_dropping_database, is_valid_connection_scheme from mycli.packages.prompt_utils import confirm, confirm_destructive_query from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import ArgType from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count +from mycli.packages.sql_utils import ( + is_dropping_database, + is_mutating, + is_select, + need_completion_refresh, + need_completion_reset, +) from mycli.packages.sqlresult import SQLResult from mycli.packages.string_utils import sanitize_terminal_title from mycli.packages.tabular_output import sql_format @@ -2702,53 +2709,6 @@ def get_password_from_file(password_file: str | None) -> str | None: mycli.close() -def need_completion_refresh(queries: str) -> bool: - """Determines if the completion needs a refresh by checking if the sql - statement is an alter, create, drop or change db.""" - for query in sqlparse.split(queries): - try: - first_token = query.split()[0] - if first_token.lower() in ("alter", "create", "use", "\\r", "\\u", "connect", "drop", "rename"): - return True - except Exception: - continue - return False - - -def need_completion_reset(queries: str) -> bool: - """Determines if the statement is a database switch such as 'use' or '\\u'. - When a database is changed the existing completions must be reset before we - start the completion refresh for the new database. - """ - for query in sqlparse.split(queries): - try: - tokens = query.split() - first_token = tokens[0] - if first_token.lower() in ("use", "\\u"): - return True - if first_token.lower() in ("\\r", "connect") and len(tokens) > 1: - return True - except Exception: - continue - return False - - -def is_mutating(status_plain: str | None) -> bool: - """Determines if the statement is mutating based on the status.""" - if not status_plain: - return False - - mutating = {"insert", "update", "delete", "alter", "create", "drop", "replace", "truncate", "load", "rename"} - return status_plain.split(None, 1)[0].lower() in mutating - - -def is_select(status_plain: str | None) -> bool: - """Returns true if the first word in status is 'select'.""" - if not status_plain: - return False - return status_plain.split(None, 1)[0].lower() == "select" - - def thanks_picker() -> str: import mycli diff --git a/mycli/main_modes/batch.py b/mycli/main_modes/batch.py index 03b18207..f4b52467 100644 --- a/mycli/main_modes/batch.py +++ b/mycli/main_modes/batch.py @@ -12,8 +12,8 @@ import pymysql from mycli.packages.batch_utils import statements_from_filehandle -from mycli.packages.parseutils import is_destructive from mycli.packages.prompt_utils import confirm_destructive_query +from mycli.packages.sql_utils import is_destructive if TYPE_CHECKING: from mycli.main import CliArgs, MyCli diff --git a/mycli/packages/cli_utils.py b/mycli/packages/cli_utils.py new file mode 100644 index 00000000..b5e7c5e6 --- /dev/null +++ b/mycli/packages/cli_utils.py @@ -0,0 +1,12 @@ +from __future__ import annotations + + +def is_valid_connection_scheme(text: str) -> tuple[bool, str | None]: + # exit early if the text does not resemble a DSN URI + if "://" not in text: + return False, None + scheme = text.split("://")[0] + if scheme not in ("mysql", "mysqlx", "tcp", "socket", "ssh"): + return False, scheme + else: + return True, None diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index 0d69701e..f623a38c 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -6,9 +6,9 @@ import sqlparse from sqlparse.sql import Comparison, Identifier, Token, Where -from mycli.packages.parseutils import extract_tables, find_prev_keyword, last_word from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.packages.special.main import parse_special_command +from mycli.packages.sql_utils import extract_tables, find_prev_keyword, last_word sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/prompt_utils.py index 68c468f6..fa0f0537 100644 --- a/mycli/packages/prompt_utils.py +++ b/mycli/packages/prompt_utils.py @@ -2,7 +2,7 @@ import click -from mycli.packages.parseutils import is_destructive +from mycli.packages.sql_utils import is_destructive class ConfirmBoolParamType(click.ParamType): diff --git a/mycli/packages/parseutils.py b/mycli/packages/sql_utils.py similarity index 90% rename from mycli/packages/parseutils.py rename to mycli/packages/sql_utils.py index 53b96823..8edb5744 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/sql_utils.py @@ -23,17 +23,6 @@ } -def is_valid_connection_scheme(text: str) -> tuple[bool, str | None]: - # exit early if the text does not resemble a DSN URI - if "://" not in text: - return False, None - scheme = text.split("://")[0] - if scheme not in ("mysql", "mysqlx", "tcp", "socket", "ssh"): - return False, scheme - else: - return True, None - - def last_word( text: str, include: Literal[ @@ -433,3 +422,50 @@ def normalize_db_name(db: str) -> str: if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: result = keywords[0].normalized == "DROP" return result + + +def need_completion_refresh(queries: str) -> bool: + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop or change db.""" + for query in sqlparse.split(queries): + try: + first_token = query.split()[0] + if first_token.lower() in ("alter", "create", "use", "\\r", "\\u", "connect", "drop", "rename"): + return True + except Exception: + continue + return False + + +def need_completion_reset(queries: str) -> bool: + """Determines if the statement is a database switch such as 'use' or '\\u'. + When a database is changed the existing completions must be reset before we + start the completion refresh for the new database. + """ + for query in sqlparse.split(queries): + try: + tokens = query.split() + first_token = tokens[0] + if first_token.lower() in ("use", "\\u"): + return True + if first_token.lower() in ("\\r", "connect") and len(tokens) > 1: + return True + except Exception: + continue + return False + + +def is_mutating(status_plain: str | None) -> bool: + """Determines if the statement is mutating based on the status.""" + if not status_plain: + return False + + mutating = {"insert", "update", "delete", "alter", "create", "drop", "replace", "truncate", "load", "rename"} + return status_plain.split(None, 1)[0].lower() in mutating + + +def is_select(status_plain: str | None) -> bool: + """Returns true if the first word in status is 'select'.""" + if not status_plain: + return False + return status_plain.split(None, 1)[0].lower() == "select" diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py index 7583c339..31def8e1 100644 --- a/mycli/packages/tabular_output/sql_format.py +++ b/mycli/packages/tabular_output/sql_format.py @@ -6,7 +6,7 @@ from cli_helpers.tabular_output import TabularOutputFormatter -from mycli.packages.parseutils import extract_tables_from_complete_statements +from mycli.packages.sql_utils import extract_tables_from_complete_statements supported_formats = ( "sql-insert", diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index e7ee2370..c0f669c8 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -13,10 +13,10 @@ from mycli.packages.completion_engine import is_inside_quotes, suggest_type from mycli.packages.filepaths import complete_path, parse_path, suggest_path -from mycli.packages.parseutils import extract_columns_from_select, extract_tables, last_word from mycli.packages.special import llm from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.packages.sql_utils import extract_columns_from_select, extract_tables, last_word _logger = logging.getLogger(__name__) _CASE_CHANGE_PAT = re.compile('(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])') diff --git a/test/pytests/test_cli_utils.py b/test/pytests/test_cli_utils.py new file mode 100644 index 00000000..7875e2e3 --- /dev/null +++ b/test/pytests/test_cli_utils.py @@ -0,0 +1,24 @@ +# type: ignore + +import pytest + +from mycli.packages.cli_utils import ( + is_valid_connection_scheme, +) + + +@pytest.mark.parametrize( + ('text', 'is_valid', 'invalid_scheme'), + [ + ('localhost', False, None), + ('mysql://user@localhost/db', True, None), + ('mysqlx://user@localhost/db', True, None), + ('tcp://localhost:3306', True, None), + ('socket:///tmp/mysql.sock', True, None), + ('ssh://user@example.com', True, None), + ('postgres://user@localhost/db', False, 'postgres'), + ('http://example.com', False, 'http'), + ], +) +def test_is_valid_connection_scheme(text, is_valid, invalid_scheme): + assert is_valid_connection_scheme(text) == (is_valid, invalid_scheme) diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 92e29a45..67889761 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -21,7 +21,6 @@ TEST_DATABASE, ) from mycli.main import EMPTY_PASSWORD_FLAG_SENTINEL, MyCli, click_entrypoint, thanks_picker -from mycli.packages.parseutils import is_valid_connection_scheme import mycli.packages.special from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.packages.sqlresult import SQLResult @@ -143,18 +142,6 @@ def test_select_from_empty_table(executor): assert expected in result.output -@dbtest -def test_is_valid_connection_scheme_valid(executor, capsys): - is_valid, scheme = is_valid_connection_scheme(f"mysql://test@{DEFAULT_HOST}:{DEFAULT_PORT}/dev") - assert is_valid - - -@dbtest -def test_is_valid_connection_scheme_invalid(executor, capsys): - is_valid, scheme = is_valid_connection_scheme(f"nope://test@{DEFAULT_HOST}:{DEFAULT_PORT}/dev") - assert not is_valid - - def test_filtered_sys_argv_maps_single_dash_h_to_help(monkeypatch): import mycli.main diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index 33f9a6c2..377cd59d 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -1329,10 +1329,6 @@ def cursor(self) -> PromptCursor: prompt = main.MyCli.get_prompt(cli, r'\H|\y|\Y|\T|\w|\W', 1) assert prompt == '127.0.0.1|123|uptime:123|TLSv1.3|7|7' - monkeypatch.setattr(main.sqlparse, 'split', lambda text: [None]) - assert main.need_completion_refresh('sql') is False - assert main.need_completion_reset('sql') is False - def test_format_sqlresult_string_paths_and_close_and_title_early_returns(monkeypatch: pytest.MonkeyPatch) -> None: cli = make_bare_mycli() @@ -1484,7 +1480,6 @@ def test_filtered_sys_argv_covers_help_and_passthrough(monkeypatch: pytest.Monke assert main.filtered_sys_argv() == ['--help'] monkeypatch.setattr(main.sys, 'argv', ['mycli', '-h', 'db.example']) assert main.filtered_sys_argv() == ['-h', 'db.example'] - assert main.need_completion_refresh('') is False def test_completion_helpers_title_helpers_thanks_tips_and_read_ssh_config(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: @@ -1526,21 +1521,6 @@ def test_completion_helpers_title_helpers_thanks_tips_and_read_ssh_config(monkey assert list(main.MyCli.get_completions(cli, 'select', 6)) == ['done'] assert entered_lock['count'] >= 2 - monkeypatch.setattr(main.sqlparse, 'split', lambda text: ['alter table t', 'broken']) - assert main.need_completion_refresh('sql') is True - monkeypatch.setattr(main.sqlparse, 'split', lambda text: ['']) - assert main.need_completion_refresh('sql') is False - monkeypatch.setattr(main.sqlparse, 'split', lambda text: ['use db']) - assert main.need_completion_reset('use db') is True - monkeypatch.setattr(main.sqlparse, 'split', lambda text: ['connect db']) - assert main.need_completion_reset('connect db') is True - monkeypatch.setattr(main.sqlparse, 'split', lambda text: ['select 1']) - assert main.need_completion_reset('select 1') is False - assert main.is_mutating('INSERT 1') is True - assert main.is_mutating(None) is False - assert main.is_select('SELECT 1') is True - assert main.is_select(None) is False - class FakeResource: def __init__(self, text: str | None) -> None: self.text = text @@ -2725,6 +2705,7 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'need_completion_refresh', lambda text: text == 'dropdb') monkeypatch.setattr(main, 'need_completion_reset', lambda text: True) monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: text == 'dropdb') + main.MyCli.run_cli(cli) assert reconnect_calls == ['', ''] assert any('bad op' in line for line in echoes) diff --git a/test/pytests/test_parseutils.py b/test/pytests/test_sql_utils.py similarity index 85% rename from test/pytests/test_parseutils.py rename to test/pytests/test_sql_utils.py index df53c06c..81619127 100644 --- a/test/pytests/test_parseutils.py +++ b/test/pytests/test_sql_utils.py @@ -5,8 +5,8 @@ from sqlparse.sql import Identifier, IdentifierList, Token, TokenList from sqlparse.tokens import DML, Keyword, Punctuation -from mycli.packages import parseutils -from mycli.packages.parseutils import ( +from mycli.packages import sql_utils +from mycli.packages.sql_utils import ( extract_columns_from_select, extract_from_part, extract_table_identifiers, @@ -16,9 +16,12 @@ get_last_select, is_destructive, is_dropping_database, + is_mutating, + is_select, is_subselect, - is_valid_connection_scheme, last_word, + need_completion_refresh, + need_completion_reset, queries_start_with, query_has_where_clause, query_is_single_table_update, @@ -175,23 +178,6 @@ def test_queries_start_with(): assert queries_start_with(sql, ['delete', 'update']) is False -@pytest.mark.parametrize( - ('text', 'is_valid', 'invalid_scheme'), - [ - ('localhost', False, None), - ('mysql://user@localhost/db', True, None), - ('mysqlx://user@localhost/db', True, None), - ('tcp://localhost:3306', True, None), - ('socket:///tmp/mysql.sock', True, None), - ('ssh://user@example.com', True, None), - ('postgres://user@localhost/db', False, 'postgres'), - ('http://example.com', False, 'http'), - ], -) -def test_is_valid_connection_scheme(text, is_valid, invalid_scheme): - assert is_valid_connection_scheme(text) == (is_valid, invalid_scheme) - - @pytest.mark.parametrize( ('text', 'include', 'expected'), [ @@ -378,7 +364,7 @@ def test_query_is_single_table_update(sql, is_single_table): def test_extract_columns_from_select_handles_falsey_last_select(monkeypatch): - monkeypatch.setattr(parseutils, 'get_last_select', lambda _parsed: []) + monkeypatch.setattr(sql_utils, 'get_last_select', lambda _parsed: []) assert extract_columns_from_select('select 1') == [] @@ -388,7 +374,7 @@ def get_real_name(self): return 'column_name' monkeypatch.setattr( - parseutils, + sql_utils, 'get_last_select', lambda _parsed: TokenList([Token(DML, 'SELECT'), SingleIdentifier([])]), ) @@ -402,7 +388,7 @@ def get_identifiers(self): return [object()] monkeypatch.setattr( - parseutils, + sql_utils, 'get_last_select', lambda _parsed: TokenList([Token(DML, 'SELECT'), WeirdIdentifierList([])]), ) @@ -412,7 +398,7 @@ def get_identifiers(self): def test_extract_columns_from_select_stops_at_keyword_before_collecting_columns(monkeypatch): monkeypatch.setattr( - parseutils, + sql_utils, 'get_last_select', lambda _parsed: TokenList([Token(DML, 'SELECT'), Token(Keyword, 'FROM')]), ) @@ -421,7 +407,7 @@ def test_extract_columns_from_select_stops_at_keyword_before_collecting_columns( def test_extract_tables_from_complete_statements_returns_empty_for_falsey_rough_parse(monkeypatch): - monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: []) + monkeypatch.setattr(sql_utils.sqlparse, 'parse', lambda _sql: []) assert extract_tables_from_complete_statements('select * from t') == [] @@ -441,14 +427,14 @@ class FakeStatement: def find_all(self, _table_type): return [FakeIdentifier()] - monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: ['stmt']) - monkeypatch.setattr(parseutils.sqlglot, 'parse_one', lambda *_args, **_kwargs: FakeStatement()) + monkeypatch.setattr(sql_utils.sqlparse, 'parse', lambda _sql: ['stmt']) + monkeypatch.setattr(sql_utils.sqlglot, 'parse_one', lambda *_args, **_kwargs: FakeStatement()) assert extract_tables_from_complete_statements('with cte as (select 1) select * from cte') == [] def test_query_is_single_table_update_returns_false_when_parse_result_is_empty(monkeypatch): - monkeypatch.setattr(parseutils.sqlparse, 'parse', lambda _sql: []) + monkeypatch.setattr(sql_utils.sqlparse, 'parse', lambda _sql: []) assert query_is_single_table_update('update test set x = 1') is False @@ -479,7 +465,7 @@ def test_is_destructive_update_without_where_clause(): def test_is_destructive_skips_empty_split_queries(monkeypatch): - monkeypatch.setattr(parseutils.sqlparse, 'split', lambda _queries: ['', '']) + monkeypatch.setattr(sql_utils.sqlparse, 'split', lambda _queries: ['', '']) assert is_destructive(['drop'], 'ignored') is False @@ -521,3 +507,86 @@ def test_is_dropping_database(sql, dbname, is_dropping): def test_is_dropping_database_skips_statements_without_enough_keywords(): assert is_dropping_database('drop foo', 'foo') is False + + +@pytest.mark.parametrize( + ('queries', 'expected'), + [ + ('select 1;', False), + ('alter table foo add column bar int;', True), + ('create table foo (id int);', True), + ('use foo;', True), + ('\\r foo localhost root', True), + ('\\u foo', True), + ('connect foo localhost root', True), + ('drop table foo;', True), + ('rename table foo to bar;', True), + ], +) +def test_need_completion_refresh(queries, expected): + assert need_completion_refresh(queries) is expected + + +def test_need_completion_refresh_ignores_queries_that_fail_to_split(monkeypatch): + class BrokenQuery: + def split(self): + raise RuntimeError('broken') + + monkeypatch.setattr(sql_utils.sqlparse, 'split', lambda _queries: [BrokenQuery(), 'select 1;']) + + assert need_completion_refresh('ignored') is False + + +@pytest.mark.parametrize( + ('queries', 'expected'), + [ + ('select 1;', False), + ('use foo;', True), + ('\\u foo', True), + ('\\r', False), + ('\\r foo localhost root', True), + ('connect', False), + ('connect foo localhost root', True), + ], +) +def test_need_completion_reset(queries, expected): + assert need_completion_reset(queries) is expected + + +def test_need_completion_reset_ignores_queries_that_fail_to_split(monkeypatch): + class BrokenQuery: + def split(self): + raise RuntimeError('broken') + + monkeypatch.setattr(sql_utils.sqlparse, 'split', lambda _queries: [BrokenQuery(), 'select 1;']) + + assert need_completion_reset('ignored') is False + + +@pytest.mark.parametrize( + ('status_plain', 'expected'), + [ + (None, False), + ('', False), + ('SELECT 1', False), + ('INSERT 1', True), + ('update 3', True), + ('rename table', True), + ], +) +def test_is_mutating(status_plain, expected): + assert is_mutating(status_plain) is expected + + +@pytest.mark.parametrize( + ('status_plain', 'expected'), + [ + (None, False), + ('', False), + ('SELECT 1', True), + ('select rows', True), + ('UPDATE 1', False), + ], +) +def test_is_select(status_plain, expected): + assert is_select(status_plain) is expected