|
47 | 47 | ) |
48 | 48 |
|
49 | 49 |
|
| 50 | +class FakeCredentials(credentials.Credentials): |
| 51 | + def __init__(self, token="fake_token", expired=False, quota_project_id=None): |
| 52 | + super().__init__() |
| 53 | + self.token = token |
| 54 | + self._expired = expired |
| 55 | + self._quota_project_id = quota_project_id |
| 56 | + self.refresh_count = 0 |
| 57 | + |
| 58 | + @property |
| 59 | + def expired(self): |
| 60 | + return self._expired |
| 61 | + |
| 62 | + @property |
| 63 | + def quota_project_id(self): |
| 64 | + return self._quota_project_id |
| 65 | + |
| 66 | + def refresh(self, request): |
| 67 | + self.refresh_count += 1 |
| 68 | + self.token = "refreshed_token" |
| 69 | + self._expired = False |
| 70 | + |
| 71 | + |
50 | 72 | @pytest.fixture(autouse=True) |
51 | 73 | def reset_has_aiohttp(): |
52 | 74 | yield |
@@ -1721,3 +1743,141 @@ async def test_get_aiohttp_session(): |
1721 | 1743 | assert initial_session is not None |
1722 | 1744 | session = await client._api_client._get_aiohttp_session() |
1723 | 1745 | assert session is initial_session |
| 1746 | + |
| 1747 | + |
| 1748 | +def test_missing_api_key_and_credentials(monkeypatch): |
| 1749 | + monkeypatch.setenv("GOOGLE_API_KEY", "") |
| 1750 | + with pytest.raises(ValueError, match="Missing key inputs argument!"): |
| 1751 | + Client() |
| 1752 | + |
| 1753 | + |
| 1754 | +auth_precedence_test_cases = [ |
| 1755 | + # client_args, env_vars, expected_headers |
| 1756 | + ( |
| 1757 | + {"credentials": FakeCredentials()}, |
| 1758 | + {"GOOGLE_API_KEY": "env_api_key"}, |
| 1759 | + {"Authorization": "Bearer fake_token"} |
| 1760 | + ), |
| 1761 | + ( |
| 1762 | + {"credentials": FakeCredentials(quota_project_id="quota-proj")}, |
| 1763 | + {"GOOGLE_API_KEY": "env_api_key"}, |
| 1764 | + { |
| 1765 | + "Authorization": "Bearer fake_token", |
| 1766 | + "x-goog-user-project": "quota-proj" |
| 1767 | + } |
| 1768 | + ), |
| 1769 | + ( |
| 1770 | + {"api_key": "test_api_key"}, |
| 1771 | + {"GOOGLE_API_KEY": "env_api_key"}, |
| 1772 | + {"x-goog-api-key": "test_api_key"} |
| 1773 | + ), |
| 1774 | + ( |
| 1775 | + {}, |
| 1776 | + {"GOOGLE_API_KEY": "env_api_key"}, |
| 1777 | + {"x-goog-api-key": "env_api_key"} |
| 1778 | + ), |
| 1779 | +] |
| 1780 | + |
| 1781 | + |
| 1782 | +@pytest.mark.parametrize( |
| 1783 | + ["client_kwargs", "env_vars", "expected_headers"], |
| 1784 | + auth_precedence_test_cases, |
| 1785 | +) |
| 1786 | +@mock.patch.object(httpx.Client, "send", autospec=True) |
| 1787 | +def test_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers): |
| 1788 | + for key, value in env_vars.items(): |
| 1789 | + monkeypatch.setenv(key, value) |
| 1790 | + |
| 1791 | + client = Client(**client_kwargs) |
| 1792 | + mock_send.return_value = httpx.Response( |
| 1793 | + status_code=200, |
| 1794 | + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]} |
| 1795 | + ) |
| 1796 | + client.models.generate_content(model="test", contents="hello?") |
| 1797 | + mock_send.assert_called_once() |
| 1798 | + request = mock_send.call_args[0][1] |
| 1799 | + |
| 1800 | + for key, value in expected_headers.items(): |
| 1801 | + assert key in request.headers |
| 1802 | + assert request.headers[key] == value |
| 1803 | + |
| 1804 | + if "Authorization" in expected_headers: |
| 1805 | + assert "x-goog-api-key" not in request.headers |
| 1806 | + if "x-goog-api-key" in expected_headers: |
| 1807 | + assert "Authorization" not in request.headers |
| 1808 | + if "x-goog-user-project" not in expected_headers: |
| 1809 | + assert "x-goog-user-project" not in request.headers |
| 1810 | + assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}' |
| 1811 | + |
| 1812 | +@pytest.mark.parametrize( |
| 1813 | + ["client_kwargs", "env_vars", "expected_headers"], |
| 1814 | + auth_precedence_test_cases, |
| 1815 | +) |
| 1816 | +@pytest.mark.asyncio |
| 1817 | +@mock.patch.object(httpx.AsyncClient, "send", autospec=True) |
| 1818 | +async def test_async_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers): |
| 1819 | + for key, value in env_vars.items(): |
| 1820 | + monkeypatch.setenv(key, value) |
| 1821 | + |
| 1822 | + client = Client(**client_kwargs) |
| 1823 | + mock_send.return_value = httpx.Response( |
| 1824 | + status_code=200, |
| 1825 | + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]} |
| 1826 | + ) |
| 1827 | + await client.aio.models.generate_content(model="test", contents="hello?") |
| 1828 | + mock_send.assert_called_once() |
| 1829 | + request = mock_send.call_args[0][1] |
| 1830 | + |
| 1831 | + for key, value in expected_headers.items(): |
| 1832 | + assert key in request.headers |
| 1833 | + assert request.headers[key] == value |
| 1834 | + |
| 1835 | + if "Authorization" in expected_headers: |
| 1836 | + assert "x-goog-api-key" not in request.headers |
| 1837 | + if "x-goog-api-key" in expected_headers: |
| 1838 | + assert "Authorization" not in request.headers |
| 1839 | + if "x-goog-user-project" not in expected_headers: |
| 1840 | + assert "x-goog-user-project" not in request.headers |
| 1841 | + assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}' |
| 1842 | + |
| 1843 | + |
| 1844 | +async def test_both_credentials_mldev(): |
| 1845 | + with pytest.raises(ValueError, match="mutually exclusive"): |
| 1846 | + creds = FakeCredentials(expired=True) |
| 1847 | + client = Client(credentials=creds, api_key="test-api-key") |
| 1848 | + |
| 1849 | + |
| 1850 | +@mock.patch.object(httpx.Client, "send", autospec=True) |
| 1851 | +def test_refresh_credentials_mldev(mock_send): |
| 1852 | + creds = FakeCredentials(expired=True) |
| 1853 | + client = Client(credentials=creds) |
| 1854 | + mock_send.return_value = httpx.Response( |
| 1855 | + status_code=200, |
| 1856 | + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}, |
| 1857 | + ) |
| 1858 | + client.models.generate_content(model="test", contents="hello?") |
| 1859 | + mock_send.assert_called_once() |
| 1860 | + request = mock_send.call_args[0][1] |
| 1861 | + assert "Authorization" in request.headers |
| 1862 | + assert request.headers["Authorization"] == "Bearer refreshed_token" |
| 1863 | + assert "x-goog-api-key" not in request.headers |
| 1864 | + assert creds.refresh_count == 1 |
| 1865 | + |
| 1866 | + |
| 1867 | +@requires_aiohttp |
| 1868 | +@pytest.mark.asyncio |
| 1869 | +@mock.patch.object(httpx.AsyncClient, "send", autospec=True) |
| 1870 | +async def test_async_refresh_credentials_mldev(mock_send): |
| 1871 | + creds = FakeCredentials(expired=True) |
| 1872 | + client = Client(credentials=creds) |
| 1873 | + mock_send.return_value = httpx.Response( |
| 1874 | + status_code=200, |
| 1875 | + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}, |
| 1876 | + ) |
| 1877 | + await client.aio.models.generate_content(model="test", contents="hello?") |
| 1878 | + mock_send.assert_called_once() |
| 1879 | + request = mock_send.call_args[0][1] |
| 1880 | + assert "Authorization" in request.headers |
| 1881 | + assert request.headers["Authorization"] == "Bearer refreshed_token" |
| 1882 | + assert "x-goog-api-key" not in request.headers |
| 1883 | + assert creds.refresh_count == 1 |
0 commit comments