Skip to content

Commit e7c606f

Browse files
committed
move prompt-format methods to main_modes/repl.py
Move these methods to main_modes/repl.py: * set_all_external_titles() * set_external_terminal_tab_title() * set_external_terminal_window_title() * set_external_multiplex_window_title() * set_external_multiplex_pane_title() * get_custom_toolbar() * get_prompt() and other rearrangements needed to effect that change. After the changes, main.py still has two calls to the new functions, which are marked with todo comments regarding the incomplete separation.
1 parent 67a6788 commit e7c606f

7 files changed

Lines changed: 423 additions & 217 deletions

File tree

mycli/clitoolbar.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22

33
from prompt_toolkit.application import get_app
44
from prompt_toolkit.enums import EditingMode
5-
from prompt_toolkit.formatted_text import to_formatted_text
5+
from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text
66
from prompt_toolkit.key_binding.vi_state import InputMode
77

88
from mycli.packages import special
99

1010

11-
def create_toolbar_tokens_func(mycli, show_initial_toolbar_help: Callable, format_string: str | None) -> Callable:
11+
def create_toolbar_tokens_func(
12+
mycli,
13+
show_initial_toolbar_help: Callable[[], bool],
14+
format_string: str | None,
15+
get_custom_toolbar: Callable[[str], AnyFormattedText],
16+
) -> Callable[[], list[tuple[str, str]]]:
1217
"""Return a function that generates the toolbar tokens."""
1318

1419
def get_toolbar_tokens() -> list[tuple[str, str]]:
@@ -73,7 +78,7 @@ def get_toolbar_tokens() -> list[tuple[str, str]]:
7378
else:
7479
amended_format = format_string
7580
result = []
76-
formatted = to_formatted_text(mycli.get_custom_toolbar(amended_format), style='class:bottom-toolbar')
81+
formatted = to_formatted_text(get_custom_toolbar(amended_format), style='class:bottom-toolbar')
7782
result.extend([*formatted]) # coerce to list for mypy
7883

7984
result.extend(dynamic)

mycli/main.py

Lines changed: 6 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33
from collections import defaultdict
44
from dataclasses import dataclass
55
from decimal import Decimal
6-
import functools
76
from io import TextIOWrapper
87
import logging
98
import os
109
import re
1110
import shutil
12-
import subprocess
1311
import sys
1412
import threading
1513
import traceback
@@ -73,17 +71,15 @@
7371
from mycli.main_modes.execute import main_execute_from_cli
7472
from mycli.main_modes.list_dsn import main_list_dsn
7573
from mycli.main_modes.list_ssh_config import main_list_ssh_config
76-
from mycli.main_modes.repl import main_repl
74+
from mycli.main_modes.repl import get_prompt, main_repl, set_all_external_titles
7775
from mycli.packages import special
7876
from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme
7977
from mycli.packages.filepaths import dir_path_exists, guess_socket_location
8078
from mycli.packages.prompt_utils import confirm_destructive_query
8179
from mycli.packages.special.favoritequeries import FavoriteQueries
8280
from mycli.packages.special.main import ArgType
83-
from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count
8481
from mycli.packages.sqlresult import SQLResult
8582
from mycli.packages.ssh_utils import read_ssh_config
86-
from mycli.packages.string_utils import sanitize_terminal_title
8783
from mycli.packages.tabular_output import sql_format
8884
from mycli.sqlcompleter import SQLCompleter
8985
from mycli.sqlexecute import FIELD_TYPES, SQLExecute
@@ -412,7 +408,9 @@ def change_db(self, arg: str, **_) -> Generator[SQLResult, None, None]:
412408
self.sqlexecute.change_db(arg)
413409
msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'
414410

415-
self.set_all_external_titles()
411+
# todo: this jump back to repl.py is a sign that separation is incomplete.
412+
# also: it should not be needed. Don't titles update on every new prompt?
413+
set_all_external_titles(self)
416414

417415
yield SQLResult(status=msg)
418416

