Skip to content

Commit 83b0cdf

Browse files
Prevent duplicate default headers on case-insensitive overrides
Co-authored-by: Shri Sukhani <shrisukhani@users.noreply.github.com>
1 parent b20aa0d commit 83b0cdf

File tree

5 files changed

+104
-17
lines changed

5 files changed

+104
-17
lines changed

hyperbrowser/header_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,33 @@ def normalize_headers(
4141
return normalized_headers
4242

4343

44+
def merge_headers(
45+
base_headers: Mapping[str, str],
46+
override_headers: Optional[Mapping[str, str]],
47+
*,
48+
mapping_error_message: str,
49+
pair_error_message: Optional[str] = None,
50+
) -> Dict[str, str]:
51+
merged_headers = dict(base_headers)
52+
normalized_overrides = normalize_headers(
53+
override_headers,
54+
mapping_error_message=mapping_error_message,
55+
pair_error_message=pair_error_message,
56+
)
57+
if not normalized_overrides:
58+
return merged_headers
59+
60+
existing_canonical_to_key = {key.lower(): key for key in merged_headers}
61+
for override_key, override_value in normalized_overrides.items():
62+
canonical_override_key = override_key.lower()
63+
existing_key = existing_canonical_to_key.get(canonical_override_key)
64+
if existing_key is not None:
65+
del merged_headers[existing_key]
66+
merged_headers[override_key] = override_value
67+
existing_canonical_to_key[canonical_override_key] = override_key
68+
return merged_headers
69+
70+
4471
def parse_headers_env_json(raw_headers: Optional[str]) -> Optional[Dict[str, str]]:
4572
if raw_headers is None:
4673
return None

hyperbrowser/transport/async_transport.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Mapping, Optional
44

55
from hyperbrowser.exceptions import HyperbrowserError
6-
from hyperbrowser.header_utils import normalize_headers
6+
from hyperbrowser.header_utils import merge_headers
77
from hyperbrowser.version import __version__
88
from .base import APIResponse, AsyncTransportStrategy
99
from .error_utils import extract_error_message, extract_request_error_context
@@ -18,16 +18,14 @@ def __init__(self, api_key: str, headers: Optional[Mapping[str, str]] = None):
1818
normalized_api_key = api_key.strip()
1919
if not normalized_api_key:
2020
raise HyperbrowserError("api_key must not be empty")
21-
merged_headers = {
22-
"x-api-key": normalized_api_key,
23-
"User-Agent": f"hyperbrowser-python-sdk/{__version__}",
24-
}
25-
normalized_headers = normalize_headers(
21+
merged_headers = merge_headers(
22+
{
23+
"x-api-key": normalized_api_key,
24+
"User-Agent": f"hyperbrowser-python-sdk/{__version__}",
25+
},
2626
headers,
2727
mapping_error_message="headers must be a mapping of string pairs",
2828
)
29-
if normalized_headers:
30-
merged_headers.update(normalized_headers)
3129
self.client = httpx.AsyncClient(headers=merged_headers)
3230
self._closed = False
3331

hyperbrowser/transport/sync.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Mapping, Optional
44

55
from hyperbrowser.exceptions import HyperbrowserError
6-
from hyperbrowser.header_utils import normalize_headers
6+
from hyperbrowser.header_utils import merge_headers
77
from hyperbrowser.version import __version__
88
from .base import APIResponse, SyncTransportStrategy
99
from .error_utils import extract_error_message, extract_request_error_context
@@ -18,16 +18,14 @@ def __init__(self, api_key: str, headers: Optional[Mapping[str, str]] = None):
1818
normalized_api_key = api_key.strip()
1919
if not normalized_api_key:
2020
raise HyperbrowserError("api_key must not be empty")
21-
merged_headers = {
22-
"x-api-key": normalized_api_key,
23-
"User-Agent": f"hyperbrowser-python-sdk/{__version__}",
24-
}
25-
normalized_headers = normalize_headers(
21+
merged_headers = merge_headers(
22+
{
23+
"x-api-key": normalized_api_key,
24+
"User-Agent": f"hyperbrowser-python-sdk/{__version__}",
25+
},
2626
headers,
2727
mapping_error_message="headers must be a mapping of string pairs",
2828
)
29-
if normalized_headers:
30-
merged_headers.update(normalized_headers)
3129
self.client = httpx.Client(headers=merged_headers)
3230

3331
def _handle_response(self, response: httpx.Response) -> APIResponse:

tests/test_header_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
import pytest
22

33
from hyperbrowser.exceptions import HyperbrowserError
4-
from hyperbrowser.header_utils import normalize_headers, parse_headers_env_json
4+
from hyperbrowser.header_utils import (
5+
merge_headers,
6+
normalize_headers,
7+
parse_headers_env_json,
8+
)
59

610

711
def test_normalize_headers_trims_header_names():
@@ -67,3 +71,16 @@ def test_parse_headers_env_json_rejects_non_mapping_payload():
6771
match="HYPERBROWSER_HEADERS must be a JSON object of string pairs",
6872
):
6973
parse_headers_env_json('["bad"]')
74+
75+
76+
def test_merge_headers_replaces_existing_headers_case_insensitively():
77+
merged = merge_headers(
78+
{"User-Agent": "default-sdk", "x-api-key": "test-key"},
79+
{"user-agent": "custom-sdk", "X-API-KEY": "override-key"},
80+
mapping_error_message="headers must be a mapping of string pairs",
81+
)
82+
83+
assert merged["user-agent"] == "custom-sdk"
84+
assert merged["X-API-KEY"] == "override-key"
85+
assert "User-Agent" not in merged
86+
assert "x-api-key" not in merged

tests/test_transport_headers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,50 @@ async def run() -> None:
3030
await transport.close()
3131

3232
asyncio.run(run())
33+
34+
35+
def test_sync_transport_case_insensitive_header_overrides_replace_defaults():
36+
transport = SyncTransport(
37+
api_key="test-key",
38+
headers={"user-agent": "custom-agent", "X-API-KEY": "override-key"},
39+
)
40+
try:
41+
user_agent_values = [
42+
value
43+
for key, value in transport.client.headers.multi_items()
44+
if key.lower() == "user-agent"
45+
]
46+
api_key_values = [
47+
value
48+
for key, value in transport.client.headers.multi_items()
49+
if key.lower() == "x-api-key"
50+
]
51+
assert user_agent_values == ["custom-agent"]
52+
assert api_key_values == ["override-key"]
53+
finally:
54+
transport.close()
55+
56+
57+
def test_async_transport_case_insensitive_header_overrides_replace_defaults():
58+
async def run() -> None:
59+
transport = AsyncTransport(
60+
api_key="test-key",
61+
headers={"user-agent": "custom-agent", "X-API-KEY": "override-key"},
62+
)
63+
try:
64+
user_agent_values = [
65+
value
66+
for key, value in transport.client.headers.multi_items()
67+
if key.lower() == "user-agent"
68+
]
69+
api_key_values = [
70+
value
71+
for key, value in transport.client.headers.multi_items()
72+
if key.lower() == "x-api-key"
73+
]
74+
assert user_agent_values == ["custom-agent"]
75+
assert api_key_values == ["override-key"]
76+
finally:
77+
await transport.close()
78+
79+
asyncio.run(run())

0 commit comments

Comments
 (0)