Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cli-tools/gpu-dev-cli/gpu_dev_cli/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import subprocess
import re
from typing import Dict, Any
from .config import Config
from .config import Config, CLI_UPGRADE_MESSAGE
from rich.spinner import Spinner


Expand Down Expand Up @@ -34,6 +34,9 @@ def authenticate_user(config: Config) -> Dict[str, Any]:
}

except Exception as e:
# Let upgrade messages pass through without re-wrapping
if CLI_UPGRADE_MESSAGE in str(e):
raise
raise RuntimeError(f"AWS authentication failed: {e}")


Expand Down
18 changes: 18 additions & 0 deletions cli-tools/gpu-dev-cli/gpu_dev_cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,20 @@
import os
import json
import boto3
from botocore.exceptions import ClientError
from pathlib import Path
from typing import Dict, Any, Optional

# User-facing notice telling the user to upgrade the CLI.
# It is raised as the message of a RuntimeError when a backend lookup reports
# that an expected resource does not exist (error codes
# "AWS.SimpleQueueService.NonExistentQueue" and "ResourceNotFoundException"),
# which is treated as a sign that the service has moved on from this CLI version.
# NOTE: the authentication error handler recognises this notice by checking
# whether this exact text appears in str(exception), so raise it verbatim
# rather than embedding an edited copy in another message.
# The surrounding newlines of the literal are removed by .strip().
CLI_UPGRADE_MESSAGE = """
The GPU Dev service has been updated and requires a newer CLI version.

Please upgrade:
pip install --upgrade git+https://github.com/pytorch/osdc.git@release

For more info: https://github.com/pytorch/osdc
""".strip()


class Config:
"""Zero-config AWS-based configuration"""
Expand Down Expand Up @@ -101,6 +112,13 @@ def get_queue_url(self) -> str:
try:
response = self.sqs_client.get_queue_url(QueueName=self.queue_name)
return response["QueueUrl"]
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code == "AWS.SimpleQueueService.NonExistentQueue":
raise RuntimeError(CLI_UPGRADE_MESSAGE)
raise RuntimeError(
f"Cannot access SQS queue {self.queue_name}. Check AWS permissions: {e}"
)
except Exception as e:
raise RuntimeError(
f"Cannot access SQS queue {self.queue_name}. Check AWS permissions: {e}"
Expand Down
9 changes: 8 additions & 1 deletion cli-tools/gpu-dev-cli/gpu_dev_cli/reservations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from rich.live import Live
from rich.spinner import Spinner

from .config import Config
from .config import Config, CLI_UPGRADE_MESSAGE
from .name_generator import sanitize_name
from . import __version__

Expand Down Expand Up @@ -971,6 +971,13 @@ def get_gpu_availability_by_type(self) -> Optional[Dict[str, Dict[str, Any]]]:

return availability_info

except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "")
if error_code == "ResourceNotFoundException":
raise RuntimeError(CLI_UPGRADE_MESSAGE)
console.print(
f"[red]❌ Error getting GPU availability: {str(e)}[/red]")
return None
except Exception as e:
console.print(
f"[red]❌ Error getting GPU availability: {str(e)}[/red]")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "gpu-dev-cli"
version = "0.3.5"
version = "0.3.6"
description = "CLI tool for PyTorch GPU developer server reservations"
authors = [{name = "PyTorch Team"}]
readme = "cli-tools/gpu-dev-cli/README.md"
Expand Down
4 changes: 2 additions & 2 deletions terraform-gpu-devservers/lambda.tf
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ resource "aws_lambda_function" "reservation_processor" {
HOSTED_ZONE_ID = local.effective_domain_name != "" ? local.hosted_zone_id : ""
SSH_DOMAIN_MAPPINGS_TABLE = local.effective_domain_name != "" ? aws_dynamodb_table.ssh_domain_mappings.name : ""
SSL_CERTIFICATE_ARN = local.effective_domain_name != "" ? aws_acm_certificate.wildcard[0].arn : ""
LAMBDA_VERSION = "0.3.5"
MIN_CLI_VERSION = "0.3.5"
LAMBDA_VERSION = "0.3.6"
MIN_CLI_VERSION = "0.3.6"
DISK_CONTENTS_BUCKET = aws_s3_bucket.disk_contents.bucket
}, local.alb_env_vars)
}
Expand Down