@@ -908,7 +906,8 @@ def get_output_margin(self, status: str | None = None) -> int:
908906
render_counter = self.prompt_session.app.render_counter
909907
else:
910908
render_counter = 0
911-
self.prompt_lines = self.get_prompt(self.prompt_format, render_counter).count('\n') + 1
909+
# todo: this jump back to get_prompt() in repl.py is a sign that separation is incomplete
910+
self.prompt_lines = get_prompt(self, self.prompt_format, render_counter).count('\n') + 1
912911
margin = self.get_reserved_space() + self.prompt_lines
913912
if special.is_timing_enabled():
914913
margin += 1
@@ -1045,155 +1044,6 @@ def get_completions(self, text: str, cursor_position: int) -> Iterable[Completio
10451044
with self._completer_lock:
10461045
return self.completer.get_completions(Document(text=text, cursor_position=cursor_position), None)
10471046

1048-
def set_all_external_titles(self) -> None:
1049-
self.set_external_terminal_tab_title()
1050-
self.set_external_terminal_window_title()
1051-
self.set_external_multiplex_window_title()
1052-
self.set_external_multiplex_pane_title()
1053-
1054-
def set_external_terminal_tab_title(self) -> None:
1055-
if not self.terminal_tab_title_format:
1056-
return
1057-
if not self.prompt_session:
1058-
return
1059-
if not sys.stderr.isatty():
1060-
return
1061-
title = sanitize_terminal_title(self.get_prompt(self.terminal_tab_title_format, self.prompt_session.app.render_counter))
1062-
print(f'\x1b]1;{title}\a', file=sys.stderr, end='')
1063-
sys.stderr.flush()
1064-
1065-
def set_external_terminal_window_title(self) -> None:
1066-
if not self.terminal_window_title_format:
1067-
return
1068-
if not self.prompt_session:
1069-
return
1070-
if not sys.stderr.isatty():
1071-
return
1072-
title = sanitize_terminal_title(self.get_prompt(self.terminal_window_title_format, self.prompt_session.app.render_counter))
1073-
print(f'\x1b]2;{title}\a', file=sys.stderr, end='')
1074-
sys.stderr.flush()
1075-
1076-
def set_external_multiplex_window_title(self) -> None:
1077-
if not self.multiplex_window_title_format:
1078-
return
1079-
if not os.getenv('TMUX'):
1080-
return
1081-
if not self.prompt_session:
1082-
return
1083-
title = sanitize_terminal_title(self.get_prompt(self.multiplex_window_title_format, self.prompt_session.app.render_counter))
1084-
try:
1085-
subprocess.run(
1086-
['tmux', 'rename-window', title],
1087-
check=False,
1088-
stdin=subprocess.DEVNULL,
1089-
stdout=subprocess.DEVNULL,
1090-
stderr=subprocess.DEVNULL,
1091-
)
1092-
except FileNotFoundError:
1093-
pass
1094-
1095-
def set_external_multiplex_pane_title(self) -> None:
1096-
if not self.multiplex_pane_title_format:
1097-
return
1098-
if not os.getenv('TMUX'):
1099-
return
1100-
if not self.prompt_session:
1101-
return
1102-
if not sys.stderr.isatty():
1103-
return
1104-
title = sanitize_terminal_title(self.get_prompt(self.multiplex_pane_title_format, self.prompt_session.app.render_counter))
1105-
print(f'\x1b]2;{title}\x1b\\', file=sys.stderr, end='')
1106-
sys.stderr.flush()
1107-
1108-
def get_custom_toolbar(self, toolbar_format: str) -> ANSI:
1109-
if not self.prompt_session:
1110-
return ANSI('')
1111-
if not self.prompt_session.app:
1112-
return ANSI('')
1113-
if self.prompt_session.app.current_buffer.text:
1114-
return self.last_custom_toolbar_message
1115-
toolbar = self.get_prompt(toolbar_format, self.prompt_session.app.render_counter)
1116-
toolbar = toolbar.replace("\\x1b", "\x1b")
1117-
self.last_custom_toolbar_message = ANSI(toolbar)
1118-
return self.last_custom_toolbar_message
1119-
1120-
# Memoizing a method leaks the instance, but we only expect one MyCli instance.
1121-
# Before memoizing, get_prompt() was called dozens of times per prompt.
1122-
# Even after memoizing, get_prompt's logic gets called twice per prompt, which
1123-
# should be addressed, because some format strings take a trip to the server.
1124-
@functools.lru_cache(maxsize=256) # noqa: B019
1125-
def get_prompt(self, string: str, _render_counter: int) -> str:
1126-
sqlexecute = self.sqlexecute
1127-
assert sqlexecute is not None
1128-
assert sqlexecute.server_info is not None
1129-
assert sqlexecute.server_info.species is not None
1130-
if self.login_path and self.login_path_as_host:
1131-
prompt_host = self.login_path
1132-
elif sqlexecute.host is not None:
1133-
prompt_host = sqlexecute.host
1134-
else:
1135-
prompt_host = DEFAULT_HOST
1136-
short_prompt_host, _, _ = prompt_host.partition('.')
1137-
if re.match(r'^[\d\.]+$', short_prompt_host):
1138-
short_prompt_host = prompt_host
1139-
now = datetime.now()
1140-
backslash_placeholder = '\ufffc_backslash'
1141-
string = string.replace('\\\\', backslash_placeholder)
1142-
string = string.replace("\\u", sqlexecute.user or "(none)")
1143-
string = string.replace("\\h", prompt_host or "(none)")
1144-
string = string.replace("\\H", short_prompt_host or "(none)")
1145-
string = string.replace("\\d", sqlexecute.dbname or "(none)")
1146-
string = string.replace("\\t", sqlexecute.server_info.species.name)
1147-
string = string.replace("\\n", "\n")
1148-
string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y"))
1149-
string = string.replace("\\m", now.strftime("%M"))
1150-
string = string.replace("\\P", now.strftime("%p"))
1151-
string = string.replace("\\R", now.strftime("%H"))
1152-
string = string.replace("\\r", now.strftime("%I"))
1153-
string = string.replace("\\s", now.strftime("%S"))
1154-
string = string.replace("\\p", str(sqlexecute.port))
1155-
string = string.replace("\\j", os.path.basename(sqlexecute.socket or '(none)'))
1156-
string = string.replace("\\J", sqlexecute.socket or '(none)')
1157-
string = string.replace("\\k", os.path.basename(sqlexecute.socket or str(sqlexecute.port)))
1158-
string = string.replace("\\K", sqlexecute.socket or str(sqlexecute.port))
1159-
string = string.replace("\\A", self.dsn_alias or "(none)")
1160-
string = string.replace("\\_", " ")
1161-
string = string.replace(backslash_placeholder, '\\')
1162-
1163-
# jump through hoops for the test environment, and for efficiency
1164-
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
1165-
if '\\y' in string:
1166-
with sqlexecute.conn.cursor() as cur:
1167-
string = string.replace('\\y', str(get_uptime(cur)) or '(none)')
1168-
if '\\Y' in string:
1169-
with sqlexecute.conn.cursor() as cur:
1170-
string = string.replace('\\Y', format_uptime(str(get_uptime(cur))) or '(none)')
1171-
else:
1172-
string = string.replace('\\y', '(none)')
1173-
string = string.replace('\\Y', '(none)')
1174-
1175-
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
1176-
if '\\T' in string:
1177-
with sqlexecute.conn.cursor() as cur:
1178-
string = string.replace('\\T', get_ssl_version(cur) or '(none)')
1179-
else:
1180-
string = string.replace('\\T', '(none)')
1181-
1182-
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
1183-
if '\\w' in string:
1184-
with sqlexecute.conn.cursor() as cur:
1185-
string = string.replace('\\w', str(get_warning_count(cur) or '(none)'))
1186-
else:
1187-
string = string.replace('\\w', '(none)')
1188-
if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None:
1189-
if '\\W' in string:
1190-
with sqlexecute.conn.cursor() as cur:
1191-
string = string.replace('\\W', str(get_warning_count(cur) or ''))
1192-
else:
1193-
string = string.replace('\\W', '')
1194-
1195-
return string
1196-
11971047
def run_query(
11981048
self,
11991049
query: str,

0 commit comments

Comments
 (0)