diff --git a/flocks/server/routes/provider.py b/flocks/server/routes/provider.py index bf13f016..ef00afcb 100644 --- a/flocks/server/routes/provider.py +++ b/flocks/server/routes/provider.py @@ -2320,9 +2320,15 @@ async def test_provider_credentials(provider_id: str, body: Optional[TestCredent requested_model_id = body.model_id if body else None test_model_id = requested_model_id or models[0].id - # Validate model belongs to this provider + # Validate model belongs to this provider. Azure OpenAI is the + # exception: users may test a deployment name before saving it. valid_ids = {m.id for m in models} - if test_model_id not in valid_ids: + is_unsaved_azure_deployment = ( + requested_model_id + and provider_id in {"azure-openai", "azure"} + and test_model_id not in valid_ids + ) + if test_model_id not in valid_ids and not is_unsaved_azure_deployment: response = { "success": False, "message": f"模型 '{test_model_id}' 不属于该 Provider", diff --git a/tests/provider/test_azure_provider.py b/tests/provider/test_azure_provider.py new file mode 100644 index 00000000..d23d426d --- /dev/null +++ b/tests/provider/test_azure_provider.py @@ -0,0 +1,33 @@ +from flocks.provider.provider import ModelCapabilities, ModelInfo +from flocks.provider.sdk.azure import AzureProvider + + +def test_azure_provider_returns_configured_deployment_models(): + provider = AzureProvider() + provider._config_models = [ + ModelInfo( + id="customer-prod-deployment", + name="Customer Production Deployment", + provider_id="azure", + capabilities=ModelCapabilities( + supports_tools=True, + supports_streaming=True, + context_window=128000, + max_tokens=4096, + ), + ) + ] + + models = provider.get_models() + + assert [m.id for m in models] == ["customer-prod-deployment"] + assert models[0].name == "Customer Production Deployment" + + +def test_azure_provider_returns_fallback_models_without_config(): + provider = AzureProvider() + + models = provider.get_models() + + assert {m.id for m in models} == {"gpt-5.4", "gpt-5-mini"} + assert all(m.provider_id == "azure" for m in models) diff --git a/tests/provider/test_test_credentials.py b/tests/provider/test_test_credentials.py index e01d5db7..aa7e7dd2 100644 --- a/tests/provider/test_test_credentials.py +++ b/tests/provider/test_test_credentials.py @@ -754,3 +754,83 @@ async def test_existing_custom_settings_are_preserved_during_provider_test(self) assert configured.api_key == "gateway-api-key" assert configured.base_url == "https://gateway.internal/v1" assert configured.custom_settings["verify_ssl"] is False + + @pytest.mark.asyncio + async def test_requested_azure_deployment_model_is_used_for_provider_test(self): + from flocks.server.routes.provider import TestCredentialRequest, test_provider_credentials + + provider = MagicMock() + provider._config = MagicMock( + custom_settings={}, + base_url="https://example-resource.openai.azure.com/", + ) + provider.chat = AsyncMock(return_value=MagicMock(content="Paris")) + + model = MagicMock() + model.id = "customer-prod-deployment" + + mock_secrets = MagicMock() + mock_secrets.get.return_value = "azure-api-key" + + mock_config = MagicMock() + + with ( + patch(_PATCH_SECRET_MGR, return_value=mock_secrets), + patch(_PATCH_CONFIG_GET, new_callable=AsyncMock, return_value=mock_config), + patch(_PATCH_PROVIDER) as mock_provider_cls, + ): + mock_provider_cls._ensure_initialized = MagicMock() + mock_provider_cls._load_dynamic_providers = MagicMock() + mock_provider_cls.apply_config = AsyncMock() + mock_provider_cls.get.return_value = provider + mock_provider_cls.list_models.return_value = [model] + + result = await test_provider_credentials( + "azure-openai", + TestCredentialRequest(model_id="customer-prod-deployment"), + ) + + assert result["success"] is True, result + assert result["model_id"] == "customer-prod-deployment" + provider.chat.assert_awaited_once() + assert provider.chat.await_args.args[0] == "customer-prod-deployment" + + @pytest.mark.asyncio + async def test_unsaved_azure_deployment_can_be_tested_without_model_definition(self): + from flocks.server.routes.provider import TestCredentialRequest, test_provider_credentials + + provider = MagicMock() + provider._config = MagicMock( + custom_settings={}, + base_url="https://example-resource.openai.azure.com/", + ) + provider.chat = AsyncMock(return_value=MagicMock(content="Paris")) + + catalog_model = MagicMock() + catalog_model.id = "gpt-5.4" + + mock_secrets = MagicMock() + mock_secrets.get.return_value = "azure-api-key" + + mock_config = MagicMock() + + with ( + patch(_PATCH_SECRET_MGR, return_value=mock_secrets), + patch(_PATCH_CONFIG_GET, new_callable=AsyncMock, return_value=mock_config), + patch(_PATCH_PROVIDER) as mock_provider_cls, + ): + mock_provider_cls._ensure_initialized = MagicMock() + mock_provider_cls._load_dynamic_providers = MagicMock() + mock_provider_cls.apply_config = AsyncMock() + mock_provider_cls.get.return_value = provider + mock_provider_cls.list_models.return_value = [catalog_model] + + result = await test_provider_credentials( + "azure-openai", + TestCredentialRequest(model_id="unsaved-prod-deployment"), + ) + + assert result["success"] is True, result + assert result["model_id"] == "unsaved-prod-deployment" + provider.chat.assert_awaited_once() + assert provider.chat.await_args.args[0] == "unsaved-prod-deployment" diff --git a/tests/server/routes/test_custom_provider_runtime.py b/tests/server/routes/test_custom_provider_runtime.py index 0e8baaee..4899ce6e 100644 --- a/tests/server/routes/test_custom_provider_runtime.py +++ b/tests/server/routes/test_custom_provider_runtime.py @@ -1,4 +1,5 @@ from flocks.provider.provider import ModelCapabilities, ModelInfo, Provider +from flocks.provider.sdk.azure import AzureProvider from flocks.server.routes.custom_provider import CreateModelReq, _add_model_to_runtime @@ -48,3 +49,35 @@ class DummyProvider: assert provider._config_models[0].capabilities.supports_reasoning is True finally: Provider._models = original_models + + +def test_add_azure_deployment_to_runtime_config_models(monkeypatch): + provider = AzureProvider() + provider.id = "azure-openai" + provider._config_models = [] + body = CreateModelReq( + model_id="customer-prod-deployment", + name="Customer Production Deployment", + context_window=128000, + max_output_tokens=4096, + supports_vision=False, + supports_tools=True, + supports_streaming=True, + supports_reasoning=False, + input_price=0.0, + output_price=0.0, + currency="USD", + ) + + original_models = Provider._models + Provider._models = {} + monkeypatch.setattr(Provider, "get", classmethod(lambda cls, provider_id: provider)) + + try: + _add_model_to_runtime("azure-openai", body) + + assert Provider._models[body.model_id].provider_id == "azure-openai" + assert provider._config_models[0].id == "customer-prod-deployment" + assert provider._config_models[0].name == "Customer Production Deployment" + finally: + Provider._models = original_models diff --git a/webui/src/locales/en-US/model.json b/webui/src/locales/en-US/model.json index c2480436..abdf84f6 100644 --- a/webui/src/locales/en-US/model.json +++ b/webui/src/locales/en-US/model.json @@ -119,7 +119,16 @@ "loadFailed": "Failed to load provider catalog", "noModelsToTest": "No enabled models to test", "batchTestDone": "Batch test complete", - "batchTestSummary": "{{success}} succeeded, {{failed}} failed" + "batchTestSummary": "{{success}} succeeded, {{failed}} failed", + "azureDeploymentName": "Azure Deployment Name", + "azureDeploymentPlaceholder": "e.g. my-gpt-4o-prod", + "azureDeploymentHint": "Azure OpenAI requests use the deployment name, not a fixed model name. The preset models are examples; enter your own deployment name here.", + "azureDeploymentDisplayName": "Display Name (optional)", + "azureDeploymentDisplayPlaceholder": "e.g. GPT-4o Production", + "azureDeploymentRequired": "Select at least one preset model or enter an Azure deployment name", + "azureModelIdHint": "For Azure OpenAI, Model ID should be the deployment name from Azure Portal.", + "azureCustomDeployments": "Custom Azure Deployments", + "azureNoCustomDeployments": "No custom Azure deployment has been added yet." }, "wizard": { "providerSaved": "Provider Saved", diff --git a/webui/src/locales/zh-CN/model.json b/webui/src/locales/zh-CN/model.json index 29cb71e4..768f69be 100644 --- a/webui/src/locales/zh-CN/model.json +++ b/webui/src/locales/zh-CN/model.json @@ -119,7 +119,16 @@ "loadFailed": "加载 Provider 目录失败", "noModelsToTest": "没有已启用的模型可测试", "batchTestDone": "批量测试完成", - "batchTestSummary": "{{success}} 成功, {{failed}} 失败" + "batchTestSummary": "{{success}} 成功, {{failed}} 失败", + "azureDeploymentName": "Azure 部署名称", + "azureDeploymentPlaceholder": "例如 my-gpt-4o-prod", + "azureDeploymentHint": "Azure OpenAI 请求使用 deployment name,而不是固定模型名。预设模型只是常用示例,你可以在这里填写自己的部署名称。", + "azureDeploymentDisplayName": "显示名称(可选)", + "azureDeploymentDisplayPlaceholder": "例如 GPT-4o Production", + "azureDeploymentRequired": "请至少选择一个预设模型,或填写 Azure deployment name", + "azureModelIdHint": "对于 Azure OpenAI,模型 ID 请填写 Azure Portal 中的 deployment name。", + "azureCustomDeployments": "自定义 Azure Deployments", + "azureNoCustomDeployments": "尚未添加自定义 Azure deployment。" }, "wizard": { "providerSaved": "Provider 已保存", diff --git a/webui/src/pages/Model/index.tsx b/webui/src/pages/Model/index.tsx index cb15baf3..eb59fb6b 100644 --- a/webui/src/pages/Model/index.tsx +++ b/webui/src/pages/Model/index.tsx @@ -55,6 +55,12 @@ function providerAllowsEmptyApiKey(providerId: string): boolean { ); } +const AZURE_PROVIDER_IDS = new Set(['azure-openai', 'azure']); + +function isAzureProviderId(providerId: string): boolean { + return AZURE_PROVIDER_IDS.has(providerId); +} + // ==================== Connection Cache ==================== const CONNECTION_CACHE_KEY = 'flocks_provider_connection_cache'; @@ -1088,6 +1094,8 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { const [baseUrl, setBaseUrl] = useState(''); const [description, setDescription] = useState(''); const [providerName, setProviderName] = useState(''); + const [azureDeploymentName, setAzureDeploymentName] = useState(''); + const [azureDeploymentDisplayName, setAzureDeploymentDisplayName] = useState(''); // Model selection (for catalog providers) const [selectedModelIds, setSelectedModelIds] = useState>(new Set()); @@ -1172,6 +1180,8 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { setDescription(provider.description || ''); setSelectedModelIds(new Set(provider.models.map(m => m.id))); setProviderName(''); + setAzureDeploymentName(''); + setAzureDeploymentDisplayName(''); } }; @@ -1212,7 +1222,8 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { base_url: baseUrl.trim() || undefined, provider_name: selectedCatalogId === 'openai-compatible' && providerName.trim() ? providerName.trim() : undefined, }); - const res = await providerAPI.testCredentials(selectedCatalogId); + const azureModelId = isAzureProviderId(selectedCatalogId) ? azureDeploymentName.trim() : ''; + const res = await providerAPI.testCredentials(selectedCatalogId, azureModelId || undefined); setTestResult({ success: res.data.success, message: res.data.message || (res.data.success ? t('status.connected') : t('form.testFailed')), @@ -1235,6 +1246,11 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { toast.warning('Please enter API Key'); return; } + const azureModelId = isAzureProviderId(selectedCatalogId) ? azureDeploymentName.trim() : ''; + if (isAzureProviderId(selectedCatalogId) && selectedModelIds.size === 0 && !azureModelId) { + toast.warning(t('form.azureDeploymentRequired')); + return; + } try { setSaving(true); if (selectedCatalogId === 'openai-compatible') { @@ -1259,6 +1275,20 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { const unselected = selectedCatalog.models.filter(m => !selectedModelIds.has(m.id)).map(m => m.id); await Promise.all(unselected.map(id => modelV2API.deleteDefinition(selectedCatalogId, id).catch(() => {}))); } + if (azureModelId) { + await modelV2API.createDefinition(selectedCatalogId, { + model_id: azureModelId, + name: azureDeploymentDisplayName.trim() || azureModelId, + }); + try { + const res = await providerAPI.testCredentials(selectedCatalogId, azureModelId); + if (!res.data.success) { + toast.warning(t('form.testFailed'), res.data.error || res.data.message); + } + } catch (testErr: any) { + toast.warning(t('form.testFailed'), testErr.response?.data?.detail || testErr.message); + } + } toast.success(t('providerAdded'), displayName); setSavedProviderId(selectedCatalogId); setSavedProviderName(displayName); @@ -1600,6 +1630,36 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: { )} + {isAzureProviderId(selectedCatalogId) && ( +
+
+ + setAzureDeploymentName(e.target.value)} + className="w-full px-3 py-2 border border-blue-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-300 text-sm bg-white" + placeholder={t('form.azureDeploymentPlaceholder')} + /> +

{t('form.azureDeploymentHint')}

+
+
+ + setAzureDeploymentDisplayName(e.target.value)} + className="w-full px-3 py-2 border border-blue-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-300 text-sm bg-white" + placeholder={azureDeploymentName.trim() || t('form.azureDeploymentDisplayPlaceholder')} + /> +
+
+ )} + {selectedCatalog.models.length > 0 && (
@@ -1712,7 +1772,13 @@ function AddProviderDialog({ connectedIds, onClose, onAdded }: {

{t('wizard.modelsAdded', { count: addedModelCount })}

)} - +
)} @@ -1791,10 +1857,12 @@ function useModelForm() { }; } -function ModelFormFields({ form, testResult, testing }: { +function ModelFormFields({ form, testResult, testing, modelIdPlaceholder, modelIdHint }: { form: ReturnType; testResult: { success: boolean; message: string; latency?: number } | null; testing: boolean; + modelIdPlaceholder?: string; + modelIdHint?: string; }) { const { t } = useTranslation('model'); return ( @@ -1809,8 +1877,9 @@ function ModelFormFields({ form, testResult, testing }: { value={form.modelId} onChange={e => form.setModelId(e.target.value)} className="w-full px-3 py-2 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-slate-400 text-sm" - placeholder="gpt-4o-custom" + placeholder={modelIdPlaceholder || 'gpt-4o-custom'} /> + {modelIdHint &&

{modelIdHint}

}
- + ); @@ -2085,7 +2160,20 @@ function ConfigureProviderDialog({ provider, existingCredentials, models, onClos // Catalog model management const [catalogModels, setCatalogModels] = useState([]); + const [catalogModelsLoaded, setCatalogModelsLoaded] = useState(false); const [selectedModelIds, setSelectedModelIds] = useState>(new Set(models.map(m => m.id))); + const [newAzureDeploymentName, setNewAzureDeploymentName] = useState(''); + const [newAzureDeploymentDisplayName, setNewAzureDeploymentDisplayName] = useState(''); + const isAzureProvider = isAzureProviderId(provider.id); + const catalogModelIds = useMemo(() => new Set(catalogModels.map(m => m.id)), [catalogModels]); + const selectedCatalogModelCount = useMemo( + () => catalogModels.filter(m => selectedModelIds.has(m.id)).length, + [catalogModels, selectedModelIds] + ); + const azureCustomModels = useMemo( + () => isAzureProvider && catalogModelsLoaded ? models.filter(m => !catalogModelIds.has(m.id)) : [], + [catalogModelIds, catalogModelsLoaded, isAzureProvider, models] + ); useEffect(() => { setApiKey(existingKey); @@ -2103,10 +2191,19 @@ function ConfigureProviderDialog({ provider, existingCredentials, models, onClos }, [provider.id, models]); useEffect(() => { + setCatalogModelsLoaded(false); catalogAPI.list().then(res => { const found = res.data.providers.find(p => p.id === provider.id); - if (found) setCatalogModels(found.models); - }).catch(() => {}); + if (found) { + setCatalogModels(found.models); + setCatalogModelsLoaded(true); + } else { + setCatalogModels([]); + } + }).catch(() => { + setCatalogModels([]); + setCatalogModelsLoaded(false); + }); }, [provider.id]); const handleToggleCatalogModel = (modelId: string) => { @@ -2148,6 +2245,13 @@ function ConfigureProviderDialog({ provider, existingCredentials, models, onClos ...toAdd.map(m => modelV2API.createDefinition(provider.id, { model_id: m.id, name: m.name }).catch(() => {})), ]); } + const azureModelId = newAzureDeploymentName.trim(); + if (isAzureProvider && azureModelId) { + await modelV2API.createDefinition(provider.id, { + model_id: azureModelId, + name: newAzureDeploymentDisplayName.trim() || azureModelId, + }); + } toast.success(t('credentialsSaved')); onConfigured(); @@ -2343,7 +2447,7 @@ ${hasExisting ? '你已有凭证配置,可以更新或测试连接。' : '请