|
1 | 1 | from typing import Self |
2 | 2 |
|
3 | | -from pydantic import model_validator |
| 3 | +from httpx import URL, Request |
| 4 | +from pydantic import Field, model_validator |
4 | 5 |
|
5 | 6 | from uipath_langchain_client.base_client import UiPathBaseChatModel |
6 | 7 | from uipath_langchain_client.settings import UiPathAPIConfig |
7 | 8 |
|
8 | 9 | try: |
9 | | - from azure.ai.inference import ChatCompletionsClient |
10 | | - from azure.ai.inference.aio import ChatCompletionsClient as ChatCompletionsClientAsync |
11 | | - from azure.core.credentials import AzureKeyCredential |
12 | | - from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel |
| 10 | + from azure.core.credentials import AzureKeyCredential, TokenCredential |
| 11 | + from azure.core.credentials_async import AsyncTokenCredential |
| 12 | + from langchain_azure_ai.chat_models import AzureAIOpenAIApiChatModel |
| 13 | + from openai import AsyncOpenAI, OpenAI |
13 | 14 | except ImportError as e: |
14 | 15 | raise ImportError( |
15 | 16 | "The 'azure' extra is required to use UiPathAzureAIChatCompletionsModel. " |
16 | 17 | "Install it with: uv add uipath-langchain-client[azure]" |
17 | 18 | ) from e |
18 | 19 |
|
19 | 20 |
|
20 | | -class UiPathAzureAIChatCompletionsModel(UiPathBaseChatModel, AzureAIChatCompletionsModel): # type: ignore[override] |
| 21 | +class UiPathAzureAIChatCompletionsModel(UiPathBaseChatModel, AzureAIOpenAIApiChatModel): # type: ignore[override] |
21 | 22 | api_config: UiPathAPIConfig = UiPathAPIConfig( |
22 | 23 | api_type="completions", |
23 | 24 | client_type="passthrough", |
24 | 25 | vendor_type="azure", |
25 | | - freeze_base_url=True, |
| 26 | + freeze_base_url=False, |
26 | 27 | ) |
27 | 28 |
|
28 | | - # Override fields to avoid errors when instantiating the class |
29 | | - endpoint: str | None = "PLACEHOLDER" |
| 29 | + # Override fields to avoid env var lookup / validation errors at instantiation |
| 30 | + endpoint: str | None = Field(default="PLACEHOLDER") |
| 31 | + credential: str | AzureKeyCredential | TokenCredential | AsyncTokenCredential | None = Field( |
| 32 | + default="PLACEHOLDER" |
| 33 | + ) |
30 | 34 |
|
31 | 35 | @model_validator(mode="after") |
32 | 36 | def setup_uipath_client(self) -> Self: |
33 | | - # TODO: finish implementation once we have a proper model in UiPath API |
34 | | - self._client = ChatCompletionsClient( |
35 | | - endpoint="PLACEHOLDER", |
36 | | - credential=AzureKeyCredential("PLACEHOLDER"), |
37 | | - model=self.model_name, |
38 | | - **self.client_kwargs, |
| 37 | + base_url = str(self.uipath_sync_client.base_url).rstrip("/") |
| 38 | + |
| 39 | + def fix_url_and_api_flavor_header(request: Request): |
| 40 | + url_suffix = str(request.url).split(base_url)[-1] |
| 41 | + if "responses" in url_suffix: |
| 42 | + request.headers["X-UiPath-LlmGateway-ApiFlavor"] = "responses" |
| 43 | + else: |
| 44 | + request.headers["X-UiPath-LlmGateway-ApiFlavor"] = "chat-completions" |
| 45 | + request.url = URL(base_url) |
| 46 | + |
| 47 | + async def fix_url_and_api_flavor_header_async(request: Request): |
| 48 | + url_suffix = str(request.url).split(base_url)[-1] |
| 49 | + if "responses" in url_suffix: |
| 50 | + request.headers["X-UiPath-LlmGateway-ApiFlavor"] = "responses" |
| 51 | + else: |
| 52 | + request.headers["X-UiPath-LlmGateway-ApiFlavor"] = "chat-completions" |
| 53 | + request.url = URL(base_url) |
| 54 | + |
| 55 | + self.uipath_sync_client.event_hooks["request"].append(fix_url_and_api_flavor_header) |
| 56 | + self.uipath_async_client.event_hooks["request"].append(fix_url_and_api_flavor_header_async) |
| 57 | + |
| 58 | + self.root_client = OpenAI( |
| 59 | + api_key="PLACEHOLDER", |
| 60 | + max_retries=0, # handled by the UiPath client |
| 61 | + http_client=self.uipath_sync_client, |
39 | 62 | ) |
40 | | - self._async_client = ChatCompletionsClientAsync( |
41 | | - endpoint="PLACEHOLDER", |
42 | | - credential=AzureKeyCredential("PLACEHOLDER"), |
43 | | - model=self.model_name, |
44 | | - **self.client_kwargs, |
| 63 | + self.root_async_client = AsyncOpenAI( |
| 64 | + api_key="PLACEHOLDER", |
| 65 | + max_retries=0, # handled by the UiPath client |
| 66 | + http_client=self.uipath_async_client, |
45 | 67 | ) |
| 68 | + self.client = self.root_client.chat.completions |
| 69 | + self.async_client = self.root_async_client.chat.completions |
46 | 70 | return self |
0 commit comments