diff --git a/mycli/main.py b/mycli/main.py index 7819b5e5..2dff8bd0 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -90,7 +90,7 @@ from mycli.main_modes.list_dsn import main_list_dsn from mycli.main_modes.list_ssh_config import main_list_ssh_config from mycli.packages import special -from mycli.packages.cli_utils import is_valid_connection_scheme +from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme from mycli.packages.filepaths import dir_path_exists, guess_socket_location from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command from mycli.packages.prompt_utils import confirm, confirm_destructive_query @@ -2706,13 +2706,6 @@ def edit_and_execute(event: KeyPressEvent) -> None: buff.open_in_editor(validate_and_handle=False) -def filtered_sys_argv() -> list[str]: - args = sys.argv[1:] - if args == ['-h']: - args = ['--help'] - return args - - def main() -> int | None: try: result = click_entrypoint.main( diff --git a/mycli/packages/cli_utils.py b/mycli/packages/cli_utils.py index b5e7c5e6..65950130 100644 --- a/mycli/packages/cli_utils.py +++ b/mycli/packages/cli_utils.py @@ -1,5 +1,14 @@ from __future__ import annotations +import sys + + +def filtered_sys_argv() -> list[str]: + args = sys.argv[1:] + if args == ['-h']: + args = ['--help'] + return args + def is_valid_connection_scheme(text: str) -> tuple[bool, str | None]: # exit early if the text does not resemble a DSN URI diff --git a/test/pytests/test_cli_utils.py b/test/pytests/test_cli_utils.py index 7875e2e3..1d01d3e6 100644 --- a/test/pytests/test_cli_utils.py +++ b/test/pytests/test_cli_utils.py @@ -2,11 +2,26 @@ import pytest +from mycli.packages import cli_utils from mycli.packages.cli_utils import ( + filtered_sys_argv, is_valid_connection_scheme, ) +@pytest.mark.parametrize( + ('argv', 'expected'), + [ + (['mycli', '-h'], ['--help']), + (['mycli', '-h', 'example.com'], ['-h', 'example.com']), + ], +) +def test_filtered_sys_argv(monkeypatch, argv, expected): + monkeypatch.setattr(cli_utils.sys, 'argv', argv) + + assert filtered_sys_argv() == expected + + @pytest.mark.parametrize( ('text', 'is_valid', 'invalid_scheme'), [