Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions mycli/clitoolbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@

from prompt_toolkit.application import get_app
from prompt_toolkit.enums import EditingMode
from prompt_toolkit.formatted_text import to_formatted_text
from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text
from prompt_toolkit.key_binding.vi_state import InputMode

from mycli.packages import special


def create_toolbar_tokens_func(mycli, show_initial_toolbar_help: Callable, format_string: str | None) -> Callable:
def create_toolbar_tokens_func(
mycli,
show_initial_toolbar_help: Callable[[], bool],
format_string: str | None,
get_custom_toolbar: Callable[[str], AnyFormattedText],
) -> Callable[[], list[tuple[str, str]]]:
"""Return a function that generates the toolbar tokens."""

def get_toolbar_tokens() -> list[tuple[str, str]]:
Expand Down Expand Up @@ -73,7 +78,7 @@ def get_toolbar_tokens() -> list[tuple[str, str]]:
else:
amended_format = format_string
result = []
formatted = to_formatted_text(mycli.get_custom_toolbar(amended_format), style='class:bottom-toolbar')
formatted = to_formatted_text(get_custom_toolbar(amended_format), style='class:bottom-toolbar')
result.extend([*formatted]) # coerce to list for mypy

result.extend(dynamic)
Expand Down
162 changes: 6 additions & 156 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from collections import defaultdict
from dataclasses import dataclass
from decimal import Decimal
import functools
from io import TextIOWrapper
import logging
import os
import re
import shutil
import subprocess
import sys
import threading
import traceback
Expand Down Expand Up @@ -73,17 +71,15 @@
from mycli.main_modes.execute import main_execute_from_cli
from mycli.main_modes.list_dsn import main_list_dsn
from mycli.main_modes.list_ssh_config import main_list_ssh_config
from mycli.main_modes.repl import main_repl
from mycli.main_modes.repl import get_prompt, main_repl, set_all_external_titles
from mycli.packages import special
from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme
from mycli.packages.filepaths import dir_path_exists, guess_socket_location
from mycli.packages.prompt_utils import confirm_destructive_query
from mycli.packages.special.favoritequeries import FavoriteQueries
from mycli.packages.special.main import ArgType
from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count
from mycli.packages.sqlresult import SQLResult
from mycli.packages.ssh_utils import read_ssh_config
from mycli.packages.string_utils import sanitize_terminal_title
from mycli.packages.tabular_output import sql_format
from mycli.sqlcompleter import SQLCompleter
from mycli.sqlexecute import FIELD_TYPES, SQLExecute
Expand Down Expand Up @@ -412,7 +408,9 @@ def change_db(self, arg: str, **_) -> Generator[SQLResult, None, None]:
self.sqlexecute.change_db(arg)
msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'

self.set_all_external_titles()
# todo: this jump back to repl.py is a sign that separation is incomplete.
# also: it should not be needed. Don't titles update on every new prompt?
set_all_external_titles(self)

yield SQLResult(status=msg)

