diff --git a/src/openai/lib/_tools.py b/src/openai/lib/_tools.py index 4070ad63bb..6b6419b898 100644 --- a/src/openai/lib/_tools.py +++ b/src/openai/lib/_tools.py @@ -1,14 +1,43 @@ from __future__ import annotations -from typing import Any, Dict, cast +from typing import Any, Dict, Iterable, List, cast import pydantic from ._pydantic import to_strict_json_schema +from .._types import Omit +from .._utils import is_given from ..types.chat import ChatCompletionFunctionToolParam from ..types.shared_params import FunctionDefinition from ..types.responses.function_tool_param import FunctionToolParam as ResponsesFunctionToolParam +_WEB_SEARCH_TOOL_TYPES = frozenset( + {"web_search", "web_search_2025_08_26", "web_search_preview", "web_search_preview_2025_03_11"} +) + + +def _apply_web_search_default_location_tools( + tools: Iterable[Any] | Omit, +) -> Iterable[Any] | Omit: + """For web_search tools that lack user_location, inject user_location with type='approximate'. + + This prevents the server from defaulting to a US-based location when no + user_location is specified, which is unexpected behavior for developers + outside the US. + """ + if not is_given(tools): + return tools + + result: List[Any] = [] + changed = False + for tool in tools: + if isinstance(tool, dict) and tool.get("type") in _WEB_SEARCH_TOOL_TYPES and "user_location" not in tool: + tool = {**tool, "user_location": {"type": "approximate"}} + changed = True + result.append(tool) + + return result if changed else tools + class PydanticFunctionTool(Dict[str, Any]): """Dictionary wrapper so we can pass the given base model diff --git a/src/openai/resources/responses/responses.py b/src/openai/resources/responses/responses.py index 48705098cc..5a089f4da5 100644 --- a/src/openai/resources/responses/responses.py +++ b/src/openai/resources/responses/responses.py @@ -44,7 +44,11 @@ AsyncInputItemsWithStreamingResponse, ) from ..._streaming import Stream, AsyncStream -from ...lib._tools import PydanticFunctionTool, ResponsesPydanticFunctionTool +from ...lib._tools import ( + PydanticFunctionTool, + ResponsesPydanticFunctionTool, + _apply_web_search_default_location_tools, +) from .input_tokens import ( InputTokens, AsyncInputTokens, @@ -942,7 +946,7 @@ def create( "temperature": temperature, "text": text, "tool_choice": tool_choice, - "tools": tools, + "tools": _apply_web_search_default_location_tools(tools), "top_logprobs": top_logprobs, "top_p": top_p, "truncation": truncation, @@ -1256,7 +1260,7 @@ def parser(raw_response: Response) -> ParsedResponse[TextFormatT]: "temperature": temperature, "text": text, "tool_choice": tool_choice, - "tools": tools, + "tools": _apply_web_search_default_location_tools(tools), "top_logprobs": top_logprobs, "top_p": top_p, "truncation": truncation, @@ -2623,7 +2627,7 @@ async def create( "temperature": temperature, "text": text, "tool_choice": tool_choice, - "tools": tools, + "tools": _apply_web_search_default_location_tools(tools), "top_logprobs": top_logprobs, "top_p": top_p, "truncation": truncation, @@ -2941,7 +2945,7 @@ def parser(raw_response: Response) -> ParsedResponse[TextFormatT]: "temperature": temperature, "text": text, "tool_choice": tool_choice, - "tools": tools, + "tools": _apply_web_search_default_location_tools(tools), "top_logprobs": top_logprobs, "top_p": top_p, "truncation": truncation, @@ -4597,7 +4601,7 @@ def create( "temperature": temperature, "text": text, "tool_choice": tool_choice, - "tools": tools, + "tools": _apply_web_search_default_location_tools(tools), "top_logprobs": top_logprobs, "top_p": top_p, "truncation": truncation, @@ -4677,7 +4681,7 @@ async def create( "temperature": temperature, "text": text, "tool_choice": tool_choice, - "tools": tools, + "tools": _apply_web_search_default_location_tools(tools), "top_logprobs": top_logprobs, "top_p": top_p, "truncation": truncation, diff --git a/tests/lib/test_tools.py b/tests/lib/test_tools.py new file mode 100644 index 0000000000..8cbf8ae302 --- /dev/null +++ b/tests/lib/test_tools.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from openai._types import Omit, omit +from openai.lib._tools import _apply_web_search_default_location_tools + + +class TestApplyWebSearchDefaultLocationTools: + """Tests for _apply_web_search_default_location_tools.""" + + def test_no_tools_returns_omit(self) -> None: + """When tools is Omit (not given), return it unchanged.""" + result = _apply_web_search_default_location_tools(omit) + assert result is omit + + def test_empty_tools_list(self) -> None: + """An empty list of tools should be returned unchanged.""" + tools: list = [] + result = _apply_web_search_default_location_tools(tools) + assert result is tools # same reference, no changes + + def test_non_web_search_tools_unchanged(self) -> None: + """Tools that are not web_search should not be modified.""" + tools = [ + {"type": "function", "function": {"name": "my_func"}}, + {"type": "code_interpreter"}, + ] + result = _apply_web_search_default_location_tools(tools) + assert result is tools # same reference, no changes + + def test_web_search_injects_user_location(self) -> None: + """web_search without user_location should get one injected.""" + tools = [{"type": "web_search"}] + result = _apply_web_search_default_location_tools(tools) + assert result is not tools # new list created + assert result[0]["user_location"] == {"type": "approximate"} + assert result[0]["type"] == "web_search" + + def test_web_search_with_existing_user_location_unchanged(self) -> None: + """web_search that already has user_location should not be overridden.""" + existing_loc = {"type": "approximate", "city": "London", "country": "GB"} + tools = [{"type": "web_search", "user_location": existing_loc}] + result = _apply_web_search_default_location_tools(tools) + assert result is tools # same reference, no changes needed + assert result[0]["user_location"] is existing_loc + + def test_web_search_2025_08_26_injects(self) -> None: + tools = [{"type": "web_search_2025_08_26"}] + result = _apply_web_search_default_location_tools(tools) + assert result[0]["user_location"] == {"type": "approximate"} + + def test_web_search_preview_injects(self) -> None: + tools = [{"type": "web_search_preview"}] + result = _apply_web_search_default_location_tools(tools) + assert result[0]["user_location"] == {"type": "approximate"} + + def test_web_search_preview_2025_03_11_injects(self) -> None: + tools = [{"type": "web_search_preview_2025_03_11"}] + result = _apply_web_search_default_location_tools(tools) + assert result[0]["user_location"] == {"type": "approximate"} + + def test_mixed_tools_only_web_search_modified(self) -> None: + """When mixing web_search and non-web-search tools, only web_search gets modified.""" + func_tool = {"type": "function", "function": {"name": "foo"}} + ws_tool = {"type": "web_search"} + tools = [func_tool, ws_tool] + result = _apply_web_search_default_location_tools(tools) + # function tool is unchanged + assert result[0] is func_tool + # web_search tool is a new dict with user_location injected + assert result[1]["user_location"] == {"type": "approximate"} + assert result[1]["type"] == "web_search" + + def test_web_search_preserves_other_keys(self) -> None: + """Injecting user_location should not drop other keys on the tool dict.""" + tools = [{"type": "web_search", "extra_key": "value"}] + result = _apply_web_search_default_location_tools(tools) + assert result[0]["extra_key"] == "value" + assert result[0]["user_location"] == {"type": "approximate"} + + def test_non_dict_tool_not_modified(self) -> None: + """Non-dict tools (e.g. string shorthand) should pass through.""" + tools = ["web_search"] # string, not dict + result = _apply_web_search_default_location_tools(tools) + assert result is tools # unchanged