diff --git a/changelog.md b/changelog.md index 43f014e6..1d9be1a2 100644 --- a/changelog.md +++ b/changelog.md @@ -39,6 +39,7 @@ Internal * Move `--execute` logic to the new `main_modes` with `--batch`. * Move `--list-dsn` logic to the new `main_modes` with `--batch`. * Move `--list-ssh-config` logic to the new `main_modes` with `--batch`. +* Move REPL logic to the new `main_modes`. * Sort coverage report in tox suite. * Skip more tests when a database connection is not present. * Move SQL utilities to a new `sql_utils.py`. diff --git a/mycli/constants.py b/mycli/constants.py index 88edaa76..2d278ae4 100644 --- a/mycli/constants.py +++ b/mycli/constants.py @@ -10,3 +10,6 @@ DEFAULT_USER = 'root' TEST_DATABASE = 'mycli_test_db' + +DEFAULT_WIDTH = 80 +DEFAULT_HEIGHT = 25 diff --git a/mycli/main.py b/mycli/main.py index ba1484b5..97d4513e 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,13 +1,12 @@ from __future__ import annotations -from collections import defaultdict, namedtuple +from collections import defaultdict from dataclasses import dataclass from decimal import Decimal import functools from io import TextIOWrapper import logging import os -import random import re import shutil import subprocess @@ -21,11 +20,8 @@ except ImportError: pass from datetime import datetime -from importlib import resources import itertools -from random import choice from textwrap import dedent -from time import time from urllib.parse import parse_qs, unquote, urlparse from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors @@ -37,11 +33,9 @@ import keyring from prompt_toolkit import print_formatted_text from prompt_toolkit.application.current import get_app -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest -from prompt_toolkit.completion import Completion, DynamicCompleter +from prompt_toolkit.completion import Completion from prompt_toolkit.document import Document -from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode -from prompt_toolkit.filters import Condition, HasFocus, IsDone +from prompt_toolkit.filters import Condition from prompt_toolkit.formatted_text import ( ANSI, HTML, @@ -50,33 +44,27 @@ to_formatted_text, to_plain_text, ) -from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor -from prompt_toolkit.lexers import PygmentsLexer -from prompt_toolkit.output import ColorDepth -from prompt_toolkit.shortcuts import CompleteStyle, PromptSession +from prompt_toolkit.shortcuts import PromptSession import pymysql from pymysql.constants.CR import CR_SERVER_LOST from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR from pymysql.cursors import Cursor import sqlparse -from mycli import __version__ -from mycli.clibuffer import cli_is_multiline +import mycli as mycli_package from mycli.clistyle import style_factory_helpers, style_factory_ptoolkit -from mycli.clitoolbar import create_toolbar_tokens_func from mycli.compat import WIN from mycli.completion_refresher import CompletionRefresher from mycli.config import get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, strip_matching_quotes, write_default_config from mycli.constants import ( DEFAULT_CHARSET, + DEFAULT_HEIGHT, DEFAULT_HOST, DEFAULT_PORT, - HOME_URL, + DEFAULT_WIDTH, ISSUES_URL, REPO_URL, ) -from mycli.key_bindings import mycli_bindings -from mycli.lexer import MyCliLexer from mycli.main_modes.batch import ( main_batch_from_stdin, main_batch_with_progress_bar, @@ -86,42 +74,25 @@ from mycli.main_modes.execute import main_execute_from_cli from mycli.main_modes.list_dsn import main_list_dsn from mycli.main_modes.list_ssh_config import main_list_ssh_config +from mycli.main_modes.repl import main_repl from mycli.packages import special from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme from mycli.packages.filepaths import dir_path_exists, guess_socket_location -from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command -from mycli.packages.key_binding_utils import ( - handle_clip_command, - handle_editor_command, -) -from mycli.packages.prompt_utils import confirm, confirm_destructive_query -from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp +from mycli.packages.prompt_utils import confirm_destructive_query from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import ArgType from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count -from mycli.packages.sql_utils import ( - is_dropping_database, - is_mutating, - is_select, - need_completion_refresh, - need_completion_reset, -) from mycli.packages.sqlresult import SQLResult from mycli.packages.ssh_utils import read_ssh_config from mycli.packages.string_utils import sanitize_terminal_title from mycli.packages.tabular_output import sql_format from mycli.sqlcompleter import SQLCompleter from mycli.sqlexecute import FIELD_TYPES, SQLExecute +from mycli.types import Query sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] -# Query tuples are used for maintaining history -Query = namedtuple("Query", ["query", "successful", "mutating"]) - -SUPPORT_INFO = f"Home: {HOME_URL}\nBug tracker: {ISSUES_URL}" -DEFAULT_WIDTH = 80 -DEFAULT_HEIGHT = 25 MIN_COMPLETION_TRIGGER = 1 EMPTY_PASSWORD_FLAG_SENTINEL = -1 @@ -880,434 +851,7 @@ def output_timing(self, timing: str, is_warnings_style: bool = False) -> None: print_formatted_text(styled_timing, style=self.ptoolkit_style) def run_cli(self) -> None: - iterations = 0 - sqlexecute = self.sqlexecute - assert isinstance(sqlexecute, SQLExecute) - logger = self.logger - self.configure_pager() - - if self.smart_completion: - self.refresh_completions() - - history_file = os.path.expanduser(os.environ.get("MYCLI_HISTFILE", self.config.get("history_file", "~/.mycli-history"))) - if dir_path_exists(history_file): - history = FileHistoryWithTimestamp(history_file) - else: - history = None - self.echo( - f'Error: Unable to open the history file "{history_file}". Your query history will not be saved.', - err=True, - fg="red", - ) - - key_bindings = mycli_bindings(self) - - if not self.less_chatty: - print(sqlexecute.server_info) - print("mycli", __version__) - print(SUPPORT_INFO) - if random.random() <= 0.5: - print("Thanks to the contributor —", thanks_picker()) - else: - print("Tip —", tips_picker()) - - def get_prompt_message(app) -> ANSI: - if app.current_buffer.text: - return self.last_prompt_message - prompt = self.get_prompt(self.prompt_format, app.render_counter) - if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: - prompt = self.get_prompt(self.default_prompt_splitln, app.render_counter) - self.prompt_lines = prompt.count('\n') + 1 - prompt = prompt.replace("\\x1b", "\x1b") - if not self.prompt_lines: - self.prompt_lines = prompt.count('\n') + 1 - self.last_prompt_message = ANSI(prompt) - return self.last_prompt_message - - def get_continuation(width: int, _two: int, _three: int) -> AnyFormattedText: - if self.multiline_continuation_char == "": - continuation = "" - elif self.multiline_continuation_char: - left_padding = width - len(self.multiline_continuation_char) - continuation = " " * max((left_padding - 1), 0) + self.multiline_continuation_char + " " - else: - continuation = " " - return [("class:continuation", continuation)] - - def show_initial_toolbar_help() -> bool: - return iterations == 0 - - # Keep track of whether or not the query is mutating. In case - # of a multi-statement query, the overall query is considered - # mutating if any one of the component statements is mutating - mutating = False - - def output_res(results: Generator[SQLResult], start: float) -> None: - nonlocal mutating - result_count = watch_count = 0 - for result in results: - logger.debug("preamble: %r", result.preamble) - logger.debug("header: %r", result.header) - logger.debug("rows: %r", result.rows) - logger.debug("status: %r", result.status) - logger.debug("command: %r", result.command) - threshold = 1000 - # If this is a watch query, offset the start time on the 2nd+ iteration - # to account for the sleep duration - if result.command is not None and result.command["name"] == "watch": - if watch_count > 0: - try: - watch_seconds = float(result.command["seconds"]) - start += watch_seconds - except ValueError as e: - self.echo(f"Invalid watch sleep time provided ({e}).", err=True, fg="red") - sys.exit(1) - else: - watch_count += 1 - if is_select(result.status_plain) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold: - self.echo( - f"The result set has more than {threshold} rows.", - fg="red", - ) - if not confirm("Do you want to continue?"): - self.echo("Aborted!", err=True, fg="red") - break - - if self.auto_vertical_output: - if self.prompt_app is not None: - max_width = self.prompt_app.output.get_size().columns - else: - max_width = DEFAULT_WIDTH - else: - max_width = None - - formatted = self.format_sqlresult( - result, - is_expanded=special.is_expanded_output(), - is_redirected=special.is_redirected(), - null_string=self.null_string, - numeric_alignment=self.numeric_alignment, - binary_display=self.binary_display, - max_width=max_width, - ) - - t = time() - start - try: - if result_count > 0: - self.echo("") - try: - self.output(formatted, result) - except KeyboardInterrupt: - pass - if self.beep_after_seconds > 0 and t >= self.beep_after_seconds: - assert self.prompt_app is not None - self.prompt_app.output.bell() - if special.is_timing_enabled(): - self.output_timing(f"Time: {t:0.03f}s") - except KeyboardInterrupt: - pass - - start = time() - result_count += 1 - mutating = mutating or is_mutating(result.status_plain) - - # get and display warnings if enabled - if self.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: - warnings = sqlexecute.run("SHOW WARNINGS") - t = time() - start - saw_warning = False - for warning in warnings: - saw_warning = True - formatted = self.format_sqlresult( - warning, - is_expanded=special.is_expanded_output(), - is_redirected=special.is_redirected(), - null_string=self.null_string, - numeric_alignment=self.numeric_alignment, - binary_display=self.binary_display, - max_width=max_width, - is_warnings_style=True, - ) - self.echo("") - self.output(formatted, warning, is_warnings_style=True) - - if saw_warning and special.is_timing_enabled(): - self.output_timing(f"Time: {t:0.03f}s", is_warnings_style=True) - - def keepalive_hook(_context): - """ - prompt_toolkit shares the event loop with this hook, which seems - to get called a bit faster than once/second on one machine. - - It would be nice to reset the counter whenever user input is made, - but was not clear how to do that with context.input_is_ready(). - - Example at https://github.com/prompt-toolkit/python-prompt-toolkit/blob/main/examples/prompts/inputhook.py - """ - if self.keepalive_ticks is None: - return - if self.keepalive_ticks < 1: - return - self._keepalive_counter += 1 - if self._keepalive_counter > self.keepalive_ticks: - self._keepalive_counter = 0 - self.logger.debug('keepalive ping') - try: - assert self.sqlexecute is not None - assert self.sqlexecute.conn is not None - self.sqlexecute.conn.ping(reconnect=False) - except Exception as e: - self.logger.debug('keepalive ping error %r', e) - - def one_iteration(text: str | None = None) -> None: - inputhook = keepalive_hook if self.keepalive_ticks and self.keepalive_ticks >= 1 else None - if text is None: - try: - assert self.prompt_app is not None - loaded_message_fn = functools.partial(get_prompt_message, self.prompt_app.app) - text = self.prompt_app.prompt( - inputhook=inputhook, - message=loaded_message_fn, - ) - except KeyboardInterrupt: - return - - special.set_expanded_output(False) - special.set_forced_horizontal_output(False) - - try: - text = handle_editor_command( - self, - text, - inputhook, - loaded_message_fn, - ) - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return - - try: - if handle_clip_command(self, text): - return - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return - # LLM command support - while special.is_llm_command(text): - start = time() - try: - assert isinstance(self.sqlexecute, SQLExecute) - assert sqlexecute.conn is not None - cur = sqlexecute.conn.cursor() - context, sql, duration = special.handle_llm( - text, - cur, - sqlexecute.dbname or '', - self.llm_prompt_field_truncate, - self.llm_prompt_section_truncate, - ) - if context: - click.echo("LLM Response:") - click.echo(context) - click.echo("---") - if special.is_timing_enabled(): - self.output_timing(f"Time: {duration:.2f} seconds") - text = self.prompt_app.prompt( - default=sql or '', - inputhook=inputhook, - message=loaded_message_fn, - ) - except KeyboardInterrupt: - return - except special.FinishIteration as e: - if e.results: - return output_res(e.results, start) - else: - return None - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return - - text = text.strip() - - if not text: - return - - if is_redirect_command(text): - sql_part, command_part, file_operator_part, file_part = get_redirect_components(text) - text = sql_part or '' - try: - special.set_redirect(command_part, file_operator_part, file_part) - except (FileNotFoundError, OSError, RuntimeError) as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - return - - if self.destructive_warning: - destroy = confirm_destructive_query(self.destructive_keywords, text) - if destroy is None: - pass # Query was not destructive. Nothing to do here. - elif destroy is True: - self.echo("Your call!") - else: - self.echo("Wise choice!") - return - else: - destroy = True - - try: - logger.debug("sql: %r", text) - - special.write_tee(self.last_prompt_message, nl=False) - special.write_tee(text) - self.log_query(text) - - successful = False - start = time() - res = sqlexecute.run(text) - self.main_formatter.query = text - self.redirect_formatter.query = text - successful = True - output_res(res, start) - special.unset_once_if_written(self.post_redirect_command) - special.flush_pipe_once_if_written(self.post_redirect_command) - except pymysql.err.InterfaceError: - # attempt to reconnect - if not self.reconnect(): - return - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except EOFError as e: - raise e - except KeyboardInterrupt: - # get last connection id - connection_id_to_kill = sqlexecute.connection_id or 0 - # some mysql-compatible databases may not implement connection_id() - if connection_id_to_kill > 0: - logger.debug("connection id to kill: %r", connection_id_to_kill) - try: - sqlexecute.connect() - for kill_result in sqlexecute.run(f"kill {connection_id_to_kill}"): - status_str = str(kill_result.status_plain).lower() - if status_str.find("ok") > -1: - logger.debug("cancelled query, connection id: %r, sql: %r", connection_id_to_kill, text) - self.echo(f"Cancelled query id: {connection_id_to_kill}", err=True, fg="blue") - else: - logger.debug( - "Failed to confirm query cancellation, connection id: %r, sql: %r", - connection_id_to_kill, - text, - ) - self.echo(f"Failed to confirm query cancellation, id: {connection_id_to_kill}", err=True, fg="red") - except Exception as e2: - self.echo(f"Encountered error while cancelling query: {e2}", err=True, fg="red") - else: - logger.debug("Did not get a connection id, skip cancelling query") - self.echo("Did not get a connection id, skip cancelling query", err=True, fg="red") - except NotImplementedError: - self.echo("Not Yet Implemented.", fg="yellow") - except pymysql.OperationalError as e1: - logger.debug("Exception: %r", e1) - if e1.args[0] in (2003, 2006, 2013): - # attempt to reconnect - if not self.reconnect(): - return - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - else: - logger.error("sql: %r, error: %r", text, e1) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e1), err=True, fg="red") - except Exception as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg="red") - else: - if is_dropping_database(text, sqlexecute.dbname): - sqlexecute.dbname = None - sqlexecute.connect() - - # Refresh the table names and column names if necessary. - if need_completion_refresh(text): - self.refresh_completions(reset=need_completion_reset(text)) - finally: - if self.logfile is False: - self.echo("Warning: This query was not logged.", err=True, fg="red") - query = Query(text, successful, mutating) - self.query_history.append(query) - - if self.toolbar_format.lower() == 'none': - get_toolbar_tokens = None - else: - get_toolbar_tokens = create_toolbar_tokens_func( - self, - show_initial_toolbar_help, - self.toolbar_format, - ) - - if self.wider_completion_menu: - complete_style = CompleteStyle.MULTI_COLUMN - else: - complete_style = CompleteStyle.COLUMN - - with self._completer_lock: - if self.key_bindings == "vi": - editing_mode = EditingMode.VI - else: - editing_mode = EditingMode.EMACS - - self.prompt_app = PromptSession( - color_depth=ColorDepth.DEPTH_24_BIT if 'truecolor' in os.getenv('COLORTERM', '').lower() else None, - lexer=PygmentsLexer(MyCliLexer), - reserve_space_for_menu=self.get_reserved_space(), - prompt_continuation=get_continuation, - bottom_toolbar=get_toolbar_tokens, - complete_style=complete_style, - input_processors=[ - ConditionalProcessor( - processor=HighlightMatchingBracketProcessor(chars="[](){}"), filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() - ) - ], - tempfile_suffix=".sql", - completer=DynamicCompleter(lambda: self.completer), - complete_in_thread=True, - history=history, - auto_suggest=ThreadedAutoSuggest(AutoSuggestFromHistory()), - complete_while_typing=complete_while_typing_filter, - multiline=cli_is_multiline(self), - # why not self.ptoolkit_style here? - style=style_factory_ptoolkit(self.syntax_style, self.cli_style), - include_default_pygments_style=False, - key_bindings=key_bindings, - enable_open_in_editor=True, - enable_system_prompt=True, - enable_suspend=True, - editing_mode=editing_mode, - search_ignore_case=True, - ) - - if self.key_bindings == 'vi': - self.prompt_app.app.ttimeoutlen = self.vi_ttimeoutlen - else: - self.prompt_app.app.ttimeoutlen = self.emacs_ttimeoutlen - - self.set_all_external_titles() - - try: - while True: - one_iteration() - iterations += 1 - except EOFError: - special.close_tee() - if not self.less_chatty: - self.echo("Goodbye!") + main_repl(self) def reconnect(self, database: str = "") -> bool: """ @@ -2107,7 +1651,7 @@ class CliArgs: @click.command() @clickdc.adddc('cli_args', CliArgs) -@click.version_option(__version__, '--version', '-V', help="Output mycli's version.") +@click.version_option(mycli_package.__version__, '--version', '-V', help="Output mycli's version.") def click_entrypoint( cli_args: CliArgs, ) -> None: @@ -2566,47 +2110,6 @@ def get_password_from_file(password_file: str | None) -> str | None: mycli.close() -def thanks_picker() -> str: - import mycli - - lines: str = "" - try: - with resources.files(mycli).joinpath("AUTHORS").open('r') as f: - lines += f.read() - except FileNotFoundError: - pass - - try: - with resources.files(mycli).joinpath("SPONSORS").open('r') as f: - lines += f.read() - except FileNotFoundError: - pass - - contents = [] - for line in lines.split("\n"): - if m := re.match(r"^ *\* (.*)", line): - contents.append(m.group(1)) - return choice(contents) if contents else 'our sponsors' - - -def tips_picker() -> str: - import mycli - - tips = [] - - try: - with resources.files(mycli).joinpath('TIPS').open('r') as f: - for line in f: - if line.startswith("#"): - continue - if tip := line.strip(): - tips.append(tip) - except FileNotFoundError: - pass - - return choice(tips) if tips else r'\? or "help" for help!' - - def main() -> int | None: try: result = click_entrypoint.main( diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py new file mode 100644 index 00000000..a507a38f --- /dev/null +++ b/mycli/main_modes/repl.py @@ -0,0 +1,572 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial +from importlib import resources +import os +import random +import re +import sys +import time +import traceback +from typing import TYPE_CHECKING, Any, Generator + +import click +import prompt_toolkit +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest +from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.filters import HasFocus, IsDone +from prompt_toolkit.formatted_text import ( + ANSI, +) +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.output import ColorDepth +from prompt_toolkit.shortcuts import CompleteStyle, PromptSession +import pymysql +from pymysql.cursors import Cursor + +import mycli as mycli_package +from mycli.clibuffer import cli_is_multiline +from mycli.clistyle import style_factory_ptoolkit +from mycli.clitoolbar import create_toolbar_tokens_func +from mycli.constants import ( + DEFAULT_WIDTH, + HOME_URL, + ISSUES_URL, +) +from mycli.key_bindings import mycli_bindings +from mycli.lexer import MyCliLexer +from mycli.packages import special +from mycli.packages.filepaths import dir_path_exists +from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command +from mycli.packages.key_binding_utils import ( + handle_clip_command, + handle_editor_command, +) +from mycli.packages.prompt_utils import confirm, confirm_destructive_query +from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp +from mycli.packages.sql_utils import ( + is_dropping_database, + is_mutating, + is_select, + need_completion_refresh, + need_completion_reset, +) +from mycli.packages.sqlresult import SQLResult +from mycli.sqlexecute import SQLExecute +from mycli.types import Query + +if TYPE_CHECKING: + from prompt_toolkit.formatted_text import AnyFormattedText + + from mycli.main import MyCli + + +SUPPORT_INFO = f"Home: {HOME_URL}\nBug tracker: {ISSUES_URL}" + + +def _main_module(): + from mycli import main as main_module + + return main_module + + +@dataclass(slots=True) +class ReplState: + iterations: int = 0 + mutating: bool = False + + +def _create_history(mycli: 'MyCli') -> FileHistoryWithTimestamp | None: + history_file = os.path.expanduser(os.environ.get('MYCLI_HISTFILE', mycli.config.get('history_file', '~/.mycli-history'))) + if dir_path_exists(history_file): + return FileHistoryWithTimestamp(history_file) + + mycli.echo( + f'Error: Unable to open the history file "{history_file}". Your query history will not be saved.', + err=True, + fg='red', + ) + return None + + +def _show_startup_banner( + mycli: 'MyCli', + sqlexecute: SQLExecute, +) -> None: + if mycli.less_chatty: + return + + print(sqlexecute.server_info) + print('mycli', mycli_package.__version__) + print(SUPPORT_INFO) + if random.random() <= 0.5: + print('Thanks to the contributor —', _thanks_picker()) + else: + print('Tip —', _tips_picker()) + + +def _get_prompt_message( + mycli: 'MyCli', + app: prompt_toolkit.application.application.Application, +) -> ANSI: + if app.current_buffer.text: + return mycli.last_prompt_message + + prompt = mycli.get_prompt(mycli.prompt_format, app.render_counter) + if mycli.prompt_format == mycli.default_prompt and len(prompt) > mycli.max_len_prompt: + prompt = mycli.get_prompt(mycli.default_prompt_splitln, app.render_counter) + mycli.prompt_lines = prompt.count('\n') + 1 + prompt = prompt.replace('\\x1b', '\x1b') + if not mycli.prompt_lines: + mycli.prompt_lines = prompt.count('\n') + 1 + mycli.last_prompt_message = ANSI(prompt) + return mycli.last_prompt_message + + +def _get_continuation( + mycli: 'MyCli', + width: int, + _two: int, + _three: int, +) -> AnyFormattedText: + if mycli.multiline_continuation_char == '': + continuation = '' + elif mycli.multiline_continuation_char: + left_padding = width - len(mycli.multiline_continuation_char) + continuation = ' ' * max((left_padding - 1), 0) + mycli.multiline_continuation_char + ' ' + else: + continuation = ' ' + return [('class:continuation', continuation)] + + +def _output_results( + mycli: 'MyCli', + state: ReplState, + results: Generator[SQLResult], + start: float, +) -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + + result_count = 0 + watch_count = 0 + for result in results: + mycli.logger.debug('preamble: %r', result.preamble) + mycli.logger.debug('header: %r', result.header) + mycli.logger.debug('rows: %r', result.rows) + mycli.logger.debug('status: %r', result.status) + mycli.logger.debug('command: %r', result.command) + threshold = 1000 + if result.command is not None and result.command['name'] == 'watch': + if watch_count > 0: + try: + watch_seconds = float(result.command['seconds']) + start += watch_seconds + except ValueError as e: + mycli.echo(f'Invalid watch sleep time provided ({e}).', err=True, fg='red') + sys.exit(1) + else: + watch_count += 1 + + if is_select(result.status_plain) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold: + mycli.echo( + f'The result set has more than {threshold} rows.', + fg='red', + ) + if not confirm('Do you want to continue?'): + mycli.echo('Aborted!', err=True, fg='red') + break + + if mycli.auto_vertical_output: + if mycli.prompt_app is not None: + max_width = mycli.prompt_app.output.get_size().columns + else: + max_width = DEFAULT_WIDTH + else: + max_width = None + + formatted = mycli.format_sqlresult( + result, + is_expanded=special.is_expanded_output(), + is_redirected=special.is_redirected(), + null_string=mycli.null_string, + numeric_alignment=mycli.numeric_alignment, + binary_display=mycli.binary_display, + max_width=max_width, + ) + + duration = time.time() - start + try: + if result_count > 0: + mycli.echo('') + try: + mycli.output(formatted, result) + except KeyboardInterrupt: + pass + if mycli.beep_after_seconds > 0 and duration >= mycli.beep_after_seconds: + assert mycli.prompt_app is not None + mycli.prompt_app.output.bell() + if special.is_timing_enabled(): + mycli.output_timing(f'Time: {duration:0.03f}s') + except KeyboardInterrupt: + pass + + start = time.time() + result_count += 1 + state.mutating = state.mutating or is_mutating(result.status_plain) + + if mycli.show_warnings and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: + warnings = sqlexecute.run('SHOW WARNINGS') + warnings_duration = time.time() - start + saw_warning = False + for warning in warnings: + saw_warning = True + formatted = mycli.format_sqlresult( + warning, + is_expanded=special.is_expanded_output(), + is_redirected=special.is_redirected(), + null_string=mycli.null_string, + numeric_alignment=mycli.numeric_alignment, + binary_display=mycli.binary_display, + max_width=max_width, + is_warnings_style=True, + ) + mycli.echo('') + mycli.output(formatted, warning, is_warnings_style=True) + + if saw_warning and special.is_timing_enabled(): + mycli.output_timing(f'Time: {warnings_duration:0.03f}s', is_warnings_style=True) + + +def _keepalive_hook( + mycli: 'MyCli', + _context: Any, +) -> None: + if mycli.keepalive_ticks is None: + return + if mycli.keepalive_ticks < 1: + return + + mycli._keepalive_counter += 1 + if mycli._keepalive_counter > mycli.keepalive_ticks: + mycli._keepalive_counter = 0 + mycli.logger.debug('keepalive ping') + try: + assert mycli.sqlexecute is not None + assert mycli.sqlexecute.conn is not None + mycli.sqlexecute.conn.ping(reconnect=False) + except Exception as e: + mycli.logger.debug('keepalive ping error %r', e) + + +def _build_prompt_session( + mycli: 'MyCli', + state: ReplState, + history: FileHistoryWithTimestamp | None, + key_bindings: KeyBindings, +) -> None: + if mycli.toolbar_format.lower() == 'none': + get_toolbar_tokens = None + else: + get_toolbar_tokens = create_toolbar_tokens_func( + mycli, + lambda: state.iterations == 0, + mycli.toolbar_format, + ) + + if mycli.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with mycli._completer_lock: + if mycli.key_bindings == 'vi': + editing_mode = EditingMode.VI + else: + editing_mode = EditingMode.EMACS + + mycli.prompt_app = PromptSession( + color_depth=ColorDepth.DEPTH_24_BIT if 'truecolor' in os.getenv('COLORTERM', '').lower() else None, + lexer=PygmentsLexer(MyCliLexer), + reserve_space_for_menu=mycli.get_reserved_space(), + prompt_continuation=lambda width, two, three: _get_continuation(mycli, width, two, three), + bottom_toolbar=get_toolbar_tokens, + complete_style=complete_style, + input_processors=[ + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars='[](){}'), + filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(), + ) + ], + tempfile_suffix='.sql', + completer=DynamicCompleter(lambda: mycli.completer), + complete_in_thread=True, + history=history, + auto_suggest=ThreadedAutoSuggest(AutoSuggestFromHistory()), + complete_while_typing=_main_module().complete_while_typing_filter, + multiline=cli_is_multiline(mycli), + style=style_factory_ptoolkit(mycli.syntax_style, mycli.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=editing_mode, + search_ignore_case=True, + ) + + if mycli.key_bindings == 'vi': + mycli.prompt_app.app.ttimeoutlen = mycli.vi_ttimeoutlen + else: + mycli.prompt_app.app.ttimeoutlen = mycli.emacs_ttimeoutlen + + +def _one_iteration( + mycli: 'MyCli', + state: ReplState, + text: str | None = None, +) -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + + inputhook = partial(_keepalive_hook, mycli) if mycli.keepalive_ticks and mycli.keepalive_ticks >= 1 else None + + if text is None: + try: + assert mycli.prompt_app is not None + loaded_message_fn = partial(_get_prompt_message, mycli, mycli.prompt_app.app) + text = mycli.prompt_app.prompt( + inputhook=inputhook, + message=loaded_message_fn, + ) + except KeyboardInterrupt: + return + + special.set_expanded_output(False) + special.set_forced_horizontal_output(False) + + try: + text = handle_editor_command( + mycli, + text, + inputhook, + loaded_message_fn, + ) + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + try: + if handle_clip_command(mycli, text): + return + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + while special.is_llm_command(text): + start = time.time() + try: + assert sqlexecute.conn is not None + cur = sqlexecute.conn.cursor() + context, sql, duration = special.handle_llm( + text, + cur, + sqlexecute.dbname or '', + mycli.llm_prompt_field_truncate, + mycli.llm_prompt_section_truncate, + ) + if context: + click.echo('LLM Response:') + click.echo(context) + click.echo('---') + if special.is_timing_enabled(): + mycli.output_timing(f'Time: {duration:.2f} seconds') + assert mycli.prompt_app is not None + text = mycli.prompt_app.prompt( + default=sql or '', + inputhook=inputhook, + message=loaded_message_fn, + ) + except KeyboardInterrupt: + return + except special.FinishIteration as e: + if e.results: + _output_results(mycli, state, e.results, start) + return + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + text = text.strip() + if not text: + return + + if is_redirect_command(text): + sql_part, command_part, file_operator_part, file_part = get_redirect_components(text) + text = sql_part or '' + try: + special.set_redirect(command_part, file_operator_part, file_part) + except (FileNotFoundError, OSError, RuntimeError) as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + if mycli.destructive_warning: + destroy = confirm_destructive_query(mycli.destructive_keywords, text) + if destroy is None: + pass + elif destroy is True: + mycli.echo('Your call!') + else: + mycli.echo('Wise choice!') + return + + successful = False + try: + mycli.logger.debug('sql: %r', text) + special.write_tee(mycli.last_prompt_message, nl=False) + special.write_tee(text) + mycli.log_query(text) + + start = time.time() + results = sqlexecute.run(text) + mycli.main_formatter.query = text + mycli.redirect_formatter.query = text + successful = True + _output_results(mycli, state, results, start) + special.unset_once_if_written(mycli.post_redirect_command) + special.flush_pipe_once_if_written(mycli.post_redirect_command) + except pymysql.err.InterfaceError: + if not mycli.reconnect(): + return + _one_iteration(mycli, state, text) + return + except EOFError as e: + raise e + except KeyboardInterrupt: + connection_id_to_kill = sqlexecute.connection_id or 0 + if connection_id_to_kill > 0: + mycli.logger.debug('connection id to kill: %r', connection_id_to_kill) + try: + sqlexecute.connect() + for kill_result in sqlexecute.run(f'kill {connection_id_to_kill}'): + status_str = str(kill_result.status_plain).lower() + if status_str.find('ok') > -1: + mycli.logger.debug('cancelled query, connection id: %r, sql: %r', connection_id_to_kill, text) + mycli.echo(f'Cancelled query id: {connection_id_to_kill}', err=True, fg='blue') + else: + mycli.logger.debug( + 'Failed to confirm query cancellation, connection id: %r, sql: %r', + connection_id_to_kill, + text, + ) + mycli.echo(f'Failed to confirm query cancellation, id: {connection_id_to_kill}', err=True, fg='red') + except Exception as e2: + mycli.echo(f'Encountered error while cancelling query: {e2}', err=True, fg='red') + else: + mycli.logger.debug('Did not get a connection id, skip cancelling query') + mycli.echo('Did not get a connection id, skip cancelling query', err=True, fg='red') + except NotImplementedError: + mycli.echo('Not Yet Implemented.', fg='yellow') + except pymysql.OperationalError as e1: + mycli.logger.debug('Exception: %r', e1) + if e1.args[0] in (2003, 2006, 2013): + if not mycli.reconnect(): + return + _one_iteration(mycli, state, text) + return + + mycli.logger.error('sql: %r, error: %r', text, e1) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e1), err=True, fg='red') + except Exception as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + else: + if is_dropping_database(text, sqlexecute.dbname): + sqlexecute.dbname = None + sqlexecute.connect() + + if need_completion_refresh(text): + mycli.refresh_completions(reset=need_completion_reset(text)) + finally: + if mycli.logfile is False: + mycli.echo('Warning: This query was not logged.', err=True, fg='red') + + query = Query(text, successful, state.mutating) + mycli.query_history.append(query) + + +def _thanks_picker() -> str: + lines: str = "" + + try: + with resources.files(mycli_package).joinpath("AUTHORS").open('r') as f: + lines += f.read() + except FileNotFoundError: + pass + + try: + with resources.files(mycli_package).joinpath("SPONSORS").open('r') as f: + lines += f.read() + except FileNotFoundError: + pass + + contents = [] + for line in lines.split("\n"): + if m := re.match(r"^ *\* (.*)", line): + contents.append(m.group(1)) + return random.choice(contents) if contents else 'our sponsors' + + +def _tips_picker() -> str: + tips = [] + + try: + with resources.files(mycli_package).joinpath('TIPS').open('r') as f: + for line in f: + if line.startswith("#"): + continue + if tip := line.strip(): + tips.append(tip) + except FileNotFoundError: + pass + + return random.choice(tips) if tips else r'\? or "help" for help!' + + +def main_repl(mycli: 'MyCli') -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + state = ReplState() + + mycli.configure_pager() + if mycli.smart_completion: + mycli.refresh_completions() + + history = _create_history(mycli) + key_bindings = mycli_bindings(mycli) + _show_startup_banner(mycli, sqlexecute) + _build_prompt_session(mycli, state, history, key_bindings) + mycli.set_all_external_titles() + + try: + while True: + _one_iteration(mycli, state) + state.iterations += 1 + except EOFError: + special.close_tee() + if not mycli.less_chatty: + mycli.echo('Goodbye!') diff --git a/mycli/types.py b/mycli/types.py new file mode 100644 index 00000000..207d62d9 --- /dev/null +++ b/mycli/types.py @@ -0,0 +1,4 @@ +from collections import namedtuple + +# Query tuples are used for maintaining history +Query = namedtuple("Query", ["query", "successful", "mutating"]) diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 3af76d21..bcfccaac 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -20,7 +20,7 @@ DEFAULT_USER, TEST_DATABASE, ) -from mycli.main import EMPTY_PASSWORD_FLAG_SENTINEL, MyCli, click_entrypoint, thanks_picker +from mycli.main import EMPTY_PASSWORD_FLAG_SENTINEL, MyCli, click_entrypoint import mycli.packages.special from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.packages.sqlresult import SQLResult @@ -682,11 +682,6 @@ def test_batch_mode_csv(executor): assert expected in "".join(result.output) -def test_thanks_picker_utf8(): - name = thanks_picker() - assert name and isinstance(name, str) - - def test_help_strings_end_with_periods(): """Make sure click options have help text that end with a period.""" for param in click_entrypoint.params: diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py new file mode 100644 index 00000000..2d1812c6 --- /dev/null +++ b/test/pytests/test_main_modes_repl.py @@ -0,0 +1,890 @@ +from __future__ import annotations + +import builtins +from collections.abc import Generator, Iterator +from dataclasses import dataclass +from io import StringIO +import os +from types import SimpleNamespace +from typing import Any, Literal, cast + +from prompt_toolkit.formatted_text import to_plain_text +import pymysql +import pytest + +import mycli.main as main_module +import mycli.main_modes.repl as repl_mode +from mycli.packages.sqlresult import SQLResult + + +class DummyLogger: + def __init__(self) -> None: + self.debug_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + self.error_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def debug(self, *args: Any, **kwargs: Any) -> None: + self.debug_calls.append((args, kwargs)) + + def error(self, *args: Any, **kwargs: Any) -> None: + self.error_calls.append((args, kwargs)) + + +@dataclass +class DummyFormatterWithQuery: + query: str = '' + + +class FakeApp: + def __init__(self, text: str = '', render_counter: int = 0) -> None: + self.current_buffer = SimpleNamespace(text=text) + self.render_counter = render_counter + self.ttimeoutlen: float | None = None + + +class FakePromptOutput: + def __init__(self, columns: int = 80, rows: int = 24) -> None: + self.columns = columns + self.rows = rows + self.bell_count = 0 + + def get_size(self) -> SimpleNamespace: + return SimpleNamespace(columns=self.columns, rows=self.rows) + + def bell(self) -> None: + self.bell_count += 1 + + +class FakePromptSession: + def __init__(self, responses: list[Any] | None = None, columns: int = 80, rows: int = 24) -> None: + self.responses = list(responses or []) + self.output = FakePromptOutput(columns=columns, rows=rows) + self.app = FakeApp() + self.prompt_calls: list[dict[str, Any]] = [] + + def prompt(self, **kwargs: Any) -> str: + self.prompt_calls.append(dict(kwargs)) + if not self.responses: + raise EOFError() + response = self.responses.pop(0) + if isinstance(response, BaseException): + raise response + return response + + +class FakeCursorBase: + def __init__( + self, + rows: list[tuple[Any, ...]] | None = None, + rowcount: int = 0, + warning_count: int = 0, + ) -> None: + self._rows = list(rows or []) + self.rowcount = rowcount + self.warning_count = warning_count + + def __iter__(self) -> Iterator[tuple[Any, ...]]: + return iter(self._rows) + + +class FakeConnection: + def __init__(self, ping_exc: Exception | None = None, cursor_value: Any = 'cursor') -> None: + self.ping_exc = ping_exc + self.cursor_value = cursor_value + self.ping_calls: list[bool] = [] + + def ping(self, reconnect: bool = False) -> None: + self.ping_calls.append(reconnect) + if self.ping_exc is not None: + raise self.ping_exc + + def cursor(self) -> Any: + return self.cursor_value + + +class ReusableLock: + def __enter__(self) -> 'ReusableLock': + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + return False + + +def sqlresult_generator(*results: SQLResult) -> Generator[SQLResult, None, None]: + for result in results: + yield result + + +class FakeResourceTree: + def __init__(self, files: dict[str, str], path: str | None = None) -> None: + self.files = files + self.path = path + + def joinpath(self, path: str) -> 'FakeResourceTree': + return FakeResourceTree(self.files, path) + + def open(self, mode: str) -> StringIO: + assert self.path is not None + if self.path not in self.files: + raise FileNotFoundError(self.path) + return StringIO(self.files[self.path]) + + +def make_repl_cli(sqlexecute: Any | None = None) -> Any: + cli = SimpleNamespace() + cli.logger = DummyLogger() + cli.query_history = [] + cli.last_prompt_message = repl_mode.ANSI('') + cli.last_custom_toolbar_message = repl_mode.ANSI('') + cli.prompt_lines = 0 + cli.default_prompt = r'\t \u@\h:\d> ' + cli.default_prompt_splitln = r'\u@\h\n(\t):\d>' + cli.max_len_prompt = 45 + cli.prompt_format = cli.default_prompt + cli.multiline_continuation_char = '>' + cli.toolbar_format = 'default' + cli.less_chatty = True + cli.keepalive_ticks = None + cli._keepalive_counter = 0 + cli.auto_vertical_output = False + cli.beep_after_seconds = 0.0 + cli.show_warnings = False + cli.null_string = '' + cli.numeric_alignment = 'right' + cli.binary_display = None + cli.prompt_app = None + cli.post_redirect_command = None + cli.logfile = None + cli.smart_completion = False + cli.config = {'history_file': '~/.mycli-history-testing'} + cli.key_bindings = 'emacs' + cli.wider_completion_menu = False + cli._completer_lock = ReusableLock() + cli.completer = object() + cli.syntax_style = 'native' + cli.cli_style = {} + cli.emacs_ttimeoutlen = 1.0 + cli.vi_ttimeoutlen = 2.0 + cli.destructive_warning = False + cli.destructive_keywords = ['drop'] + cli.llm_prompt_field_truncate = 0 + cli.llm_prompt_section_truncate = 0 + cli.main_formatter = DummyFormatterWithQuery() + cli.redirect_formatter = DummyFormatterWithQuery() + cli.pager_configured = 0 + refresh_calls: list[bool] = [] + output_calls: list[tuple[list[str], Any, bool]] = [] + echo_calls: list[str] = [] + timing_calls: list[tuple[str, bool]] = [] + log_queries: list[str] = [] + cli.refresh_calls = refresh_calls + cli.output_calls = output_calls + cli.echo_calls = echo_calls + cli.timing_calls = timing_calls + cli.log_queries = log_queries + cli.title_calls = 0 + cli.sqlexecute = sqlexecute + cli.get_reserved_space = lambda: 3 + cli.get_last_query = lambda: cli.query_history[-1].query if cli.query_history else None + cli.configure_pager = lambda: setattr(cli, 'pager_configured', cli.pager_configured + 1) + + def refresh_completions(reset: bool = False) -> list[SQLResult]: + cli.refresh_calls.append(reset) + return [SQLResult(status='refresh')] + + cli.refresh_completions = refresh_completions + cli.set_all_external_titles = lambda: setattr(cli, 'title_calls', cli.title_calls + 1) + + def output_timing(timing: str, is_warnings_style: bool = False) -> None: + cli.timing_calls.append((timing, is_warnings_style)) + + cli.output_timing = output_timing + + def log_query(text: str) -> None: + cli.log_queries.append(text) + + cli.log_query = log_query + cli.reconnect = lambda database='': False + + def echo(message: Any, **kwargs: Any) -> None: + cli.echo_calls.append(str(message)) + + cli.echo = echo + + def format_sqlresult(result: SQLResult, **kwargs: Any) -> Iterator[str]: + return iter([str(kwargs.get('max_width')), result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult + + def output(formatted: Any, result: Any, is_warnings_style: bool = False) -> None: + cli.output_calls.append((list(formatted), result, is_warnings_style)) + + cli.output = output + cli.get_prompt = lambda string, render_counter: f'{string}:{render_counter}' + return cli + + +def patch_repl_runtime_defaults(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(repl_mode.special, 'set_expanded_output', lambda value: None) + monkeypatch.setattr(repl_mode.special, 'set_forced_horizontal_output', lambda value: None) + monkeypatch.setattr(repl_mode.special, 'is_llm_command', lambda text: False) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr(repl_mode.special, 'write_tee', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'unset_once_if_written', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: None) + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda mycli, text, inputhook, loaded_message_fn: text) + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: False) + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(repl_mode, 'confirm_destructive_query', lambda keywords, text: None) + monkeypatch.setattr(repl_mode, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(repl_mode, 'need_completion_reset', lambda text: False) + monkeypatch.setattr(repl_mode, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + +def test_repl_main_module_and_create_history(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli() + monkeypatch.setenv('MYCLI_HISTFILE', '~/override-history') + monkeypatch.setattr(repl_mode, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(repl_mode, 'FileHistoryWithTimestamp', lambda path: f'history:{path}') + assert repl_mode._main_module() is main_module + history = cast(Any, repl_mode._create_history(cli)) + assert history == f'history:{os.path.expanduser("~/override-history")}' + + monkeypatch.delenv('MYCLI_HISTFILE') + monkeypatch.setattr(repl_mode, 'dir_path_exists', lambda path: False) + assert repl_mode._create_history(cli) is None + assert 'Unable to open the history file' in cli.echo_calls[-1] + + +def test_repl_picker_helpers_cover_present_and_missing_resources(monkeypatch: pytest.MonkeyPatch) -> None: + files = { + 'AUTHORS': '* Alice\n* Bob\n', + 'SPONSORS': '* Carol\n', + 'TIPS': '# comment\nTip 1\n\nTip 2\n', + } + monkeypatch.setattr(repl_mode.resources, 'files', lambda package: FakeResourceTree(files)) + monkeypatch.setattr(repl_mode.random, 'choice', lambda seq: seq[0]) + assert repl_mode._thanks_picker() == 'Alice' + assert repl_mode._tips_picker() == 'Tip 1' + + monkeypatch.setattr(repl_mode.resources, 'files', lambda package: FakeResourceTree({})) + assert repl_mode._thanks_picker() == 'our sponsors' + assert repl_mode._tips_picker() == r'\? or "help" for help!' + + +def test_repl_show_startup_banner_and_prompt_helpers(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace(server_info='Server')) + printed: list[str] = [] + monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: printed.append(' '.join(str(x) for x in args))) + monkeypatch.setattr(repl_mode.random, 'random', lambda: 0.4) + monkeypatch.setattr(repl_mode, '_thanks_picker', lambda: 'Alice') + monkeypatch.setattr(repl_mode, '_tips_picker', lambda: 'Tip') + + cli.less_chatty = False + repl_mode._show_startup_banner(cli, cli.sqlexecute) + monkeypatch.setattr(repl_mode.random, 'random', lambda: 0.6) + repl_mode._show_startup_banner(cli, cli.sqlexecute) + cli.less_chatty = True + repl_mode._show_startup_banner(cli, cli.sqlexecute) + assert any('Thanks to the contributor' in line for line in printed) + assert any('Tip — Tip' in line for line in printed) + + cli.get_prompt = lambda string, render_counter: '0123456' if string == cli.default_prompt else 'a\nb' + cli.max_len_prompt = 5 + prompt_text = to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=2)))) + assert prompt_text == 'a\nb' + assert cli.prompt_lines == 2 + + cli.last_prompt_message = repl_mode.ANSI('cached') + assert to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='typing', render_counter=3)))) == 'cached' + + cli.prompt_format = 'custom' + cli.prompt_lines = 0 + cli.get_prompt = lambda string, render_counter: 'single' + assert to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=4)))) == 'single' + assert cli.prompt_lines == 1 + + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', ' > ')] + cli.multiline_continuation_char = '' + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', '')] + cli.multiline_continuation_char = None + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', ' ')] + + +def test_output_results_covers_watch_warning_timing_beep_and_interrupts(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeSQLExecute: + def run(self, text: str) -> list[SQLResult]: + assert text == 'SHOW WARNINGS' + return [SQLResult(status='warning', rows=[('warn',)])] + + cli = make_repl_cli(FakeSQLExecute()) + cli.auto_vertical_output = True + cli.prompt_app = FakePromptSession(columns=91) + cli.beep_after_seconds = 0.1 + cli.show_warnings = True + state = repl_mode.ReplState() + format_widths: list[int | None] = [] + + def format_sqlresult(result: SQLResult, **kwargs: Any) -> Iterator[str]: + format_widths.append(kwargs.get('max_width')) + return iter([result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult + time_values = iter([0.2, 1.0, 2.0, 3.0, 3.2]) + monkeypatch.setattr(repl_mode.time, 'time', lambda: next(time_values)) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: status == 'mut') + + results = sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='mut', rows=cast(Any, FakeCursorBase(rowcount=1, warning_count=1))), + ) + + repl_mode._output_results(cli, state, results, start=0.0) + + assert state.mutating is True + assert format_widths[:2] == [91, 91] + assert cli.prompt_app.output.bell_count == 2 + assert '' in cli.echo_calls + assert any(is_warnings_style is True for _, _, is_warnings_style in cli.output_calls) + assert any(is_warnings_style is False for _, is_warnings_style in cli.timing_calls) + assert any(is_warnings_style is True for _, is_warnings_style in cli.timing_calls) + + cli_interrupt = make_repl_cli(SimpleNamespace()) + cli_interrupt.echo = lambda message, **kwargs: ( + (_ for _ in ()).throw(KeyboardInterrupt()) if message == '' else cli_interrupt.echo_calls.append(str(message)) + ) + cli_interrupt.output = lambda formatted, result, is_warnings_style=False: (_ for _ in ()).throw(KeyboardInterrupt()) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + monkeypatch.setattr(repl_mode.time, 'time', lambda: 0.0) + repl_mode._output_results( + cli_interrupt, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='first'), SQLResult(status='second')), + start=0.0, + ) + + +def test_output_results_handles_abort_default_width_and_bad_watch(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.auto_vertical_output = True + widths: list[int | None] = [] + + def format_sqlresult_with_width(result: SQLResult, **kwargs: Any) -> Iterator[str]: + widths.append(kwargs.get('max_width')) + return iter([result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult_with_width + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: status == 'select') + monkeypatch.setattr(repl_mode, 'confirm', lambda text: False) + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='select', rows=cast(Any, FakeCursorBase(rowcount=1001)))), + start=0.0, + ) + assert 'The result set has more than 1000 rows.' in cli.echo_calls + assert 'Aborted!' in cli.echo_calls + + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='ok')), + start=0.0, + ) + assert widths[-1] == repl_mode.DEFAULT_WIDTH + + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + with pytest.raises(SystemExit): + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='watch', command={'name': 'watch', 'seconds': 'bad'}), + ), + start=0.0, + ) + + +def test_keepalive_hook_covers_threshold_and_errors() -> None: + cli = make_repl_cli(SimpleNamespace(conn=FakeConnection())) + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + + cli.keepalive_ticks = 0 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + + cli.keepalive_ticks = 1 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 1 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + assert cli.sqlexecute.conn.ping_calls == [False] + + cli.sqlexecute.conn = FakeConnection(ping_exc=RuntimeError('boom')) + repl_mode._keepalive_hook(cli, None) + repl_mode._keepalive_hook(cli, None) + assert any('keepalive ping error' in call[0][0] for call in cli.logger.debug_calls) + + +def test_build_prompt_session_covers_toolbar_modes_and_editing_modes(monkeypatch: pytest.MonkeyPatch) -> None: + captured_kwargs: list[dict[str, Any]] = [] + toolbar_help: list[bool] = [] + + def fake_prompt_session(**kwargs: Any) -> FakePromptSession: + captured_kwargs.append(kwargs) + return FakePromptSession() + + monkeypatch.setattr(repl_mode, 'PromptSession', fake_prompt_session) + monkeypatch.setattr(repl_mode, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') + monkeypatch.setattr(repl_mode, 'cli_is_multiline', lambda mycli: False) + + def fake_toolbar_tokens(mycli: Any, show_help: Any, fmt: str) -> str: + toolbar_help.append(show_help()) + return 'toolbar' + + monkeypatch.setattr(repl_mode, 'create_toolbar_tokens_func', fake_toolbar_tokens) + + cli = make_repl_cli(SimpleNamespace()) + state = repl_mode.ReplState() + cli.toolbar_format = 'none' + cli.key_bindings = 'vi' + cli.wider_completion_menu = True + repl_mode._build_prompt_session(cli, state, history=cast(Any, 'history'), key_bindings=cast(Any, 'bindings')) + first_kwargs = captured_kwargs[-1] + assert first_kwargs['bottom_toolbar'] is None + assert first_kwargs['complete_style'] == repl_mode.CompleteStyle.MULTI_COLUMN + assert first_kwargs['editing_mode'] == repl_mode.EditingMode.VI + assert cli.prompt_app.app.ttimeoutlen == cli.vi_ttimeoutlen + + cli.toolbar_format = 'default' + cli.key_bindings = 'emacs' + cli.wider_completion_menu = False + state.iterations = 0 + repl_mode._build_prompt_session(cli, state, history=cast(Any, 'history'), key_bindings=cast(Any, 'bindings')) + latest_kwargs = captured_kwargs[-1] + assert latest_kwargs['bottom_toolbar'] == 'toolbar' + assert latest_kwargs['complete_style'] == repl_mode.CompleteStyle.COLUMN + assert latest_kwargs['editing_mode'] == repl_mode.EditingMode.EMACS + assert toolbar_help == [True] + assert cli.prompt_app.app.ttimeoutlen == cli.emacs_ttimeoutlen + assert latest_kwargs['prompt_continuation'](4, 0, 0) == [('class:continuation', ' > ')] + + +def test_one_iteration_handles_prompt_interrupt_empty_editor_clip_and_clip_true(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + cli = make_repl_cli(SimpleNamespace(run=lambda text: iter([SQLResult(status='ok')]), conn=FakeConnection())) + cli.keepalive_ticks = 1 + cli.prompt_app = FakePromptSession([KeyboardInterrupt(), ' ', 'edit-error', 'clip-error', 'clip-stop']) + + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + inputhook = cli.prompt_app.prompt_calls[-1]['inputhook'] + assert inputhook is not None + inputhook(None) + + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda *args: (_ for _ in ()).throw(RuntimeError('edit boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert 'edit boom' in cli.echo_calls[-1] + + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda mycli, text, inputhook, loaded_message_fn: text) + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: (_ for _ in ()).throw(RuntimeError('clip boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert 'clip boom' in cli.echo_calls[-1] + + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: True) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + + +def test_one_iteration_covers_llm_paths(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + click_output: list[str] = [] + monkeypatch.setattr(repl_mode.click, 'echo', lambda message='', **kwargs: click_output.append(str(message))) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode.special, 'is_llm_command', lambda text: text.startswith('\\llm')) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.conn = FakeConnection(cursor_value='cursor') + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status=f'ran:{text}')]) + + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda text, cur, dbname, field_truncate, section_truncate: ('context', 'select 1', 1.25), + ) + cli = make_repl_cli(FakeSQLExecute()) + cli.prompt_app = FakePromptSession(['\\llm ask', 'select 1']) + repl_mode._one_iteration( + cli, + repl_mode.ReplState(), + ) + assert click_output[:3] == ['LLM Response:', 'context', '---'] + assert cli.output_calls[0][0] == ['None', 'ran:select 1'] + + cli_finish = make_repl_cli(FakeSQLExecute()) + cli_finish.prompt_app = FakePromptSession(['\\llm finish']) + cli_finish.format_sqlresult = lambda result, **kwargs: iter([result.status_plain or 'row']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(repl_mode.special.FinishIteration(iter([SQLResult(status='done')]))), + ) + repl_mode._one_iteration(cli_finish, repl_mode.ReplState()) + assert cli_finish.output_calls[0][0] == ['done'] + + cli_empty = make_repl_cli(FakeSQLExecute()) + cli_empty.prompt_app = FakePromptSession(['\\llm empty']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(repl_mode.special.FinishIteration(None)), + ) + repl_mode._one_iteration(cli_empty, repl_mode.ReplState()) + assert cli_empty.output_calls == [] + + cli_err = make_repl_cli(FakeSQLExecute()) + cli_err.prompt_app = FakePromptSession(['\\llm err']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError('llm boom')), + ) + repl_mode._one_iteration(cli_err, repl_mode.ReplState()) + assert 'llm boom' in cli_err.echo_calls[-1] + + cli_interrupt = make_repl_cli(FakeSQLExecute()) + cli_interrupt.prompt_app = FakePromptSession(['\\llm stop']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(KeyboardInterrupt()), + ) + repl_mode._one_iteration(cli_interrupt, repl_mode.ReplState()) + assert cli_interrupt.output_calls == [] + + cli_quiet = make_repl_cli(FakeSQLExecute()) + cli_quiet.prompt_app = FakePromptSession(['\\llm quiet', 'select 2']) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda text, cur, dbname, field_truncate, section_truncate: ('', 'select 2', 0.5), + ) + repl_mode._one_iteration(cli_quiet, repl_mode.ReplState()) + assert cli_quiet.output_calls[0][0] == ['None', 'ran:select 2'] + + +def test_one_iteration_covers_redirect_destructive_success_refresh_and_logfile(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def connect(self) -> None: + self.calls.append('connect') + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + return iter([SQLResult(status='DROP 1')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.logfile = False + cli.destructive_warning = True + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: text == 'redirect') + monkeypatch.setattr(repl_mode, 'get_redirect_components', lambda text: ('dropdb', 'tee', '>', 'out.txt')) + redirects: list[tuple[Any, ...]] = [] + monkeypatch.setattr(repl_mode.special, 'set_redirect', lambda *args: redirects.append(args)) + monkeypatch.setattr( + repl_mode, + 'confirm_destructive_query', + lambda keywords, text: None if text == 'dropdb' else (True if text == 'approved' else False), + ) + monkeypatch.setattr(repl_mode, 'is_dropping_database', lambda text, dbname: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'need_completion_refresh', lambda text: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'need_completion_reset', lambda text: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: True) + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'redirect') + assert redirects == [('tee', '>', 'out.txt')] + assert cli.refresh_calls == [True] + assert cli.query_history[-1].query == 'dropdb' + assert cli.query_history[-1].successful is True + assert cli.query_history[-1].mutating is True + assert sqlexecute.dbname is None + assert sqlexecute.calls == ['dropdb', 'connect'] + assert 'Warning: This query was not logged.' in cli.echo_calls + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'approved') + assert 'Your call!' in cli.echo_calls + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'denied') + assert 'Wise choice!' in cli.echo_calls + + +def test_one_iteration_covers_reconnect_and_error_paths(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class InterfaceSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + if text == 'iface' and self.calls.count('iface') == 1: + raise pymysql.err.InterfaceError() + return iter([SQLResult(status=f'ok:{text}')]) + + interface_sql = InterfaceSQLExecute() + cli_interface = make_repl_cli(interface_sql) + interface_reconnect_calls: list[str] = [] + interface_results = iter([True]) + + def interface_reconnect(database: str = '') -> bool: + interface_reconnect_calls.append(database) + return next(interface_results) + + cli_interface.reconnect = interface_reconnect + + repl_mode._one_iteration(cli_interface, repl_mode.ReplState(), 'iface') + assert interface_sql.calls.count('iface') == 2 + assert cli_interface.query_history[-1].query == 'iface' + assert interface_reconnect_calls == [''] + + cli_interface_false = make_repl_cli(InterfaceSQLExecute()) + false_calls: list[str] = [] + + def interface_reconnect_false(database: str = '') -> bool: + false_calls.append(database) + return False + + cli_interface_false.reconnect = interface_reconnect_false + repl_mode._one_iteration(cli_interface_false, repl_mode.ReplState(), 'iface') + assert false_calls == [''] + + class ErrorSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + if text == 'oplost' and self.calls.count('oplost') == 1: + raise pymysql.OperationalError(2003, 'lost') + if text == 'opbad': + raise pymysql.OperationalError(9999, 'bad op') + if text == 'nyi': + raise NotImplementedError() + if text == 'boom': + raise RuntimeError('boom') + if text == 'eof': + raise EOFError() + return iter([SQLResult(status=f'ok:{text}')]) + + error_sql = ErrorSQLExecute() + cli_error = make_repl_cli(error_sql) + error_reconnect_calls: list[str] = [] + + def error_reconnect(database: str = '') -> bool: + error_reconnect_calls.append(database) + return True + + cli_error.reconnect = error_reconnect + + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'oplost') + assert error_sql.calls.count('oplost') == 2 + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'opbad') + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'nyi') + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'boom') + assert any('bad op' in line for line in cli_error.echo_calls) + assert 'Not Yet Implemented.' in cli_error.echo_calls + assert any('boom' in line for line in cli_error.echo_calls) + assert error_reconnect_calls == [''] + + cli_error_false = make_repl_cli(ErrorSQLExecute()) + false_reconnect_calls: list[str] = [] + + def error_reconnect_false(database: str = '') -> bool: + false_reconnect_calls.append(database) + return False + + cli_error_false.reconnect = error_reconnect_false + repl_mode._one_iteration(cli_error_false, repl_mode.ReplState(), 'oplost') + assert false_reconnect_calls == [''] + + with pytest.raises(EOFError): + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'eof') + + +def test_one_iteration_reraises_eoferror(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class EofSQLExecute: + dbname = 'db' + connection_id = 0 + + def run(self, text: str) -> Iterator[SQLResult]: + raise EOFError() + + with pytest.raises(EOFError): + repl_mode._one_iteration(make_repl_cli(EofSQLExecute()), repl_mode.ReplState(), 'eof') + + +def test_one_iteration_covers_cancel_paths_and_redirect_error(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + + def connect(self) -> None: + return None + + def run(self, text: str) -> Iterator[SQLResult]: + if text == 'cancel-ok': + self.connection_id = 7 + raise KeyboardInterrupt() + if text == 'kill 7': + return iter([SQLResult(status='OK')]) + if text == 'cancel-fail': + self.connection_id = 8 + raise KeyboardInterrupt() + if text == 'kill 8': + return iter([SQLResult(status='failed')]) + if text == 'cancel-error': + self.connection_id = 9 + raise KeyboardInterrupt() + if text == 'kill 9': + raise RuntimeError('kill failed') + if text == 'cancel-missing': + self.connection_id = 0 + raise KeyboardInterrupt() + return iter([SQLResult(status='ok')]) + + cli = make_repl_cli(FakeSQLExecute()) + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: text == 'redirect-bad') + monkeypatch.setattr(repl_mode, 'get_redirect_components', lambda text: ('sql', 'tee', '>', 'out.txt')) + monkeypatch.setattr(repl_mode.special, 'set_redirect', lambda *args: (_ for _ in ()).throw(RuntimeError('redirect boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'redirect-bad') + assert 'redirect boom' in cli.echo_calls[-1] + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-ok') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-fail') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-error') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-missing') + assert 'Cancelled query id: 7' in cli.echo_calls + assert any('Failed to confirm query cancellation' in line for line in cli.echo_calls) + assert any('Encountered error while cancelling query' in line for line in cli.echo_calls) + assert 'Did not get a connection id, skip cancelling query' in cli.echo_calls + + +def test_main_repl_covers_setup_loop_and_goodbye(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.less_chatty = False + cli.smart_completion = True + loop_iterations: list[int] = [] + monkeypatch.setattr(repl_mode, '_create_history', lambda mycli: 'history') + monkeypatch.setattr(repl_mode, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(repl_mode, '_show_startup_banner', lambda mycli, sqlexecute: None) + monkeypatch.setattr( + repl_mode, + '_build_prompt_session', + lambda mycli, state, history, key_bindings: setattr(mycli, 'prompt_app', FakePromptSession()), + ) + + def fake_one_iteration(mycli: Any, state: repl_mode.ReplState) -> None: + loop_iterations.append(state.iterations) + if len(loop_iterations) == 2: + raise EOFError() + + closed: list[bool] = [] + monkeypatch.setattr(repl_mode, '_one_iteration', fake_one_iteration) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: closed.append(True)) + + repl_mode.main_repl(cli) + + assert cli.pager_configured == 1 + assert cli.refresh_calls == [False] + assert cli.title_calls == 1 + assert loop_iterations == [0, 1] + assert closed == [True] + assert cli.echo_calls[-1] == 'Goodbye!' + + +def test_main_repl_covers_no_refresh_and_quiet_exit(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.less_chatty = True + cli.smart_completion = False + monkeypatch.setattr(repl_mode, '_create_history', lambda mycli: 'history') + monkeypatch.setattr(repl_mode, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(repl_mode, '_show_startup_banner', lambda mycli, sqlexecute: None) + monkeypatch.setattr( + repl_mode, + '_build_prompt_session', + lambda mycli, state, history, key_bindings: setattr(mycli, 'prompt_app', FakePromptSession()), + ) + monkeypatch.setattr(repl_mode, '_one_iteration', lambda mycli, state: (_ for _ in ()).throw(EOFError())) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: None) + + repl_mode.main_repl(cli) + + assert cli.refresh_calls == [] + assert cli.echo_calls == [] + + +def test_output_results_covers_remaining_watch_select_and_warning_branches(monkeypatch: pytest.MonkeyPatch) -> None: + class WarninglessSQLExecute: + def run(self, text: str) -> list[SQLResult]: + assert text == 'SHOW WARNINGS' + return [] + + cli = make_repl_cli(WarninglessSQLExecute()) + cli.show_warnings = True + cli.auto_vertical_output = False + cli.prompt_app = FakePromptSession(columns=77) + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + monkeypatch.setattr(repl_mode, 'confirm', lambda text: True) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: status == 'select') + monkeypatch.setattr(repl_mode.time, 'time', lambda: 0.0) + + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='watch', command={'name': 'watch', 'seconds': '2'}), + SQLResult(status='select', rows=cast(Any, FakeCursorBase(rowcount=1001, warning_count=1))), + ), + start=0.0, + ) + assert cli.output_calls diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index f12bd1a5..a04de3ec 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -10,6 +10,10 @@ * migrating individual tests if content moves out of main.py * migrating individual tests to test_main.py after assessment of quality * removing and rewriting these tests if contracts change + +For example, since the generation of these tests, main_modes/repl.py was +created, and all tests here touching the REPL functionality should in +principle be removed. """ from __future__ import annotations @@ -21,7 +25,9 @@ import itertools import os from pathlib import Path +import random import sys +import time from types import ModuleType, SimpleNamespace from typing import Any, Callable, Literal, cast @@ -31,7 +37,9 @@ import pymysql import pytest -from mycli import key_bindings, main +from mycli import main +import mycli.key_bindings +import mycli.main_modes.repl from mycli.packages import key_binding_utils from mycli.packages.sqlresult import SQLResult @@ -677,15 +685,18 @@ def test_initialize_logging_covers_none_bad_path_and_file_handler(tmp_path: Path cli.echo = lambda message, **kwargs: echo_calls.append(message) # type: ignore[assignment] cli.config = {'main': {'log_file': str(tmp_path / 'mycli.log'), 'log_level': 'NONE'}} monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) main.MyCli.initialize_logging(cli) cli.config = {'main': {'log_file': str(tmp_path / 'missing' / 'mycli.log'), 'log_level': 'INFO'}} monkeypatch.setattr(main, 'dir_path_exists', lambda path: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: False) main.MyCli.initialize_logging(cli) assert echo_calls[-1].startswith('Error: Unable to open the log file') cli.config = {'main': {'log_file': str(tmp_path / 'mycli.log'), 'log_level': 'INFO'}} monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) main.MyCli.initialize_logging(cli) @@ -1043,23 +1054,23 @@ def test_handle_editor_clip_and_output_timing(monkeypatch: pytest.MonkeyPatch) - monkeypatch.setattr(main.special, 'get_filename', lambda text: 'query.sql') monkeypatch.setattr(main.special, 'get_editor_query', lambda text: 'select 1') monkeypatch.setattr(main.special, 'open_external_editor', lambda filename, sql: ('edited sql', None)) - assert key_binding_utils.handle_editor_command(cli, r'select 1\e', None, lambda: None) == 'edited sql' + assert mycli.main_modes.repl.handle_editor_command(cli, r'select 1\e', None, lambda: None) == 'edited sql' monkeypatch.setattr(main.special, 'open_external_editor', lambda filename, sql: ('', 'boom')) with pytest.raises(RuntimeError, match='boom'): - key_binding_utils.handle_editor_command(cli, r'select 1\e', None, lambda: None) + mycli.main_modes.repl.handle_editor_command(cli, r'select 1\e', None, lambda: None) monkeypatch.setattr(main.special, 'clip_command', lambda text: True) monkeypatch.setattr(main.special, 'get_clip_query', lambda text: None) monkeypatch.setattr(main.special, 'copy_query_to_clipboard', lambda sql: None) - assert key_binding_utils.handle_clip_command(cli, r'select 1\clip') is True + assert mycli.main_modes.repl.handle_clip_command(cli, r'select 1\clip') is True monkeypatch.setattr(main.special, 'copy_query_to_clipboard', lambda sql: 'clipboard failed') with pytest.raises(RuntimeError, match='clipboard failed'): - key_binding_utils.handle_clip_command(cli, r'select 1\clip') + mycli.main_modes.repl.handle_clip_command(cli, r'select 1\clip') monkeypatch.setattr(main.special, 'clip_command', lambda text: False) - assert key_binding_utils.handle_clip_command(cli, 'select 1') is False + assert mycli.main_modes.repl.handle_clip_command(cli, 'select 1') is False printed: list[tuple[Any, Any]] = [] monkeypatch.setattr(main, 'print_formatted_text', lambda text, style=None: printed.append((text, style))) @@ -1103,7 +1114,7 @@ def test_format_sqlresult_run_query_reserved_space_and_last_query(monkeypatch: p assert main.MyCli.get_last_query(cli) == 'select 1' -def test_reconnect_logging_output_titles_prompt_and_picker_fallbacks(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: +def test_reconnect_logging_output_titles_prompt(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: cli = make_bare_mycli() sqlexecute = object.__new__(main.SQLExecute) @@ -1195,17 +1206,6 @@ def failing_connect() -> None: monkeypatch.setattr(main.subprocess, 'run', lambda *args, **kwargs: (_ for _ in ()).throw(FileNotFoundError())) main.MyCli.set_external_multiplex_window_title(cli) - class MissingResource: - def joinpath(self, name: str) -> 'MissingResource': - return self - - def open(self, mode: str) -> StringIO: - raise FileNotFoundError() - - monkeypatch.setattr(main.resources, 'files', lambda package: MissingResource()) - assert main.thanks_picker() == 'our sponsors' - assert main.tips_picker() == r'\? or "help" for help!' - def test_reconnect_first_and_second_passes(monkeypatch: pytest.MonkeyPatch) -> None: cli = make_bare_mycli() @@ -1480,40 +1480,6 @@ def test_completion_helpers_title_helpers_thanks_tips(monkeypatch: pytest.Monkey assert list(main.MyCli.get_completions(cli, 'select', 6)) == ['done'] assert entered_lock['count'] >= 2 - class FakeResource: - def __init__(self, text: str | None) -> None: - self.text = text - - def joinpath(self, name: str) -> 'FakeResource': - if name == 'AUTHORS': - return FakeResource('* Alice\n') - if name == 'SPONSORS': - raise FileNotFoundError() - if name == 'TIPS': - return FakeResource('# comment\nTip one\n\nTip two\n') - raise FileNotFoundError() - - def open(self, mode: str) -> StringIO: - if self.text is None: - raise FileNotFoundError() - return StringIO(self.text) - - monkeypatch.setattr(main.resources, 'files', lambda package: FakeResource(None)) - monkeypatch.setattr(main, 'choice', lambda values: values[0]) - assert main.thanks_picker() == 'Alice' - assert main.tips_picker() == 'Tip one' - - class SponsorResource(FakeResource): - def joinpath(self, name: str) -> 'FakeResource': - if name == 'AUTHORS': - raise FileNotFoundError() - if name == 'SPONSORS': - return FakeResource('* Sponsor Person\n') - raise FileNotFoundError() - - monkeypatch.setattr(main.resources, 'files', lambda package: SponsorResource(None)) - assert main.thanks_picker() == 'Sponsor Person' - def test_main_wrapper_and_edit_and_execute(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(main, 'filtered_sys_argv', lambda: ['--help']) @@ -1555,7 +1521,7 @@ class ErrorNoCode(click.ClickException): current_buffer=SimpleNamespace(open_in_editor=lambda validate_and_handle=False: opened.append(validate_and_handle)) ), ) - key_bindings.edit_and_execute(event) + mycli.key_bindings.edit_and_execute(event) assert opened == [False] @@ -2057,6 +2023,9 @@ def __init__(self) -> None: self.server_info = SimpleNamespace(species=SimpleNamespace(name='MySQL')) self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> list[SQLResult]: return [SQLResult(status='SELECT 1', header=['a'], rows=[(1,)])] @@ -2065,12 +2034,13 @@ def run(self, text: str) -> list[SQLResult]: sqlexecute = FakeRunSQLExecute() cli.sqlexecute = cast(Any, sqlexecute) monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: False) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: False) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2081,10 +2051,10 @@ def run(self, text: str) -> list[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) main.MyCli.run_cli(cli) assert refresh_resets == [False] assert outputs == [['formatted']] @@ -2093,6 +2063,14 @@ def run(self, text: str) -> list[SQLResult]: assert prompt_session.app.ttimeoutlen == cli.emacs_ttimeoutlen +def test_run_cli_delegates_to_main_repl(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + calls: list[Any] = [] + monkeypatch.setattr(main, 'main_repl', lambda target: calls.append(target)) + main.MyCli.run_cli(cli) + assert calls == [cli] + + def test_run_cli_large_select_asks_for_confirmation(monkeypatch: pytest.MonkeyPatch) -> None: cli = make_bare_mycli() cli.config = {'history_file': '~/.mycli-history-testing'} @@ -2105,13 +2083,14 @@ def test_run_cli_large_select_asks_for_confirmation(monkeypatch: pytest.MonkeyPa echoed: list[str] = [] cli.echo = lambda message, **kwargs: echoed.append(str(message)) # type: ignore[assignment] prompt_session = FakePromptSession(responses=['select * from t', EOFError()]) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) - monkeypatch.setattr(main, 'Cursor', FakeCursorBase) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'Cursor', FakeCursorBase) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2122,11 +2101,11 @@ def test_run_cli_large_select_asks_for_confirmation(monkeypatch: pytest.MonkeyPa monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) - monkeypatch.setattr(main, 'confirm', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'confirm', lambda text: False) rows = FakeCursorBase(rows=[(1,)], rowcount=1001, description=[('id', 3)], warning_count=0) class FakeRunSQLExecute: @@ -2161,13 +2140,14 @@ def test_run_cli_outputs_warnings_and_timing(monkeypatch: pytest.MonkeyPatch) -> cli.output_timing = lambda timing, is_warnings_style=False: timings.append((timing, is_warnings_style)) # type: ignore[assignment] cli.format_sqlresult = lambda result, **kwargs: iter([result.status_plain or 'row']) # type: ignore[assignment] prompt_session = FakePromptSession(responses=['select 1', EOFError()]) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) - monkeypatch.setattr(main, 'Cursor', FakeCursorBase) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'Cursor', FakeCursorBase) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2178,10 +2158,10 @@ def test_run_cli_outputs_warnings_and_timing(monkeypatch: pytest.MonkeyPatch) -> monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) warning_rows = FakeCursorBase(rows=[('Level', 1, 'Message')], rowcount=1, description=[('id', 3)], warning_count=1) main_result = SQLResult(status='SELECT 1', header=['id'], rows=cast(Any, warning_rows)) warning_result = SQLResult(status='Warning', header=['level'], rows=[('Warning',)]) @@ -2191,6 +2171,9 @@ def __init__(self) -> None: self.server_info = SimpleNamespace(species=SimpleNamespace(name='MySQL')) self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> list[SQLResult]: if text == 'SHOW WARNINGS': @@ -2237,6 +2220,9 @@ def __init__(self) -> None: self.server_info = 'Server' self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) @@ -2249,21 +2235,20 @@ def fake_prompt_session(**kwargs: Any) -> InspectPromptSession: continuations.append(kwargs['prompt_continuation'](4, 0, 0)) return prompt_session - monkeypatch.setattr(main, 'PromptSession', fake_prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', fake_prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') def fake_create_toolbar_tokens(mycli: Any, show_help: Any, fmt: str) -> str: toolbar_help.append(show_help()) return 'toolbar' - monkeypatch.setattr(main, 'create_toolbar_tokens_func', fake_create_toolbar_tokens) + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', fake_create_toolbar_tokens) monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main.random, 'random', lambda: 0.4) - monkeypatch.setattr(main, 'thanks_picker', lambda: 'Alice') - monkeypatch.setattr(main, 'tips_picker', lambda: 'Tip') + monkeypatch.setattr(random, 'random', lambda: 0.4) monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: prints.append(' '.join(str(x) for x in args))) echoed: list[str] = [] cli.echo = lambda message, **kwargs: echoed.append(str(message)) # type: ignore[assignment] @@ -2303,6 +2288,9 @@ def __init__(self) -> None: self.dbname = 'db' self.connection_id = 0 self.conn = LLMConnection() + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> Iterator[SQLResult]: return iter([SQLResult(status=f'ran:{text}')]) @@ -2310,12 +2298,13 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) prompt_session = FakePromptSession(responses=['\\llm ask', 'select 1', '\\llm finish', '\\llm empty', '\\llm err', EOFError()]) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_expanded_output', lambda: False) @@ -2325,10 +2314,10 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: text.startswith('\\llm')) def fake_handle_llm(text: str, cur: Any, dbname: str, field_truncate: int, section_truncate: int) -> tuple[str, str, float]: @@ -2396,6 +2385,9 @@ def __init__(self) -> None: self.connection_id = 0 self.conn = SimpleNamespace() self.calls: list[str] = [] + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def connect(self) -> None: self.calls.append('connect') @@ -2417,12 +2409,13 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) sqlexecute = FakeRunSQLExecute() cli.sqlexecute = cast(Any, sqlexecute) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2433,11 +2426,11 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: text == 'dropdb') - monkeypatch.setattr(main, 'need_completion_reset', lambda text: True) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: text == 'dropdb') + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: text == 'dropdb') + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_reset', lambda text: True) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: text == 'dropdb') main.MyCli.run_cli(cli) assert reconnect_calls == ['', ''] @@ -2479,6 +2472,9 @@ def __init__(self) -> None: self.dbname = 'db' self.connection_id = 0 self.conn = SimpleNamespace(cursor=lambda: 'cursor') + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def connect(self) -> None: return None @@ -2496,12 +2492,13 @@ def run(self, text: str) -> Iterator[SQLResult]: raise EOFError() return iter([SQLResult(status=f'ok:{text}')]) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_expanded_output', lambda: False) @@ -2511,10 +2508,10 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: text.startswith('\\llm')) monkeypatch.setattr(main.special, 'handle_llm', lambda *args, **kwargs: (_ for _ in ()).throw(KeyboardInterrupt())) monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) @@ -2542,18 +2539,22 @@ def __init__(self) -> None: self.server_info = SimpleNamespace(species=SimpleNamespace(name='MySQL')) self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> Iterator[SQLResult]: if text == 'iface': raise pymysql.err.InterfaceError() raise pymysql.OperationalError(2003, 'lost') - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2564,62 +2565,15 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) main.MyCli.run_cli(cli) -def test_run_cli_tip_prompt_lines_toolbar_none_and_keepalive_noops(monkeypatch: pytest.MonkeyPatch) -> None: - cli = make_bare_mycli() - cli.less_chatty = False - cli.toolbar_format = 'none' - cli.keepalive_ticks = 1 - cli.prompt_format = 'prompt' - cli.config = {'history_file': '~/.mycli-history-testing'} - cli.set_all_external_titles = lambda: None # type: ignore[assignment] - cli.get_prompt = lambda string, render_counter: 'prompt' # type: ignore[assignment] - printed: list[str] = [] - - class PromptOnce(FakePromptSession): - def prompt(self, **kwargs: Any) -> str: - inputhook = kwargs.get('inputhook') - if inputhook is not None: - cli.keepalive_ticks = None - inputhook(None) - cli.keepalive_ticks = 0 - inputhook(None) - kwargs['message']() - raise EOFError() - - class FakeRunSQLExecute: - def __init__(self) -> None: - self.server_info = 'Server' - self.dbname = 'db' - self.connection_id = 0 - - monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) - cli.sqlexecute = cast(Any, FakeRunSQLExecute()) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: PromptOnce()) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr( - main, 'create_toolbar_tokens_func', lambda *args: (_ for _ in ()).throw(AssertionError('toolbar should be disabled')) - ) - monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') - monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) - monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main.random, 'random', lambda: 0.6) - monkeypatch.setattr(main, 'tips_picker', lambda: 'Tip') - monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: printed.append(' '.join(str(x) for x in args))) - main.MyCli.run_cli(cli) - assert any('Tip' in line for line in printed) - assert cli.prompt_lines == 1 - - def test_run_cli_watch_beep_auto_vertical_and_cancel_failure_paths(monkeypatch: pytest.MonkeyPatch) -> None: cli = make_bare_mycli() cli.config = {'history_file': '~/.mycli-history-testing'} @@ -2649,6 +2603,9 @@ def __init__(self) -> None: self.dbname = 'db' self.connection_id = 0 self.conn = SimpleNamespace() + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def connect(self) -> None: return None @@ -2673,12 +2630,13 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2689,11 +2647,11 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) - monkeypatch.setattr(main, 'time', iter([0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).__next__) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(time, 'time', iter([0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]).__next__) main.MyCli.run_cli(cli) assert recorded_widths[:2] == [91, 91] assert '' in echoes @@ -2726,6 +2684,9 @@ def __init__(self) -> None: self.server_info = SimpleNamespace(species=SimpleNamespace(name='MySQL')) self.dbname = 'db' self.connection_id = 0 + self.host = 'localhost' + self.port = 3306 + self.user = 'root' def run(self, text: str) -> Iterator[SQLResult]: cli.prompt_app = None @@ -2733,12 +2694,13 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main, 'SQLExecute', FakeRunSQLExecute) cli.sqlexecute = cast(Any, FakeRunSQLExecute()) - monkeypatch.setattr(main, 'PromptSession', lambda **kwargs: prompt_session) - monkeypatch.setattr(main, 'mycli_bindings', lambda mycli: 'bindings') - monkeypatch.setattr(main, 'create_toolbar_tokens_func', lambda *args: 'toolbar') + monkeypatch.setattr(mycli.main_modes.repl, 'PromptSession', lambda **kwargs: prompt_session) + monkeypatch.setattr(mycli.main_modes.repl, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(mycli.main_modes.repl, 'create_toolbar_tokens_func', lambda *args: 'toolbar') monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) - monkeypatch.setattr(main, 'cli_is_multiline', lambda mycli: False) + monkeypatch.setattr(mycli.main_modes.repl, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(mycli.main_modes.repl, 'cli_is_multiline', lambda mycli: False) monkeypatch.setattr(main.special, 'set_expanded_output', lambda value: None) monkeypatch.setattr(main.special, 'set_forced_horizontal_output', lambda value: None) monkeypatch.setattr(main.special, 'is_llm_command', lambda text: False) @@ -2749,9 +2711,9 @@ def run(self, text: str) -> Iterator[SQLResult]: monkeypatch.setattr(main.special, 'unset_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) monkeypatch.setattr(main.special, 'close_tee', lambda: None) - monkeypatch.setattr(main, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_redirect_command', lambda text: False) monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, text: None) - monkeypatch.setattr(main, 'need_completion_refresh', lambda text: False) - monkeypatch.setattr(main, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(mycli.main_modes.repl, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(mycli.main_modes.repl, 'is_dropping_database', lambda text, dbname: False) main.MyCli.run_cli(cli) assert widths == [main.DEFAULT_WIDTH]