From 5c4d3e374c6c587d6bcdddfb739ebaa0300826db Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 4 Apr 2026 17:44:51 -0400 Subject: [PATCH] move REPL execution paths to main_modes/repl.py Motivation: move code out of the monolithic main.py into logical layers. There is no functional change for this refactor, just the creation of main_modes/repl.py and the migration of (some of) the REPL logic out of main.py. There is much more to do, just in relation to the REPL. For example, complete_while_typing_filter() doesn't logically still belong in main.py, but it is bound up with state set in main.py. Likewise, the updates of the prompt string, and similar updates of window title and toolbar, logically belong with the REPL, but are a bit interwoven with other code in main.py. Another desirable change might be migrating the handlers in key_binding_utils.py to repl.py or some new file repl_handlers.py. We might also consider now removing sections of the relatively brittle tests in test_main_regression.py which relate to the REPL, and a note is left to that effect. run_cli() is left in place for now, but the intention is to fully replace it with main_repl(). --- changelog.md | 1 + mycli/constants.py | 3 + mycli/main.py | 521 +--------------- mycli/main_modes/repl.py | 572 +++++++++++++++++ mycli/types.py | 4 + test/pytests/test_main.py | 7 +- test/pytests/test_main_modes_repl.py | 890 +++++++++++++++++++++++++++ test/pytests/test_main_regression.py | 312 +++++----- 8 files changed, 1620 insertions(+), 690 deletions(-) create mode 100644 mycli/main_modes/repl.py create mode 100644 mycli/types.py create mode 100644 test/pytests/test_main_modes_repl.py diff --git a/changelog.md b/changelog.md index 43f014e6..1d9be1a2 100644 --- a/changelog.md +++ b/changelog.md @@ -39,6 +39,7 @@ Internal * Move `--execute` logic to the new `main_modes` with `--batch`. * Move `--list-dsn` logic to the new `main_modes` with `--batch`. * Move `--list-ssh-config` logic to the new `main_modes` with `--batch`. +* Move REPL logic to the new `main_modes`. * Sort coverage report in tox suite. * Skip more tests when a database connection is not present. * Move SQL utilities to a new `sql_utils.py`. diff --git a/mycli/constants.py b/mycli/constants.py index 88edaa76..2d278ae4 100644 --- a/mycli/constants.py +++ b/mycli/constants.py @@ -10,3 +10,6 @@ DEFAULT_USER = 'root' TEST_DATABASE = 'mycli_test_db' + +DEFAULT_WIDTH = 80 +DEFAULT_HEIGHT = 25 diff --git a/mycli/main.py b/mycli/main.py index ba1484b5..97d4513e 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,13 +1,12 @@ from __future__ import annotations -from collections import defaultdict, namedtuple +from collections import defaultdict from dataclasses import dataclass from decimal import Decimal import functools from io import TextIOWrapper import logging import os -import random import re import shutil import subprocess @@ -21,11 +20,8 @@ except ImportError: pass from datetime import datetime -from importlib import resources import itertools -from random import choice from textwrap import dedent -from time import time from urllib.parse import parse_qs, unquote, urlparse from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors @@ -37,11 +33,9 @@ import keyring from prompt_toolkit import print_formatted_text from prompt_toolkit.application.current import get_app -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest -from prompt_toolkit.completion import Completion, DynamicCompleter +from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document -from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode -from prompt_toolkit.filters import Condition, HasFocus, IsDone +from prompt_toolkit.filters import Condition from prompt_toolkit.formatted_text import ( ANSI, HTML, @@ -50,33 +44,27 @@ to_formatted_text, to_plain_text, ) -from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor -from prompt_toolkit.lexers import PygmentsLexer -from prompt_toolkit.output import ColorDepth -from prompt_toolkit.shortcuts import CompleteStyle, PromptSession +from prompt_toolkit.shortcuts import PromptSession import pymysql from pymysql.constants.CR import CR_SERVER_LOST from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR from pymysql.cursors import Cursor import sqlparse -from mycli import __version__ -from mycli.clibuffer import cli_is_multiline +import mycli as mycli_package from mycli.clistyle import style_factory_helpers, style_factory_ptoolkit -from mycli.clitoolbar import create_toolbar_tokens_func from mycli.compat import WIN from mycli.completion_refresher import CompletionRefresher from mycli.config import get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, strip_matching_quotes, write_default_config from mycli.constants import ( DEFAULT_CHARSET, + DEFAULT_HEIGHT, DEFAULT_HOST, DEFAULT_PORT, - HOME_URL, + DEFAULT_WIDTH, ISSUES_URL, REPO_URL, ) -from mycli.key_bindings import mycli_bindings -from mycli.lexer import MyCliLexer from mycli.main_modes.batch import ( main_batch_from_stdin, main_batch_with_progress_bar, @@ -86,42 +74,25 @@ from mycli.main_modes.execute import main_execute_from_cli from mycli.main_modes.list_dsn import main_list_dsn from mycli.main_modes.list_ssh_config import main_list_ssh_config +from mycli.main_modes.repl import main_repl from mycli.packages import special from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme from mycli.packages.filepaths import dir_path_exists, guess_socket_location -from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command -from mycli.packages.key_binding_utils import ( - handle_clip_command, - handle_editor_command, -) -from mycli.packages.prompt_utils import confirm, confirm_destructive_query -from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp +from mycli.packages.prompt_utils import confirm_destructive_query from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import ArgType from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count -from mycli.packages.sql_utils import ( - is_dropping_database, - is_mutating, - is_select, - need_completion_refresh, - need_completion_reset, -) from mycli.packages.sqlresult import SQLResult from mycli.packages.ssh_utils import read_ssh_config from mycli.packages.string_utils import sanitize_terminal_title from mycli.packages.tabular_output import sql_format from mycli.sqlcompleter import SQLCompleter from mycli.sqlexecute import FIELD_TYPES, SQLExecute +from mycli.types import Query sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] -# Query tuples are used for maintaining history -Query = namedtuple("Query", ["query", "successful", "mutating"]) - -SUPPORT_INFO = f"Home: {HOME_URL}\nBug tracker: {ISSUES_URL}" -DEFAULT_WIDTH = 80 -DEFAULT_HEIGHT = 25 MIN_COMPLETION_TRIGGER = 1 EMPTY_PASSWORD_FLAG_SENTINEL = -1 @@ -880,434 +851,7 @@ def output_timing(self, timing: str, is_warnings_style: bool = False) -> None: print_formatted_text(styled_timing, style=self.ptoolkit_style) def run_cli(self) -> None: - iterations = 0 - sqlexecute = self.sqlexecute - assert isinstance(sqlexecute, SQLExecute) - logger = self.logger - self.configure_pager() - - if self.smart_completion: - self.refresh_completions() - - history_file = os.path.expanduser(os.environ.get("MYCLI_HISTFILE", self.config.get("history_file", "~/.mycli-history"))) - if dir_path_exists(history_file): - history = FileHistoryWithTimestamp(history_file) - else: - history = None - self.echo( - f'Error: Unable to open the history file "{history_file}". Your query history will not be saved.', - err=True, - fg="red", - ) - - key_bindings = mycli_bindings(self) - - if not self.less_chatty: - print(sqlexecute.server_info) - print("mycli", __version__) - print(SUPPORT_INFO) - if random.random() <= 0.5: - print("Thanks to the contributor —", thanks_picker()) - else: - print("Tip —", tips_picker()) - - def get_prompt_message(app) -> ANSI: - if app.current_buffer.text: - return self.last_prompt_message - prompt = self.get_prompt(self.prompt_format, app.render_counter) - if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: - prompt = self.get_prompt(self.default_prompt_splitln, app.render_counter) - self.prompt_lines = prompt.count('\n') + 1 - prompt = prompt.replace("\\x1b", "\x1b") - if not self.prompt_lines: - self.prompt_lines = prompt.count('\n') + 1 - self.last_prompt_message = ANSI(prompt) - return self.last_prompt_message - - def get_continuation(width: int, _two: int, _three: int) -> AnyFormattedText: - if self.multiline_continuation_char == "": - continuation = "" - elif self.multiline_continuation_char: - left_padding = width - len(self.multiline_continuation_char) - continuation = " " * max((left_padding - 1), 0) + self.multiline_continuation_char + " " - else: - continuation = " " - return [("class:continuation", continuation)] - - def show_initial_toolbar_help() -> bool: - return iterations == 0 - - # Keep track of whether or not the query is mutating. In case - # of a multi-statement query, the overall query is considered - # mutating if any one of the component statements is mutating - mutating = False - - def output_res(results: Generator[SQLResult], start: float) -> None: - nonlocal mutating - result_count = watch_count = 0 - for result in results: - logger.debug("preamble: %r", result.preamble) - logger.debug("header: %r", result.header) - logger.debug("rows: %r", result.rows) - logger.debug("status: %r", result.status) - logger.debug("command: %r", result.command) - threshold = 1000 - # If this is a watch query, offset the start time on the 2nd+ iteration - # to account for the sleep duration - if result.command is not None and result.command["name"] == "watch": - if watch_count > 0: - try: - watch_seconds = float(result.command["seconds"]) - start += watch_seconds - except ValueError as e: - self.echo(f"Invalid watch sleep time provided ({e}).", err=True, fg="red") - sys.exit(1) - else: - watch_count += 1 - if is_select(result.status_plain) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold: - self.echo( - f"The result set has more than {threshold} rows.", - fg="red", - ) - if not confirm("Do you want to continue?"): - self.echo("Aborted!", err=True, fg="red") - break - - if self.auto_vertical_output: - if self.prompt_app is not None: - max_width = self.prompt_app.output.get_size().columns - else: - max_width = DEFAULT_WIDTH - else: - max_width = None - - formatted = self.format_sqlresult( - result, - is_expanded=special.is_expanded_output(), - is_redirected=special.is_redirected(), - null_string=self.null_string, - numeric_alignment=self.numeric_alignment, - binary_display=self.binary_display, - max_width=max_width, - ) - - t = time() - start - try: - if result_count > 0: - self.echo("") - try: - self.output(formatted, result) - except KeyboardInterrupt: - pass - if self.beep_after_seconds > 0 and t >= self.beep_after_seconds: - assert self.prompt_app is not None - self.prompt_app.output.bell() - if special.is_timing_enabled(): - self.output_timing(f"Time: {t:0.03f}s") - except KeyboardInterrupt: - pass - - start = time() - result_count += 1 - mutating = mutating or is_mutating(result.status_plain) - - # get and display warnings if enabled - if self.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: - warnings = sqlexecute.run("SHOW WARNINGS") - t = time() - start - saw_warning = False - for warning in warnings: - saw_warning = True - formatted = self.format_sqlresult( - warning, - is_expanded=special.is_expanded_output(), - is_redirected=special.is_redirected(), - null_string=self.null_string, - numeric_alignment=self.numeric_alignment, - binary_display=self.binary_display, - max_width=max_width, - is_warnings_style=True, - ) - self.echo("") - self.output(formatted, warning, is_warnings_style=True) - - if saw_warning and special.is_timing_enabled(): - self.output_timing(f"Time: {t:0.03f}s", is_warnings_style=True) - - def keepalive_hook(_context): - """ - prompt_toolkit shares the event loop with this hook, which seems - to get called a bit faster than once/second on one machine. - - It would be nice to reset the counter whenever user input is made, - but was not clear how to do that with context.input_is_ready(). - - Example at https://github.com/prompt-toolkit/python-prompt-toolkit/blob/main/examples/prompts/inputhook.py - """ - if self.keepalive_ticks is None: - return - if self.keepalive_ticks < 1: - return - self._keepalive_counter += 1 - if self._keepalive_counter > self.keepalive_ticks: - self._keepalive_counter = 0 - self.logger.debug('keepalive ping') - try: - assert self.sqlexecute is not None - assert self.sqlexecute.conn is not None - self.sqlexecute.conn.ping(reconnect=False) - except Exception as e: - self.logger.debug('keepalive ping error %r', e) - - def one_iteration(text: str | None = None) -> None: - inputhook = keepalive_hook if self.keepalive_ticks and self.keepalive_ticks >= 1 else None - if text is None: - try: - assert self.prompt_app is not None - loaded_message_fn = functools.partial(get_prompt_message, self.prompt_app.app) - text = self.prompt_app.prompt( - inputhook=inputhook, - message=loaded_message_fn, - ) - except KeyboardInterrupt: - return - - special.set_expanded_output(False) - special.set_forced_horizontal_output(False) - - try: - text = handle_editor_command( - self, - text, - inputhook, - loaded_message_fn, - ) - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return - - try: - if handle_clip_command(self, text): - return - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return - # LLM command support - while special.is_llm_command(text): - start = time() - try: - assert isinstance(self.sqlexecute, SQLExecute) - assert sqlexecute.conn is not None - cur = sqlexecute.conn.cursor() - context, sql, duration = special.handle_llm( - text, - cur, - sqlexecute.dbname or '', - self.llm_prompt_field_truncate, - self.llm_prompt_section_truncate, - ) - if context: - click.echo("LLM Response:") - click.echo(context) - click.echo("---") - if special.is_timing_enabled(): - self.output_timing(f"Time: {duration:.2f} seconds") - text = self.prompt_app.prompt( - default=sql or '', - inputhook=inputhook, - message=loaded_message_fn, - ) - except KeyboardInterrupt: - return - except special.FinishIteration as e: - if e.results: - return output_res(e.results, start) - else: - return None - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return - - text = text.strip() - - if not text: - return - - if is_redirect_command(text): - sql_part, command_part, file_operator_part, file_part = get_redirect_components(text) - text = sql_part or '' - try: - special.set_redirect(command_part, file_operator_part, file_part) - except (FileNotFoundError, OSError, RuntimeError) as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return - - if self.destructive_warning: - destroy = confirm_destructive_query(self.destructive_keywords, text) - if destroy is None: - pass # Query was not destructive. Nothing to do here. - elif destroy is True: - self.echo("Your call!") - else: - self.echo("Wise choice!") - return - else: - destroy = True - - try: - logger.debug("sql: %r", text) - - special.write_tee(self.last_prompt_message, nl=False) - special.write_tee(text) - self.log_query(text) - - successful = False - start = time() - res = sqlexecute.run(text) - self.main_formatter.query = text - self.redirect_formatter.query = text - successful = True - output_res(res, start) - special.unset_once_if_written(self.post_redirect_command) - special.flush_pipe_once_if_written(self.post_redirect_command) - except pymysql.err.InterfaceError: - # attempt to reconnect - if not self.reconnect(): - return - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except EOFError as e: - raise e - except KeyboardInterrupt: - # get last connection id - connection_id_to_kill = sqlexecute.connection_id or 0 - # some mysql-compatible databases may not implement connection_id() - if connection_id_to_kill > 0: - logger.debug("connection id to kill: %r", connection_id_to_kill) - try: - sqlexecute.connect() - for kill_result in sqlexecute.run(f"kill {connection_id_to_kill}"): - status_str = str(kill_result.status_plain).lower() - if status_str.find("ok") > -1: - logger.debug("cancelled query, connection id: %r, sql: %r", connection_id_to_kill, text) - self.echo(f"Cancelled query id: {connection_id_to_kill}", err=True, fg="blue") - else: - logger.debug( - "Failed to confirm query cancellation, connection id: %r, sql: %r", - connection_id_to_kill, - text, - ) - self.echo(f"Failed to confirm query cancellation, id: {connection_id_to_kill}", err=True, fg="red") - except Exception as e2: - self.echo(f"Encountered error while cancelling query: {e2}", err=True, fg="red") - else: - logger.debug("Did not get a connection id, skip cancelling query") - self.echo("Did not get a connection id, skip cancelling query", err=True, fg="red") - except NotImplementedError: - self.echo("Not Yet Implemented.", fg="yellow") - except pymysql.OperationalError as e1: - logger.debug("Exception: %r", e1) - if e1.args[0] in (2003, 2006, 2013): - # attempt to reconnect - if not self.reconnect(): - return - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - else: - logger.error("sql: %r, error: %r", text, e1) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e1), err=True, fg="red") - except Exception as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - else: - if is_dropping_database(text, sqlexecute.dbname): - sqlexecute.dbname = None - sqlexecute.connect() - - # Refresh the table names and column names if necessary. - if need_completion_refresh(text): - self.refresh_completions(reset=need_completion_reset(text)) - finally: - if self.logfile is False: - self.echo("Warning: This query was not logged.", err=True, fg="red") - query = Query(text, successful, mutating) - self.query_history.append(query) - - if self.toolbar_format.lower() == 'none': - get_toolbar_tokens = None - else: - get_toolbar_tokens = create_toolbar_tokens_func( - self, - show_initial_toolbar_help, - self.toolbar_format, - ) - - if self.wider_completion_menu: - complete_style = CompleteStyle.MULTI_COLUMN - else: - complete_style = CompleteStyle.COLUMN - - with self._completer_lock: - if self.key_bindings == "vi": - editing_mode = EditingMode.VI - else: - editing_mode = EditingMode.EMACS - - self.prompt_app = PromptSession( - color_depth=ColorDepth.DEPTH_24_BIT if 'truecolor' in os.getenv('COLORTERM', '').lower() else None, - lexer=PygmentsLexer(MyCliLexer), - reserve_space_for_menu=self.get_reserved_space(), - prompt_continuation=get_continuation, - bottom_toolbar=get_toolbar_tokens, - complete_style=complete_style, - input_processors=[ - ConditionalProcessor( - processor=HighlightMatchingBracketProcessor(chars="[](){}"), filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() - ) - ], - tempfile_suffix=".sql", - completer=DynamicCompleter(lambda: self.completer), - complete_in_thread=True, - history=history, - auto_suggest=ThreadedAutoSuggest(AutoSuggestFromHistory()), - complete_while_typing=complete_while_typing_filter, - multiline=cli_is_multiline(self), - # why not self.ptoolkit_style here? - style=style_factory_ptoolkit(self.syntax_style, self.cli_style), - include_default_pygments_style=False, - key_bindings=key_bindings, - enable_open_in_editor=True, - enable_system_prompt=True, - enable_suspend=True, - editing_mode=editing_mode, - search_ignore_case=True, - ) - - if self.key_bindings == 'vi': - self.prompt_app.app.ttimeoutlen = self.vi_ttimeoutlen - else: - self.prompt_app.app.ttimeoutlen = self.emacs_ttimeoutlen - - self.set_all_external_titles() - - try: - while True: - one_iteration() - iterations += 1 - except EOFError: - special.close_tee() - if not self.less_chatty: - self.echo("Goodbye!") + main_repl(self) def reconnect(self, database: str = "") -> bool: """ @@ -2107,7 +1651,7 @@ class CliArgs: @click.command() @clickdc.adddc('cli_args', CliArgs) -@click.version_option(__version__, '--version', '-V', help="Output mycli's version.") +@click.version_option(mycli_package.__version__, '--version', '-V', help="Output mycli's version.") def click_entrypoint( cli_args: CliArgs, ) -> None: @@ -2566,47 +2110,6 @@ def get_password_from_file(password_file: str | None) -> str | None: mycli.close() -def thanks_picker() -> str: - import mycli - - lines: str = "" - try: - with resources.files(mycli).joinpath("AUTHORS").open('r') as f: - lines += f.read() - except FileNotFoundError: - pass - - try: - with resources.files(mycli).joinpath("SPONSORS").open('r') as f: - lines += f.read() - except FileNotFoundError: - pass - - contents = [] - for line in lines.split("\n"): - if m := re.match(r"^ *\* (.*)", line): - contents.append(m.group(1)) - return choice(contents) if contents else 'our sponsors' - - -def tips_picker() -> str: - import mycli - - tips = [] - - try: - with resources.files(mycli).joinpath('TIPS').open('r') as f: - for line in f: - if line.startswith("#"): - continue - if tip := line.strip(): - tips.append(tip) - except FileNotFoundError: - pass - - return choice(tips) if tips else r'\? or "help" for help!' - - def main() -> int | None: try: result = click_entrypoint.main( diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py new file mode 100644 index 00000000..a507a38f --- /dev/null +++ b/mycli/main_modes/repl.py @@ -0,0 +1,572 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +from importlib import resources +import os +import random +import re +import sys +import time +import traceback +from typing import TYPE_CHECKING, Any, Generator + +import click +import prompt_toolkit +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest +from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ( + ANSI, +) +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.output import ColorDepth +from prompt_toolkit.shortcuts import CompleteStyle, PromptSession +import pymysql +from pymysql.cursors import Cursor + +import mycli as mycli_package +from mycli.clibuffer import cli_is_multiline +from mycli.clistyle import style_factory_ptoolkit +from mycli.clitoolbar import create_toolbar_tokens_func +from mycli.constants import ( + DEFAULT_WIDTH, + HOME_URL, + ISSUES_URL, +) +from mycli.key_bindings import mycli_bindings +from mycli.lexer import MyCliLexer +from mycli.packages import special +from mycli.packages.filepaths import dir_path_exists +from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command +from mycli.packages.key_binding_utils import ( + handle_clip_command, + handle_editor_command, +) +from mycli.packages.prompt_utils import confirm, confirm_destructive_query +from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp +from mycli.packages.sql_utils import ( + is_dropping_database, + is_mutating, + is_select, + need_completion_refresh, + need_completion_reset, +) +from mycli.packages.sqlresult import SQLResult +from mycli.sqlexecute import SQLExecute +from mycli.types import Query + +if TYPE_CHECKING: + from prompt_toolkit.formatted_text import AnyFormattedText + + from mycli.main import MyCli + + +SUPPORT_INFO = f"Home: {HOME_URL}\nBug tracker: {ISSUES_URL}" + + +def _main_module(): + from mycli import main as main_module + + return main_module + + +@dataclass(slots=True) +class ReplState: + iterations: int = 0 + mutating: bool = False + + +def _create_history(mycli: 'MyCli') -> FileHistoryWithTimestamp | None: + history_file = os.path.expanduser(os.environ.get('MYCLI_HISTFILE', mycli.config.get('history_file', '~/.mycli-history'))) + if dir_path_exists(history_file): + return FileHistoryWithTimestamp(history_file) + + mycli.echo( + f'Error: Unable to open the history file "{history_file}". Your query history will not be saved.', + err=True, + fg='red', + ) + return None + + +def _show_startup_banner( + mycli: 'MyCli', + sqlexecute: SQLExecute, +) -> None: + if mycli.less_chatty: + return + + print(sqlexecute.server_info) + print('mycli', mycli_package.__version__) + print(SUPPORT_INFO) + if random.random() <= 0.5: + print('Thanks to the contributor —', _thanks_picker()) + else: + print('Tip —', _tips_picker()) + + +def _get_prompt_message( + mycli: 'MyCli', + app: prompt_toolkit.application.application.Application, +) -> ANSI: + if app.current_buffer.text: + return mycli.last_prompt_message + + prompt = mycli.get_prompt(mycli.prompt_format, app.render_counter) + if mycli.prompt_format == mycli.default_prompt and len(prompt) > mycli.max_len_prompt: + prompt = mycli.get_prompt(mycli.default_prompt_splitln, app.render_counter) + mycli.prompt_lines = prompt.count('\n') + 1 + prompt = prompt.replace('\\x1b', '\x1b') + if not mycli.prompt_lines: + mycli.prompt_lines = prompt.count('\n') + 1 + mycli.last_prompt_message = ANSI(prompt) + return mycli.last_prompt_message + + +def _get_continuation( + mycli: 'MyCli', + width: int, + _two: int, + _three: int, +) -> AnyFormattedText: + if mycli.multiline_continuation_char == '': + continuation = '' + elif mycli.multiline_continuation_char: + left_padding = width - len(mycli.multiline_continuation_char) + continuation = ' ' * max((left_padding - 1), 0) + mycli.multiline_continuation_char + ' ' + else: + continuation = ' ' + return [('class:continuation', continuation)] + + +def _output_results( + mycli: 'MyCli', + state: ReplState, + results: Generator[SQLResult], + start: float, +) -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + + result_count = 0 + watch_count = 0 + for result in results: + mycli.logger.debug('preamble: %r', result.preamble) + mycli.logger.debug('header: %r', result.header) + mycli.logger.debug('rows: %r', result.rows) + mycli.logger.debug('status: %r', result.status) + mycli.logger.debug('command: %r', result.command) + threshold = 1000 + if result.command is not None and result.command['name'] == 'watch': + if watch_count > 0: + try: + watch_seconds = float(result.command['seconds']) + start += watch_seconds + except ValueError as e: + mycli.echo(f'Invalid watch sleep time provided ({e}).', err=True, fg='red') + sys.exit(1) + else: + watch_count += 1 + + if is_select(result.status_plain) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold: + mycli.echo( + f'The result set has more than {threshold} rows.', + fg='red', + ) + if not confirm('Do you want to continue?'): + mycli.echo('Aborted!', err=True, fg='red') + break + + if mycli.auto_vertical_output: + if mycli.prompt_app is not None: + max_width = mycli.prompt_app.output.get_size().columns + else: + max_width = DEFAULT_WIDTH + else: + max_width = None + + formatted = mycli.format_sqlresult( + result, + is_expanded=special.is_expanded_output(), + is_redirected=special.is_redirected(), + null_string=mycli.null_string, + numeric_alignment=mycli.numeric_alignment, + binary_display=mycli.binary_display, + max_width=max_width, + ) + + duration = time.time() - start + try: + if result_count > 0: + mycli.echo('') + try: + mycli.output(formatted, result) + except KeyboardInterrupt: + pass + if mycli.beep_after_seconds > 0 and duration >= mycli.beep_after_seconds: + assert mycli.prompt_app is not None + mycli.prompt_app.output.bell() + if special.is_timing_enabled(): + mycli.output_timing(f'Time: {duration:0.03f}s') + except KeyboardInterrupt: + pass + + start = time.time() + result_count += 1 + state.mutating = state.mutating or is_mutating(result.status_plain) + + if mycli.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: + warnings = sqlexecute.run('SHOW WARNINGS') + warnings_duration = time.time() - start + saw_warning = False + for warning in warnings: + saw_warning = True + formatted = mycli.format_sqlresult( + warning, + is_expanded=special.is_expanded_output(), + is_redirected=special.is_redirected(), + null_string=mycli.null_string, + numeric_alignment=mycli.numeric_alignment, + binary_display=mycli.binary_display, + max_width=max_width, + is_warnings_style=True, + ) + mycli.echo('') + mycli.output(formatted, warning, is_warnings_style=True) + + if saw_warning and special.is_timing_enabled(): + mycli.output_timing(f'Time: {warnings_duration:0.03f}s', is_warnings_style=True) + + +def _keepalive_hook( + mycli: 'MyCli', + _context: Any, +) -> None: + if mycli.keepalive_ticks is None: + return + if mycli.keepalive_ticks < 1: + return + + mycli._keepalive_counter += 1 + if mycli._keepalive_counter > mycli.keepalive_ticks: + mycli._keepalive_counter = 0 + mycli.logger.debug('keepalive ping') + try: + assert mycli.sqlexecute is not None + assert mycli.sqlexecute.conn is not None + mycli.sqlexecute.conn.ping(reconnect=False) + except Exception as e: + mycli.logger.debug('keepalive ping error %r', e) + + +def _build_prompt_session( + mycli: 'MyCli', + state: ReplState, + history: FileHistoryWithTimestamp | None, + key_bindings: KeyBindings, +) -> None: + if mycli.toolbar_format.lower() == 'none': + get_toolbar_tokens = None + else: + get_toolbar_tokens = create_toolbar_tokens_func( + mycli, + lambda: state.iterations == 0, + mycli.toolbar_format, + ) + + if mycli.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with mycli._completer_lock: + if mycli.key_bindings == 'vi': + editing_mode = EditingMode.VI + else: + editing_mode = EditingMode.EMACS + + mycli.prompt_app = PromptSession( + color_depth=ColorDepth.DEPTH_24_BIT if 'truecolor' in os.getenv('COLORTERM', '').lower() else None, + lexer=PygmentsLexer(MyCliLexer), + reserve_space_for_menu=mycli.get_reserved_space(), + prompt_continuation=lambda width, two, three: _get_continuation(mycli, width, two, three), + bottom_toolbar=get_toolbar_tokens, + complete_style=complete_style, + input_processors=[ + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars='[](){}'), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ) + ], + tempfile_suffix='.sql', + completer=DynamicCompleter(lambda: mycli.completer), + complete_in_thread=True, + history=history, + auto_suggest=ThreadedAutoSuggest(AutoSuggestFromHistory()), + complete_while_typing=_main_module().complete_while_typing_filter, + multiline=cli_is_multiline(mycli), + style=style_factory_ptoolkit(mycli.syntax_style, mycli.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=editing_mode, + search_ignore_case=True, + ) + + if mycli.key_bindings == 'vi': + mycli.prompt_app.app.ttimeoutlen = mycli.vi_ttimeoutlen + else: + mycli.prompt_app.app.ttimeoutlen = mycli.emacs_ttimeoutlen + + +def _one_iteration( + mycli: 'MyCli', + state: ReplState, + text: str | None = None, +) -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + + inputhook = partial(_keepalive_hook, mycli) if mycli.keepalive_ticks and mycli.keepalive_ticks >= 1 else None + + if text is None: + try: + assert mycli.prompt_app is not None + loaded_message_fn = partial(_get_prompt_message, mycli, mycli.prompt_app.app) + text = mycli.prompt_app.prompt( + inputhook=inputhook, + message=loaded_message_fn, + ) + except KeyboardInterrupt: + return + + special.set_expanded_output(False) + special.set_forced_horizontal_output(False) + + try: + text = handle_editor_command( + mycli, + text, + inputhook, + loaded_message_fn, + ) + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + try: + if handle_clip_command(mycli, text): + return + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + while special.is_llm_command(text): + start = time.time() + try: + assert sqlexecute.conn is not None + cur = sqlexecute.conn.cursor() + context, sql, duration = special.handle_llm( + text, + cur, + sqlexecute.dbname or '', + mycli.llm_prompt_field_truncate, + mycli.llm_prompt_section_truncate, + ) + if context: + click.echo('LLM Response:') + click.echo(context) + click.echo('---') + if special.is_timing_enabled(): + mycli.output_timing(f'Time: {duration:.2f} seconds') + assert mycli.prompt_app is not None + text = mycli.prompt_app.prompt( + default=sql or '', + inputhook=inputhook, + message=loaded_message_fn, + ) + except KeyboardInterrupt: + return + except special.FinishIteration as e: + if e.results: + _output_results(mycli, state, e.results, start) + return + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + text = text.strip() + if not text: + return + + if is_redirect_command(text): + sql_part, command_part, file_operator_part, file_part = get_redirect_components(text) + text = sql_part or '' + try: + special.set_redirect(command_part, file_operator_part, file_part) + except (FileNotFoundError, OSError, RuntimeError) as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + if mycli.destructive_warning: + destroy = confirm_destructive_query(mycli.destructive_keywords, text) + if destroy is None: + pass + elif destroy is True: + mycli.echo('Your call!') + else: + mycli.echo('Wise choice!') + return + + successful = False + try: + mycli.logger.debug('sql: %r', text) + special.write_tee(mycli.last_prompt_message, nl=False) + special.write_tee(text) + mycli.log_query(text) + + start = time.time() + results = sqlexecute.run(text) + mycli.main_formatter.query = text + mycli.redirect_formatter.query = text + successful = True + _output_results(mycli, state, results, start) + special.unset_once_if_written(mycli.post_redirect_command) + special.flush_pipe_once_if_written(mycli.post_redirect_command) + except pymysql.err.InterfaceError: + if not mycli.reconnect(): + return + _one_iteration(mycli, state, text) + return + except EOFError as e: + raise e + except KeyboardInterrupt: + connection_id_to_kill = sqlexecute.connection_id or 0 + if connection_id_to_kill > 0: + mycli.logger.debug('connection id to kill: %r', connection_id_to_kill) + try: + sqlexecute.connect() + for kill_result in sqlexecute.run(f'kill {connection_id_to_kill}'): + status_str = str(kill_result.status_plain).lower() + if status_str.find('ok') > -1: + mycli.logger.debug('cancelled query, connection id: %r, sql: %r', connection_id_to_kill, text) + mycli.echo(f'Cancelled query id: {connection_id_to_kill}', err=True, fg='blue') + else: + mycli.logger.debug( + 'Failed to confirm query cancellation, connection id: %r, sql: %r', + connection_id_to_kill, + text, + ) + mycli.echo(f'Failed to confirm query cancellation, id: {connection_id_to_kill}', err=True, fg='red') + except Exception as e2: + mycli.echo(f'Encountered error while cancelling query: {e2}', err=True, fg='red') + else: + mycli.logger.debug('Did not get a connection id, skip cancelling query') + mycli.echo('Did not get a connection id, skip cancelling query', err=True, fg='red') + except NotImplementedError: + mycli.echo('Not Yet Implemented.', fg='yellow') + except pymysql.OperationalError as e1: + mycli.logger.debug('Exception: %r', e1) + if e1.args[0] in (2003, 2006, 2013): + if not mycli.reconnect(): + return + _one_iteration(mycli, state, text) + return + + mycli.logger.error('sql: %r, error: %r', text, e1) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e1), err=True, fg='red') + except Exception as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + else: + if is_dropping_database(text, sqlexecute.dbname): + sqlexecute.dbname = None + sqlexecute.connect() + + if need_completion_refresh(text): + mycli.refresh_completions(reset=need_completion_reset(text)) + finally: + if mycli.logfile is False: + mycli.echo('Warning: This query was not logged.', err=True, fg='red') + + query = Query(text, successful, state.mutating) + mycli.query_history.append(query) + + +def _thanks_picker() -> str: + lines: str = "" + + try: + with resources.files(mycli_package).joinpath("AUTHORS").open('r') as f: + lines += f.read() + except FileNotFoundError: + pass + + try: + with resources.files(mycli_package).joinpath("SPONSORS").open('r') as f: + lines += f.read() + except FileNotFoundError: + pass + + contents = [] + for line in lines.split("\n"): + if m := re.match(r"^ *\* (.*)", line): + contents.append(m.group(1)) + return random.choice(contents) if contents else 'our sponsors' + + +def _tips_picker() -> str: + tips = [] + + try: + with resources.files(mycli_package).joinpath('TIPS').open('r') as f: + for line in f: + if line.startswith("#"): + continue + if tip := line.strip(): + tips.append(tip) + except FileNotFoundError: + pass + + return random.choice(tips) if tips else r'\? or "help" for help!' + + +def main_repl(mycli: 'MyCli') -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + state = ReplState() + + mycli.configure_pager() + if mycli.smart_completion: + mycli.refresh_completions() + + history = _create_history(mycli) + key_bindings = mycli_bindings(mycli) + _show_startup_banner(mycli, sqlexecute) + _build_prompt_session(mycli, state, history, key_bindings) + mycli.set_all_external_titles() + + try: + while True: + _one_iteration(mycli, state) + state.iterations += 1 + except EOFError: + special.close_tee() + if not mycli.less_chatty: + mycli.echo('Goodbye!') diff --git a/mycli/types.py b/mycli/types.py new file mode 100644 index 00000000..207d62d9 --- /dev/null +++ b/mycli/types.py @@ -0,0 +1,4 @@ +from collections import namedtuple + +# Query tuples are used for maintaining history +Query = namedtuple("Query", ["query", "successful", "mutating"]) diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 3af76d21..bcfccaac 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -20,7 +20,7 @@ DEFAULT_USER, TEST_DATABASE, ) -from mycli.main import EMPTY_PASSWORD_FLAG_SENTINEL, MyCli, click_entrypoint, thanks_picker +from mycli.main import EMPTY_PASSWORD_FLAG_SENTINEL, MyCli, click_entrypoint import mycli.packages.special from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.packages.sqlresult import SQLResult @@ -682,11 +682,6 @@ def test_batch_mode_csv(executor): assert expected in "".join(result.output) -def test_thanks_picker_utf8(): - name = thanks_picker() - assert name and isinstance(name, str) - - def test_help_strings_end_with_periods(): """Make sure click options have help text that end with a period.""" for param in click_entrypoint.params: diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py new file mode 100644 index 00000000..2d1812c6 --- /dev/null +++ b/test/pytests/test_main_modes_repl.py @@ -0,0 +1,890 @@ +from __future__ import annotations + +import builtins +from collections.abc import Generator, Iterator +from dataclasses import dataclass +from io import StringIO +import os +from types import SimpleNamespace +from typing import Any, Literal, cast + +from prompt_toolkit.formatted_text import to_plain_text +import pymysql +import pytest + +import mycli.main as main_module +import mycli.main_modes.repl as repl_mode +from mycli.packages.sqlresult import SQLResult + + +class DummyLogger: + def __init__(self) -> None: + self.debug_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + self.error_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def debug(self, *args: Any, **kwargs: Any) -> None: + self.debug_calls.append((args, kwargs)) + + def error(self, *args: Any, **kwargs: Any) -> None: + self.error_calls.append((args, kwargs)) + + +@dataclass +class DummyFormatterWithQuery: + query: str = '' + + +class FakeApp: + def __init__(self, text: str = '', render_counter: int = 0) -> None: + self.current_buffer = SimpleNamespace(text=text) + self.render_counter = render_counter + self.ttimeoutlen: float | None = None + + +class FakePromptOutput: + def __init__(self, columns: int = 80, rows: int = 24) -> None: + self.columns = columns + self.rows = rows + self.bell_count = 0 + + def get_size(self) -> SimpleNamespace: + return SimpleNamespace(columns=self.columns, rows=self.rows) + + def bell(self) -> None: + self.bell_count += 1 + + +class FakePromptSession: + def __init__(self, responses: list[Any] | None = None, columns: int = 80, rows: int = 24) -> None: + self.responses = list(responses or []) + self.output = FakePromptOutput(columns=columns, rows=rows) + self.app = FakeApp() + self.prompt_calls: list[dict[str, Any]] = [] + + def prompt(self, **kwargs: Any) -> str: + self.prompt_calls.append(dict(kwargs)) + if not self.responses: + raise EOFError() + response = self.responses.pop(0) + if isinstance(response, BaseException): + raise response + return response + + +class FakeCursorBase: + def __init__( + self, + rows: list[tuple[Any, ...]] | None = None, + rowcount: int = 0, + warning_count: int = 0, + ) -> None: + self._rows = list(rows or []) + self.rowcount = rowcount + self.warning_count = warning_count + + def __iter__(self) -> Iterator[tuple[Any, ...]]: + return iter(self._rows) + + +class FakeConnection: + def __init__(self, ping_exc: Exception | None = None, cursor_value: Any = 'cursor') -> None: + self.ping_exc = ping_exc + self.cursor_value = cursor_value + self.ping_calls: list[bool] = [] + + def ping(self, reconnect: bool = False) -> None: + self.ping_calls.append(reconnect) + if self.ping_exc is not None: + raise self.ping_exc + + def cursor(self) -> Any: + return self.cursor_value + + +class ReusableLock: + def __enter__(self) -> 'ReusableLock': + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + return False + + +def sqlresult_generator(*results: SQLResult) -> Generator[SQLResult, None, None]: + for result in results: + yield result + + +class FakeResourceTree: + def __init__(self, files: dict[str, str], path: str | None = None) -> None: + self.files = files + self.path = path + + def joinpath(self, path: str) -> 'FakeResourceTree': + return FakeResourceTree(self.files, path) + + def open(self, mode: str) -> StringIO: + assert self.path is not None + if self.path not in self.files: + raise FileNotFoundError(self.path) + return StringIO(self.files[self.path]) + + +def make_repl_cli(sqlexecute: Any | None = None) -> Any: + cli = SimpleNamespace() + cli.logger = DummyLogger() + cli.query_history = [] + cli.last_prompt_message = repl_mode.ANSI('') + cli.last_custom_toolbar_message = repl_mode.ANSI('') + cli.prompt_lines = 0 + cli.default_prompt = r'\t \u@\h:\d> ' + cli.default_prompt_splitln = r'\u@\h\n(\t):\d>' + cli.max_len_prompt = 45 + cli.prompt_format = cli.default_prompt + cli.multiline_continuation_char = '>' + cli.toolbar_format = 'default' + cli.less_chatty = True + cli.keepalive_ticks = None + cli._keepalive_counter = 0 + cli.auto_vertical_output = False + cli.beep_after_seconds = 0.0 + cli.show_warnings = False + cli.null_string = '' + cli.numeric_alignment = 'right' + cli.binary_display = None + cli.prompt_app = None + cli.post_redirect_command = None + cli.logfile = None + cli.smart_completion = False + cli.config = {'history_file': '~/.mycli-history-testing'} + cli.key_bindings = 'emacs' + cli.wider_completion_menu = False + cli._completer_lock = ReusableLock() + cli.completer = object() + cli.syntax_style = 'native' + cli.cli_style = {} + cli.emacs_ttimeoutlen = 1.0 + cli.vi_ttimeoutlen = 2.0 + cli.destructive_warning = False + cli.destructive_keywords = ['drop'] + cli.llm_prompt_field_truncate = 0 + cli.llm_prompt_section_truncate = 0 + cli.main_formatter = DummyFormatterWithQuery() + cli.redirect_formatter = DummyFormatterWithQuery() + cli.pager_configured = 0 + refresh_calls: list[bool] = [] + output_calls: list[tuple[list[str], Any, bool]] = [] + echo_calls: list[str] = [] + timing_calls: list[tuple[str, bool]] = [] + log_queries: list[str] = [] + cli.refresh_calls = refresh_calls + cli.output_calls = output_calls + cli.echo_calls = echo_calls + cli.timing_calls = timing_calls + cli.log_queries = log_queries + cli.title_calls = 0 + cli.sqlexecute = sqlexecute + cli.get_reserved_space = lambda: 3 + cli.get_last_query = lambda: cli.query_history[-1].query if cli.query_history else None + cli.configure_pager = lambda: setattr(cli, 'pager_configured', cli.pager_configured + 1) + + def refresh_completions(reset: bool = False) -> list[SQLResult]: + cli.refresh_calls.append(reset) + return [SQLResult(status='refresh')] + + cli.refresh_completions = refresh_completions + cli.set_all_external_titles = lambda: setattr(cli, 'title_calls', cli.title_calls + 1) + + def output_timing(timing: str, is_warnings_style: bool = False) -> None: + cli.timing_calls.append((timing, is_warnings_style)) + + cli.output_timing = output_timing + + def log_query(text: str) -> None: + cli.log_queries.append(text) + + cli.log_query = log_query + cli.reconnect = lambda database='': False + + def echo(message: Any, **kwargs: Any) -> None: + cli.echo_calls.append(str(message)) + + cli.echo = echo + + def format_sqlresult(result: SQLResult, **kwargs: Any) -> Iterator[str]: + return iter([str(kwargs.get('max_width')), result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult + + def output(formatted: Any, result: Any, is_warnings_style: bool = False) -> None: + cli.output_calls.append((list(formatted), result, is_warnings_style)) + + cli.output = output + cli.get_prompt = lambda string, render_counter: f'{string}:{render_counter}' + return cli + + +def patch_repl_runtime_defaults(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(repl_mode.special, 'set_expanded_output', lambda value: None) + monkeypatch.setattr(repl_mode.special, 'set_forced_horizontal_output', lambda value: None) + monkeypatch.setattr(repl_mode.special, 'is_llm_command', lambda text: False) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr(repl_mode.special, 'write_tee', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'unset_once_if_written', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: None) + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda mycli, text, inputhook, loaded_message_fn: text) + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: False) + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(repl_mode, 'confirm_destructive_query', lambda keywords, text: None) + monkeypatch.setattr(repl_mode, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(repl_mode, 'need_completion_reset', lambda text: False) + monkeypatch.setattr(repl_mode, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + +def test_repl_main_module_and_create_history(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli() + monkeypatch.setenv('MYCLI_HISTFILE', '~/override-history') + monkeypatch.setattr(repl_mode, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(repl_mode, 'FileHistoryWithTimestamp', lambda path: f'history:{path}') + assert repl_mode._main_module() is main_module + history = cast(Any, repl_mode._create_history(cli)) + assert history == f'history:{os.path.expanduser("~/override-history")}' + + monkeypatch.delenv('MYCLI_HISTFILE') + monkeypatch.setattr(repl_mode, 'dir_path_exists', lambda path: False) + assert repl_mode._create_history(cli) is None + assert 'Unable to open the history file' in cli.echo_calls[-1] + + +def test_repl_picker_helpers_cover_present_and_missing_resources(monkeypatch: pytest.MonkeyPatch) -> None: + files = { + 'AUTHORS': '* Alice\n* Bob\n', + 'SPONSORS': '* Carol\n', + 'TIPS': '# comment\nTip 1\n\nTip 2\n', + } + monkeypatch.setattr(repl_mode.resources, 'files', lambda package: FakeResourceTree(files)) + monkeypatch.setattr(repl_mode.random, 'choice', lambda seq: seq[0]) + assert repl_mode._thanks_picker() == 'Alice' + assert repl_mode._tips_picker() == 'Tip 1' + + monkeypatch.setattr(repl_mode.resources, 'files', lambda package: FakeResourceTree({})) + assert repl_mode._thanks_picker() == 'our sponsors' + assert repl_mode._tips_picker() == r'\? or "help" for help!' + + +def test_repl_show_startup_banner_and_prompt_helpers(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace(server_info='Server')) + printed: list[str] = [] + monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: printed.append(' '.join(str(x) for x in args))) + monkeypatch.setattr(repl_mode.random, 'random', lambda: 0.4) + monkeypatch.setattr(repl_mode, '_thanks_picker', lambda: 'Alice') + monkeypatch.setattr(repl_mode, '_tips_picker', lambda: 'Tip') + + cli.less_chatty = False + repl_mode._show_startup_banner(cli, cli.sqlexecute) + monkeypatch.setattr(repl_mode.random, 'random', lambda: 0.6) + repl_mode._show_startup_banner(cli, cli.sqlexecute) + cli.less_chatty = True + repl_mode._show_startup_banner(cli, cli.sqlexecute) + assert any('Thanks to the contributor' in line for line in printed) + assert any('Tip — Tip' in line for line in printed) + + cli.get_prompt = lambda string, render_counter: '0123456' if string == cli.default_prompt else 'a\nb' + cli.max_len_prompt = 5 + prompt_text = to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=2)))) + assert prompt_text == 'a\nb' + assert cli.prompt_lines == 2 + + cli.last_prompt_message = repl_mode.ANSI('cached') + assert to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='typing', render_counter=3)))) == 'cached' + + cli.prompt_format = 'custom' + cli.prompt_lines = 0 + cli.get_prompt = lambda string, render_counter: 'single' + assert to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=4)))) == 'single' + assert cli.prompt_lines == 1 + + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', ' > ')] + cli.multiline_continuation_char = '' + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', '')] + cli.multiline_continuation_char = None + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', ' ')] + + +def test_output_results_covers_watch_warning_timing_beep_and_interrupts(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeSQLExecute: + def run(self, text: str) -> list[SQLResult]: + assert text == 'SHOW WARNINGS' + return [SQLResult(status='warning', rows=[('warn',)])] + + cli = make_repl_cli(FakeSQLExecute()) + cli.auto_vertical_output = True + cli.prompt_app = FakePromptSession(columns=91) + cli.beep_after_seconds = 0.1 + cli.show_warnings = True + state = repl_mode.ReplState() + format_widths: list[int | None] = [] + + def format_sqlresult(result: SQLResult, **kwargs: Any) -> Iterator[str]: + format_widths.append(kwargs.get('max_width')) + return iter([result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult + time_values = iter([0.2, 1.0, 2.0, 3.0, 3.2]) + monkeypatch.setattr(repl_mode.time, 'time', lambda: next(time_values)) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: status == 'mut') + + results = sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='mut', rows=cast(Any, FakeCursorBase(rowcount=1, warning_count=1))), + ) + + repl_mode._output_results(cli, state, results, start=0.0) + + assert state.mutating is True + assert format_widths[:2] == [91, 91] + assert cli.prompt_app.output.bell_count == 2 + assert '' in cli.echo_calls + assert any(is_warnings_style is True for _, _, is_warnings_style in cli.output_calls) + assert any(is_warnings_style is False for _, is_warnings_style in cli.timing_calls) + assert any(is_warnings_style is True for _, is_warnings_style in cli.timing_calls) + + cli_interrupt = make_repl_cli(SimpleNamespace()) + cli_interrupt.echo = lambda message, **kwargs: ( + (_ for _ in ()).throw(KeyboardInterrupt()) if message == '' else cli_interrupt.echo_calls.append(str(message)) + ) + cli_interrupt.output = lambda formatted, result, is_warnings_style=False: (_ for _ in ()).throw(KeyboardInterrupt()) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + monkeypatch.setattr(repl_mode.time, 'time', lambda: 0.0) + repl_mode._output_results( + cli_interrupt, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='first'), SQLResult(status='second')), + start=0.0, + ) + + +def test_output_results_handles_abort_default_width_and_bad_watch(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.auto_vertical_output = True + widths: list[int | None] = [] + + def format_sqlresult_with_width(result: SQLResult, **kwargs: Any) -> Iterator[str]: + widths.append(kwargs.get('max_width')) + return iter([result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult_with_width + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: status == 'select') + monkeypatch.setattr(repl_mode, 'confirm', lambda text: False) + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='select', rows=cast(Any, FakeCursorBase(rowcount=1001)))), + start=0.0, + ) + assert 'The result set has more than 1000 rows.' in cli.echo_calls + assert 'Aborted!' in cli.echo_calls + + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='ok')), + start=0.0, + ) + assert widths[-1] == repl_mode.DEFAULT_WIDTH + + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + with pytest.raises(SystemExit): + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='watch', command={'name': 'watch', 'seconds': 'bad'}), + ), + start=0.0, + ) + + +def test_keepalive_hook_covers_threshold_and_errors() -> None: + cli = make_repl_cli(SimpleNamespace(conn=FakeConnection())) + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + + cli.keepalive_ticks = 0 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + + cli.keepalive_ticks = 1 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 1 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + assert cli.sqlexecute.conn.ping_calls == [False] + + cli.sqlexecute.conn = FakeConnection(ping_exc=RuntimeError('boom')) + repl_mode._keepalive_hook(cli, None) + repl_mode._keepalive_hook(cli, None) + assert any('keepalive ping error' in call[0][0] for call in cli.logger.debug_calls) + + +def test_build_prompt_session_covers_toolbar_modes_and_editing_modes(monkeypatch: pytest.MonkeyPatch) -> None: + captured_kwargs: list[dict[str, Any]] = [] + toolbar_help: list[bool] = [] + + def fake_prompt_session(**kwargs: Any) -> FakePromptSession: + captured_kwargs.append(kwargs) + return FakePromptSession() + + monkeypatch.setattr(repl_mode, 'PromptSession', fake_prompt_session) + monkeypatch.setattr(repl_mode, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') + monkeypatch.setattr(repl_mode, 'cli_is_multiline', lambda mycli: False) + + def fake_toolbar_tokens(mycli: Any, show_help: Any, fmt: str) -> str: + toolbar_help.append(show_help()) + return 'toolbar' + + monkeypatch.setattr(repl_mode, 'create_toolbar_tokens_func', fake_toolbar_tokens) + + cli = make_repl_cli(SimpleNamespace()) + state = repl_mode.ReplState() + cli.toolbar_format = 'none' + cli.key_bindings = 'vi' + cli.wider_completion_menu = True + repl_mode._build_prompt_session(cli, state, history=cast(Any, 'history'), key_bindings=cast(Any, 'bindings')) + first_kwargs = captured_kwargs[-1] + assert first_kwargs['bottom_toolbar'] is None + assert first_kwargs['complete_style'] == repl_mode.CompleteStyle.MULTI_COLUMN + assert first_kwargs['editing_mode'] == repl_mode.EditingMode.VI + assert cli.prompt_app.app.ttimeoutlen == cli.vi_ttimeoutlen + + cli.toolbar_format = 'default' + cli.key_bindings = 'emacs' + cli.wider_completion_menu = False + state.iterations = 0 + repl_mode._build_prompt_session(cli, state, history=cast(Any, 'history'), key_bindings=cast(Any, 'bindings')) + latest_kwargs = captured_kwargs[-1] + assert latest_kwargs['bottom_toolbar'] == 'toolbar' + assert latest_kwargs['complete_style'] == repl_mode.CompleteStyle.COLUMN + assert latest_kwargs['editing_mode'] == repl_mode.EditingMode.EMACS + assert toolbar_help == [True] + assert cli.prompt_app.app.ttimeoutlen == cli.emacs_ttimeoutlen + assert latest_kwargs['prompt_continuation'](4, 0, 0) == [('class:continuation', ' > ')] + + +def test_one_iteration_handles_prompt_interrupt_empty_editor_clip_and_clip_true(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + cli = make_repl_cli(SimpleNamespace(run=lambda text: iter([SQLResult(status='ok')]), conn=FakeConnection())) + cli.keepalive_ticks = 1 + cli.prompt_app = FakePromptSession([KeyboardInterrupt(), ' ', 'edit-error', 'clip-error', 'clip-stop']) + + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + inputhook = cli.prompt_app.prompt_calls[-1]['inputhook'] + assert inputhook is not None + inputhook(None) + + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda *args: (_ for _ in ()).throw(RuntimeError('edit boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert 'edit boom' in cli.echo_calls[-1] + + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda mycli, text, inputhook, loaded_message_fn: text) + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: (_ for _ in ()).throw(RuntimeError('clip boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert 'clip boom' in cli.echo_calls[-1] + + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: True) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + + +def test_one_iteration_covers_llm_paths(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + click_output: list[str] = [] + monkeypatch.setattr(repl_mode.click, 'echo', lambda message='', **kwargs: click_output.append(str(message))) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode.special, 'is_llm_command', lambda text: text.startswith('\\llm')) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.conn = FakeConnection(cursor_value='cursor') + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status=f'ran:{text}')]) + + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda text, cur, dbname, field_truncate, section_truncate: ('context', 'select 1', 1.25), + ) + cli = make_repl_cli(FakeSQLExecute()) + cli.prompt_app = FakePromptSession(['\\llm ask', 'select 1']) + repl_mode._one_iteration( + cli, + repl_mode.ReplState(), + ) + assert click_output[:3] == ['LLM Response:', 'context', '---'] + assert cli.output_calls[0][0] == ['None', 'ran:select 1'] + + cli_finish = make_repl_cli(FakeSQLExecute()) + cli_finish.prompt_app = FakePromptSession(['\\llm finish']) + cli_finish.format_sqlresult = lambda result, **kwargs: iter([result.status_plain or 'row']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(repl_mode.special.FinishIteration(iter([SQLResult(status='done')]))), + ) + repl_mode._one_iteration(cli_finish, repl_mode.ReplState()) + assert cli_finish.output_calls[0][0] == ['done'] + + cli_empty = make_repl_cli(FakeSQLExecute()) + cli_empty.prompt_app = FakePromptSession(['\\llm empty']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(repl_mode.special.FinishIteration(None)), + ) + repl_mode._one_iteration(cli_empty, repl_mode.ReplState()) + assert cli_empty.output_calls == [] + + cli_err = make_repl_cli(FakeSQLExecute()) + cli_err.prompt_app = FakePromptSession(['\\llm err']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError('llm boom')), + ) + repl_mode._one_iteration(cli_err, repl_mode.ReplState()) + assert 'llm boom' in cli_err.echo_calls[-1] + + cli_interrupt = make_repl_cli(FakeSQLExecute()) + cli_interrupt.prompt_app = FakePromptSession(['\\llm stop']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(KeyboardInterrupt()), + ) + repl_mode._one_iteration(cli_interrupt, repl_mode.ReplState()) + assert cli_interrupt.output_calls == [] + + cli_quiet = make_repl_cli(FakeSQLExecute()) + cli_quiet.prompt_app = FakePromptSession(['\\llm quiet', 'select 2']) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda text, cur, dbname, field_truncate, section_truncate: ('', 'select 2', 0.5), + ) + repl_mode._one_iteration(cli_quiet, repl_mode.ReplState()) + assert cli_quiet.output_calls[0][0] == ['None', 'ran:select 2'] + + +def test_one_iteration_covers_redirect_destructive_success_refresh_and_logfile(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def connect(self) -> None: + self.calls.append('connect') + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + return iter([SQLResult(status='DROP 1')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.logfile = False + cli.destructive_warning = True + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: text == 'redirect') + monkeypatch.setattr(repl_mode, 'get_redirect_components', lambda text: ('dropdb', 'tee', '>', 'out.txt')) + redirects: list[tuple[Any, ...]] = [] + monkeypatch.setattr(repl_mode.special, 'set_redirect', lambda *args: redirects.append(args)) + monkeypatch.setattr( + repl_mode, + 'confirm_destructive_query', + lambda keywords, text: None if text == 'dropdb' else (True if text == 'approved' else False), + ) + monkeypatch.setattr(repl_mode, 'is_dropping_database', lambda text, dbname: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'need_completion_refresh', lambda text: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'need_completion_reset', lambda text: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: True) + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'redirect') + assert redirects == [('tee', '>', 'out.txt')] + assert cli.refresh_calls == [True] + assert cli.query_history[-1].query == 'dropdb' + assert cli.query_history[-1].successful is True + assert cli.query_history[-1].mutating is True + assert sqlexecute.dbname is None + assert sqlexecute.calls == ['dropdb', 'connect'] + assert 'Warning: This query was not logged.' in cli.echo_calls + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'approved') + assert 'Your call!' in cli.echo_calls + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'denied') + assert 'Wise choice!' in cli.echo_calls + + +def test_one_iteration_covers_reconnect_and_error_paths(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class InterfaceSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + if text == 'iface' and self.calls.count('iface') == 1: + raise pymysql.err.InterfaceError() + return iter([SQLResult(status=f'ok:{text}')]) + + interface_sql = InterfaceSQLExecute() + cli_interface = make_repl_cli(interface_sql) + interface_reconnect_calls: list[str] = [] + interface_results = iter([True]) + + def interface_reconnect(database: str = '') -> bool: + interface_reconnect_calls.append(database) + return next(interface_results) + + cli_interface.reconnect = interface_reconnect + + repl_mode._one_iteration(cli_interface, repl_mode.ReplState(), 'iface') + assert interface_sql.calls.count('iface') == 2 + assert cli_interface.query_history[-1].query == 'iface' + assert interface_reconnect_calls == [''] + + cli_interface_false = make_repl_cli(InterfaceSQLExecute()) + false_calls: list[str] = [] + + def interface_reconnect_false(database: str = '') -> bool: + false_calls.append(database) + return False + + cli_interface_false.reconnect = interface_reconnect_false + repl_mode._one_iteration(cli_interface_false, repl_mode.ReplState(), 'iface') + assert false_calls == [''] + + class ErrorSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + if text == 'oplost' and self.calls.count('oplost') == 1: + raise pymysql.OperationalError(2003, 'lost') + if text == 'opbad': + raise pymysql.OperationalError(9999, 'bad op') + if text == 'nyi': + raise NotImplementedError() + if text == 'boom': + raise RuntimeError('boom') + if text == 'eof': + raise EOFError() + return iter([SQLResult(status=f'ok:{text}')]) + + error_sql = ErrorSQLExecute() + cli_error = make_repl_cli(error_sql) + error_reconnect_calls: list[str] = [] + + def error_reconnect(database: str = '') -> bool: + error_reconnect_calls.append(database) + return True + + cli_error.reconnect = error_reconnect + + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'oplost') + assert error_sql.calls.count('oplost') == 2 + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'opbad') + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'nyi') + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'boom') + assert any('bad op' in line for line in cli_error.echo_calls) + assert 'Not Yet Implemented.' in cli_error.echo_calls + assert any('boom' in line for line in cli_error.echo_calls) + assert error_reconnect_calls == [''] + + cli_error_false = make_repl_cli(ErrorSQLExecute()) + false_reconnect_calls: list[str] = [] + + def error_reconnect_false(database: str = '') -> bool: + false_reconnect_calls.append(database) + return False + + cli_error_false.reconnect = error_reconnect_false + repl_mode._one_iteration(cli_error_false, repl_mode.ReplState(), 'oplost') + assert false_reconnect_calls == [''] + + with pytest.raises(EOFError): + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'eof') + + +def test_one_iteration_reraises_eoferror(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class EofSQLExecute: + dbname = 'db' + connection_id = 0 + + def run(self, text: str) -> Iterator[SQLResult]: + raise EOFError() + + with pytest.raises(EOFError): + repl_mode._one_iteration(make_repl_cli(EofSQLExecute()), repl_mode.ReplState(), 'eof') + + +def test_one_iteration_covers_cancel_paths_and_redirect_error(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + + def connect(self) -> None: + return None + + def run(self, text: str) -> Iterator[SQLResult]: + if text == 'cancel-ok': + self.connection_id = 7 + raise KeyboardInterrupt() + if text == 'kill 7': + return iter([SQLResult(status='OK')]) + if text == 'cancel-fail': + self.connection_id = 8 + raise KeyboardInterrupt() + if text == 'kill 8': + return iter([SQLResult(status='failed')]) + if text == 'cancel-error': + self.connection_id = 9 + raise KeyboardInterrupt() + if text == 'kill 9': + raise RuntimeError('kill failed') + if text == 'cancel-missing': + self.connection_id = 0 + raise KeyboardInterrupt() + return iter([SQLResult(status='ok')]) + + cli = make_repl_cli(FakeSQLExecute()) + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: text == 'redirect-bad') + monkeypatch.setattr(repl_mode, 'get_redirect_components', lambda text: ('sql', 'tee', '>', 'out.txt')) + monkeypatch.setattr(repl_mode.special, 'set_redirect', lambda *args: (_ for _ in ()).throw(RuntimeError('redirect boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'redirect-bad') + assert 'redirect boom' in cli.echo_calls[-1] + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-ok') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-fail') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-error') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-missing') + assert 'Cancelled query id: 7' in cli.echo_calls + assert any('Failed to confirm query cancellation' in line for line in cli.echo_calls) + assert any('Encountered error while cancelling query' in line for line in cli.echo_calls) + assert 'Did not get a connection id, skip cancelling query' in cli.echo_calls + + +def test_main_repl_covers_setup_loop_and_goodbye(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.less_chatty = False + cli.smart_completion = True + loop_iterations: list[int] = [] + monkeypatch.setattr(repl_mode, '_create_history', lambda mycli: 'history') + monkeypatch.setattr(repl_mode, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(repl_mode, '_show_startup_banner', lambda mycli, sqlexecute: None) + monkeypatch.setattr( + repl_mode, + '_build_prompt_session', + lambda mycli, state, history, key_bindings: setattr(mycli, 'prompt_app', FakePromptSession()), + ) + + def fake_one_iteration(mycli: Any, state: repl_mode.ReplState) -> None: + loop_iterations.append(state.iterations) + if len(loop_iterations) == 2: + raise EOFError() + + closed: list[bool] = [] + monkeypatch.setattr(repl_mode, '_one_iteration', fake_one_iteration) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: closed.append(True)) + + repl_mode.main_repl(cli) + + assert cli.pager_configured == 1 + assert cli.refresh_calls == [False] + assert cli.title_calls == 1 + assert loop_iterations == [0, 1] + assert closed == [True] + assert cli.echo_calls[-1] == 'Goodbye!' + + +def test_main_repl_covers_no_refresh_and_quiet_exit(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.less_chatty = True + cli.smart_completion = False + monkeypatch.setattr(repl_mode, '_create_history', lambda mycli: 'history') + monkeypatch.setattr(repl_mode, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(repl_mode, '_show_startup_banner', lambda mycli, sqlexecute: None) + monkeypatch.setattr( + repl_mode, + '_build_prompt_session', + lambda mycli, state, history, key_bindings: setattr(mycli, 'prompt_app', FakePromptSession()), + ) + monkeypatch.setattr(repl_mode, '_one_iteration', lambda mycli, state: (_ for _ in ()).throw(EOFError())) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: None) + + repl_mode.main_repl(cli) + + assert cli.refresh_calls == [] + assert cli.echo_calls == [] + + +def test_output_results_covers_remaining_watch_select_and_warning_branches(monkeypatch: pytest.MonkeyPatch) -> None: + class WarninglessSQLExecute: + def run(self, text: str) -> list[SQLResult]: + assert text == 'SHOW WARNINGS' + return [] + + cli = make_repl_cli(WarninglessSQLExecute()) + cli.show_warnings = True + cli.auto_vertical_output = False + cli.prompt_app = FakePromptSession(columns=77) + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + monkeypatch.setattr(repl_mode, 'confirm', lambda text: True) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: status == 'select') + monkeypatch.setattr(repl_mode.time, 'time', lambda: 0.0) + + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='watch', command={'name': 'watch', 'seconds': '2'}), + SQLResult(status='select', rows=cast(Any, FakeCursorBase(rowcount=1001, warning_count=1))), + ), + start=0.0, + ) + assert cli.output_calls diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index f12bd1a5..a04de3ec 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -10,6 +10,10 @@ * migrating individual tests if content moves out of main.py * migrating individual tests to test_main.py after assessment of quality * removing and rewriting these tests if contracts change + +For example, since the generation of these tests, main_modes/repl.py was +created, and all tests here touching the REPL functionality should in +principle be removed. """ from __future__ import annotations @@ -21,7 +25,9 @@ import itertools import os from pathlib import Path +import random import sys +import time from types import ModuleType, SimpleNamespace from typing import Any, Callable, Literal, cast @@ -31,7 +37,9 @@ import pymysql import pytest -from mycli import key_bindings, main +from mycli import main +import mycli.key_bindings +import mycli.main_modes.repl from mycli.packages import key_binding_utils from mycli.packages.sqlresult import SQLResult @@ -677,15 +685,18 @@ def test_initialize_logging_covers_none_bad_path_and_file_handler(tmp_path: Path cli.echo = lambda message, **kwargs: echo_calls.append(message) # type: ignore[assignment] cli.config = {'main': {'log_file': str(tmp_path / 'mycli.log'), 'log_level': 'NONE'}} monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) main.MyCli.initialize_logging(cli) cli.config = {'main': {'log_file': str(tmp_path / 'missing' / 'mycli.log'), 'log_level': 'INFO'}} monkeypatch.setattr(main, 'dir_path_exists', lambda path: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: False) main.MyCli.initialize_logging(cli) assert echo_calls[-1].startswith('Error: Unable to open the log file') cli.config = {'main': {'log_file': str(tmp_path / 'mycli.log'), 'log_level': 'INFO'}} monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) main.MyCli.initialize_logging(cli) @@ -1043,23 +1054,23 @@ def test_handle_editor_clip_and_output_timing(monkeypatch: pytest.MonkeyPatch) - monkeypatch.setattr(main.special, 'get_filename', lambda text: 'query.sql') monkeypatch.setattr(main.special, 'get_editor_query', lambda text: 'select 1') monkeypatch.setattr(main.special, 'open_external_editor', lambda filename, sql: ('edited sql', None)) - assert key_binding_utils.handle_editor_command(cli, r'select 1\e', None, lambda: None) == 'edited sql' + assert mycli.main_modes.repl.handle_editor_command(cli, r'select 1\e', None, lambda: None) == 'edited sql' monkeypatch.setattr(main.special, 'open_external_editor', lambda filename, sql: ('', 'boom')) with pytest.raises(RuntimeError, match='boom'): - key_binding_utils.handle_editor_command(cli, r'select 1\e', None, lambda: None) + mycli.main_modes.repl.handle_editor_command(cli, r'select 1\e', None, lambda: None) monkeypatch.setattr(main.special, 'clip_command', lambda text: True) monkeypatch.setattr(main.special, 'get_clip_query', lambda text: None) monkeypatch.setattr(main.special, 'copy_query_to_clipboard', lambda sql: None) - assert key_binding_utils.handle_clip_command(cli, r'select 1\clip') is True + assert mycli.main_modes.repl.handle_clip_command(cli, r'select 1\clip') is True monkeypatch.setattr(main.special, 'copy_query_to_clipboard', lambda sql: 'clipboard failed') with pytest.raises(RuntimeError, match='clipboard failed'): - key_binding_utils.handle_clip_command(cli, r'select 1\clip') + mycli.main_modes.repl.handle_clip_command(cli, r'select 1\clip') monkeypatch.setattr(main.special, 'clip_command', lambda text: False) - assert key_binding_utils.handle_clip_command(cli, 'select 1') is False + assert mycli.main_modes.repl.handle_clip_command(cli, 'select 1') is False printed: list[tuple[Any, Any]] = [] monkeypatch.setattr(main, 'print_formatted_text', lambda text, style=None: printed.append((text, style))) @@ -1103,7 +1114,7 @@ def test_format_sqlresult_run_query_reserved_space_and_last_query(monkeypatch: p assert main.MyCli.get_last_query(cli) == 'select 1' -def test_reconnect_logging_output_titles_prompt_and_picker_fallbacks(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: +def test_reconnect_logging_output_titles_prompt(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: cli = make_bare_mycli() sqlexecute = object.__new__(main.SQLExecute) @@ -1195,17 +1206,6 @@ def failing_connect() -> None: monkeypatch.setattr(main.subprocess, 'run', lambda *args, **kwargs: (_ for _ in ()).throw(FileNotFoundError())) main.MyCli.set_external_multiplex_window_title(cli) - class MissingResource: - def joinpath(self, name: str) -> 'MissingResource': - return self - - def open(self, mode: str) -> StringIO: - raise FileNotFoundError() - - monkeypatch.setattr(main.resources, 'files', lambda package: MissingResource()) - assert main.thanks_picker() == 'our sponsors' - assert main.tips_picker() == r'\? or "help" for help!' - def test_reconnect_first_and_second_passes(monkeypatch: pytest.MonkeyPatch) -> None: cli = make_bare_mycli() @@ -1480,40 +1480,6 @@ def test_completion_helpers_title_helpers_thanks_tips(monkeypatch: pytest.Monkey assert list(main.MyCli.get_completions(cli, 'select', 6)) == ['done'] assert entered_lock['count'] >= 2 - class FakeResource: - def __init__(self, text: str | None) -> None: - self.text = text - - def joinpath(self, name: str) -> 'FakeResource': - if name == 'AUTHORS': - return FakeResource('* Alice\n') - if name == 'SPONSORS': - raise FileNotFoundError() - if name == 'TIPS': - return FakeResource('# comment\nTip one\n\nTip two\n') - raise FileNotFoundError() - - def open(self, mode: str) -> StringIO: - if self.text is None: - raise FileNotFoundError() - return StringIO(self.text) - - monkeypatch.setattr(main.resources, 'files', lambda package: FakeResource(None)) - monkeypatch.setattr(main, 'choice', lambda values: values[0]) - assert main.thanks_picker() == 'Alice' - assert main.tips_picker() == 'Tip one' - - class SponsorResource(FakeResource): - def joinpath(self, name: str) -> 'FakeResource': - if name == 'AUTHORS': - raise FileNotFoundError() - if name == 'SPONSORS': - return FakeResource('* Sponsor Person\n') - raise FileNotFoundError() - - monkeypatch.setattr(main.resources, 'files', lambda package: SponsorResource(None)) - assert main.thanks_picker() == 'Sponsor Person' - def test_main_wrapper_and_edit_and_execute(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(main, 'filtered_sys_argv', lambda: ['--help']) @@ -1555,7 +1521,7 @@ class ErrorNoCode(click.ClickException): current_buffer=SimpleNamespace(open_in_editor=lambda validate_and_handle=False: opened.append(validate_and_handle)) ), ) - key_bindings.edit_and_execute(event) + mycli.key_bindings.edit_and_execute(event) assert opened == [False] @@ -2057,6 +2023,9 @@ def __init__(self) -> None: self.server_info = SimpleNamespace(species=SimpleNamespace(name='MySQL')) self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> list[SQLResult]: return [SQLResult(status='SELECT 1', header=['a'], rows=[(1,)])] @@ -2065,12 +2034,13 @@ def run(self, text: str) -> list[SQLResult]: sqlexecute = FakeRunSQLExecute() cli.sqlexecute = cast(Any, sqlexecute) monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: False) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: False) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2081,10 +2051,10 @@ def run(self, text: str) -> list[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) main.MyCli.run_cli(cli) assert refresh_resets == [False] assert outputs == [['formatted']] @@ -2093,6 +2063,14 @@ def run(self, text: str) -> list[SQLResult]: assert prompt_session.app.ttimeoutlen == cli.emacs_ttimeoutlen +def test_run_cli_delegates_to_main_repl(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + calls: list[Any] = [] + monkeypatch.setattr(main, 'main_repl', lambda target: calls.append(target)) + main.MyCli.run_cli(cli) + assert calls == [cli] + + def test_run_cli_large_select_asks_for_confirmation(monkeypatch: pytest.MonkeyPatch) -> None: cli = make_bare_mycli() cli.config = {'history_file': '~/.mycli-history-testing'} @@ -2105,13 +2083,14 @@ def test_run_cli_large_select_asks_for_confirmation(monkeypatch: pytest.MonkeyPa echoed: list[str] = [] cli.echo = lambda message, **kwargs: echoed.append(str(message)) # type: ignore[assignment] prompt_session = FakePromptSession(responses=['select * from t', EOFError()]) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) - monkeypatch.setattr(main, 'Cursor', FakeCursorBase) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'Cursor', FakeCursorBase) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2122,11 +2101,11 @@ def test_run_cli_large_select_asks_for_confirmation(monkeypatch: pytest.MonkeyPa monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) - monkeypatch.setattr(main, 'confirm', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'confirm', lambda text: False) rows = FakeCursorBase(rows=[(1,)], rowcount=1001, description=[('id', 3)], warning_count=0) class FakeRunSQLExecute: @@ -2161,13 +2140,14 @@ def test_run_cli_outputs_warnings_and_timing(monkeypatch: pytest.MonkeyPatch) -> cli.output_timing = lambda timing, is_warnings_style=False: timings.append((timing, is_warnings_style)) # type: ignore[assignment] cli.format_sqlresult = lambda result, **kwargs: iter([result.status_plain or 'row']) # type: ignore[assignment] prompt_session = FakePromptSession(responses=['select 1', EOFError()]) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) - monkeypatch.setattr(main, 'Cursor', FakeCursorBase) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'Cursor', FakeCursorBase) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2178,10 +2158,10 @@ def test_run_cli_outputs_warnings_and_timing(monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) warning_rows = FakeCursorBase(rows=[('Level', 1, 'Message')], rowcount=1, description=[('id', 3)], warning_count=1) main_result = SQLResult(status='SELECT 1', header=['id'], rows=cast(Any, warning_rows)) warning_result = SQLResult(status='Warning', header=['level'], rows=[('Warning',)]) @@ -2191,6 +2171,9 @@ def __init__(self) -> None: self.server_info = SimpleNamespace(species=SimpleNamespace(name='MySQL')) self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> list[SQLResult]: if text == 'SHOW WARNINGS': @@ -2237,6 +2220,9 @@ def __init__(self) -> None: self.server_info = 'Server' self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) @@ -2249,21 +2235,20 @@ def fake_prompt_session(**kwargs: Any) -> InspectPromptSession: continuations.append(kwargs['prompt_continuation'](4, 0, 0)) return prompt_session - monkeypatch.setattr(main, 'PromptSession', fake_prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', fake_prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') def fake_create_toolbar_tokens(mycli: Any, show_help: Any, fmt: str) -> str: toolbar_help.append(show_help()) return 'toolbar' - monkeypatch.setattr(main, 'create_toolbar_tokens_func', fake_create_toolbar_tokens) + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', fake_create_toolbar_tokens) monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main.random, 'random', lambda: 0.4) - monkeypatch.setattr(main, 'thanks_picker', lambda: 'Alice') - monkeypatch.setattr(main, 'tips_picker', lambda: 'Tip') + monkeypatch.setattr(random, 'random', lambda: 0.4) monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: prints.append(' '.join(str(x) for x in args))) echoed: list[str] = [] cli.echo = lambda message, **kwargs: echoed.append(str(message)) # type: ignore[assignment] @@ -2303,6 +2288,9 @@ def __init__(self) -> None: self.dbname = 'db' self.connection_id = 0 self.conn = LLMConnection() + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> Iterator[SQLResult]: return iter([SQLResult(status=f'ran:{text}')]) @@ -2310,12 +2298,13 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) prompt_session = FakePromptSession(responses=['\\llm ask', 'select 1', '\\llm finish', '\\llm empty', '\\llm err', EOFError()]) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_expanded_output', lambda: False) @@ -2325,10 +2314,10 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: text.startswith('\\llm')) def fake_handle_llm(text: str, cur: Any, dbname: str, field_truncate: int, section_truncate: int) -> tuple[str, str, float]: @@ -2396,6 +2385,9 @@ def __init__(self) -> None: self.connection_id = 0 self.conn = SimpleNamespace() self.calls: list[str] = [] + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def connect(self) -> None: self.calls.append('connect') @@ -2417,12 +2409,13 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) sqlexecute = FakeRunSQLExecute() cli.sqlexecute = cast(Any, sqlexecute) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2433,11 +2426,11 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: text == 'dropdb') - monkeypatch.setattr(main, 'need_completion_reset', lambda text: True) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: text == 'dropdb') + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: text == 'dropdb') + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_reset', lambda text: True) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: text == 'dropdb') main.MyCli.run_cli(cli) assert reconnect_calls == ['', ''] @@ -2479,6 +2472,9 @@ def __init__(self) -> None: self.dbname = 'db' self.connection_id = 0 self.conn = SimpleNamespace(cursor=lambda: 'cursor') + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def connect(self) -> None: return None @@ -2496,12 +2492,13 @@ def run(self, text: str) -> Iterator[SQLResult]: raise EOFError() return iter([SQLResult(status=f'ok:{text}')]) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_expanded_output', lambda: False) @@ -2511,10 +2508,10 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: text.startswith('\\llm')) monkeypatch.setattr(main.special, 'handle_llm', lambda *args, **kwargs: (_ for _ in ()).throw(KeyboardInterrupt())) monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) @@ -2542,18 +2539,22 @@ def __init__(self) -> None: self.server_info = SimpleNamespace(species=SimpleNamespace(name='MySQL')) self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> Iterator[SQLResult]: if text == 'iface': raise pymysql.err.InterfaceError() raise pymysql.OperationalError(2003, 'lost') - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2564,62 +2565,15 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) main.MyCli.run_cli(cli) -def test_run_cli_tip_prompt_lines_toolbar_none_and_keepalive_noops(monkeypatch: pytest.MonkeyPatch) -> None: - cli = make_bare_mycli() - cli.less_chatty = False - cli.toolbar_format = 'none' - cli.keepalive_ticks = 1 - cli.prompt_format = 'prompt' - cli.config = {'history_file': '~/.mycli-history-testing'} - cli.set_all_external_titles = lambda: None # type: ignore[assignment] - cli.get_prompt = lambda string, render_counter: 'prompt' # type: ignore[assignment] - printed: list[str] = [] - - class PromptOnce(FakePromptSession): - def prompt(self, **kwargs: Any) -> str: - inputhook = kwargs.get('inputhook') - if inputhook is not None: - cli.keepalive_ticks = None - inputhook(None) - cli.keepalive_ticks = 0 - inputhook(None) - kwargs['message']() - raise EOFError() - - class FakeRunSQLExecute: - def __init__(self) -> None: - self.server_info = 'Server' - self.dbname = 'db' - self.connection_id = 0 - - monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) - cli.sqlexecute = cast(Any, FakeRunSQLExecute()) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: PromptOnce()) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr( - main, 'create_toolbar_tokens_func', lambda *args: (_ for _ in ()).throw(AssertionError('toolbar should be disabled')) - ) - monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') - monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) - monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main.random, 'random', lambda: 0.6) - monkeypatch.setattr(main, 'tips_picker', lambda: 'Tip') - monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: printed.append(' '.join(str(x) for x in args))) - main.MyCli.run_cli(cli) - assert any('Tip' in line for line in printed) - assert cli.prompt_lines == 1 - - def test_run_cli_watch_beep_auto_vertical_and_cancel_failure_paths(monkeypatch: pytest.MonkeyPatch) -> None: cli = make_bare_mycli() cli.config = {'history_file': '~/.mycli-history-testing'} @@ -2649,6 +2603,9 @@ def __init__(self) -> None: self.dbname = 'db' self.connection_id = 0 self.conn = SimpleNamespace() + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def connect(self) -> None: return None @@ -2673,12 +2630,13 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2689,11 +2647,11 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) - monkeypatch.setattr(main, 'time', iter([0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).__next__) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(time, 'time', iter([0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).__next__) main.MyCli.run_cli(cli) assert recorded_widths[:2] == [91, 91] assert '' in echoes @@ -2726,6 +2684,9 @@ def __init__(self) -> None: self.server_info = SimpleNamespace(species=SimpleNamespace(name='MySQL')) self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> Iterator[SQLResult]: cli.prompt_app = None @@ -2733,12 +2694,13 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2749,9 +2711,9 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) main.MyCli.run_cli(cli) assert widths == [main.DEFAULT_WIDTH]