Skip to content

Commit 41a8235

Browse files
author
giulio-leone
committed
fix(azure): strip model from request body for deployment-based endpoints
1 parent 656e3ca commit 41a8235

2 files changed

Lines changed: 36 additions & 1 deletion

File tree

src/openai/lib/azure.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ def _build_request(
6161
retries_taken: int = 0,
6262
) -> httpx.Request:
6363
if options.url in _deployments_endpoints and is_mapping(options.json_data):
64-
model = options.json_data.get("model")
64+
json_data = cast(Mapping[str, Any], options.json_data)
65+
model = json_data.get("model")
6566
if model is not None and "/deployments" not in str(self.base_url.path):
6667
options.url = f"/deployments/{model}{options.url}"
68+
options.json_data = {k: v for k, v in json_data.items() if k != "model"}
6769

6870
return super()._build_request(options, retries_taken=retries_taken)
6971

tests/lib/test_azure.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
import logging
45
from typing import Union, cast
56
from typing_extensions import Literal, Protocol
@@ -47,6 +48,38 @@ def test_implicit_deployment_path(client: Client) -> None:
4748
)
4849

4950

51+
@pytest.mark.parametrize("client", [sync_client, async_client])
52+
@pytest.mark.parametrize(
53+
"endpoint,model",
54+
[
55+
("/chat/completions", "gpt-4o"),
56+
("/completions", "gpt-4o"),
57+
("/embeddings", "text-embedding-ada-002"),
58+
("/images/generations", "gpt-image-1-5"),
59+
("/images/edits", "gpt-image-1-5"),
60+
("/audio/transcriptions", "whisper-1"),
61+
("/audio/translations", "whisper-1"),
62+
("/audio/speech", "tts-1"),
63+
],
64+
)
65+
def test_implicit_deployment_strips_model_from_body(client: Client, endpoint: str, model: str) -> None:
66+
req = client._build_request(
67+
FinalRequestOptions.construct(
68+
method="post",
69+
url=endpoint,
70+
json_data={"model": model, "extra": "value"},
71+
)
72+
)
73+
74+
body = json.loads(req.content.decode())
75+
assert "model" not in body
76+
assert body["extra"] == "value"
77+
assert (
78+
str(req.url)
79+
== f"https://example-resource.azure.openai.com/openai/deployments/{model}{endpoint}?api-version=2023-07-01"
80+
)
81+
82+
5083
@pytest.mark.parametrize(
5184
"client,method",
5285
[

0 commit comments

Comments
 (0)