|
3 | 3 | from collections import defaultdict |
4 | 4 | from dataclasses import dataclass |
5 | 5 | from decimal import Decimal |
6 | | -import functools |
7 | 6 | from io import TextIOWrapper |
8 | 7 | import logging |
9 | 8 | import os |
10 | 9 | import re |
11 | 10 | import shutil |
12 | | -import subprocess |
13 | 11 | import sys |
14 | 12 | import threading |
15 | 13 | import traceback |
|
73 | 71 | from mycli.main_modes.execute import main_execute_from_cli |
74 | 72 | from mycli.main_modes.list_dsn import main_list_dsn |
75 | 73 | from mycli.main_modes.list_ssh_config import main_list_ssh_config |
76 | | -from mycli.main_modes.repl import main_repl |
| 74 | +from mycli.main_modes.repl import get_prompt, main_repl, set_all_external_titles |
77 | 75 | from mycli.packages import special |
78 | 76 | from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme |
79 | 77 | from mycli.packages.filepaths import dir_path_exists, guess_socket_location |
80 | 78 | from mycli.packages.prompt_utils import confirm_destructive_query |
81 | 79 | from mycli.packages.special.favoritequeries import FavoriteQueries |
82 | 80 | from mycli.packages.special.main import ArgType |
83 | | -from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count |
84 | 81 | from mycli.packages.sqlresult import SQLResult |
85 | 82 | from mycli.packages.ssh_utils import read_ssh_config |
86 | | -from mycli.packages.string_utils import sanitize_terminal_title |
87 | 83 | from mycli.packages.tabular_output import sql_format |
88 | 84 | from mycli.sqlcompleter import SQLCompleter |
89 | 85 | from mycli.sqlexecute import FIELD_TYPES, SQLExecute |
@@ -412,7 +408,9 @@ def change_db(self, arg: str, **_) -> Generator[SQLResult, None, None]: |
412 | 408 | self.sqlexecute.change_db(arg) |
413 | 409 | msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"' |
414 | 410 |
|
415 | | - self.set_all_external_titles() |
| 411 | + # todo: this jump back to repl.py is a sign that separation is incomplete. |
| 412 | + # also: it should not be needed. Don't titles update on every new prompt? |
| 413 | + set_all_external_titles(self) |
416 | 414 |
|
417 | 415 | yield SQLResult(status=msg) |
418 | 416 |
|
@@ -908,7 +906,8 @@ def get_output_margin(self, status: str | None = None) -> int: |
908 | 906 | render_counter = self.prompt_session.app.render_counter |
909 | 907 | else: |
910 | 908 | render_counter = 0 |
911 | | - self.prompt_lines = self.get_prompt(self.prompt_format, render_counter).count('\n') + 1 |
| 909 | + # todo: this jump back to get_prompt() in repl.py is a sign that separation is incomplete |
| 910 | + self.prompt_lines = get_prompt(self, self.prompt_format, render_counter).count('\n') + 1 |
912 | 911 | margin = self.get_reserved_space() + self.prompt_lines |
913 | 912 | if special.is_timing_enabled(): |
914 | 913 | margin += 1 |
@@ -1045,155 +1044,6 @@ def get_completions(self, text: str, cursor_position: int) -> Iterable[Completio |
1045 | 1044 | with self._completer_lock: |
1046 | 1045 | return self.completer.get_completions(Document(text=text, cursor_position=cursor_position), None) |
1047 | 1046 |
|
1048 | | - def set_all_external_titles(self) -> None: |
1049 | | - self.set_external_terminal_tab_title() |
1050 | | - self.set_external_terminal_window_title() |
1051 | | - self.set_external_multiplex_window_title() |
1052 | | - self.set_external_multiplex_pane_title() |
1053 | | - |
1054 | | - def set_external_terminal_tab_title(self) -> None: |
1055 | | - if not self.terminal_tab_title_format: |
1056 | | - return |
1057 | | - if not self.prompt_session: |
1058 | | - return |
1059 | | - if not sys.stderr.isatty(): |
1060 | | - return |
1061 | | - title = sanitize_terminal_title(self.get_prompt(self.terminal_tab_title_format, self.prompt_session.app.render_counter)) |
1062 | | - print(f'\x1b]1;{title}\a', file=sys.stderr, end='') |
1063 | | - sys.stderr.flush() |
1064 | | - |
1065 | | - def set_external_terminal_window_title(self) -> None: |
1066 | | - if not self.terminal_window_title_format: |
1067 | | - return |
1068 | | - if not self.prompt_session: |
1069 | | - return |
1070 | | - if not sys.stderr.isatty(): |
1071 | | - return |
1072 | | - title = sanitize_terminal_title(self.get_prompt(self.terminal_window_title_format, self.prompt_session.app.render_counter)) |
1073 | | - print(f'\x1b]2;{title}\a', file=sys.stderr, end='') |
1074 | | - sys.stderr.flush() |
1075 | | - |
1076 | | - def set_external_multiplex_window_title(self) -> None: |
1077 | | - if not self.multiplex_window_title_format: |
1078 | | - return |
1079 | | - if not os.getenv('TMUX'): |
1080 | | - return |
1081 | | - if not self.prompt_session: |
1082 | | - return |
1083 | | - title = sanitize_terminal_title(self.get_prompt(self.multiplex_window_title_format, self.prompt_session.app.render_counter)) |
1084 | | - try: |
1085 | | - subprocess.run( |
1086 | | - ['tmux', 'rename-window', title], |
1087 | | - check=False, |
1088 | | - stdin=subprocess.DEVNULL, |
1089 | | - stdout=subprocess.DEVNULL, |
1090 | | - stderr=subprocess.DEVNULL, |
1091 | | - ) |
1092 | | - except FileNotFoundError: |
1093 | | - pass |
1094 | | - |
1095 | | - def set_external_multiplex_pane_title(self) -> None: |
1096 | | - if not self.multiplex_pane_title_format: |
1097 | | - return |
1098 | | - if not os.getenv('TMUX'): |
1099 | | - return |
1100 | | - if not self.prompt_session: |
1101 | | - return |
1102 | | - if not sys.stderr.isatty(): |
1103 | | - return |
1104 | | - title = sanitize_terminal_title(self.get_prompt(self.multiplex_pane_title_format, self.prompt_session.app.render_counter)) |
1105 | | - print(f'\x1b]2;{title}\x1b\\', file=sys.stderr, end='') |
1106 | | - sys.stderr.flush() |
1107 | | - |
1108 | | - def get_custom_toolbar(self, toolbar_format: str) -> ANSI: |
1109 | | - if not self.prompt_session: |
1110 | | - return ANSI('') |
1111 | | - if not self.prompt_session.app: |
1112 | | - return ANSI('') |
1113 | | - if self.prompt_session.app.current_buffer.text: |
1114 | | - return self.last_custom_toolbar_message |
1115 | | - toolbar = self.get_prompt(toolbar_format, self.prompt_session.app.render_counter) |
1116 | | - toolbar = toolbar.replace("\\x1b", "\x1b") |
1117 | | - self.last_custom_toolbar_message = ANSI(toolbar) |
1118 | | - return self.last_custom_toolbar_message |
1119 | | - |
1120 | | - # Memoizing a method leaks the instance, but we only expect one MyCli instance. |
1121 | | - # Before memoizing, get_prompt() was called dozens of times per prompt. |
1122 | | - # Even after memoizing, get_prompt's logic gets called twice per prompt, which |
1123 | | - # should be addressed, because some format strings take a trip to the server. |
1124 | | - @functools.lru_cache(maxsize=256) # noqa: B019 |
1125 | | - def get_prompt(self, string: str, _render_counter: int) -> str: |
1126 | | - sqlexecute = self.sqlexecute |
1127 | | - assert sqlexecute is not None |
1128 | | - assert sqlexecute.server_info is not None |
1129 | | - assert sqlexecute.server_info.species is not None |
1130 | | - if self.login_path and self.login_path_as_host: |
1131 | | - prompt_host = self.login_path |
1132 | | - elif sqlexecute.host is not None: |
1133 | | - prompt_host = sqlexecute.host |
1134 | | - else: |
1135 | | - prompt_host = DEFAULT_HOST |
1136 | | - short_prompt_host, _, _ = prompt_host.partition('.') |
1137 | | - if re.match(r'^[\d\.]+$', short_prompt_host): |
1138 | | - short_prompt_host = prompt_host |
1139 | | - now = datetime.now() |
1140 | | - backslash_placeholder = '\ufffc_backslash' |
1141 | | - string = string.replace('\\\\', backslash_placeholder) |
1142 | | - string = string.replace("\\u", sqlexecute.user or "(none)") |
1143 | | - string = string.replace("\\h", prompt_host or "(none)") |
1144 | | - string = string.replace("\\H", short_prompt_host or "(none)") |
1145 | | - string = string.replace("\\d", sqlexecute.dbname or "(none)") |
1146 | | - string = string.replace("\\t", sqlexecute.server_info.species.name) |
1147 | | - string = string.replace("\\n", "\n") |
1148 | | - string = string.replace("\\D", now.strftime("%a %b %d %H:%M:%S %Y")) |
1149 | | - string = string.replace("\\m", now.strftime("%M")) |
1150 | | - string = string.replace("\\P", now.strftime("%p")) |
1151 | | - string = string.replace("\\R", now.strftime("%H")) |
1152 | | - string = string.replace("\\r", now.strftime("%I")) |
1153 | | - string = string.replace("\\s", now.strftime("%S")) |
1154 | | - string = string.replace("\\p", str(sqlexecute.port)) |
1155 | | - string = string.replace("\\j", os.path.basename(sqlexecute.socket or '(none)')) |
1156 | | - string = string.replace("\\J", sqlexecute.socket or '(none)') |
1157 | | - string = string.replace("\\k", os.path.basename(sqlexecute.socket or str(sqlexecute.port))) |
1158 | | - string = string.replace("\\K", sqlexecute.socket or str(sqlexecute.port)) |
1159 | | - string = string.replace("\\A", self.dsn_alias or "(none)") |
1160 | | - string = string.replace("\\_", " ") |
1161 | | - string = string.replace(backslash_placeholder, '\\') |
1162 | | - |
1163 | | - # jump through hoops for the test environment, and for efficiency |
1164 | | - if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: |
1165 | | - if '\\y' in string: |
1166 | | - with sqlexecute.conn.cursor() as cur: |
1167 | | - string = string.replace('\\y', str(get_uptime(cur)) or '(none)') |
1168 | | - if '\\Y' in string: |
1169 | | - with sqlexecute.conn.cursor() as cur: |
1170 | | - string = string.replace('\\Y', format_uptime(str(get_uptime(cur))) or '(none)') |
1171 | | - else: |
1172 | | - string = string.replace('\\y', '(none)') |
1173 | | - string = string.replace('\\Y', '(none)') |
1174 | | - |
1175 | | - if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: |
1176 | | - if '\\T' in string: |
1177 | | - with sqlexecute.conn.cursor() as cur: |
1178 | | - string = string.replace('\\T', get_ssl_version(cur) or '(none)') |
1179 | | - else: |
1180 | | - string = string.replace('\\T', '(none)') |
1181 | | - |
1182 | | - if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: |
1183 | | - if '\\w' in string: |
1184 | | - with sqlexecute.conn.cursor() as cur: |
1185 | | - string = string.replace('\\w', str(get_warning_count(cur) or '(none)')) |
1186 | | - else: |
1187 | | - string = string.replace('\\w', '(none)') |
1188 | | - if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: |
1189 | | - if '\\W' in string: |
1190 | | - with sqlexecute.conn.cursor() as cur: |
1191 | | - string = string.replace('\\W', str(get_warning_count(cur) or '')) |
1192 | | - else: |
1193 | | - string = string.replace('\\W', '') |
1194 | | - |
1195 | | - return string |
1196 | | - |
1197 | 1047 | def run_query( |
1198 | 1048 | self, |
1199 | 1049 | query: str, |
|
0 commit comments