Skip to content

Commit cd841bb

Browse files
MarkDaoustcopybara-github
authored andcommitted
Feat: add support for using provided credentials in non-Vertex mode.
PiperOrigin-RevId: 875347183
1 parent dd52cc2 commit cd841bb

2 files changed

Lines changed: 194 additions & 19 deletions

File tree

google/genai/_api_client.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,21 @@ def __init__(
702702
)
703703
self._http_options.api_version = 'v1beta1'
704704
else: # Implicit initialization or missing arguments.
705-
if not self.api_key:
705+
if env_api_key and api_key:
706+
# Explicit api_key takes precedence over implicit api_key.
707+
logger.info(
708+
'The client initializer api_key argument takes '
709+
'precedence over the API key from the environment variable.'
710+
)
711+
if credentials:
712+
if env_api_key:
713+
logger.info(
714+
'The user `credentials` argument will take precedence over the'
715+
' api key from the environment variables.'
716+
)
717+
self.api_key = None
718+
719+
if not self.api_key and not credentials:
706720
raise ValueError(
707721
'No API key was provided. Please pass a valid API key. Learn how to'
708722
' create an API key at'
@@ -1175,20 +1189,21 @@ def _request_once(
11751189
stream: bool = False,
11761190
) -> HttpResponse:
11771191
data: Optional[Union[str, bytes]] = None
1178-
# If using proj/location, fetch ADC
1179-
if self.vertexai and (self.project or self.location):
1192+
1193+
uses_vertex_creds = self.vertexai and (self.project or self.location)
1194+
uses_mldev_creds = not self.vertexai and self._credentials
1195+
if (uses_vertex_creds or uses_mldev_creds):
11801196
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
11811197
if self._credentials and self._credentials.quota_project_id:
11821198
http_request.headers['x-goog-user-project'] = (
11831199
self._credentials.quota_project_id
11841200
)
1185-
data = json.dumps(http_request.data) if http_request.data else None
1186-
else:
1187-
if http_request.data:
1188-
if not isinstance(http_request.data, bytes):
1189-
data = json.dumps(http_request.data) if http_request.data else None
1190-
else:
1191-
data = http_request.data
1201+
1202+
if http_request.data:
1203+
if not isinstance(http_request.data, bytes):
1204+
data = json.dumps(http_request.data)
1205+
else:
1206+
data = http_request.data
11921207

11931208
if stream:
11941209
httpx_request = self._httpx_client.build_request(
@@ -1241,22 +1256,22 @@ async def _async_request_once(
12411256
) -> HttpResponse:
12421257
data: Optional[Union[str, bytes]] = None
12431258

1244-
# If using proj/location, fetch ADC
1245-
if self.vertexai and (self.project or self.location):
1259+
uses_vertex_creds = self.vertexai and (self.project or self.location)
1260+
uses_mldev_creds = not self.vertexai and self._credentials
1261+
if (uses_vertex_creds or uses_mldev_creds):
12461262
http_request.headers['Authorization'] = (
12471263
f'Bearer {await self._async_access_token()}'
12481264
)
12491265
if self._credentials and self._credentials.quota_project_id:
12501266
http_request.headers['x-goog-user-project'] = (
12511267
self._credentials.quota_project_id
12521268
)
1253-
data = json.dumps(http_request.data) if http_request.data else None
1254-
else:
1255-
if http_request.data:
1256-
if not isinstance(http_request.data, bytes):
1257-
data = json.dumps(http_request.data) if http_request.data else None
1258-
else:
1259-
data = http_request.data
1269+
1270+
if http_request.data:
1271+
if not isinstance(http_request.data, bytes):
1272+
data = json.dumps(http_request.data)
1273+
else:
1274+
data = http_request.data
12601275

12611276
if stream:
12621277
if self._use_aiohttp():

google/genai/tests/client/test_client_initialization.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,28 @@
4747
)
4848

4949

50+
class FakeCredentials(credentials.Credentials):
51+
def __init__(self, token="fake_token", expired=False, quota_project_id=None):
52+
super().__init__()
53+
self.token = token
54+
self._expired = expired
55+
self._quota_project_id = quota_project_id
56+
self.refresh_count = 0
57+
58+
@property
59+
def expired(self):
60+
return self._expired
61+
62+
@property
63+
def quota_project_id(self):
64+
return self._quota_project_id
65+
66+
def refresh(self, request):
67+
self.refresh_count += 1
68+
self.token = "refreshed_token"
69+
self._expired = False
70+
71+
5072
@pytest.fixture(autouse=True)
5173
def reset_has_aiohttp():
5274
yield
@@ -1721,3 +1743,141 @@ async def test_get_aiohttp_session():
17211743
assert initial_session is not None
17221744
session = await client._api_client._get_aiohttp_session()
17231745
assert session is initial_session
1746+
1747+
1748+
def test_missing_api_key_and_credentials(monkeypatch):
1749+
monkeypatch.setenv("GOOGLE_API_KEY", "")
1750+
with pytest.raises(ValueError, match="Missing key inputs argument!"):
1751+
Client()
1752+
1753+
1754+
auth_precedence_test_cases = [
1755+
# client_args, env_vars, expected_headers
1756+
(
1757+
{"credentials": FakeCredentials()},
1758+
{"GOOGLE_API_KEY": "env_api_key"},
1759+
{"Authorization": "Bearer fake_token"}
1760+
),
1761+
(
1762+
{"credentials": FakeCredentials(quota_project_id="quota-proj")},
1763+
{"GOOGLE_API_KEY": "env_api_key"},
1764+
{
1765+
"Authorization": "Bearer fake_token",
1766+
"x-goog-user-project": "quota-proj"
1767+
}
1768+
),
1769+
(
1770+
{"api_key": "test_api_key"},
1771+
{"GOOGLE_API_KEY": "env_api_key"},
1772+
{"x-goog-api-key": "test_api_key"}
1773+
),
1774+
(
1775+
{},
1776+
{"GOOGLE_API_KEY": "env_api_key"},
1777+
{"x-goog-api-key": "env_api_key"}
1778+
),
1779+
]
1780+
1781+
1782+
@pytest.mark.parametrize(
1783+
["client_kwargs", "env_vars", "expected_headers"],
1784+
auth_precedence_test_cases,
1785+
)
1786+
@mock.patch.object(httpx.Client, "send", autospec=True)
1787+
def test_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers):
1788+
for key, value in env_vars.items():
1789+
monkeypatch.setenv(key, value)
1790+
1791+
client = Client(**client_kwargs)
1792+
mock_send.return_value = httpx.Response(
1793+
status_code=200,
1794+
json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}
1795+
)
1796+
client.models.generate_content(model="test", contents="hello?")
1797+
mock_send.assert_called_once()
1798+
request = mock_send.call_args[0][1]
1799+
1800+
for key, value in expected_headers.items():
1801+
assert key in request.headers
1802+
assert request.headers[key] == value
1803+
1804+
if "Authorization" in expected_headers:
1805+
assert "x-goog-api-key" not in request.headers
1806+
if "x-goog-api-key" in expected_headers:
1807+
assert "Authorization" not in request.headers
1808+
if "x-goog-user-project" not in expected_headers:
1809+
assert "x-goog-user-project" not in request.headers
1810+
assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}'
1811+
1812+
@pytest.mark.parametrize(
1813+
["client_kwargs", "env_vars", "expected_headers"],
1814+
auth_precedence_test_cases,
1815+
)
1816+
@pytest.mark.asyncio
1817+
@mock.patch.object(httpx.AsyncClient, "send", autospec=True)
1818+
async def test_async_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers):
1819+
for key, value in env_vars.items():
1820+
monkeypatch.setenv(key, value)
1821+
1822+
client = Client(**client_kwargs)
1823+
mock_send.return_value = httpx.Response(
1824+
status_code=200,
1825+
json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}
1826+
)
1827+
await client.aio.models.generate_content(model="test", contents="hello?")
1828+
mock_send.assert_called_once()
1829+
request = mock_send.call_args[0][1]
1830+
1831+
for key, value in expected_headers.items():
1832+
assert key in request.headers
1833+
assert request.headers[key] == value
1834+
1835+
if "Authorization" in expected_headers:
1836+
assert "x-goog-api-key" not in request.headers
1837+
if "x-goog-api-key" in expected_headers:
1838+
assert "Authorization" not in request.headers
1839+
if "x-goog-user-project" not in expected_headers:
1840+
assert "x-goog-user-project" not in request.headers
1841+
assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}'
1842+
1843+
1844+
async def test_both_credentials_mldev():
1845+
with pytest.raises(ValueError, match="mutually exclusive"):
1846+
creds = FakeCredentials(expired=True)
1847+
client = Client(credentials=creds, api_key="test-api-key")
1848+
1849+
1850+
@mock.patch.object(httpx.Client, "send", autospec=True)
1851+
def test_refresh_credentials_mldev(mock_send):
1852+
creds = FakeCredentials(expired=True)
1853+
client = Client(credentials=creds)
1854+
mock_send.return_value = httpx.Response(
1855+
status_code=200,
1856+
json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]},
1857+
)
1858+
client.models.generate_content(model="test", contents="hello?")
1859+
mock_send.assert_called_once()
1860+
request = mock_send.call_args[0][1]
1861+
assert "Authorization" in request.headers
1862+
assert request.headers["Authorization"] == "Bearer refreshed_token"
1863+
assert "x-goog-api-key" not in request.headers
1864+
assert creds.refresh_count == 1
1865+
1866+
1867+
@requires_aiohttp
1868+
@pytest.mark.asyncio
1869+
@mock.patch.object(httpx.AsyncClient, "send", autospec=True)
1870+
async def test_async_refresh_credentials_mldev(mock_send):
1871+
creds = FakeCredentials(expired=True)
1872+
client = Client(credentials=creds)
1873+
mock_send.return_value = httpx.Response(
1874+
status_code=200,
1875+
json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]},
1876+
)
1877+
await client.aio.models.generate_content(model="test", contents="hello?")
1878+
mock_send.assert_called_once()
1879+
request = mock_send.call_args[0][1]
1880+
assert "Authorization" in request.headers
1881+
assert request.headers["Authorization"] == "Bearer refreshed_token"
1882+
assert "x-goog-api-key" not in request.headers
1883+
assert creds.refresh_count == 1

0 commit comments

Comments
 (0)