Skip to content

Commit a1f0d58

Browse files
[UX] Extend dstack login with interactive selection of url and default project (#3492)
* [UX] Extend `dstack login` with interactive selection of `url` and default project * Updated tests * Addressing issues from the PR review (removed overly-agressibe catching of import and other exceptions when detecting if interactive menu is available). Plus, added handling KeyboardException in `dstack login`.
1 parent ca56b3b commit a1f0d58

File tree

4 files changed

+520
-79
lines changed

4 files changed

+520
-79
lines changed

src/dstack/_internal/cli/commands/login.py

Lines changed: 143 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,34 @@
11
import argparse
22
import queue
3+
import sys
34
import threading
45
import urllib.parse
56
import webbrowser
67
from http.server import BaseHTTPRequestHandler, HTTPServer
7-
from typing import Optional
8+
from typing import Any, Optional
9+
10+
import questionary
11+
from rich.prompt import Prompt as RichPrompt
12+
from rich.text import Text
813

914
from dstack._internal.cli.commands import BaseCommand
15+
from dstack._internal.cli.commands.project import select_default_project
1016
from dstack._internal.cli.utils.common import console, resolve_url
1117
from dstack._internal.core.errors import ClientError, CLIError
1218
from dstack._internal.core.models.users import UserWithCreds
19+
from dstack._internal.utils.logging import get_logger
1320
from dstack.api._public.runs import ConfigManager
1421
from dstack.api.server import APIClient
1522

23+
logger = get_logger(__name__)
24+
25+
is_project_menu_supported = sys.stdin.isatty()
26+
27+
28+
class UrlPrompt(RichPrompt):
29+
def render_default(self, default: Any) -> Text:
30+
return Text(f"({default})", style="bold orange1")
31+
1632

1733
class LoginCommand(BaseCommand):
1834
NAME = "login"
@@ -23,7 +39,7 @@ def _register(self):
2339
self._parser.add_argument(
2440
"--url",
2541
help="The server URL, e.g. https://sky.dstack.ai",
26-
required=True,
42+
required=not is_project_menu_supported,
2743
)
2844
self._parser.add_argument(
2945
"-p",
@@ -33,10 +49,25 @@ def _register(self):
3349
" Selected automatically if the server supports only one provider."
3450
),
3551
)
52+
self._parser.add_argument(
53+
"-y",
54+
"--yes",
55+
help="Don't ask for confirmation (e.g. set first project as default)",
56+
action="store_true",
57+
)
58+
self._parser.add_argument(
59+
"-n",
60+
"--no",
61+
help="Don't ask for confirmation (e.g. do not change default project)",
62+
action="store_true",
63+
)
3664

3765
def _command(self, args: argparse.Namespace):
3866
super()._command(args)
39-
base_url = _normalize_url_or_error(args.url)
67+
url = args.url
68+
if url is None:
69+
url = self._prompt_url()
70+
base_url = _normalize_url_or_error(url)
4071
api_client = APIClient(base_url=base_url)
4172
provider = self._select_provider_or_error(api_client=api_client, provider=args.provider)
4273
server = _LoginServer(api_client=api_client, provider=provider)
@@ -56,9 +87,9 @@ def _command(self, args: argparse.Namespace):
5687
server.shutdown()
5788
if user is None:
5889
raise CLIError("CLI authentication failed")
59-
console.print(f"Logged in as [code]{user.username}[/].")
90+
console.print(f"Logged in as [code]{user.username}[/]")
6091
api_client = APIClient(base_url=base_url, token=user.creds.token)
61-
self._configure_projects(api_client=api_client, user=user)
92+
self._configure_projects(api_client=api_client, user=user, args=args)
6293