Expand Down Expand Up @@ -908,7 +906,8 @@ def get_output_margin(self, status: str | None = None) -> int:
render_counter = self.prompt_session.app.render_counter
else:
render_counter = 0
self.prompt_lines = self.get_prompt(self.prompt_format, render_counter).count('\n') + 1
# todo: this jump back to get_prompt() in repl.py is a sign that separation is incomplete
self.prompt_lines = get_prompt(self, self.prompt_format, render_counter).count('\n') + 1
margin = self.get_reserved_space() + self.prompt_lines
if special.is_timing_enabled():
margin += 1
Expand Down Expand Up @@ -1045,155 +1044,6 @@ def get_completions(self, text: str, cursor_position: int) -> Iterable[Completio
with self._completer_lock:
return self.completer.get_completions(Document(text=text, cursor_position=cursor_position), None)

def set_all_external_titles(self) -> None:
self.set_external_terminal_tab_title()
self.set_external_terminal_window_title()
self.set_external_multiplex_window_title()
self.set_external_multiplex_pane_title()

def set_external_terminal_tab_title(self) -> None:
if not self.terminal_tab_title_format:
return
if not self.prompt_session:
return
if not sys.stderr.isatty():
return
title = sanitize_terminal_title(self.get_prompt(self.terminal_tab_title_format, self.prompt_session.app.render_counter))
print(f'\x1b]1;{title}\a', file=sys.stderr, end='')
sys.stderr.flush()

def set_external_terminal_window_title(self) -> None:
if not self.terminal_window_title_format:
return
if not self.prompt_session:
return
if not sys.stderr.isatty():
return
title = sanitize_terminal_title(self.get_prompt(self.terminal_window_title_format, self.prompt_session.app.render_counter))
print(f'\x1b]2;{title}\a', file=sys.stderr, end='')
sys.stderr.flush()

def set_external_multiplex_window_title(self) -> None:
if not self.multiplex_window_title_format:
return
if not os.getenv('TMUX'):
return
if not self.prompt_session:
return
title = sanitize_terminal_title(self.get_prompt(self.multiplex_window_title_format, self.prompt_session.app.render_counter))
try:
subprocess.run(
['tmux', 'rename-window', title],
check=False,
stdin=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
except FileNotFoundError:
pass

def set_external_multiplex_pane_title(self) -> None:
if not self.multiplex_pane_title_format:
return
if not os.getenv('TMUX'):
return
if not self.prompt_session:
return
if not sys.stderr.isatty():
return
title = sanitize_terminal_title(self.get_prompt(self.multiplex_pane_title_format, self.prompt_session.app.render_counter))
print(f'\x1b]2;{title}\x1b\\', file=sys.stderr, end='')
sys.stderr.flush()

def get_custom_toolbar(self, toolbar_format: str) -> ANSI:
if not self.prompt_session:
return ANSI('')
if not self.prompt_session.app:
return ANSI('')
if self.prompt_session.app.current_buffer.text:
return self.last_custom_toolbar_message
toolbar = self.get_prompt(toolbar_format, self.prompt_session.app.render_counter)
toolbar = toolbar.replace("\\x1b", "\x1b")
self.last_custom_toolbar_message = ANSI(toolbar)
return self.last_custom_toolbar_message

# Memoizing a method leaks the instance, but we only expect one MyCli instance.
# Before memoizing, get_prompt() was called dozens of times per prompt.
# Even after memoizing, get_prompt's logic gets called twice per prompt, which
# should be addressed, because some format strings take a trip to the server.
@functools.lru_cache(maxsize=256) # noqa: B019
def get_prompt(self, string: str, _render_counter: int) -> str:
sqlexecute = self.sqlexecute
assert sqlexecute is not None
assert sqlexecute.server_info is not None
assert sqlexecute.server_info.species is not None
if self.login_path and self.login_path_as_host:
prompt_host = self.login_path
elif sqlexecute.host is not None:
prompt_host = sqlexecute.host
else:
prompt_host = DEFAULT_HOST
short_prompt_host, _, _ = prompt_host.partition('.')
if re.match(r'^[\d\.]+$', short_prompt_host):
short_prompt_host = prompt_host
now = datetime.now()
backslash_placeholder = '\ufffc_backslash'
string = string.replace('\\\\', backslash_placeholder)
string = string.replace("\\u", sqlexecute.user or "(none)")
string = string.replace("\\h", prompt_host or "(none)")
string = string.replace("\\H", short_prompt_host or "(none)")
string = string.replace("\\d", sqlexecute.dbname or "(none)")
string = string.replace("\\t", sqlexecute.server_info.species.name)
string = string.replace("\\n", "\n")
string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y"))
string = string.replace("\\m", now.strftime("%M"))
string = string.replace("\\P", now.strftime("%p"))
string = string.replace("\\R", now.strftime("%H"))
string = string.replace("\\r", now.strftime("%I"))
string = string.replace("\\s", now.strftime("%S"))
string = string.replace("\\p", str(sqlexecute.port))
string = string.replace("\\j", os.path.basename(sqlexecute.socket or '(none)'))
string = string.replace("\\J", sqlexecute.socket or '(none)')
string = string.replace("\\k", os.path.basename(sqlexecute.socket or str(sqlexecute.port)))
string = string.replace("\\K", sqlexecute.socket or str(sqlexecute.port))
string = string.replace("\\A", self.dsn_alias or "(none)")
string = string.replace("\\_", " ")
string = string.replace(backslash_placeholder, '\\')

# jump through hoops for the test environment, and for efficiency
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
if '\\y' in string:
with sqlexecute.conn.cursor() as cur:
string = string.replace('\\y', str(get_uptime(cur)) or '(none)')
if '\\Y' in string:
with sqlexecute.conn.cursor() as cur:
string = string.replace('\\Y', format_uptime(str(get_uptime(cur))) or '(none)')
else:
string = string.replace('\\y', '(none)')
string = string.replace('\\Y', '(none)')

if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
if '\\T' in string:
with sqlexecute.conn.cursor() as cur:
string = string.replace('\\T', get_ssl_version(cur) or '(none)')
else:
string = string.replace('\\T', '(none)')

if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
if '\\w' in string:
with sqlexecute.conn.cursor() as cur:
string = string.replace('\\w', str(get_warning_count(cur) or '(none)'))
else:
string = string.replace('\\w', '(none)')
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
if '\\W' in string:
with sqlexecute.conn.cursor() as cur:
string = string.replace('\\W', str(get_warning_count(cur) or ''))
else:
string = string.replace('\\W', '')

return string

def run_query(
self,
query: str,
Expand Down
Loading
Loading