diff --git a/Makefile b/Makefile index ad50242be..13719e7fb 100644 --- a/Makefile +++ b/Makefile @@ -101,6 +101,9 @@ clean-llama-stack: remove-llama-stack-container ## Remove container and image echo "Removing llama-stack image..."; \ $(CONTAINER_RUNTIME) rmi $(LLAMA_STACK_IMAGE); \ fi +run-llama-stack: ## Start Llama Stack with enriched config (for local service mode) + uv run src/llama_stack_configuration.py -c $(CONFIG) -i $(LLAMA_STACK_CONFIG) -o $(LLAMA_STACK_CONFIG) && \ + uv run llama stack run $(LLAMA_STACK_CONFIG) test-unit: ## Run the unit tests @echo "Running unit tests..." diff --git a/docs/providers.md b/docs/providers.md index 7419c9408..a742ca130 100644 --- a/docs/providers.md +++ b/docs/providers.md @@ -92,51 +92,44 @@ azure_entra_id: #### Llama Stack Configuration Requirements -Because Lightspeed builds on top of Llama Stack, certain configuration fields are required to satisfy the base Llama Stack schema. The config block for the Azure inference provider **must** include `api_key`, `api_base`, and `api_version` — Llama Stack will fail to start if any of these are missing. +Because Lightspeed builds on top of Llama Stack, certain configuration fields are required to satisfy the base Llama Stack schema. The config block for the Azure inference provider **must** include `base_url` and `api_version`. When using Entra ID authentication, `api_key` is not required to be configured, since the API key is acquired and passed automatically at runtime. -**Important:** The `api_key` field must be set to `${env.AZURE_API_KEY}` exactly as shown below. This is not optional — Lightspeed uses this specific environment variable name as a placeholder for injection of the Entra ID access token. Using a different variable name will break the authentication flow. 
+When `azure_entra_id` is configured in Lightspeed, config enrichment automatically sets `model_validation: false` on the `remote::azure` provider so Llama Stack can start without validating models against Azure at startup. ```yaml inference: - provider_id: azure provider_type: remote::azure config: - api_key: ${env.AZURE_API_KEY} # Must be exactly this - placeholder for Entra ID token - api_base: ${env.AZURE_API_BASE} + # api_key: ${env.AZURE_API_KEY} # Can be omitted when Entra ID configured in LCORE + base_url: ${env.AZURE_API_BASE} api_version: 2025-01-01-preview + model_validation: false # added automatically by Lightspeed enrichment ``` -**How it works:** At startup, Lightspeed acquires an Entra ID access token and stores it in the `AZURE_API_KEY` environment variable. When Llama Stack initializes, it reads the config, substitutes `${env.AZURE_API_KEY}` with the token value, and uses it to authenticate with Azure OpenAI. Llama Stack also calls `models.list()` during initialization to validate provider connectivity, which is why the token must be available before client initialization. +**How it works:** Llama Stack defers Azure authentication to inference time. Lightspeed acquires Entra ID tokens at runtime and passes them via the `X-LlamaStack-Provider-Data` header (`azure_api_key`, `azure_api_base`). #### Access Token Lifecycle and Management -**Library mode startup:** +**Lightspeed startup (library and service mode):** 1. Lightspeed reads your Entra ID configuration -2. Acquires an initial access token from Microsoft Entra ID -3. Stores the token in the `AZURE_API_KEY` environment variable -4. **Then** initializes the Llama Stack library client +2. Does not acquire or cache access tokens at startup—authentication is deferred until request time +3. 
Initializes the Llama Stack client without Azure credentials; credentials are supplied later via `X-LlamaStack-Provider-Data` when an Azure model is used -This ordering is critical because Llama Stack calls `models.list()` during initialization to validate provider connectivity. If the token is not set before client initialization, Azure requests will fail with authentication errors. - -**Service mode startup:** - -When running Llama Stack as a separate service, Lightspeed runs a pre-startup script that: -1. Reads the Entra ID configuration -2. Acquires an initial access token -3. Writes the token to the `AZURE_API_KEY` environment variable -4. **Then** Llama Stack service starts - -This initial token is used solely for the `models.list()` validation call during Llama Stack startup. After startup, Lightspeed manages token refresh independently and passes fresh tokens via request headers. +**Llama Stack service startup (container mode):** +1. Config enrichment sets `model_validation: false` on the Azure provider +2. Llama Stack starts without authenticating models against Azure +3. Lightspeed connects to this service at startup without Azure credentials; tokens are added only for Azure inference requests **During inference requests:** 1. Before each request, Lightspeed checks if the token has expired -2. If expired, a new token is automatically acquired and the environment variable is updated -3. For library mode: the Llama Stack client is reloaded to pick up the new token -4. For service mode: the token is passed via `X-LlamaStack-Provider-Data` request headers +2. If expired, a new token is automatically acquired and cached in memory +3. 
The token is passed via `X-LlamaStack-Provider-Data` (library and service mode) **Token security:** - Access tokens are wrapped in `SecretStr` to prevent accidental logging -- Tokens are stored only in the `AZURE_API_KEY` environment variable (single source of truth) +- Tokens are cached in `AzureEntraIDManager` singleton class +- Inference uses `X-LlamaStack-Provider-Data` headers - Each Uvicorn worker maintains its own token lifecycle independently **Token validity:** diff --git a/docs/rag_guide.md b/docs/rag_guide.md index 598272ea7..548f3d0b0 100644 --- a/docs/rag_guide.md +++ b/docs/rag_guide.md @@ -83,7 +83,6 @@ The script reads your `lightspeed-stack.yaml` configuration and enriches a base - `-c, --config`: Lightspeed config file (default: `lightspeed-stack.yaml`) - `-i, --input`: Input Llama Stack config (default: `run.yaml`) - `-o, --output`: Output enriched config (default: `run_.yaml`) -- `-e, --env-file`: Path to .env file for AZURE_API_KEY (default: `.env`) > [!TIP] > Use this script to generate your initial `run.yaml` configuration, then manually customize as needed for your specific setup. 
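Note on the request-time flow documented in docs/providers.md above: the following is a minimal sketch of it, not the code in this PR. It assumes a stdlib-only in-memory cache; `TokenCache`, `merge_provider_data`, `LEEWAY_SECONDS`, and the `>=` expiry comparison are hypothetical stand-ins. Only the header name `X-LlamaStack-Provider-Data` and the keys `azure_api_key` and `azure_api_base` come from the diff.

```python
import json
import time
from typing import Optional

# Hypothetical leeway; the real TOKEN_EXPIRATION_LEEWAY value is not shown in this diff.
LEEWAY_SECONDS = 30
PROVIDER_DATA_HEADER = "X-LlamaStack-Provider-Data"


class TokenCache:
    """Illustrative in-memory cache, loosely modelled on AzureEntraIDManager."""

    def __init__(self) -> None:
        self._token = ""
        self._expires_on = 0
        self.base_url: Optional[str] = None

    def update(self, token: str, expires_on: int) -> None:
        # Treat the token as expired slightly early so it is never sent stale.
        self._token = token
        self._expires_on = expires_on - LEEWAY_SECONDS

    def is_expired(self, now: Optional[float] = None) -> bool:
        # Assumed comparison; the body of is_token_expired is not part of this diff.
        current = time.time() if now is None else now
        return current >= self._expires_on

    def provider_data(self) -> Optional[dict[str, str]]:
        # Nothing is injected until both a token and a base URL are known.
        if not self._token or self.base_url is None:
            return None
        return {"azure_api_key": self._token, "azure_api_base": self.base_url}


def merge_provider_data(
    headers: dict[str, str], updates: dict[str, str]
) -> dict[str, str]:
    """Return a copy of headers with updates merged into the provider-data JSON."""
    raw = headers.get(PROVIDER_DATA_HEADER)
    try:
        current = json.loads(raw) if raw else {}
    except json.JSONDecodeError:
        current = {}
    if not isinstance(current, dict):
        current = {}
    current.update(updates)
    return {**headers, PROVIDER_DATA_HEADER: json.dumps(current)}
```

Merging into the existing header value rather than overwriting it keeps any provider data that other components have already set, which matches how the service-client branch of `update_azure_token` treats `default_headers`.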
diff --git a/examples/azure-run.yaml b/examples/azure-run.yaml index 894e24528..91cc92bdf 100644 --- a/examples/azure-run.yaml +++ b/examples/azure-run.yaml @@ -22,9 +22,9 @@ providers: - provider_id: azure provider_type: remote::azure config: - api_key: ${env.AZURE_API_KEY} base_url: https://ols-test.openai.azure.com/openai/v1 api_version: 2024-02-15-preview + model_validation: false - provider_id: openai provider_type: remote::openai config: diff --git a/scripts/llama-stack-entrypoint.sh b/scripts/llama-stack-entrypoint.sh index a7eeb797b..f8d9aefc6 100755 --- a/scripts/llama-stack-entrypoint.sh +++ b/scripts/llama-stack-entrypoint.sh @@ -7,7 +7,6 @@ set -e INPUT_CONFIG="${LLAMA_STACK_CONFIG:-/opt/app-root/run.yaml}" ENRICHED_CONFIG="/opt/app-root/run.yaml" LIGHTSPEED_CONFIG="${LIGHTSPEED_CONFIG:-/opt/app-root/lightspeed-stack.yaml}" -ENV_FILE="/opt/app-root/.env" # Enrich config if lightspeed config exists if [ -f "$LIGHTSPEED_CONFIG" ]; then @@ -16,14 +15,7 @@ if [ -f "$LIGHTSPEED_CONFIG" ]; then python3 /opt/app-root/llama_stack_configuration.py \ -c "$LIGHTSPEED_CONFIG" \ -i "$INPUT_CONFIG" \ - -o "$ENRICHED_CONFIG" \ - -e "$ENV_FILE" 2>&1 || ENRICHMENT_FAILED=1 - - # Source .env if generated (contains AZURE_API_KEY) - if [ -f "$ENV_FILE" ]; then - # shellcheck source=/dev/null - set -a && . 
"$ENV_FILE" && set +a - fi + -o "$ENRICHED_CONFIG" 2>&1 || ENRICHMENT_FAILED=1 if [ -f "$ENRICHED_CONFIG" ] && [ "$ENRICHMENT_FAILED" -eq 0 ]; then echo "Using enriched config: $ENRICHED_CONFIG" diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 7b985e9c0..f7fd5f632 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -54,7 +54,6 @@ is_context_length_error, prepare_input, store_query_results, - update_azure_token, validate_attachments_metadata, validate_model_provider_override, ) @@ -204,7 +203,7 @@ async def query_endpoint_handler( and AzureEntraIDManager().is_token_expired and AzureEntraIDManager().refresh_token() ): - client = await update_azure_token(client) + client = await AsyncLlamaStackClientHolder().update_azure_token() # Retrieve response using Responses API turn_summary = await retrieve_response( diff --git a/src/app/endpoints/responses.py b/src/app/endpoints/responses.py index 2705a27dd..9027885d1 100644 --- a/src/app/endpoints/responses.py +++ b/src/app/endpoints/responses.py @@ -78,7 +78,6 @@ handle_known_apistatus_errors, is_context_length_error, store_query_results, - update_azure_token, validate_model_provider_override, ) from utils.quota import check_tokens_available, get_available_quotas @@ -405,7 +404,7 @@ async def responses_endpoint_handler( and AzureEntraIDManager().is_token_expired and AzureEntraIDManager().refresh_token() ): - client = await update_azure_token(client) + client = await AsyncLlamaStackClientHolder().update_azure_token() input_text = ( original_request.input diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index 08555a4d5..8e1838b93 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -20,6 +20,7 @@ import constants from authentication import get_auth_dependency from authentication.interface import AuthTuple +from authorization.azure_token_manager import AzureEntraIDManager from authorization.middleware import authorize 
from client import AsyncLlamaStackClientHolder from configuration import configuration @@ -327,6 +328,16 @@ async def _call_llm( """ client = AsyncLlamaStackClientHolder().get_client() resolved_model_id = model_id or await _get_default_model_id() + + # Handle Azure token refresh if needed + if ( + resolved_model_id.startswith("azure") + and AzureEntraIDManager().is_entra_id_configured + and AzureEntraIDManager().is_token_expired + and AzureEntraIDManager().refresh_token() + ): + client = await AsyncLlamaStackClientHolder().update_azure_token() + logger.debug("Using model %s for rlsapi v1 inference", resolved_model_id) response = await client.responses.create( diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 6079d2aa4..fe1820920 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -92,7 +92,6 @@ is_context_length_error, prepare_input, store_query_results, - update_azure_token, update_conversation_topic_summary, validate_attachments_metadata, validate_model_provider_override, @@ -262,7 +261,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals and AzureEntraIDManager().is_token_expired and AzureEntraIDManager().refresh_token() ): - client = await update_azure_token(client) + client = await AsyncLlamaStackClientHolder().update_azure_token() request_id = get_suid() diff --git a/src/app/main.py b/src/app/main.py index af42dd9f5..e5c45d9c3 100644 --- a/src/app/main.py +++ b/src/app/main.py @@ -77,15 +77,6 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: initialize_sentry() - azure_config = configuration.configuration.azure_entra_id - if azure_config is not None: - AzureEntraIDManager().set_config(azure_config) - if not AzureEntraIDManager().refresh_token(): - logger.warning( - "Failed to refresh Azure token at startup. " - "Token refresh will be retried on next Azure request." 
- ) - llama_stack_config = configuration.configuration.llama_stack await AsyncLlamaStackClientHolder().load(llama_stack_config) client = AsyncLlamaStackClientHolder().get_client() @@ -104,6 +95,11 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]: ) raise + azure_entra_id_config = configuration.configuration.azure_entra_id + if azure_entra_id_config is not None: + AzureEntraIDManager().set_config(azure_entra_id_config) + azure_base_url = await AsyncLlamaStackClientHolder().get_azure_base_url() + AzureEntraIDManager().set_base_url(azure_base_url) logger.info("Registering MCP servers") await register_mcp_servers_async(logger, configuration.configuration) logger.info("App startup complete") diff --git a/src/authorization/azure_token_manager.py b/src/authorization/azure_token_manager.py index bc1cf323d..702efe465 100644 --- a/src/authorization/azure_token_manager.py +++ b/src/authorization/azure_token_manager.py @@ -1,6 +1,5 @@ """Azure Entra ID token manager for Azure OpenAI authentication.""" -import os import time from typing import Optional @@ -34,7 +33,13 @@ class AzureEntraIDManager(metaclass=Singleton): def __init__(self) -> None: """Initialize the token manager with empty state.""" self._expires_on: int = 0 + self._access_token: SecretStr = SecretStr("") self._entra_id_config: Optional[AzureEntraIdConfiguration] = None + self._azure_base_url: Optional[str] = None + + def set_base_url(self, base_url: Optional[str]) -> None: + """Set the Azure API base.""" + self._azure_base_url = base_url def set_config(self, azure_config: AzureEntraIdConfiguration) -> None: """Set the Azure Entra ID configuration.""" @@ -53,8 +58,24 @@ def is_token_expired(self) -> bool: @property def access_token(self) -> SecretStr: - """Return the access token from environment variable as SecretStr.""" - return SecretStr(os.environ.get("AZURE_API_KEY", "")) + """Return the cached access token.""" + return self._access_token + + @property + def azure_base_url(self) -> Optional[str]: 
+ """Return the cached Azure API base.""" + return self._azure_base_url + + def build_azure_provider_data(self) -> Optional[dict[str, str]]: + """Build azure_api_key and azure_base_url entries for provider data. + + Returns: + Provider data dict when a token and base_url are available. + """ + token = self.access_token.get_secret_value() + if not token or self.azure_base_url is None: + return None + return {"azure_api_key": token, "azure_api_base": self.azure_base_url} def refresh_token(self) -> bool: """Refresh the cached Azure access token. @@ -76,9 +97,9 @@ def refresh_token(self) -> bool: return False def _update_access_token(self, token: str, expires_on: int) -> None: - """Update the token in env var and track expiration time.""" + """Update the cached token and track expiration time.""" + self._access_token = SecretStr(token) self._expires_on = expires_on - TOKEN_EXPIRATION_LEEWAY - os.environ["AZURE_API_KEY"] = token expiry_time = time.strftime( "%Y-%m-%d %H:%M:%S", time.localtime(self._expires_on) ) diff --git a/src/client.py b/src/client.py index a503c0094..8fd1e0370 100644 --- a/src/client.py +++ b/src/client.py @@ -3,15 +3,21 @@ import json import os import tempfile -from typing import Optional +from typing import Optional, cast import yaml from fastapi import HTTPException from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient +from authorization.azure_token_manager import AzureEntraIDManager from configuration import configuration -from llama_stack_configuration import YamlDumper, enrich_byok_rag, enrich_solr +from llama_stack_configuration import ( + YamlDumper, + enrich_azure_entra_id_inference, + enrich_byok_rag, + enrich_solr, +) from log import get_logger from models.api.responses.error import ServiceUnavailableResponse from models.config import LlamaStackConfiguration @@ -90,6 +96,12 @@ def _enrich_library_config(self, input_config_path: 
str) -> str: # Enrichment: Solr - enabled when "okp" appears in either inline or tool list enrich_solr(ls_config, config.rag.model_dump(), config.okp.model_dump()) + # Enrichment: Azure Entra ID deferred auth + entra_id_config = ( + config.azure_entra_id.model_dump() if config.azure_entra_id else None + ) + enrich_azure_entra_id_inference(ls_config, entra_id_config) + enriched_path = os.path.join( tempfile.gettempdir(), "llama_stack_enriched_config.yaml" ) @@ -211,23 +223,35 @@ async def check_model_available(self, model_id: str) -> tuple[bool, str]: ) return False, f"Model {model_id} not found in model registry" - def update_provider_data(self, updates: dict[str, str]) -> AsyncLlamaStackClient: - """Update provider data headers for service client. - - For use with service mode only. - - Args: - updates: Key-value pairs to merge into provider data header. + async def update_azure_token(self) -> AsyncLlamaStackClient: + """Apply cached Azure credentials and replace the held client. Returns: - The updated client instance. + The new client instance assigned to this holder. """ - if not self._lsc: - raise RuntimeError( - "AsyncLlamaStackClient has not been initialised. Ensure 'load(..)' has been called." 
+ updates = AzureEntraIDManager().build_azure_provider_data() + if not updates: + return self.get_client() + + if self.is_library_client: + if not self._config_path: + logger.warning("Cannot update Azure token: config path not set") + return self.get_client() + + current_provider_data = dict( + cast(AsyncLlamaStackAsLibraryClient, self._lsc).provider_data or {} + ) + current_provider_data.update(updates) + client = AsyncLlamaStackAsLibraryClient( + self._config_path, provider_data=current_provider_data ) + await client.initialize() + self._lsc = client + return client - current_headers = self._lsc.default_headers or {} + # Service client mode + current_client = self.get_client() + current_headers = current_client.default_headers or {} provider_data_json = current_headers.get("X-LlamaStack-Provider-Data") try: @@ -242,5 +266,32 @@ def update_provider_data(self, updates: dict[str, str]) -> AsyncLlamaStackClient "X-LlamaStack-Provider-Data": json.dumps(provider_data), } - self._lsc = self._lsc.copy(set_default_headers=updated_headers) # type: ignore[arg-type] - return self._lsc + updated_client = current_client.copy( + set_default_headers=updated_headers # type: ignore[arg-type] + ) + self._lsc = updated_client + return updated_client + + async def get_azure_base_url(self) -> Optional[str]: + """ + Retrieve the Azure base_url endpoint from the remote Llama Stack provider configuration. + + Returns: + Optional[str]: The Azure base_url if available, otherwise None. 
+ """ + if not self._lsc: + return None + + try: + providers = await self._lsc.providers.list() + except (APIConnectionError, APIStatusError) as err: + logger.warning("Failed to list providers for Azure base_url: %s", err) + return None + + for provider in providers: + if provider.provider_type != "remote::azure": + continue + base = provider.config.get("base_url") + if base is not None: + return str(base) + return None diff --git a/src/llama_stack_configuration.py b/src/llama_stack_configuration.py index 050ffabd5..ca0775bcf 100644 --- a/src/llama_stack_configuration.py +++ b/src/llama_stack_configuration.py @@ -5,15 +5,11 @@ 2. As a module: `from llama_stack_configuration import generate_configuration` """ -import os from argparse import ArgumentParser -from pathlib import Path from typing import Any, Optional from urllib.parse import urljoin import yaml -from azure.core.exceptions import ClientAuthenticationError -from azure.identity import ClientSecretCredential, CredentialUnavailableError from llama_stack.core.stack import replace_env_vars import constants @@ -47,71 +43,39 @@ def increase_indent(self, flow: bool = False, indentless: bool = False) -> None: # ============================================================================= -def setup_azure_entra_id_token( - azure_config: Optional[dict[str, Any]], env_file: str +def enrich_azure_entra_id_inference( + ls_config: dict[str, Any], + azure_entra_id: Optional[dict[str, Any]], ) -> None: - """Generate Azure Entra ID access token and write to .env file. + """Enrich remote::azure inference provider for Entra ID authentication. - Skips generation if AZURE_API_KEY is already set (e.g., orchestrator-injected). 
- """ - # Skip if already injected by orchestrator (secure production setup) - if os.environ.get("AZURE_API_KEY"): - logger.info("Azure Entra ID: AZURE_API_KEY already set, skipping generation") - return - - if azure_config is None: - logger.info("Azure Entra ID: Not configured, skipping") - return + When Azure Entra ID is configured, the remote::azure inference provider is enriched + with model_validation=false to defer model validation to runtime. - tenant_id = azure_config.get("tenant_id") - client_id = azure_config.get("client_id") - client_secret = azure_config.get("client_secret") - scope = azure_config.get("scope", "https://cognitiveservices.azure.com/.default") + Parameters: + ls_config (dict[str, Any]): Mutable Llama Stack configuration dictionary to update. + azure_entra_id (Optional[dict[str, Any]]): Lightspeed azure_entra_id block, + or None. - if not all([tenant_id, client_id, client_secret]): - logger.warning( - "Azure Entra ID: Missing required fields (tenant_id, client_id, client_secret)" - ) + Returns: + None: The configuration is modified in place. 
+ """ + if azure_entra_id is None: return - try: - credential = ClientSecretCredential( - tenant_id=str(tenant_id), - client_id=str(client_id), - client_secret=str(client_secret), - ) + inference_providers = ls_config.get("providers", {}).get("inference", []) - token = credential.get_token(scope) - - # Write to .env file - # Create file if it doesn't exist - Path(env_file).touch() - - lines = [] - with open(env_file, "r", encoding="utf-8") as f: - lines = f.readlines() - - # Update or add AZURE_API_KEY - key_found = False - for i, line in enumerate(lines): - if line.startswith("AZURE_API_KEY="): - lines[i] = f"AZURE_API_KEY={token.token}\n" - key_found = True - break - - if not key_found: - lines.append(f"AZURE_API_KEY={token.token}\n") - - with open(env_file, "w", encoding="utf-8") as f: - f.writelines(lines) + for provider in inference_providers: + if provider.get("provider_type") != "remote::azure": + continue + provider_config = provider.setdefault("config", {}) + provider_config["model_validation"] = False logger.info( - "Azure Entra ID: Access token set in env and written to %s", env_file + "Azure Entra ID: configured remote::azure provider with " + "model_validation=false" ) - except (ClientAuthenticationError, CredentialUnavailableError) as e: - logger.error("Azure Entra ID: Failed to generate token: %s", e) - # ============================================================================= # Enrichment: BYOK RAG @@ -602,7 +566,6 @@ def generate_configuration( input_file: str, output_file: str, config: dict[str, Any], - env_file: str = ".env", ) -> None: """Generate enriched Llama Stack configuration for service/container mode. 
@@ -610,7 +573,6 @@ def generate_configuration( input_file: Path to input Llama Stack config output_file: Path to write enriched config config: Lightspeed config dict (from YAML) - env_file: Path to .env file """ logger.info("Reading Llama Stack configuration from file %s", input_file) @@ -619,8 +581,8 @@ def generate_configuration( dedupe_providers_vector_io(ls_config) - # Enrichment: Azure Entra ID token - setup_azure_entra_id_token(config.get("azure_entra_id"), env_file) + # Enrichment: Azure Entra ID deferred auth + enrich_azure_entra_id_inference(ls_config, config.get("azure_entra_id")) # Enrichment: BYOK RAG enrich_byok_rag(ls_config, config.get("byok_rag", [])) @@ -664,19 +626,12 @@ def main() -> None: default="run_.yaml", help="Output enriched config (default: run_.yaml)", ) - parser.add_argument( - "-e", - "--env-file", - default=".env", - help="Path to .env file for AZURE_API_KEY (default: .env)", - ) args = parser.parse_args() with open(args.config, "r", encoding="utf-8") as f: config = yaml.safe_load(f) - config = replace_env_vars(config) - generate_configuration(args.input, args.output, config, args.env_file) + generate_configuration(args.input, args.output, config) if __name__ == "__main__": diff --git a/src/utils/query.py b/src/utils/query.py index 9daa086aa..aefe121fa 100644 --- a/src/utils/query.py +++ b/src/utils/query.py @@ -6,10 +6,6 @@ import psycopg2 from fastapi import HTTPException -from llama_stack_client import ( - APIConnectionError, - AsyncLlamaStackClient, -) from llama_stack_client import ( APIStatusError as LLSApiStatusError, ) @@ -20,9 +16,7 @@ import constants from app.database import get_session -from authorization.azure_token_manager import AzureEntraIDManager from cache.cache_error import CacheError -from client import AsyncLlamaStackClientHolder from configuration import configuration from log import get_logger from models.api.requests import QueryRequest @@ -32,7 +26,6 @@ InternalServerErrorResponse, PromptTooLongResponse, 
QuotaExceededResponse, - ServiceUnavailableResponse, UnprocessableEntityResponse, ) from models.cache_entry import CacheEntry @@ -173,52 +166,6 @@ def is_input_shield(shield: Shield) -> bool: return _is_inout_shield(shield) or not is_output_shield(shield) -async def update_azure_token( - client: AsyncLlamaStackClient, -) -> AsyncLlamaStackClient: - """ - Update the client with a fresh Azure token. - - Updates the client with the fresh Azure token. Should be called after - verifying that token refresh is needed and successful. - - Args: - client: The current AsyncLlamaStackClient instance - - Returns: - AsyncLlamaStackClient: The client instance (reloaded or updated with fresh token) - """ - if AsyncLlamaStackClientHolder().is_library_client: - library_client: AsyncLlamaStackClient = ( - await AsyncLlamaStackClientHolder().reload_library_client() - ) - return library_client - try: - providers = await client.providers.list() - azure_config = next( - p.config for p in providers if p.provider_type == "remote::azure" - ) - except APIConnectionError as e: - error_response = ServiceUnavailableResponse( - backend_name="Llama Stack", - cause=str(e), - ) - raise HTTPException(**error_response.model_dump()) from e - except LLSApiStatusError as e: - error_response = InternalServerErrorResponse.generic() - raise HTTPException(**error_response.model_dump()) from e - - updated_client: ( - AsyncLlamaStackClient - ) = AsyncLlamaStackClientHolder().update_provider_data( - { - "azure_api_key": AzureEntraIDManager().access_token.get_secret_value(), - "azure_api_base": str(azure_config.get("api_base")), - } - ) - return updated_client - - def prepare_input( query_request: QueryRequest, inline_rag_context: Optional[str] = None ) -> str: diff --git a/tests/e2e-prow/rhoai/manifests/lightspeed/llama-stack-openai.yaml b/tests/e2e-prow/rhoai/manifests/lightspeed/llama-stack-openai.yaml index 3efea3fc1..7e6693fb4 100644 --- a/tests/e2e-prow/rhoai/manifests/lightspeed/llama-stack-openai.yaml 
+++ b/tests/e2e-prow/rhoai/manifests/lightspeed/llama-stack-openai.yaml @@ -154,18 +154,13 @@ spec: INPUT_CONFIG="${LLAMA_STACK_CONFIG:-/opt/app-root/run.yaml}" ENRICHED_CONFIG="/opt/app-root/run.yaml" LIGHTSPEED_CONFIG="${LIGHTSPEED_CONFIG:-/opt/app-root/lightspeed-stack.yaml}" - ENV_FILE="/opt/app-root/.env" if [[ -f "$LIGHTSPEED_CONFIG" ]]; then echo "Enriching llama-stack config..." ENRICHMENT_FAILED=0 python3 /opt/app-root/llama_stack_configuration.py \ -c "$LIGHTSPEED_CONFIG" \ -i "$INPUT_CONFIG" \ - -o "$ENRICHED_CONFIG" \ - -e "$ENV_FILE" 2>&1 || ENRICHMENT_FAILED=1 - if [[ -f "$ENV_FILE" ]]; then - set -a && . "$ENV_FILE" && set +a - fi + -o "$ENRICHED_CONFIG" 2>&1 || ENRICHMENT_FAILED=1 if [[ -f "$ENRICHED_CONFIG" ]] && [[ "$ENRICHMENT_FAILED" -eq 0 ]]; then echo "Using enriched config: $ENRICHED_CONFIG" restore_rag_seed diff --git a/tests/e2e-prow/rhoai/manifests/lightspeed/llama-stack-prow.yaml b/tests/e2e-prow/rhoai/manifests/lightspeed/llama-stack-prow.yaml index 757933c3d..789d42852 100644 --- a/tests/e2e-prow/rhoai/manifests/lightspeed/llama-stack-prow.yaml +++ b/tests/e2e-prow/rhoai/manifests/lightspeed/llama-stack-prow.yaml @@ -144,18 +144,13 @@ spec: INPUT_CONFIG="${LLAMA_STACK_CONFIG:-/opt/app-root/run.yaml}" ENRICHED_CONFIG="/opt/app-root/run.yaml" LIGHTSPEED_CONFIG="${LIGHTSPEED_CONFIG:-/opt/app-root/lightspeed-stack.yaml}" - ENV_FILE="/opt/app-root/.env" if [[ -f "$LIGHTSPEED_CONFIG" ]]; then echo "Enriching llama-stack config..." ENRICHMENT_FAILED=0 python3 /opt/app-root/src/llama_stack_configuration.py \ -c "$LIGHTSPEED_CONFIG" \ -i "$INPUT_CONFIG" \ - -o "$ENRICHED_CONFIG" \ - -e "$ENV_FILE" 2>&1 || ENRICHMENT_FAILED=1 - if [[ -f "$ENV_FILE" ]]; then - set -a && . 
"$ENV_FILE" && set +a - fi + -o "$ENRICHED_CONFIG" 2>&1 || ENRICHMENT_FAILED=1 if [[ -f "$ENRICHED_CONFIG" ]] && [[ "$ENRICHMENT_FAILED" -eq 0 ]]; then echo "Using enriched config: $ENRICHED_CONFIG" restore_rag_seed diff --git a/tests/e2e/configs/run-azure.yaml b/tests/e2e/configs/run-azure.yaml index 27e4022a8..6f627ba8b 100644 --- a/tests/e2e/configs/run-azure.yaml +++ b/tests/e2e/configs/run-azure.yaml @@ -22,10 +22,10 @@ providers: - provider_id: azure provider_type: remote::azure config: - api_key: ${env.AZURE_API_KEY} base_url: https://ols-test.openai.azure.com/openai/v1 api_version: 2024-02-15-preview allowed_models: ["gpt-4o-mini"] + model_validation: false - provider_id: openai provider_type: remote::openai config: diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index eba6dd711..5a6b43684 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -541,19 +541,8 @@ async def test_query_azure_token_refresh( ) mock_updated_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) - mock_response_obj_updated = mocker.Mock() - mock_response_obj_updated.output = [] - mock_updated_client.responses = mocker.Mock() - mock_updated_client.responses.create = mocker.AsyncMock( - return_value=mock_response_obj_updated - ) - mock_update_token = mocker.patch( - "app.endpoints.query.update_azure_token", - new=mocker.AsyncMock(return_value=mock_updated_client), - ) - mocker.patch( - "app.endpoints.query.get_topic_summary", - new=mocker.AsyncMock(return_value=None), + mock_client_holder.update_azure_token = mocker.AsyncMock( + return_value=mock_updated_client ) async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: @@ -576,7 +565,7 @@ async def mock_retrieve_response(*_args: Any, **_kwargs: Any) -> TurnSummary: mcp_headers={}, ) - mock_update_token.assert_called_once() + mock_client_holder.update_azure_token.assert_called_once() class TestRetrieveResponse: diff --git 
a/tests/unit/app/endpoints/test_responses.py b/tests/unit/app/endpoints/test_responses.py index b5d2e5bd9..37c2c7598 100644 --- a/tests/unit/app/endpoints/test_responses.py +++ b/tests/unit/app/endpoints/test_responses.py @@ -466,7 +466,7 @@ async def test_responses_azure_token_refresh( """Test that Azure token refresh is called when model starts with azure.""" responses_request = ResponsesRequest(input="Hi", model="azure/some-model") _patch_base(mocker, minimal_config) - _patch_client(mocker) + _mock_client, mock_holder = _patch_client(mocker) _patch_resolve_response_context(mocker) mocker.patch( f"{MODULE}.select_model_for_responses", @@ -482,10 +482,7 @@ async def test_responses_azure_token_refresh( mock_azure.refresh_token.return_value = True mocker.patch(f"{MODULE}.AzureEntraIDManager", return_value=mock_azure) updated_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) - mock_update_token = mocker.patch( - f"{MODULE}.update_azure_token", - new=mocker.AsyncMock(return_value=updated_client), - ) + mock_holder.update_azure_token = mocker.AsyncMock(return_value=updated_client) _patch_rag(mocker) _patch_moderation(mocker, decision="passed") mocker.patch( @@ -505,7 +502,7 @@ async def test_responses_azure_token_refresh( auth=MOCK_AUTH, mcp_headers={}, ) - mock_update_token.assert_called_once() + mock_holder.update_azure_token.assert_called_once() @pytest.mark.asyncio async def test_responses_structured_input_appends_rag_message( diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 1894190cf..1f4a84cac 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -728,8 +728,12 @@ async def test_streaming_query_azure_token_refresh( ) mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) + mock_updated_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) mock_client_holder = mocker.Mock() mock_client_holder.get_client.return_value = 
mock_client + mock_client_holder.update_azure_token = mocker.AsyncMock( + return_value=mock_updated_client + ) mocker.patch( "app.endpoints.streaming_query.AsyncLlamaStackClientHolder", return_value=mock_client_holder, @@ -757,12 +761,6 @@ async def test_streaming_query_azure_token_refresh( return_value=mock_azure_manager, ) - mock_updated_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) - mock_update_token = mocker.patch( - "app.endpoints.streaming_query.update_azure_token", - new=mocker.AsyncMock(return_value=mock_updated_client), - ) - mocker.patch( "app.endpoints.streaming_query.extract_provider_and_model_from_model_id", return_value=("azure", "model1"), @@ -804,7 +802,7 @@ async def mock_generate_response( mcp_headers={}, ) - mock_update_token.assert_called_once() + mock_client_holder.update_azure_token.assert_called_once() class TestCreateResponseGenerator: diff --git a/tests/unit/authorization/test_azure_token_manager.py b/tests/unit/authorization/test_azure_token_manager.py index f5b7d1fe6..d2c93d4e2 100644 --- a/tests/unit/authorization/test_azure_token_manager.py +++ b/tests/unit/authorization/test_azure_token_manager.py @@ -53,14 +53,12 @@ def test_singleton_behavior(self, token_manager: AzureEntraIDManager) -> None: manager2 = AzureEntraIDManager() assert token_manager is manager2 - def test_initial_state( - self, token_manager: AzureEntraIDManager, mocker: MockerFixture - ) -> None: + def test_initial_state(self, token_manager: AzureEntraIDManager) -> None: """Check the initial token manager state.""" - mocker.patch.dict("os.environ", {"AZURE_API_KEY": ""}, clear=False) assert token_manager.access_token.get_secret_value() == "" assert token_manager.is_token_expired assert not token_manager.is_entra_id_configured + assert token_manager.azure_base_url is None def test_set_config( self, @@ -73,12 +71,26 @@ def test_set_config( def test_token_expiration_logic(self, token_manager: AzureEntraIDManager) -> None: """Verify token expiration logic works 
correctly.""" - token_manager._expires_on = int(time.time()) + 100 + token_manager._update_access_token("valid-token", int(time.time()) + 100) assert not token_manager.is_token_expired token_manager._expires_on = 0 assert token_manager.is_token_expired + def test_build_azure_provider_data( + self, token_manager: AzureEntraIDManager + ) -> None: + """Test build_azure_provider_data returns token and api_base when set.""" + assert token_manager.build_azure_provider_data() is None + + token_manager.set_base_url("https://azure.example.com") + token_manager._update_access_token("my-token", int(time.time()) + 3600) + + assert token_manager.build_azure_provider_data() == { + "azure_api_key": "my-token", + "azure_api_base": "https://azure.example.com", + } + def test_refresh_token_raises_without_config( self, token_manager: AzureEntraIDManager ) -> None: @@ -153,12 +165,15 @@ def test_token_expired_property_dynamic( ) -> None: """Simulate time passage to test token expiration property.""" now = 1000000 - token_manager._expires_on = now + 10 + token_manager._update_access_token( + "valid-token", now + TOKEN_EXPIRATION_LEEWAY + 60 + ) mocker.patch("authorization.azure_token_manager.time.time", return_value=now) assert not token_manager.is_token_expired mocker.patch( - "authorization.azure_token_manager.time.time", return_value=now + 20 + "authorization.azure_token_manager.time.time", + return_value=now + TOKEN_EXPIRATION_LEEWAY + 120, ) assert token_manager.is_token_expired diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index fd74c56ad..c4cbc9514 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -3,16 +3,19 @@ # pylint: disable=protected-access import json +import time from collections.abc import Callable from typing import Any import pytest from fastapi import HTTPException from llama_stack_client import APIConnectionError, APIStatusError -from pydantic import AnyHttpUrl +from pydantic import AnyHttpUrl, SecretStr from pytest_mock 
import MockerFixture +from authorization.azure_token_manager import AzureEntraIDManager from client import AsyncLlamaStackClientHolder +from configuration import AzureEntraIdConfiguration from models.config import LlamaStackConfiguration from utils.types import Singleton @@ -96,8 +99,21 @@ async def test_get_async_llama_stack_wrong_configuration() -> None: @pytest.mark.asyncio -async def test_update_provider_data_service_client() -> None: - """Test that update_provider_data updates headers for service clients.""" +async def test_update_azure_token_service_client() -> None: + """Test update_azure_token replaces the service client with new provider headers.""" + AzureEntraIDManager._instances = {} # type: ignore[attr-defined] + manager = AzureEntraIDManager() + manager.set_config( + AzureEntraIdConfiguration( + tenant_id=SecretStr("tenant"), + client_id=SecretStr("client"), + client_secret=SecretStr("secret"), + scope="https://cognitiveservices.azure.com/.default", + ) + ) + manager.set_base_url("https://api.example.com") + manager._update_access_token("fresh-token", int(time.time()) + 3600) + cfg = LlamaStackConfiguration( url=AnyHttpUrl("http://localhost:8321"), api_key=None, @@ -107,41 +123,59 @@ async def test_update_provider_data_service_client() -> None: ) holder = AsyncLlamaStackClientHolder() await holder.load(cfg) - original_client = holder.get_client() - assert not holder.is_library_client - - # Pre-populate with existing provider data via headers - original_client._custom_headers["X-LlamaStack-Provider-Data"] = json.dumps( - { - "existing_field": "keep_this", - "azure_api_key": "old_token", - } - ) - updated_client = holder.update_provider_data( - { - "azure_api_key": "new_token", - "azure_api_base": "https://new.example.com", - } - ) + updated_client = await holder.update_azure_token() - # Returns new client and updates holder assert updated_client is not original_client assert holder.get_client() is updated_client + provider_data_json = 
updated_client.default_headers.get( + "X-LlamaStack-Provider-Data" + ) + provider_data = json.loads(provider_data_json) + assert provider_data["azure_api_key"] == "fresh-token" + assert provider_data["azure_api_base"] == "https://api.example.com" + - # Verify headers on updated client +@pytest.mark.asyncio +async def test_load_service_client_defers_azure_provider_data() -> None: + """Test service client load does not set Azure headers until update_azure_token.""" + AzureEntraIDManager._instances = {} # type: ignore[attr-defined] + manager = AzureEntraIDManager() + manager.set_config( + AzureEntraIdConfiguration( + tenant_id=SecretStr("tenant"), + client_id=SecretStr("client"), + client_secret=SecretStr("secret"), + scope="https://cognitiveservices.azure.com/.default", + ) + ) + manager.set_base_url("https://ols-test.openai.azure.com/openai/v1") + manager._update_access_token("startup-token", int(time.time()) + 3600) + + cfg = LlamaStackConfiguration( + url=AnyHttpUrl("http://localhost:8321"), + api_key=None, + use_as_library_client=False, + library_client_config_path=None, + timeout=60, + ) + holder = AsyncLlamaStackClientHolder() + await holder.load(cfg) + + default_headers = holder.get_client().default_headers or {} + assert "X-LlamaStack-Provider-Data" not in default_headers + + updated_client = await holder.update_azure_token() provider_data_json = updated_client.default_headers.get( "X-LlamaStack-Provider-Data" ) assert provider_data_json is not None - assert isinstance(provider_data_json, str) provider_data = json.loads(provider_data_json) - - # Existing fields preserved, new fields updated - assert provider_data["existing_field"] == "keep_this" - assert provider_data["azure_api_key"] == "new_token" - assert provider_data["azure_api_base"] == "https://new.example.com" + assert provider_data["azure_api_key"] == "startup-token" + assert ( + provider_data["azure_api_base"] == "https://ols-test.openai.azure.com/openai/v1" + ) @pytest.mark.asyncio diff --git 
a/tests/unit/test_llama_stack_configuration.py b/tests/unit/test_llama_stack_configuration.py index d10f60580..aaa3bf53e 100644 --- a/tests/unit/test_llama_stack_configuration.py +++ b/tests/unit/test_llama_stack_configuration.py @@ -12,6 +12,7 @@ construct_vector_io_providers_section, construct_vector_stores_section, dedupe_providers_vector_io, + enrich_azure_entra_id_inference, enrich_byok_rag, enrich_solr, generate_configuration, @@ -24,6 +25,85 @@ UserDataCollection, ) +# ============================================================================= +# Test enrich_azure_entra_id_inference +# ============================================================================= + + +def test_enrich_azure_entra_id_inference_skips_when_not_configured() -> None: + """Test enrich_azure_entra_id_inference does nothing without Entra ID config.""" + ls_config: dict[str, Any] = { + "providers": { + "inference": [ + { + "provider_id": "azure", + "provider_type": "remote::azure", + "config": {"model_validation": True}, + } + ] + } + } + enrich_azure_entra_id_inference(ls_config, None) + assert ls_config["providers"]["inference"][0]["config"] == { + "model_validation": True + } + + +def test_enrich_azure_entra_id_inference_sets_model_validation_false() -> None: + """Test enrich_azure_entra_id_inference disables startup model validation.""" + ls_config: dict[str, Any] = { + "providers": { + "inference": [ + { + "provider_id": "azure", + "provider_type": "remote::azure", + "config": {}, + } + ] + } + } + enrich_azure_entra_id_inference(ls_config, {"tenant_id": "t"}) + azure_config = ls_config["providers"]["inference"][0]["config"] + assert azure_config["model_validation"] is False + + +def test_generate_configuration_enriches_azure_entra_id(tmp_path: Path) -> None: + """Test generate_configuration applies Azure Entra ID enrichment.""" + infile = tmp_path / "in.yaml" + outfile = tmp_path / "out.yaml" + with open(infile, "w", encoding="utf-8") as f: + yaml.dump( + { + "version": 2, + 
"providers": { + "inference": [ + { + "provider_id": "azure", + "provider_type": "remote::azure", + "config": {"api_key": "${env.AZURE_API_KEY}"}, + } + ] + }, + "registered_resources": {}, + }, + f, + ) + + generate_configuration( + str(infile), + str(outfile), + {"azure_entra_id": {"tenant_id": "tenant"}}, + ) + + with open(outfile, encoding="utf-8") as f: + result = yaml.safe_load(f) + + azure_config = result["providers"]["inference"][0]["config"] + assert azure_config["model_validation"] is False + assert azure_config["api_key"] == "${env.AZURE_API_KEY}" + assert not (tmp_path / ".env").exists() + + # ============================================================================= # Test construct_vector_stores_section # ============================================================================= @@ -444,9 +524,7 @@ def test_generate_configuration_dedupes_vector_io_on_load(tmp_path: Path) -> Non }, f, ) - generate_configuration( - str(infile), str(outfile), {}, env_file=str(tmp_path / ".env") - ) + generate_configuration(str(infile), str(outfile), {}) with open(outfile, encoding="utf-8") as f: result = yaml.safe_load(f) dupme = [ diff --git a/tests/unit/utils/test_query.py b/tests/unit/utils/test_query.py index 18bd887bb..7985f1316 100644 --- a/tests/unit/utils/test_query.py +++ b/tests/unit/utils/test_query.py @@ -8,7 +8,6 @@ import psycopg2 import pytest from fastapi import HTTPException -from llama_stack_client import APIConnectionError, APIStatusError from llama_stack_client.types import ModelListResponse from pytest_mock import MockerFixture from sqlalchemy.exc import SQLAlchemyError @@ -38,7 +37,6 @@ prepare_input, store_conversation_into_cache, store_query_results, - update_azure_token, validate_attachments_metadata, validate_model_provider_override, ) @@ -554,102 +552,6 @@ def test_consume_tokens_database_error(self, mocker: MockerFixture) -> None: assert exc_info.value.status_code == 500 -class TestUpdateAzureToken: - """Tests for update_azure_token 
function.""" - - @pytest.mark.asyncio - async def test_update_with_library_client(self, mocker: MockerFixture) -> None: - """Test updating token with library client.""" - mock_client_holder = mocker.Mock() - mock_client_holder.is_library_client = True - mock_client_holder.reload_library_client = mocker.AsyncMock( - return_value="client" - ) - mocker.patch( - "utils.query.AsyncLlamaStackClientHolder", return_value=mock_client_holder - ) - - mock_client = mocker.Mock() - result = await update_azure_token(mock_client) - assert result == "client" - mock_client_holder.reload_library_client.assert_called_once() - - @pytest.mark.asyncio - async def test_update_with_remote_client(self, mocker: MockerFixture) -> None: - """Test updating token with remote client.""" - mock_client_holder = mocker.Mock() - mock_client_holder.is_library_client = False - mock_client_holder.update_provider_data = mocker.Mock( - return_value="updated_client" - ) - mocker.patch( - "utils.query.AsyncLlamaStackClientHolder", return_value=mock_client_holder - ) - - mock_provider = type( - "Provider", - (), - { - "provider_type": "remote::azure", - "config": {"api_base": "https://api.example.com"}, - }, - )() - mock_client = mocker.AsyncMock() - mock_client.providers.list = mocker.AsyncMock(return_value=[mock_provider]) - - mocker.patch( - "utils.query.AzureEntraIDManager", - return_value=mocker.Mock( - access_token=mocker.Mock( - get_secret_value=mocker.Mock(return_value="token") - ) - ), - ) - - result = await update_azure_token(mock_client) - assert result == "updated_client" - - @pytest.mark.asyncio - async def test_update_with_connection_error(self, mocker: MockerFixture) -> None: - """Test updating token raises HTTPException on connection error.""" - mock_client_holder = mocker.Mock() - mock_client_holder.is_library_client = False - mocker.patch( - "utils.query.AsyncLlamaStackClientHolder", return_value=mock_client_holder - ) - - mock_client = mocker.AsyncMock() - mock_client.providers.list = 
mocker.AsyncMock( - side_effect=APIConnectionError( - message="Connection failed", request=mocker.Mock() - ) - ) - - with pytest.raises(HTTPException) as exc_info: - await update_azure_token(mock_client) - assert exc_info.value.status_code == 503 - - @pytest.mark.asyncio - async def test_update_with_api_status_error(self, mocker: MockerFixture) -> None: - """Test updating token raises HTTPException on API status error.""" - mock_client_holder = mocker.Mock() - mock_client_holder.is_library_client = False - mocker.patch( - "utils.query.AsyncLlamaStackClientHolder", return_value=mock_client_holder - ) - - mock_client = mocker.AsyncMock() - # Create a mock exception that will be caught by except APIStatusError - mock_error = APIStatusError( - message="API error", response=mocker.Mock(request=None), body=None - ) - mock_client.providers.list = mocker.AsyncMock(side_effect=mock_error) - - with pytest.raises(HTTPException) as exc_info: - await update_azure_token(mock_client) - assert exc_info.value.status_code == 500 - - class TestStoreQueryResults: """Tests for store_query_results function."""