6394
def _select_provider_or_error(self, api_client: APIClient, provider: Optional[str]) -> str:
6495
providers = api_client.auth.list_providers()
@@ -67,6 +98,8 @@ def _select_provider_or_error(self, api_client: APIClient, provider: Optional[st
6798
raise CLIError("No SSO providers configured on the server.")
6899
if provider is None:
69100
if len(available_providers) > 1:
101+
if is_project_menu_supported:
102+
return self._prompt_provider(available_providers)
70103
raise CLIError(
71104
"Specify -p/--provider to choose SSO provider"
72105
f" Available providers: {', '.join(available_providers)}"
@@ -79,7 +112,37 @@ def _select_provider_or_error(self, api_client: APIClient, provider: Optional[st
79112
)
80113
return provider
81114

82-
def _configure_projects(self, api_client: APIClient, user: UserWithCreds):
115+
def _prompt_url(self) -> str:
116+
try:
117+
url = UrlPrompt.ask(
118+
"Enter the server URL",
119+
default="https://sky.dstack.ai",
120+
console=console,
121+
)
122+
except KeyboardInterrupt:
123+
console.print("\nCancelled by user")
124+
raise SystemExit(1)
125+
if url is None:
126+
raise CLIError("URL is required")
127+
return url
128+
129+
def _prompt_provider(self, available_providers: list[str]) -> str:
130+
choices = [
131+
questionary.Choice(title=provider, value=provider) for provider in available_providers
132+
]
133+
selected_provider = questionary.select(
134+
message="Select SSO provider:",
135+
choices=choices,
136+
qmark="",
137+
instruction="(↑↓ Enter)",
138+
).ask()
139+
if selected_provider is None:
140+
raise SystemExit(1)
141+
return selected_provider
142+
143+
def _configure_projects(
144+
self, api_client: APIClient, user: UserWithCreds, args: argparse.Namespace
145+
):
83146
projects = api_client.projects.list(include_not_joined=False)
84147
if len(projects) == 0:
85148
console.print(
@@ -89,30 +152,88 @@ def _configure_projects(self, api_client: APIClient, user: UserWithCreds):
89152
return
90153
config_manager = ConfigManager()
91154
default_project = config_manager.get_project_config()
92-
new_default_project = None
93-
for i, project in enumerate(projects):
94-
set_as_default = (
95-
default_project is None
96-
and i == 0
97-
or default_project is not None
98-
and default_project.name == project.project_name
99-
)
100-
if set_as_default:
101-
new_default_project = project
155+
for project in projects:
102156
config_manager.configure_project(
103157
name=project.project_name,
104158
url=api_client.base_url,
105159
token=user.creds.token,
106-
default=set_as_default,
160+
default=False,
107161
)
108162
config_manager.save()
163+
project_names = ", ".join(f"[code]{p.project_name}[/]" for p in projects)
109164
console.print(
110-
f"Configured projects: {', '.join(f'[code]{p.project_name}[/]' for p in projects)}."
165+
f"Added {project_names} project{'' if len(projects) == 1 else 's'} at {config_manager.config_filepath}"
111166
)
112-
if new_default_project:
113-
console.print(
114-
f"Set project [code]{new_default_project.project_name}[/] as default project."
115-
)
167+
168+
project_configs = config_manager.list_project_configs()
169+
170+
if args.no:
171+
return
172+
173+
if args.yes:
174+
if len(projects) > 0:
175+
first_project_from_server = projects[0]
176+
first_project_config = next(
177+
(
178+
pc
179+
for pc in project_configs
180+
if pc.name == first_project_from_server.project_name
181+
),
182+
None,
183+
)
184+
if first_project_config is not None:
185+
config_manager.configure_project(
186+
name=first_project_config.name,
187+
url=first_project_config.url,
188+
token=first_project_config.token,
189+
default=True,
190+
)
191+
config_manager.save()
192+
console.print(
193+
f"Set [code]{first_project_config.name}[/] project as default at {config_manager.config_filepath}"
194+
)
195+
return
196+
197+
if len(project_configs) == 1 or not is_project_menu_supported:
198+
selected_project = None
199+
if len(project_configs) == 1:
200+
selected_project = project_configs[0]
201+
else:
202+
for i, project in enumerate(projects):
203+
set_as_default = (
204+
default_project is None
205+
and i == 0
206+
or default_project is not None
207+
and default_project.name == project.project_name
208+
)
209+
if set_as_default:
210+
selected_project = next(
211+
(pc for pc in project_configs if pc.name == project.project_name),
212+
None,
213+
)
214+
break
215+
if selected_project is not None:
216+
config_manager.configure_project(
217+
name=selected_project.name,
218+
url=selected_project.url,
219+
token=selected_project.token,
220+
default=True,
221+
)
222+
config_manager.save()
223+
console.print(
224+
f"Set [code]{selected_project.name}[/] project as default at {config_manager.config_filepath}"
225+
)
226+
else:
227+
console.print()
228+
selected_project = select_default_project(project_configs, default_project)
229+
if selected_project is not None:
230+
config_manager.configure_project(
231+
name=selected_project.name,
232+
url=selected_project.url,
233+
token=selected_project.token,
234+
default=True,
235+
)
236+
config_manager.save()
116237

117238

118239
class _BadRequestError(Exception):

src/dstack/_internal/cli/commands/project.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,10 @@
22
import sys
33
from typing import Any, Optional, Union
44

5+
import questionary
56
from requests import HTTPError
67
from rich.table import Table
78

8-
try:
9-
import questionary
10-
11-
is_project_menu_supported = sys.stdin.isatty()
12-
except (ImportError, NotImplementedError, AttributeError):
13-
is_project_menu_supported = False
14-
159
import dstack.api.server
1610
from dstack._internal.cli.commands import BaseCommand
1711
from dstack._internal.cli.utils.common import add_row_from_dict, confirm_ask, console
@@ -22,6 +16,8 @@
2216

2317
logger = get_logger(__name__)
2418

19+
is_project_menu_supported = sys.stdin.isatty()
20+
2521

2622
def select_default_project(
2723
project_configs: list[ProjectConfig], default_project: Optional[ProjectConfig]
@@ -57,9 +53,9 @@ def select_default_project(
5753
default_index = i
5854
menu_entries.append((entry, i))
5955

60-
choices = [questionary.Choice(title=entry, value=index) for entry, index in menu_entries] # pyright: ignore[reportPossiblyUnboundVariable]
56+
choices = [questionary.Choice(title=entry, value=index) for entry, index in menu_entries]
6157
default_value = default_index
62-
selected_index = questionary.select( # pyright: ignore[reportPossiblyUnboundVariable]
58+
selected_index = questionary.select(
6359
message="Select the default project:",
6460
choices=choices,
6561
default=default_value, # pyright: ignore[reportArgumentType]

src/dstack/_internal/cli/utils/common.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ def configure_logging():
9999

100100
def confirm_ask(prompt, **kwargs) -> bool:
101101
kwargs["console"] = console
102-
return Confirm.ask(prompt=prompt, **kwargs)
102+
try:
103+
return Confirm.ask(prompt=prompt, **kwargs)
104+
except KeyboardInterrupt:
105+
console.print("\nCancelled by user")
106+
raise SystemExit(1)
103107

104108

105109
def add_row_from_dict(table: Table, data: Dict[Union[str, int], Any], **kwargs):

0 commit comments

Comments
 (0)