diff --git a/AGENTS.md b/AGENTS.md index fea48ec..386245c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -68,6 +68,10 @@ If no new rule is detected -> do not update the file. - build: `dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore` - analyze: `dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore -p:RunAnalyzers=true` - test: `dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build` +- test-list: `dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build --list-tests` +- test-detailed: `dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build --output Detailed --no-progress` +- test-trx: `dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build --report-trx --results-directory ./artifacts/test-results` +- test-runner-help: `tests/ManagedCode.MCPGateway.Tests/bin/Release/net10.0/ManagedCode.MCPGateway.Tests --help` - format: `dotnet format ManagedCode.MCPGateway.slnx` - skills-validate: `python3 .codex/skills/mcaf-skill-curation/scripts/validate_skills.py .codex/skills` - skills-metadata: `python3 .codex/skills/mcaf-skill-curation/scripts/generate_available_skills.py .codex/skills --absolute` @@ -90,6 +94,7 @@ If no new rule is detected -> do not update the file. - Keep the gateway reusable as a NuGet library, not as an app-specific host. - Preserve one public execution surface for local `AITool` instances and MCP tools. - Preserve one searchable catalog that supports vector ranking when embeddings are available and lexical fallback when they are not. +- For multilingual or noisy search inputs, prefer a generic English-normalization step before ranking when an AI/query-rewrite component is available, because the user wants the searchable representation to converge to English instead of relying only on language-specific token overlap. - Keep meta-tools available through `McpGatewayToolSet` and `IMcpGateway.CreateMetaTools(...)`. 
- If a user adds or corrects a persistent workflow rule, update `AGENTS.md` first and only then continue with the task. @@ -97,12 +102,17 @@ If no new rule is detected -> do not update the file. - `src/ManagedCode.MCPGateway/` contains the package source. - `tests/ManagedCode.MCPGateway.Tests/` contains integration-style package tests. -- `src/ManagedCode.MCPGateway/Abstractions/` contains public interfaces. -- `src/ManagedCode.MCPGateway/Models/` contains public contracts and internal source registrations. +- `src/ManagedCode.MCPGateway/Abstractions/` contains public interfaces, grouped by concern when needed. +- `src/ManagedCode.MCPGateway/Configuration/` contains public configuration types and service keys. +- `src/ManagedCode.MCPGateway/Models/` contains public contracts grouped by behavior such as search, invocation, catalog, and embeddings. +- `src/ManagedCode.MCPGateway/Embeddings/` contains public embedding-store implementations. +- `src/ManagedCode.MCPGateway/Internal/` contains internal catalog, runtime, and helper implementation details. - `src/ManagedCode.MCPGateway/Registration/` contains DI registration extensions. -- `src/ManagedCode.MCPGateway/McpGateway.cs` is the main runtime implementation. +- `src/ManagedCode.MCPGateway/McpGateway.cs` is the public gateway facade. +- `src/ManagedCode.MCPGateway/Internal/Runtime/` contains the internal runtime orchestration implementation, grouped by core, catalog, search, invocation, and embeddings responsibilities. - `src/ManagedCode.MCPGateway/McpGatewayToolSet.cs` exposes the gateway as reusable `AITool` meta-tools. - `.codex/skills/` contains repo-local MCAF skills for Codex. 
+- Keep the source tree explicitly modular: separate public API folders from `Internal/` implementation folders, and group runtime classes by responsibility in dedicated folders instead of dumping search, indexing, invocation, registry, and infrastructure files into the package root, because flat structure hides boundaries and invites god-object design. ### Skills (ALL TASKS) @@ -116,6 +126,9 @@ If no new rule is detected -> do not update the file. ### Documentation (ALL TASKS) - Update `README.md` whenever public API shape, setup, or usage changes. +- For non-trivial architecture, runtime-flow, or cross-cutting search changes, always add or update an ADR under `docs/ADR/`, update `docs/Architecture/Overview.md`, and keep `README.md` synchronized with the shipped behavior and examples so the docs describe the real package rather than an older design snapshot. +- When the package requires an initialization step such as index building, provide an ergonomic optional integration path (for example DI extension or hosted background warmup) instead of forcing every consumer to call it manually, and document when manual initialization is still appropriate. +- Keep documented configuration defaults synchronized with the actual `McpGatewayOptions` defaults; for example, `MaxSearchResults` default is `15`, not stale sample values. - Keep the README focused on package usage and onboarding, not internal implementation notes. - Document optional DI dependencies explicitly in README examples so consumers know which services they must register themselves, such as embedding generators. - Keep README code examples as real example code blocks, not commented-out pseudo-code; if behavior is optional, show it in a separate example instead of commenting lines inside another snippet. @@ -128,8 +141,10 @@ If no new rule is detected -> do not update the file. ### Testing (ALL TASKS) - Test framework in this repository is TUnit. Never add or keep xUnit here. 
+- This repository uses `TUnit` on `Microsoft.Testing.Platform`; never use VSTest-only flags such as `--filter` or `--logger`, because they are not supported here. - For TUnit solution runs, always invoke `dotnet test --solution ...`; do not pass the solution path positionally. - Every behavior change must include or update tests in `tests/ManagedCode.MCPGateway.Tests/`. +- Add tests only when they close a meaningful behavior or regression gap; avoid low-signal tests that only increase count without improving confidence. - Keep tests focused on real gateway behavior: - local tool indexing and invocation - MCP tool indexing and invocation @@ -139,6 +154,7 @@ If no new rule is detected -> do not update the file. - Keep request context behavior covered when search or invocation consumes contextual inputs. - Do not remove tests to get green builds. - Keep `global.json` configured for `Microsoft.Testing.Platform` when TUnit is used. +- At the end of implementation work, run code-size and quality verification with `cloc`, `roslynator`, and the repository's strict .NET build/test checks, then fix actionable findings so oversized files and quality drift do not accumulate. - Run verification in this order: - restore - build @@ -149,12 +165,30 @@ If no new rule is detected -> do not update the file. - Follow `.editorconfig` and repository analyzers. - Keep warnings clean; repository builds treat warnings as errors. - Prefer simple, readable C# over clever abstractions. +- Prefer modern C# 14 syntax when it improves clarity and keep replacing stale legacy syntax with current idiomatic language constructs instead of preserving older forms by inertia. +- Prefer straightforward DI-native constructors in public types; avoid redundant constructor chaining that only wraps `new SomeRuntime(...)` behind a second constructor, because in modern C# this adds ceremony without improving clarity. 
+- In hot runtime paths, prefer single-pass loops over allocation-heavy LINQ chains when the logic is simple, because duplicate enumeration and transient allocations have already been called out as unacceptable in this repository. +- Avoid open-ended `while (true)` loops in runtime code when a real termination condition exists; use an explicit condition such as cancellation or lifecycle state so concurrency code stays auditable. +- Avoid transient collection + `string.Join` assembly in hot runtime string paths; build the final string directly when only a few optional segments exist, because these extra allocations have already been called out as wasteful in this repository. +- Prefer readable imperative conditionals over long multi-line boolean expression bodies; if a predicate stops being obvious at a glance, split it into guard clauses or named locals instead of compressing it into one chained return expression. +- Prefer non-blocking coordination over coarse locking when practical; use concurrent collections, atomic state, and single-flight patterns instead of `lock`-heavy designs, because blocking synchronization has already proven to obscure concurrency behavior in this package. +- Keep concurrency coordination intention-revealing: avoid opaque fields such as generic drain/task signals inside runtime services when a named helper or clearer lifecycle abstraction can express the behavior, because hidden synchronization state quickly turns registry/runtime code into unreadable infrastructure. +- Prefer serializer-first JSON/schema handling; avoid ad-hoc manual special cases for `JsonElement`/`JsonNode`/schema objects when normal `System.Text.Json` serialization can represent them correctly. +- For JSON and schema payloads, always route serialization through the repository's canonical JSON converter/options path; do not hand-roll ad-hoc `JsonSerializer.Serialize*` handling inside feature code when the package already defines how JSON should be materialized. 
+- For context/object flattening, do not maintain parallel per-type serialization trees by hand; normalize once through the canonical JSON path and traverse the normalized representation, because duplicated type-switch logic drifts and keeps reintroducing ad-hoc serialization. +- Prefer explicit SOLID object decomposition over large `partial` types; when responsibilities like registry, indexing, invocation, or schema handling can live in dedicated classes, extract real collaborators instead of only splitting files. +- Keep `McpGateway` focused on search/invoke orchestration only; do not embed registry or mutation responsibilities into the gateway type itself, because that mixes lifecycle/catalog mutation with runtime execution concerns. - Keep public API names aligned with package identity `ManagedCode.MCPGateway`. - Do not duplicate package metadata or version blocks inside project files unless a project-specific override is required. - Use constants for stable tool names and protocol-facing identifiers. -- Never leave stable string literals inline in runtime code; extract named constants for diagnostic codes, messages, modes, keys, and other durable identifiers so changes stay centralized. +- Never leave stable string literals inline in runtime code; extract named constants for diagnostic codes, messages, modes, keys, tool descriptions, and other durable identifiers so changes stay centralized. +- Use the correct contextual logger type for each service; internal collaborators must log with their own type category instead of reusing a parent facade logger, because wrong logger categories make diagnostics misleading. - Keep transport-specific logic inside the gateway and source registration abstractions, not scattered across the codebase. - Keep the package dependency surface small and justified. 
+- Prefer direct generic DI registrations such as `services.TryAddSingleton<TService, TImplementation>()` over lambda alias registrations when wiring package services, because the lambda style has already been called out as unreadable and error-prone in this repository. +- Keep runtime services DI-native from their public/internal constructors; types such as `McpGatewayRegistry` must be creatable through `IOptions<McpGatewayOptions>` and other DI-managed dependencies rather than ad-hoc state-only constructors, because the package design requires services to live fully inside the container. +- When emitting package identity to external protocols such as MCP client info, never hardcode a fake version string; use the actual assembly/build version so runtime metadata stays aligned with the package being shipped. +- For search-quality improvements, prefer mathematical or statistical ranking changes over hardcoded phrase lists or ad-hoc query text hacks, because the user explicitly wants tokenizer search to improve through general scoring behavior rather than manual exceptions. 
### Critical (NEVER violate) diff --git a/Directory.Build.props b/Directory.Build.props index d982c42..dade15e 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -11,7 +11,7 @@ latest-recommended true $(NoWarn);CS1591;CA1707;CA1848;CA1859;CA1873 - 0.1.1 + 0.2.0 $(Version) diff --git a/Directory.Packages.props b/Directory.Packages.props index 7c77703..f8bcc20 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -8,12 +8,15 @@ + + + - \ No newline at end of file + diff --git a/README.md b/README.md index 97e896d..63a7f76 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ The package is built on: - `Microsoft.Extensions.AI` - the official `ModelContextProtocol` .NET SDK -- in-memory descriptor indexing with optional embedding-based ranking +- in-memory descriptor indexing with vector ranking and built-in tokenizer-backed fallback ## Install @@ -20,13 +20,27 @@ The package is built on: dotnet add package ManagedCode.MCPGateway ``` +## Architecture And Decision Records + +- [Architecture overview](docs/Architecture/Overview.md) +- [ADR-0001: Runtime boundaries and index lifecycle](docs/ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md) +- [ADR-0002: Search ranking and query normalization](docs/ADR/ADR-0002-search-ranking-and-query-normalization.md) +- [Feature spec: Search query normalization and ranking](docs/Features/SearchQueryNormalizationAndRanking.md) + ## What You Get -- one registry for local tools, stdio MCP servers, HTTP MCP servers, or prebuilt `McpClient` instances +- one registry for local tools, stdio MCP servers, HTTP MCP servers, existing `McpClient` instances, or deferred `McpClient` factories +- a DI-native split between `IMcpGateway` for runtime search/invoke and `IMcpGatewayRegistry` for catalog mutation - descriptor indexing that enriches search with tool name, description, required arguments, and input schema +- lazy index build on the first catalog/search/invoke operation, plus optional eager warmup 
hooks for startup scenarios +- configurable search strategy with embeddings or tokenizer-backed heuristic ranking +- `SearchStrategy.Auto` by default: use embeddings when available, otherwise fall back to tokenizer-backed ranking automatically +- built-in `ChatGptO200kBase` tokenizer path for tokenizer search and tokenizer fallback +- optional English query normalization before ranking when a keyed search rewrite `IChatClient` is registered +- top 5 matches by default when `maxResults` is not specified - vector search when an `IEmbeddingGenerator>` is registered - optional persisted tool embeddings through `IMcpGatewayToolEmbeddingStore` -- lexical fallback when embeddings are unavailable +- token-aware lexical fallback when embeddings are unavailable or vector search cannot complete - one invoke surface for both local `AIFunction` tools and MCP tools - optional meta-tools you can hand back to another model as normal `AITool` instances @@ -61,18 +75,197 @@ services.AddManagedCodeMcpGateway(options => await using var serviceProvider = services.BuildServiceProvider(); var gateway = serviceProvider.GetRequiredService(); -await gateway.BuildIndexAsync(); - -var search = await gateway.SearchAsync("find github repositories", maxResults: 3); +var search = await gateway.SearchAsync("find github repositories"); var selectedTool = search.Matches[0]; var invoke = await gateway.InvokeAsync(new McpGatewayInvokeRequest( ToolId: selectedTool.ToolId, - Query: "managedcode")); + Query: "managedcode")); ``` `AddManagedCodeMcpGateway(...)` does not create or configure an embedding generator for you. Vector ranking is enabled only when the same DI container also has an `IEmbeddingGenerator>`. The gateway first tries the keyed registration `McpGatewayServiceKeys.EmbeddingGenerator` and falls back to any regular registration. Otherwise it stays fully functional and uses lexical ranking. 
+The gateway builds its catalog lazily on the first `ListToolsAsync(...)`, `SearchAsync(...)`, or `InvokeAsync(...)` call. If you add more tools later through the registry, the next catalog/search/invoke operation rebuilds the index automatically. You only need an explicit warmup call when you want eager startup validation or a pre-warmed cache. + +`McpGateway` is the runtime search/invoke facade. If you need to add tools or MCP sources after the container is built, resolve `IMcpGatewayRegistry` separately: + +```csharp +var registry = serviceProvider.GetRequiredService<IMcpGatewayRegistry>(); + +registry.AddTool( + "local", + AIFunctionFactory.Create( + static (string query) => $"weather:{query}", + new AIFunctionFactoryOptions + { + Name = "weather_search_forecast", + Description = "Search weather forecast and temperature information by city name." + })); + +var tools = await gateway.ListToolsAsync(); +``` + +`AddManagedCodeMcpGateway(...)` registers `IMcpGateway`, `IMcpGatewayRegistry`, and `McpGatewayToolSet`. Add `AddManagedCodeMcpGatewayIndexWarmup()` only when you want hosted eager initialization. + +## Public Surfaces + +Resolve these services depending on what the host needs: + +- `IMcpGateway`: build, list, search, invoke, and create meta-tools +- `IMcpGatewayRegistry`: add local tools or MCP sources after the container is built +- `McpGatewayToolSet`: expose the gateway itself as reusable `AITool` instances + +Those three services deliberately separate runtime execution, catalog mutation, and meta-tool creation instead of collapsing everything into one mutable gateway type. 
+ +## Register Existing Or Deferred MCP Clients + +`IMcpGatewayRegistry` supports both immediate `McpClient` instances and deferred client factories: + +```csharp +var registry = serviceProvider.GetRequiredService(); + +registry.AddMcpClient( + sourceId: "issues", + client: existingClient, + disposeClient: false); + +registry.AddMcpClientFactory( + sourceId: "work-items", + clientFactory: static async cancellationToken => + { + return await CreateWorkItemClientAsync(cancellationToken); + }); +``` + +Use `AddMcpClient(...)` when another part of the host already owns the client lifetime. Use `AddMcpClientFactory(...)` when the gateway should lazily create and cache the client through its normal source-loading path. + +## Invoke By Tool Id Or Stable Identity + +The common flow is search first, then invoke by `ToolId`: + +```csharp +var search = await gateway.SearchAsync("find github repositories"); +var invoke = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: search.Matches[0].ToolId, + Query: "managedcode")); +``` + +If the host already knows the stable tool name, invocation can target `ToolName` and optionally `SourceId` instead: + +```csharp +var invoke = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolName: "github_search_repositories", + SourceId: "local", + Query: "managedcode")); +``` + +Use `SourceId` when the same tool name may exist in more than one registered source. + +## Optional Eager Warmup + +If you want to warm the catalog immediately after building the container, use the service-provider extension: + +```csharp +await using var serviceProvider = services.BuildServiceProvider(); + +var build = await serviceProvider.InitializeManagedCodeMcpGatewayAsync(); +``` + +`InitializeManagedCodeMcpGatewayAsync()` returns `McpGatewayIndexBuildResult`, so startup code can inspect diagnostics or fail fast explicitly. 
+ +For hosted applications, register background warmup once and let the host trigger it on startup: + +```csharp +var services = new ServiceCollection(); + +services.AddManagedCodeMcpGateway(options => +{ + options.AddTool( + "local", + AIFunctionFactory.Create( + static (string query) => $"github:{query}", + new AIFunctionFactoryOptions + { + Name = "github_search_repositories", + Description = "Search GitHub repositories by user query." + })); +}); + +services.AddManagedCodeMcpGatewayIndexWarmup(); +``` + +Use eager warmup when you want fail-fast startup behavior, a warmed cache before the first request, or deterministic startup benchmarking. Otherwise the lazy default is enough. + +## Recommended Hosted Setup + +This example shows the full production-oriented integration shape in one place. Remove the optional registrations if your host does not need vector search, query normalization, or persistent embedding reuse. + +```csharp +using ManagedCode.MCPGateway; +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +var services = new ServiceCollection(); + +services.AddKeyedSingleton>, MyEmbeddingGenerator>( + McpGatewayServiceKeys.EmbeddingGenerator); +services.AddKeyedSingleton( + McpGatewayServiceKeys.SearchQueryChatClient); +services.AddSingleton(); + +services.AddManagedCodeMcpGateway(options => +{ + options.SearchStrategy = McpGatewaySearchStrategy.Auto; + options.SearchQueryNormalization = McpGatewaySearchQueryNormalization.TranslateToEnglishWhenAvailable; + + options.AddTool( + "local", + AIFunctionFactory.Create( + static (string query) => $"github:{query}", + new AIFunctionFactoryOptions + { + Name = "github_search_repositories", + Description = "Search GitHub repositories by user query." 
+ })); + + options.AddHttpServer( + sourceId: "docs", + endpoint: new Uri("https://example.com/mcp")); +}); + +services.AddManagedCodeMcpGatewayIndexWarmup(); + +await using var serviceProvider = services.BuildServiceProvider(); + +var gateway = serviceProvider.GetRequiredService(); +var registry = serviceProvider.GetRequiredService(); +var metaTools = serviceProvider.GetRequiredService().CreateTools(); + +registry.AddTool( + "runtime", + AIFunctionFactory.Create( + static (string query) => $"status:{query}", + new AIFunctionFactoryOptions + { + Name = "project_status_lookup", + Description = "Look up project status by identifier or short title." + })); + +var search = await gateway.SearchAsync( + new McpGatewaySearchRequest( + Query: "review qeue for managedcode prs", + ContextSummary: "User is looking at repository maintenance work")); +``` + +Notes: + +- `SearchStrategy.Auto` is the default and is usually the right production setting. +- the embedding generator, search-query rewrite client, and embedding store are all optional DI integrations +- hosted warmup is optional; if you omit it, the gateway builds its catalog lazily on first use +- runtime registrations through `IMcpGatewayRegistry` invalidate the catalog automatically, so the next list/search/invoke call rebuilds the index +- `McpGatewayToolSet` and `gateway.CreateMetaTools()` expose the same meta-tools in two integration styles + ## Context-Aware Search And Invoke When the current turn has extra UI, workflow, or chat context, pass it through the request models: @@ -113,6 +306,21 @@ You can expose the gateway itself as two reusable `AITool` instances: var tools = gateway.CreateMetaTools(); ``` +Or resolve the reusable helper from DI: + +```csharp +var toolSet = serviceProvider.GetRequiredService(); +var tools = toolSet.CreateTools(); +``` + +Custom stable tool names are supported: + +```csharp +var tools = gateway.CreateMetaTools( + searchToolName: "workspace_tool_search", + invokeToolName: 
"workspace_tool_invoke"); +``` + By default this creates: - `gateway_tools_search` @@ -130,10 +338,201 @@ These tools are useful when another model should first search the gateway catalo - required arguments - input schema summaries -If an embedding generator is registered, the gateway vectorizes those descriptor documents and uses cosine similarity plus a small lexical boost. It first tries the keyed registration `McpGatewayServiceKeys.EmbeddingGenerator` and then falls back to any regular `IEmbeddingGenerator>`. If no embedding generator is present, it falls back to lexical ranking without disabling execution. +Default search profile: + +- `SearchStrategy = McpGatewaySearchStrategy.Auto` +- `SearchQueryNormalization = McpGatewaySearchQueryNormalization.TranslateToEnglishWhenAvailable` +- `DefaultSearchLimit = 5` +- `MaxSearchResults = 15` + +`McpGatewaySearchStrategy.Auto` means: + +- vector search when an embedding generator is registered +- tokenizer-backed heuristic search when embeddings are unavailable +- tokenizer-backed fallback when vector search cannot complete for a request + +The tokenizer-backed mode builds field-aware search documents from tool names, display names, descriptions, required arguments, and schema properties. Ranking then happens in two stages: + +- stage 1 retrieval with BM25-style field scoring, tokenizer-term cosine similarity, and character 3-gram similarity +- stage 2 reranking over the candidate pool with calibrated coverage, lexical similarity, approximate typo matching, and tool-name evidence + +This keeps the search mathematical and tokenizer-driven instead of relying on hand-written query phrase exceptions. The tokenizer-backed path uses the built-in `ChatGptO200kBase` profile for the GPT-4o / ChatGPT tokenizer family. + +There is no public tokenizer-selection option. 
The package ships one built-in tokenizer-backed path and keeps the behavior configurable through search strategy, optional embeddings, and optional English query normalization. + +If an embedding generator is registered and vector search is active, the gateway vectorizes descriptor documents and uses cosine similarity plus lexical boosts. It first tries the keyed registration `McpGatewayServiceKeys.EmbeddingGenerator` and then falls back to any regular `IEmbeddingGenerator>`. The embedding generator is resolved per gateway operation, so singleton, scoped, and transient DI registrations all work with index builds and search. +### Reading Search Diagnostics + +`McpGatewaySearchResult` exposes both the ranking mode and diagnostics for the chosen path: + +```csharp +var result = await gateway.SearchAsync("review qeue for managedcode prs"); + +Console.WriteLine(result.RankingMode); + +foreach (var diagnostic in result.Diagnostics) +{ + Console.WriteLine($"{diagnostic.Code}: {diagnostic.Message}"); +} +``` + +Common diagnostics: + +- `query_normalized` +- `lexical_fallback` +- `vector_search_failed` + +## Optional English Query Normalization + +By default, the gateway may rewrite the incoming search query into concise English before ranking: + +- it only happens when `options.SearchQueryNormalization` is enabled +- it only uses a keyed `IChatClient` registered as `McpGatewayServiceKeys.SearchQueryChatClient` +- if no keyed chat client is registered, search continues unchanged +- if normalization fails, search continues with the original query and emits a diagnostic + +Preferred registration: + +```csharp +var services = new ServiceCollection(); + +services.AddKeyedSingleton( + McpGatewayServiceKeys.SearchQueryChatClient, + mySearchRewriteChatClient); + +services.AddManagedCodeMcpGateway(options => +{ + options.SearchStrategy = McpGatewaySearchStrategy.Auto; + options.SearchQueryNormalization = McpGatewaySearchQueryNormalization.TranslateToEnglishWhenAvailable; + + 
options.AddTool( + "local", + AIFunctionFactory.Create( + static (string query) => $"travel:{query}", + new AIFunctionFactoryOptions + { + Name = "travel_hotel_search", + Description = "Find hotels by city, district, amenities, breakfast, or cancellation policy." + })); +}); +``` + +Disable normalization when the host wants purely local tokenizer behavior: + +```csharp +services.AddManagedCodeMcpGateway(options => +{ + options.SearchStrategy = McpGatewaySearchStrategy.Tokenizer; + options.SearchQueryNormalization = McpGatewaySearchQueryNormalization.Disabled; +}); +``` + +The package does not register or configure an `IChatClient` for you. This keeps the gateway generic while still allowing multilingual and typo-heavy search inputs to converge to an English retrieval form when the host opts in. + +`McpGatewaySearchResult.RankingMode` stays: + +- `vector` for embedding-backed ranking +- `lexical` for tokenizer-backed ranking and tokenizer fallback +- `browse` when no search text/context is provided +- `empty` when the catalog is empty + +In other words, the current `lexical` mode is the working tokenizer mode. 
+ +## Search Strategy Matrix + +Use `McpGatewaySearchStrategy.Auto` when you want one production default that works everywhere: + +- if embeddings are registered, use embeddings +- if embeddings are missing, use tokenizer ranking +- if embeddings fail for a query, fall back to tokenizer ranking + +Use `McpGatewaySearchStrategy.Embeddings` when: + +- embeddings are expected in that host +- you want vector search whenever it is available +- tokenizer ranking should only be the fallback path when vector search cannot complete + +Use `McpGatewaySearchStrategy.Tokenizer` when: + +- you want a zero-embedding deployment +- you want deterministic local search behavior without an embedding provider +- you want to benchmark tokenizer-backed ranking independently from vector search + +## Search Strategy Configuration + +Force embeddings when they are available: + +```csharp +var services = new ServiceCollection(); + +services.AddKeyedSingleton>, MyEmbeddingGenerator>( + McpGatewayServiceKeys.EmbeddingGenerator); + +services.AddManagedCodeMcpGateway(options => +{ + options.SearchStrategy = McpGatewaySearchStrategy.Embeddings; + + options.AddTool( + "local", + AIFunctionFactory.Create( + static (string query) => $"github:{query}", + new AIFunctionFactoryOptions + { + Name = "github_search_repositories", + Description = "Search GitHub repositories by user query." + })); +}); +``` + +Force tokenizer-backed ranking: + +```csharp +var services = new ServiceCollection(); + +services.AddManagedCodeMcpGateway(options => +{ + options.SearchStrategy = McpGatewaySearchStrategy.Tokenizer; + + options.AddTool( + "local", + AIFunctionFactory.Create( + static (string query) => $"github:{query}", + new AIFunctionFactoryOptions + { + Name = "github_search_repositories", + Description = "Search GitHub repositories by user query." 
+ })); +}); +``` + +Keep the default auto strategy, but make the defaults explicit in code: + +```csharp +var services = new ServiceCollection(); + +services.AddManagedCodeMcpGateway(options => +{ + options.SearchStrategy = McpGatewaySearchStrategy.Auto; + options.DefaultSearchLimit = 5; + options.MaxSearchResults = 15; + + options.AddTool( + "local", + AIFunctionFactory.Create( + static (string query) => $"github:{query}", + new AIFunctionFactoryOptions + { + Name = "github_search_repositories", + Description = "Search GitHub repositories by user query." + })); +}); +``` + +If you do not register an embedding generator, the same configuration still works and automatically uses tokenizer ranking. + ## Optional Embeddings Register any provider-specific implementation of `IEmbeddingGenerator>` in the same DI container before building the service provider. @@ -183,6 +582,45 @@ services.AddManagedCodeMcpGateway(options => The keyed registration is the preferred one, so you can dedicate a specific embedder to the gateway without affecting other app services. +## Tokenizer Fallback Without Embeddings + +This is the default operational fallback: + +```csharp +var services = new ServiceCollection(); + +services.AddManagedCodeMcpGateway(options => +{ + options.SearchStrategy = McpGatewaySearchStrategy.Auto; + + options.AddTool( + "local", + AIFunctionFactory.Create( + static (string query) => $"github:{query}", + new AIFunctionFactoryOptions + { + Name = "github_search_repositories", + Description = "Search GitHub repositories by user query." 
+ })); +}); + +await using var serviceProvider = services.BuildServiceProvider(); +var gateway = serviceProvider.GetRequiredService(); + +var result = await gateway.SearchAsync("review qeue for managedcode prs"); +``` + +With no embedding generator registered: + +- the gateway still builds the catalog +- search uses the built-in `ChatGptO200kBase` tokenizer path and two-stage lexical ranking +- an optional keyed search rewrite `IChatClient` can normalize the query to English first +- typo-tolerant term heuristics still participate in ranking +- the result diagnostics contain `lexical_fallback` +- the default result set size is 5 + +If you want tokenizer search only, set `options.SearchStrategy = McpGatewaySearchStrategy.Tokenizer`. + ## Persistent Tool Embeddings For process-local caching, the package already includes `McpGatewayInMemoryToolEmbeddingStore`: @@ -231,7 +669,7 @@ services.AddManagedCodeMcpGateway(options => }); ``` -During `BuildIndexAsync()` the gateway: +When an index build runs, whether explicitly or through lazy/background warmup, the gateway: - computes a descriptor-document hash per tool - asks `IMcpGatewayToolEmbeddingStore` for matching stored vectors @@ -240,6 +678,31 @@ During `BuildIndexAsync()` the gateway: This avoids recalculating tool embeddings on every rebuild while still refreshing them automatically when the descriptor document changes. Stored vectors are scoped to both the descriptor hash and the resolved embedding-generator fingerprint, so changing the provider or model automatically forces regeneration. Query embeddings are still generated at search time from the registered `IEmbeddingGenerator>`. +## Search Evaluation + +The repository includes a tokenizer evaluation suite built around a 50-tool catalog with intentionally overlapping verbs such as `search`, `lookup`, `timeline`, and `summary`, while keeping the domain semantics separated in the descriptions. 
+ +Coverage buckets in `tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationTests.cs`: + +- high relevance +- borderline / semantically adjacent tools +- multilingual +- typo / spelling mistakes +- weak-intent / underspecified commands +- irrelevant queries + +The evaluation asserts: + +- `top1`, `top3`, and `top5` +- mean reciprocal rank +- low-confidence behavior for irrelevant queries + +The noisy-query buckets intentionally include spelling mistakes and weakly specified commands so the tokenizer path is exercised as a real fallback, not only on clean benchmark phrasing. + +Current reference numbers from the repository test corpus: + +- `ChatGptO200kBase`: high relevance `top1=95.65%`, `top3=100%`, `top5=100%`, `MRR=0.98`; typo `top1=100%`; weak intent `top1=100%`; irrelevant `low-confidence=100%` + ## Supported Sources - local `AITool` / `AIFunction` @@ -253,6 +716,20 @@ This avoids recalculating tool embeddings on every rebuild while still refreshin ```bash dotnet restore ManagedCode.MCPGateway.slnx -dotnet build ManagedCode.MCPGateway.slnx -c Release -dotnet test --solution ManagedCode.MCPGateway.slnx -c Release +dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore +dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build +``` + +Analyzer pass: + +```bash +dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore -p:RunAnalyzers=true +``` + +Detailed TUnit output: + +```bash +dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build --output Detailed --no-progress ``` + +This repository uses `TUnit` on top of `Microsoft.Testing.Platform`, so prefer the `dotnet test --solution ...` commands above. Do not assume VSTest-only flags such as `--filter` or `--logger` are available here. 
diff --git a/docs/ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md b/docs/ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md new file mode 100644 index 0000000..96588e9 --- /dev/null +++ b/docs/ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md @@ -0,0 +1,142 @@ +# ADR-0001: Runtime Boundaries And Index Lifecycle + +## Context + +`ManagedCode.MCPGateway` started as a single gateway concept, but the package now has three distinct concerns: + +- runtime search and invocation for local `AITool` instances and remote MCP tools +- mutable catalog registration for local tools, stdio/HTTP MCP servers, and deferred `McpClient` factories +- index lifecycle management for lazy builds, hosted warmup, and rebuilds after registry mutations + +Recent changes also made index construction cancellation-aware and single-flight, so startup warmup, shutdown, and concurrent callers do not keep issuing duplicate MCP loads or continue rebuilding after a canceled operation should stop. + +The repository needs an explicit record for these boundaries so the public package surface, internal runtime structure, and README examples stay aligned. + +## Decision + +`ManagedCode.MCPGateway` will keep a thin public runtime facade, a separate DI-managed registry mutation surface, and an internal runtime orchestrator with lazy, cancellation-aware single-flight index builds plus optional eager warmup integration. 
+ +## Diagram + +```mermaid +flowchart LR + Host["Host application"] --> DI["AddManagedCodeMcpGateway(...)"] + DI --> Gateway["IMcpGateway / McpGateway"] + DI --> Registry["IMcpGatewayRegistry / McpGatewayRegistry"] + DI --> ToolSet["McpGatewayToolSet"] + DI --> Warmup["AddManagedCodeMcpGatewayIndexWarmup()"] + Gateway --> Runtime["McpGatewayRuntime"] + Registry --> Snapshot["Catalog snapshots"] + Runtime --> Snapshot + Warmup --> Runtime + Runtime --> Search["Search / invoke / index build"] +``` + +## Alternatives + +### Alternative 1: Keep one monolithic gateway type that also mutates the registry + +Pros: + +- fewer types to explain +- direct mutation calls on the same service + +Cons: + +- violates single responsibility for runtime versus mutation +- makes DI usage less explicit +- encourages `McpGateway` to become a god object again + +### Alternative 2: Require every host to call `BuildIndexAsync()` manually + +Pros: + +- very explicit startup workflow +- easy to reason about in small demos + +Cons: + +- forces boilerplate on every consumer +- easy to forget in real hosts +- contradicts the package goal of working lazily by default + +### Alternative 3: Use blocking locks around registry mutation and index lifecycle + +Pros: + +- straightforward first implementation +- familiar concurrency model + +Cons: + +- obscures cancellation and shutdown behavior +- harder to scale under concurrent search/build callers +- already caused readability and lifecycle issues in this repository + +## Consequences + +Positive: + +- public DI wiring is explicit: `IMcpGateway` for runtime work, `IMcpGatewayRegistry` for catalog mutation, `McpGatewayToolSet` for meta-tools +- hosts get lazy behavior by default and optional eager warmup through `InitializeManagedCodeMcpGatewayAsync()` or `AddManagedCodeMcpGatewayIndexWarmup()` +- cancellation now propagates into source loading, embedding generation, and embedding-store I/O during index builds +- runtime rebuilds after registry 
mutations remain automatic without forcing every host into startup code + +Trade-offs: + +- there are more internal collaborator types to document +- lazy behavior means startup failures may surface on first use unless the host opts into eager warmup +- single-flight lifecycle code is more subtle than a naive sequential implementation + +Mitigations: + +- keep `McpGateway` thin and document the boundaries in `docs/Architecture/Overview.md` +- keep README examples for both lazy default usage and eager warmup +- cover cancellation, retry-after-cancel, and concurrent build behavior with tests + +## Invariants + +- `IMcpGateway` MUST remain the public runtime facade for build, list, search, invoke, and meta-tool creation. +- `IMcpGatewayRegistry` MUST remain the public mutation surface for adding tools and MCP sources after container build. +- `AddManagedCodeMcpGateway(...)` MUST register `IMcpGateway`, `IMcpGatewayRegistry`, and `McpGatewayToolSet`. +- Index builds MUST be lazy by default and MUST rebuild automatically after registry mutations invalidate the snapshot. +- Hosted warmup MUST stay optional and MUST use the same runtime/index path as normal gateway operations. +- Cancellation of `BuildIndexAsync(...)` MUST propagate into underlying source loading and embedding work. + +## Rollout And Rollback + +Rollout: + +1. Keep the separated facade/registry/runtime structure in `src/ManagedCode.MCPGateway/`. +2. Keep README startup guidance aligned with lazy default plus optional eager warmup. +3. Keep tests covering concurrent builds, cancellation, and post-mutation rebuild behavior. + +Rollback: + +1. Revert the runtime/registry split only if the package intentionally changes back to a single mutable gateway facade. +2. Remove warmup helpers only if startup prewarming is intentionally dropped as a supported scenario. 
+ +## Verification + +- `dotnet restore ManagedCode.MCPGateway.slnx` +- `dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore` +- `dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore -p:RunAnalyzers=true` +- `dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build` +- `roslynator analyze src/ManagedCode.MCPGateway/ManagedCode.MCPGateway.csproj -p Configuration=Release --severity-level warning` +- `roslynator analyze tests/ManagedCode.MCPGateway.Tests/ManagedCode.MCPGateway.Tests.csproj -p Configuration=Release --severity-level warning` +- `cloc --include-lang=C# src tests` + +## Implementation Plan (step-by-step) + +1. Keep `McpGateway` as a thin public facade over `McpGatewayRuntime`. +2. Keep `McpGatewayRegistry` as the DI-managed mutation surface and snapshot source. +3. Keep `McpGatewayRuntime` responsible for lazy single-flight index builds and search/invocation orchestration. +4. Expose eager warmup through service-provider and hosted-service extensions instead of forcing manual `BuildIndexAsync()` in every host. +5. Keep cancellation and concurrency regression coverage in the search/build test suite. + +## Stakeholder Notes + +- Product: hosts can choose lazy startup or eager warmup without changing the public runtime API. +- Dev: runtime and mutation responsibilities are intentionally separate and must stay that way. +- QA: warmup, cancellation, and rebuild-after-mutation scenarios are first-class verification targets. +- DevOps: startup behavior is configurable; eager warmup is the fail-fast option for production hosts. 
diff --git a/docs/ADR/ADR-0002-search-ranking-and-query-normalization.md b/docs/ADR/ADR-0002-search-ranking-and-query-normalization.md new file mode 100644 index 0000000..408496b --- /dev/null +++ b/docs/ADR/ADR-0002-search-ranking-and-query-normalization.md @@ -0,0 +1,138 @@ +# ADR-0002: Search Ranking And Query Normalization + +## Context + +`ManagedCode.MCPGateway` must stay useful when a host does not register embeddings, while still taking advantage of embeddings when they are available. The package also needs to handle multilingual, typo-heavy, and weakly specified queries without degenerating into hardcoded phrase lists. + +Recent work introduced a tokenizer-backed ranking pipeline with BM25-style field scoring, token and character n-gram similarity, approximate typo handling, optional English query normalization, and automatic fallback from vector search to tokenizer ranking. The package also removed public tokenizer selection and standardized on the built-in `ChatGptO200kBase` path. + +This decision needs a durable record because it affects defaults, DI integration, tests, and user-facing README guidance. + +## Decision + +`ManagedCode.MCPGateway` will use `SearchStrategy.Auto` as the default production mode, keep one built-in `ChatGptO200kBase` tokenizer-backed search path, and improve non-vector retrieval through mathematical ranking plus optional English query normalization via a keyed `IChatClient`. 
+ +## Diagram + +```mermaid +flowchart LR + Request["Search request"] --> Normalize{"Normalization enabled\nand keyed IChatClient available?"} + Normalize -->|Yes| Rewrite["Rewrite to concise English"] + Normalize -->|No| Raw["Use original query"] + Rewrite --> Strategy{"Vector search available?"} + Raw --> Strategy + Strategy -->|Yes| Vector["Embedding search"] + Strategy -->|No| Token["Tokenizer ranking"] + Vector -->|Failure| Token + Token --> Result["Search result + diagnostics"] + Vector --> Result +``` + +## Alternatives + +### Alternative 1: Embeddings only + +Pros: + +- simple ranking story +- potentially higher semantic quality when embeddings are always available + +Cons: + +- unusable in zero-embedding hosts +- adds a hard external dependency for a core package feature + +### Alternative 2: Keep multiple tokenizer options public + +Pros: + +- more tuning knobs for experiments +- easier A/B comparisons in package consumers + +Cons: + +- larger public API surface +- more documentation and compatibility burden +- unnecessary once the built-in tokenizer path became the only supported production option + +### Alternative 3: Improve search with hardcoded synonym or phrase lists + +Pros: + +- quick tactical gains for a few known queries +- easy to demo + +Cons: + +- brittle and domain-specific +- violates the repository rule to prefer mathematical ranking improvements over query text hacks +- scales poorly across languages and noisy inputs + +## Consequences + +Positive: + +- `SearchStrategy.Auto` works as one production default across embedding and zero-embedding hosts +- tokenizer search remains deterministic and local when no embedder is registered +- optional English normalization improves multilingual/noisy inputs without making the package depend on an AI client +- search quality improvements are explainable through diagnostics and testable through evaluation buckets + +Trade-offs: + +- ranking logic is more sophisticated than a flat score function +- 
multilingual quality without a normalizer still depends on tokenizer overlap and character n-grams +- documentation must explain that normalization is optional and keyed, not bundled by the package + +Mitigations: + +- keep the tokenizer path fully self-contained and deterministic +- keep README examples for both normalized and non-normalized deployments +- keep regression/evaluation tests for high-relevance, borderline, typo, multilingual, weak-intent, and irrelevant buckets + +## Invariants + +- `SearchStrategy.Auto` MUST remain the default search strategy. +- `SearchQueryNormalization` MUST default to `TranslateToEnglishWhenAvailable`. +- The package MUST keep only one built-in tokenizer-backed search path and MUST NOT expose stale tokenizer-choice options. +- If no embedding generator is registered, search MUST still function through tokenizer-backed ranking. +- If vector search fails for a request, the gateway MUST fall back to tokenizer-backed ranking and emit diagnostics instead of failing the request. +- If English normalization is enabled but no keyed `IChatClient` is registered, search MUST continue with the original query. +- Search-quality improvements MUST prefer mathematical scoring changes over hardcoded phrase exceptions. + +## Rollout And Rollback + +Rollout: + +1. Keep README defaults aligned with `SearchStrategy.Auto`, `SearchQueryNormalization`, and the top-5 default result size. +2. Keep `McpGatewayServiceKeys.SearchQueryChatClient` documented as the optional keyed normalizer dependency. +3. Keep tokenizer evaluation tests current with real noisy-query buckets. + +Rollback: + +1. Reintroduce a public tokenizer-selection option only if there is a concrete product requirement and a supported compatibility story. +2. Disable normalization by default only if the package intentionally stops preferring English retrieval convergence for multilingual or noisy inputs. 
+ +## Verification + +- `dotnet restore ManagedCode.MCPGateway.slnx` +- `dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore` +- `dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore -p:RunAnalyzers=true` +- `dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build` +- `roslynator analyze src/ManagedCode.MCPGateway/ManagedCode.MCPGateway.csproj -p Configuration=Release --severity-level warning` +- `roslynator analyze tests/ManagedCode.MCPGateway.Tests/ManagedCode.MCPGateway.Tests.csproj -p Configuration=Release --severity-level warning` +- `cloc --include-lang=C# src tests` + +## Implementation Plan (step-by-step) + +1. Keep `McpGatewayOptions` defaults aligned with the intended production behavior. +2. Keep the tokenizer-backed ranking pipeline in `Internal/Runtime/Search/` field-aware and diagnostic-friendly. +3. Keep the optional keyed English query normalizer behind `McpGatewayServiceKeys.SearchQueryChatClient`. +4. Keep README examples for tokenizer-only, auto, embeddings, and optional normalization scenarios. +5. Keep evaluation coverage current and representative of noisy production-like queries. + +## Stakeholder Notes + +- Product: the package has one recommended search default that works with or without embeddings. +- Dev: search quality work should continue through ranking math and evaluation data, not manual phrase hacks. +- QA: typo, multilingual, weak-intent, and irrelevant buckets are required test coverage, not optional benchmark extras. +- DevOps: hosts may run fully local tokenizer search or add keyed chat/embedding services when needed. diff --git a/docs/Architecture/Overview.md b/docs/Architecture/Overview.md new file mode 100644 index 0000000..5befc98 --- /dev/null +++ b/docs/Architecture/Overview.md @@ -0,0 +1,126 @@ +# Architecture Overview + +## Scoping (read first) + +This document is the module map for `ManagedCode.MCPGateway`. 
+ +In scope: + +- package boundaries +- runtime collaboration between the public facade, registry, meta-tools, warmup hooks, and internal runtime +- dependency direction between public APIs, internal modules, and optional AI services + +Out of scope: + +- feature-level ranking metrics +- test corpus details +- CI or release process + +## Summary + +`ManagedCode.MCPGateway` exposes three public DI surfaces: + +- `IMcpGateway` for list/search/invoke +- `IMcpGatewayRegistry` for catalog mutation +- `McpGatewayToolSet` for reusable meta-tools + +`McpGateway` stays a thin facade over `McpGatewayRuntime`, which reads immutable catalog snapshots, coordinates vector or tokenizer-backed search, optionally rewrites queries through a keyed `IChatClient`, and invokes local or MCP tools. Optional startup warmup is available through a service-provider extension or hosted background service without changing the lazy default. + +## System And Module Map + +```mermaid +flowchart LR + Host["Host application"] --> DI["DI registration"] + DI --> Facade["IMcpGateway / McpGateway"] + DI --> Registry["IMcpGatewayRegistry / McpGatewayRegistry"] + DI --> ToolSet["McpGatewayToolSet"] + DI --> Warmup["Optional warmup hooks"] + ToolSet --> Facade + Warmup --> Facade + Facade --> Runtime["Internal runtime orchestration"] + Runtime --> Catalog["Internal catalog snapshots"] + Registry --> Catalog + Catalog --> Sources["Catalog source registrations"] + Sources --> Local["Local AITool instances"] + Sources --> MCP["HTTP, stdio, and provided MCP clients"] + Runtime --> Embedder["Optional embedding generator"] + Runtime --> Store["Optional embedding store"] + Runtime --> Normalizer["Optional keyed search IChatClient"] +``` + +## Interfaces And Contracts + +```mermaid +flowchart LR + IMcpGateway["IMcpGateway"] --> McpGateway["McpGateway"] + IMcpGatewayRegistry["IMcpGatewayRegistry"] --> Registry["McpGatewayRegistry"] + ToolSet["McpGatewayToolSet"] --> IMcpGateway + 
Warmup["McpGatewayServiceProviderExtensions / McpGatewayIndexWarmupService"] --> IMcpGateway + McpGateway --> Runtime["McpGatewayRuntime"] + Runtime --> SearchRequest["McpGatewaySearchRequest"] + Runtime --> InvokeRequest["McpGatewayInvokeRequest"] + Runtime --> Descriptor["McpGatewayToolDescriptor"] + Runtime --> Options["McpGatewayOptions"] + Runtime --> EmbeddingStore["IMcpGatewayToolEmbeddingStore"] + Runtime --> ChatClient["IChatClient (keyed)"] + Registry --> CatalogSource["IMcpGatewayCatalogSource"] +``` + +## Key Classes And Types + +```mermaid +flowchart LR + McpGateway["McpGateway"] --> McpGatewayRuntime["McpGatewayRuntime"] + McpGatewayRuntime --> RuntimeCore["Internal/Runtime/Core/*"] + McpGatewayRuntime --> RuntimeCatalog["Internal/Runtime/Catalog/*"] + McpGatewayRuntime --> RuntimeSearch["Internal/Runtime/Search/*"] + McpGatewayRuntime --> RuntimeInvocation["Internal/Runtime/Invocation/*"] + McpGatewayRuntime --> RuntimeEmbeddings["Internal/Runtime/Embeddings/*"] + Registry["McpGatewayRegistry"] --> RegistrationCollection["McpGatewayRegistrationCollection"] + Registry --> OperationGate["McpGatewayOperationGate"] + RegistrationCollection --> SourceRegistrations["McpGatewayToolSourceRegistration*"] + RuntimeSearch --> Json["McpGatewayJsonSerializer"] + Warmup["McpGatewayIndexWarmupService"] --> McpGateway + InMemoryStore["McpGatewayInMemoryToolEmbeddingStore"] --> StoreIndex["McpGatewayToolEmbeddingStoreIndex"] +``` + +## Module Index + +- Public facade: [`src/ManagedCode.MCPGateway/McpGateway.cs`](../../src/ManagedCode.MCPGateway/McpGateway.cs) exposes the package runtime API and delegates work to the internal runtime. +- Public abstractions: [`src/ManagedCode.MCPGateway/Abstractions/`](../../src/ManagedCode.MCPGateway/Abstractions/) defines the stable interfaces consumers resolve from DI. 
+- Public configuration: [`src/ManagedCode.MCPGateway/Configuration/`](../../src/ManagedCode.MCPGateway/Configuration/) contains options and service keys that shape host integration. +- Public models: [`src/ManagedCode.MCPGateway/Models/`](../../src/ManagedCode.MCPGateway/Models/) contains request/result contracts and enums grouped by search, invocation, catalog, and embeddings behavior. +- Public embeddings: [`src/ManagedCode.MCPGateway/Embeddings/`](../../src/ManagedCode.MCPGateway/Embeddings/) provides optional embedding-store implementations. +- Public meta-tools: [`src/ManagedCode.MCPGateway/McpGatewayToolSet.cs`](../../src/ManagedCode.MCPGateway/McpGatewayToolSet.cs) exposes the gateway as reusable `AITool` instances for model-driven search and invoke flows. +- Internal catalog module: [`src/ManagedCode.MCPGateway/Internal/Catalog/`](../../src/ManagedCode.MCPGateway/Internal/Catalog/) owns mutable tool-source registration state and read-only snapshots for indexing. +- Internal catalog sources: [`src/ManagedCode.MCPGateway/Internal/Catalog/Sources/`](../../src/ManagedCode.MCPGateway/Internal/Catalog/Sources/) owns transport-specific source registrations and MCP client creation. +- Internal runtime module: [`src/ManagedCode.MCPGateway/Internal/Runtime/`](../../src/ManagedCode.MCPGateway/Internal/Runtime/) owns orchestration and is split by core, catalog, search, invocation, and embeddings concerns. +- Internal embedding helpers: [`src/ManagedCode.MCPGateway/Internal/Embeddings/`](../../src/ManagedCode.MCPGateway/Internal/Embeddings/) contains non-public embedding indexing helpers. +- Internal serialization: [`src/ManagedCode.MCPGateway/Internal/Serialization/`](../../src/ManagedCode.MCPGateway/Internal/Serialization/) contains the canonical JSON materialization path used by runtime features. 
+- Warmup hooks: [`src/ManagedCode.MCPGateway/Registration/McpGatewayServiceProviderExtensions.cs`](../../src/ManagedCode.MCPGateway/Registration/McpGatewayServiceProviderExtensions.cs) and [`src/ManagedCode.MCPGateway/Internal/Warmup/McpGatewayIndexWarmupService.cs`](../../src/ManagedCode.MCPGateway/Internal/Warmup/McpGatewayIndexWarmupService.cs) provide optional eager index-building integration. +- DI registration: [`src/ManagedCode.MCPGateway/Registration/McpGatewayServiceCollectionExtensions.cs`](../../src/ManagedCode.MCPGateway/Registration/McpGatewayServiceCollectionExtensions.cs) wires facade, registry, meta-tools, and warmup support into the container. + +## Dependency Rules + +- Public code may depend on `Models`, `Configuration`, and `Abstractions`, but internal modules must not depend on tests or docs. +- `McpGateway` is a thin facade only. It may delegate to `McpGatewayRuntime`, but it must not own registry mutation logic. +- `Internal/Catalog` owns mutable source registration state. `Internal/Runtime` may read snapshots from it, but must not mutate registrations directly. +- `Internal/Catalog/Sources` owns MCP transport-specific creation and caching. Transport setup must not leak into `Internal/Runtime`, `Models`, or `Configuration`. +- `Internal/Runtime` may depend on `Internal/Catalog`, `Internal/Embeddings`, `Embeddings`, `Models`, `Configuration`, and `Abstractions`. +- Optional AI services such as embedding generators and query-normalization chat clients must stay outside the package core and be resolved through DI service keys rather than hardwired provider code. +- `Models` should stay contract-first. Internal transport, registry, or lifecycle helpers do not belong there. +- Embedding support must stay optional and isolated behind `IMcpGatewayToolEmbeddingStore` and embedding-generator abstractions. +- Warmup remains optional. The package must work correctly with lazy indexing and must not require manual initialization for every host. 
+ +## Key Decisions (ADRs) + +- [`docs/ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md`](../ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md): documents the public/runtime/catalog split, DI boundaries, lazy indexing, cancellation-aware single-flight builds, and optional warmup hooks. +- [`docs/ADR/ADR-0002-search-ranking-and-query-normalization.md`](../ADR/ADR-0002-search-ranking-and-query-normalization.md): documents the default `Auto` search behavior, tokenizer-backed fallback, optional English query normalization, and mathematical ranking strategy. + +## Related Docs + +- [`README.md`](../../README.md) +- [`docs/ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md`](../ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md) +- [`docs/ADR/ADR-0002-search-ranking-and-query-normalization.md`](../ADR/ADR-0002-search-ranking-and-query-normalization.md) +- [`docs/Features/SearchQueryNormalizationAndRanking.md`](../Features/SearchQueryNormalizationAndRanking.md) +- [`AGENTS.md`](../../AGENTS.md) diff --git a/docs/Features/SearchQueryNormalizationAndRanking.md b/docs/Features/SearchQueryNormalizationAndRanking.md new file mode 100644 index 0000000..b60e890 --- /dev/null +++ b/docs/Features/SearchQueryNormalizationAndRanking.md @@ -0,0 +1,124 @@ +# Search Query Normalization And Ranking + +## Purpose And Scope + +This feature improves `ManagedCode.MCPGateway` search quality for multilingual, typo-heavy, and weakly specified search requests without introducing phrase-level hardcoded rules. 
+ +In scope: + +- optional English query normalization before ranking +- tokenizer-backed ranking improvements for non-embedding search +- deterministic fallback when no AI normalizer is registered +- automated verification for multilingual, noisy, and borderline search buckets + +Out of scope: + +- embedding model changes +- vendor-specific AI SDK setup inside the package +- domain-specific synonym lists or handcrafted query exceptions + +## Affected Modules + +- `src/ManagedCode.MCPGateway/Configuration/McpGatewayOptions.cs` +- `src/ManagedCode.MCPGateway/Configuration/McpGatewayServiceKeys.cs` +- `src/ManagedCode.MCPGateway/Models/Search/*` +- `src/ManagedCode.MCPGateway/Internal/Runtime/Search/*` +- `tests/ManagedCode.MCPGateway.Tests/Search/*` +- `README.md` + +## Business Rules + +1. Tokenizer-backed search must stay functional with zero embedding or chat-model dependencies. +2. When query normalization is enabled and a keyed `IChatClient` is available, the gateway must normalize the user query into concise English before ranking. +3. Query normalization must preserve identifiers and retrieval-critical literals such as emails, repository names, CVE references, order numbers, tracking numbers, and SKUs. +4. If normalization is enabled but no keyed normalizer client is registered, the gateway must continue with the original query and must not fail the search. +5. If normalization fails, the gateway must continue with the original query and expose a diagnostic rather than throwing. +6. Tokenizer-backed ranking must prefer mathematical retrieval improvements over text-level hardcoded exceptions. +7. Tokenizer-backed ranking must improve recall for typos and multilingual cognates while also reducing domain-local ties such as `invoice` versus `payment reconciliation`. +8. The package must keep one built-in tokenizer-backed search path and must not expose stale tokenizer-selection options. +9. 
Default search result limits and existing public search/invoke entry points must remain intact. + +## Main Flow + +```mermaid +flowchart LR + Request["Search request"] --> Normalize{"English normalization enabled\nand keyed IChatClient present?"} + Normalize -->|Yes| Rewrite["Rewrite query to concise English"] + Normalize -->|No| Original["Use original query"] + Rewrite --> Retrieval["Stage 1 retrieval\nBM25F + token cosine + char 3-gram"] + Original --> Retrieval + Retrieval --> CandidatePool["Top candidate pool"] + CandidatePool --> Rerank["Stage 2 rerank\nfield-aware features"] + Rerank --> Result["Search result + diagnostics"] +``` + +## Negative And Edge Cases + +- Empty query with no context still returns `browse` mode. +- Empty catalog still returns `empty` mode. +- A registered embedding generator still takes precedence when vector search is active. +- A normalization client that returns blank output must not replace the original query. +- A normalization client that times out or throws must emit a diagnostic and fall back to the original query. +- Typo-heavy inputs such as `shipmnt` and `contcat` must still retrieve the expected tool in the result set. +- Multilingual inputs without a normalizer must still benefit from tokenizer aliases and character n-grams. 
+ +## System Behavior + +- Entry points: + - `IMcpGateway.SearchAsync(string?, int?, CancellationToken)` + - `IMcpGateway.SearchAsync(McpGatewaySearchRequest, CancellationToken)` +- Reads: + - tool catalog snapshot from `IMcpGatewayCatalogSource` + - keyed optional search normalizer client from DI + - search options from `McpGatewayOptions` +- Writes: + - no persistent writes beyond existing optional embedding-store behavior +- Side effects: + - optional `IChatClient` request for query normalization + - diagnostics describing normalization fallback or low-confidence conditions +- Idempotency: + - same indexed catalog and same deterministic query-normalizer response yield stable ranking +- Errors: + - search must not throw only because the optional normalizer is missing or fails + +## Verification + +Environment assumptions: + +- .NET 10 SDK from `global.json` +- `TUnit` on `Microsoft.Testing.Platform` + +Verification commands: + +- `dotnet restore ManagedCode.MCPGateway.slnx` +- `dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore` +- `dotnet build ManagedCode.MCPGateway.slnx -c Release --no-restore -p:RunAnalyzers=true` +- `dotnet test --solution ManagedCode.MCPGateway.slnx -c Release --no-build` + +Test mapping: + +- normalization success and fallback behavior in `tests/ManagedCode.MCPGateway.Tests/Search/` +- tokenizer ranking regression coverage in `McpGatewayTokenizerSearchTests.cs` +- evaluation-bucket quality coverage in `McpGatewayTokenizerSearchEvaluationTests.cs` + +## Definition Of Done + +- tokenizer-backed search supports optional English query normalization through `Microsoft.Extensions.AI` +- multilingual and typo-heavy evaluation scenarios remain covered by automated tests +- docs explain how to register the optional query-normalization client +- build, analyzers, and tests stay green + +## Related Docs + +- [`README.md`](../../README.md) +- 
[`docs/ADR/ADR-0002-search-ranking-and-query-normalization.md`](../ADR/ADR-0002-search-ranking-and-query-normalization.md) +- [`docs/ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md`](../ADR/ADR-0001-runtime-boundaries-and-index-lifecycle.md) +- [`docs/Architecture/Overview.md`](../Architecture/Overview.md) + +## Implementation Plan (step-by-step) + +1. Add search-normalization configuration to `McpGatewayOptions` and a keyed DI service key for the optional normalizer chat client. +2. Implement normalization in the search pipeline with graceful fallback and diagnostics. +3. Replace the flat tokenizer score with a two-stage ranking flow that uses field-aware BM25-style retrieval plus character n-gram retrieval. +4. Add deterministic tests for query normalization, failure fallback, and updated ranking behavior. +5. Update `README.md` with configuration and operational guidance. diff --git a/src/ManagedCode.MCPGateway/Abstractions/IMcpGatewayRegistry.cs b/src/ManagedCode.MCPGateway/Abstractions/Catalog/IMcpGatewayRegistry.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Abstractions/IMcpGatewayRegistry.cs rename to src/ManagedCode.MCPGateway/Abstractions/Catalog/IMcpGatewayRegistry.cs diff --git a/src/ManagedCode.MCPGateway/Abstractions/IMcpGatewayToolEmbeddingStore.cs b/src/ManagedCode.MCPGateway/Abstractions/Embeddings/IMcpGatewayToolEmbeddingStore.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Abstractions/IMcpGatewayToolEmbeddingStore.cs rename to src/ManagedCode.MCPGateway/Abstractions/Embeddings/IMcpGatewayToolEmbeddingStore.cs diff --git a/src/ManagedCode.MCPGateway/Configuration/McpGatewayOptions.cs b/src/ManagedCode.MCPGateway/Configuration/McpGatewayOptions.cs new file mode 100644 index 0000000..c73dbaf --- /dev/null +++ b/src/ManagedCode.MCPGateway/Configuration/McpGatewayOptions.cs @@ -0,0 +1,70 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; + +namespace ManagedCode.MCPGateway; + +public sealed 
class McpGatewayOptions +{ + private readonly McpGatewayRegistrationCollection _sourceRegistrations = new(); + + public McpGatewaySearchStrategy SearchStrategy { get; set; } = McpGatewaySearchStrategy.Auto; + + public McpGatewaySearchQueryNormalization SearchQueryNormalization { get; set; } = + McpGatewaySearchQueryNormalization.TranslateToEnglishWhenAvailable; + + public int DefaultSearchLimit { get; set; } = 5; + + public int MaxSearchResults { get; set; } = 15; + + public int MaxDescriptorLength { get; set; } = 4096; + + internal IReadOnlyList SourceRegistrations => _sourceRegistrations.Snapshot(); + + public McpGatewayOptions AddTool(string sourceId, AITool tool, string? displayName = null) + => ConfigureRegistrations(registrations => registrations.AddTool(sourceId, tool, displayName)); + + public McpGatewayOptions AddTool(AITool tool, string sourceId = McpGatewayDefaults.DefaultSourceId, string? displayName = null) + => ConfigureRegistrations(registrations => registrations.AddTool(tool, sourceId, displayName)); + + public McpGatewayOptions AddTools(string sourceId, IEnumerable tools, string? displayName = null) + => ConfigureRegistrations(registrations => registrations.AddTools(sourceId, tools, displayName)); + + public McpGatewayOptions AddTools(IEnumerable tools, string sourceId = McpGatewayDefaults.DefaultSourceId, string? displayName = null) + => ConfigureRegistrations(registrations => registrations.AddTools(tools, sourceId, displayName)); + + public McpGatewayOptions AddHttpServer( + string sourceId, + Uri endpoint, + IReadOnlyDictionary? headers = null, + string? displayName = null) + => ConfigureRegistrations(registrations => registrations.AddHttpServer(sourceId, endpoint, headers, displayName)); + + public McpGatewayOptions AddStdioServer( + string sourceId, + string command, + IReadOnlyList? arguments = null, + string? workingDirectory = null, + IReadOnlyDictionary? environmentVariables = null, + string? 
displayName = null) + => ConfigureRegistrations(registrations => registrations.AddStdioServer(sourceId, command, arguments, workingDirectory, environmentVariables, displayName)); + + public McpGatewayOptions AddMcpClient( + string sourceId, + McpClient client, + bool disposeClient = false, + string? displayName = null) + => ConfigureRegistrations(registrations => registrations.AddMcpClient(sourceId, client, disposeClient, displayName)); + + public McpGatewayOptions AddMcpClientFactory( + string sourceId, + Func> clientFactory, + bool disposeClient = true, + string? displayName = null) + => ConfigureRegistrations(registrations => registrations.AddMcpClientFactory(sourceId, clientFactory, disposeClient, displayName)); + + private McpGatewayOptions ConfigureRegistrations(Action configure) + { + configure(_sourceRegistrations); + return this; + } +} diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayServiceKeys.cs b/src/ManagedCode.MCPGateway/Configuration/McpGatewayServiceKeys.cs similarity index 64% rename from src/ManagedCode.MCPGateway/Models/McpGatewayServiceKeys.cs rename to src/ManagedCode.MCPGateway/Configuration/McpGatewayServiceKeys.cs index 1cc4b28..b754d43 100644 --- a/src/ManagedCode.MCPGateway/Models/McpGatewayServiceKeys.cs +++ b/src/ManagedCode.MCPGateway/Configuration/McpGatewayServiceKeys.cs @@ -3,4 +3,5 @@ namespace ManagedCode.MCPGateway; public static class McpGatewayServiceKeys { public const string EmbeddingGenerator = "ManagedCode.MCPGateway.EmbeddingGenerator"; + public const string SearchQueryChatClient = "ManagedCode.MCPGateway.SearchQueryChatClient"; } diff --git a/src/ManagedCode.MCPGateway/Embeddings/McpGatewayInMemoryToolEmbeddingStore.cs b/src/ManagedCode.MCPGateway/Embeddings/McpGatewayInMemoryToolEmbeddingStore.cs new file mode 100644 index 0000000..49073cb --- /dev/null +++ b/src/ManagedCode.MCPGateway/Embeddings/McpGatewayInMemoryToolEmbeddingStore.cs @@ -0,0 +1,21 @@ +using ManagedCode.MCPGateway.Abstractions; + +namespace 
ManagedCode.MCPGateway; + +public sealed class McpGatewayInMemoryToolEmbeddingStore : IMcpGatewayToolEmbeddingStore +{ + private readonly McpGatewayToolEmbeddingStoreIndex _index = new(); + + public Task> GetAsync( + IReadOnlyList lookups, + CancellationToken cancellationToken = default) + => Task.FromResult(_index.Get(lookups, cancellationToken)); + + public Task UpsertAsync( + IReadOnlyList embeddings, + CancellationToken cancellationToken = default) + { + _index.Upsert(embeddings, cancellationToken); + return Task.CompletedTask; + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Catalog/IMcpGatewayCatalogSource.cs b/src/ManagedCode.MCPGateway/Internal/Catalog/IMcpGatewayCatalogSource.cs new file mode 100644 index 0000000..2302d61 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Catalog/IMcpGatewayCatalogSource.cs @@ -0,0 +1,6 @@ +namespace ManagedCode.MCPGateway; + +internal interface IMcpGatewayCatalogSource +{ + McpGatewayCatalogSourceSnapshot CreateSnapshot(); +} diff --git a/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayCatalogSourceSnapshot.cs b/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayCatalogSourceSnapshot.cs new file mode 100644 index 0000000..f78f420 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayCatalogSourceSnapshot.cs @@ -0,0 +1,5 @@ +namespace ManagedCode.MCPGateway; + +internal sealed record McpGatewayCatalogSourceSnapshot( + int Version, + IReadOnlyList Registrations); diff --git a/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayDefaults.cs b/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayDefaults.cs new file mode 100644 index 0000000..46049f3 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayDefaults.cs @@ -0,0 +1,6 @@ +namespace ManagedCode.MCPGateway; + +internal static class McpGatewayDefaults +{ + public const string DefaultSourceId = "local"; +} diff --git a/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayOperationGate.cs 
b/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayOperationGate.cs new file mode 100644 index 0000000..028dd15 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayOperationGate.cs @@ -0,0 +1,76 @@ +namespace ManagedCode.MCPGateway; + +internal sealed class McpGatewayOperationGate +{ + private int _disposed; + private int _activeOperations; + private TaskCompletionSource? _operationsDrainedSignal; + + public void ThrowIfDisposed(object owner) + { + ObjectDisposedException.ThrowIf(Volatile.Read(ref _disposed) != 0, owner); + } + + public void Enter(object owner) + { + while (Volatile.Read(ref _disposed) == 0) + { + ThrowIfDisposed(owner); + Interlocked.Increment(ref _activeOperations); + if (Volatile.Read(ref _disposed) == 0) + { + return; + } + + Exit(); + } + + ThrowIfDisposed(owner); + } + + public void Exit() + { + if (Interlocked.Decrement(ref _activeOperations) == 0) + { + Volatile.Read(ref _operationsDrainedSignal)?.TrySetResult(null); + } + } + + public bool TryStartDispose(out ValueTask waitForDrain) + { + if (Interlocked.Exchange(ref _disposed, 1) != 0) + { + waitForDrain = ValueTask.CompletedTask; + return false; + } + + var operationsDrainedSignal = EnsureOperationsDrainedSignal(); + waitForDrain = Volatile.Read(ref _activeOperations) > 0 + ? 
new ValueTask(operationsDrainedSignal.Task) + : ValueTask.CompletedTask; + return true; + } + + private TaskCompletionSource EnsureOperationsDrainedSignal() + { + var operationsDrainedSignal = Volatile.Read(ref _operationsDrainedSignal); + while (operationsDrainedSignal is null) + { + var created = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + if (Interlocked.CompareExchange(ref _operationsDrainedSignal, created, null) is null) + { + operationsDrainedSignal = created; + break; + } + + operationsDrainedSignal = Volatile.Read(ref _operationsDrainedSignal); + } + + if (Volatile.Read(ref _activeOperations) == 0) + { + operationsDrainedSignal.TrySetResult(null); + } + + return operationsDrainedSignal; + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayRegistry.cs b/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayRegistry.cs new file mode 100644 index 0000000..5eebd5e --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Catalog/McpGatewayRegistry.cs @@ -0,0 +1,109 @@ +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Client; + +namespace ManagedCode.MCPGateway; + +internal sealed class McpGatewayRegistry(IOptions options) : IMcpGatewayRegistry, IMcpGatewayCatalogSource, IAsyncDisposable +{ + private readonly McpGatewayRegistrationCollection _registrations = CreateRegistrations(options); + private readonly McpGatewayOperationGate _operationGate = new(); + private int _version; + + public void AddTool(string sourceId, AITool tool, string? displayName = null) + => Mutate(registrations => registrations.AddTool(sourceId, tool, displayName)); + + public void AddTool(AITool tool, string sourceId = McpGatewayDefaults.DefaultSourceId, string? displayName = null) + => Mutate(registrations => registrations.AddTool(tool, sourceId, displayName)); + + public void AddTools(string sourceId, IEnumerable tools, string? 
displayName = null) + => Mutate(registrations => registrations.AddTools(sourceId, tools, displayName)); + + public void AddTools(IEnumerable tools, string sourceId = McpGatewayDefaults.DefaultSourceId, string? displayName = null) + => Mutate(registrations => registrations.AddTools(tools, sourceId, displayName)); + + public void AddHttpServer( + string sourceId, + Uri endpoint, + IReadOnlyDictionary? headers = null, + string? displayName = null) + => Mutate(registrations => registrations.AddHttpServer(sourceId, endpoint, headers, displayName)); + + public void AddStdioServer( + string sourceId, + string command, + IReadOnlyList? arguments = null, + string? workingDirectory = null, + IReadOnlyDictionary? environmentVariables = null, + string? displayName = null) + => Mutate(registrations => registrations.AddStdioServer(sourceId, command, arguments, workingDirectory, environmentVariables, displayName)); + + public void AddMcpClient( + string sourceId, + McpClient client, + bool disposeClient = false, + string? displayName = null) + => Mutate(registrations => registrations.AddMcpClient(sourceId, client, disposeClient, displayName)); + + public void AddMcpClientFactory( + string sourceId, + Func> clientFactory, + bool disposeClient = true, + string? 
displayName = null) + => Mutate(registrations => registrations.AddMcpClientFactory(sourceId, clientFactory, disposeClient, displayName)); + + public McpGatewayCatalogSourceSnapshot CreateSnapshot() + { + _operationGate.Enter(this); + try + { + _operationGate.ThrowIfDisposed(this); + return new McpGatewayCatalogSourceSnapshot( + Volatile.Read(ref _version), + _registrations.Snapshot()); + } + finally + { + _operationGate.Exit(); + } + } + + public async ValueTask DisposeAsync() + { + if (!_operationGate.TryStartDispose(out var waitForDrain)) + { + return; + } + + Interlocked.Increment(ref _version); + await waitForDrain; + + var registrations = _registrations.Drain(); + foreach (var registration in registrations) + { + await registration.DisposeAsync(); + } + } + + private void Mutate(Action mutation) + { + _operationGate.Enter(this); + try + { + _operationGate.ThrowIfDisposed(this); + mutation(_registrations); + Interlocked.Increment(ref _version); + } + finally + { + _operationGate.Exit(); + } + } + + private static McpGatewayRegistrationCollection CreateRegistrations(IOptions options) + { + ArgumentNullException.ThrowIfNull(options); + return new McpGatewayRegistrationCollection(options.Value.SourceRegistrations); + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayClientFactory.cs b/src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayClientFactory.cs new file mode 100644 index 0000000..9ecdc5f --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayClientFactory.cs @@ -0,0 +1,27 @@ +using System.Reflection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; + +namespace ManagedCode.MCPGateway; + +internal static class McpGatewayClientFactory +{ + private const string ClientName = "managedcode-mcpgateway"; + private static readonly string ClientVersion = ResolveClientVersion(); + + public static McpClientOptions CreateClientOptions() + => new() + { + ClientInfo = new 
Implementation + { + Name = ClientName, + Version = ClientVersion + } + }; + + private static string ResolveClientVersion() + => typeof(McpGatewayClientFactory).Assembly + .GetCustomAttribute()?.InformationalVersion + ?? typeof(McpGatewayClientFactory).Assembly.GetName().Version?.ToString() + ?? "unknown"; +} diff --git a/src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayRegistrationCollection.cs b/src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayRegistrationCollection.cs new file mode 100644 index 0000000..8d74d2a --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayRegistrationCollection.cs @@ -0,0 +1,158 @@ +using System.Collections.Concurrent; +using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; + +namespace ManagedCode.MCPGateway; + +internal sealed class McpGatewayRegistrationCollection(IEnumerable? registrations = null) +{ + private const string CommandRequiredMessage = "A command is required."; + private const string SourceIdRequiredMessage = "A source id is required."; + + private ConcurrentQueue _registrations = new(registrations ?? []); + private ConcurrentDictionary _localRegistrations = + CreateLocalRegistrations(registrations); + + public void AddTool(string sourceId, AITool tool, string? displayName = null) + => AddTool(tool, sourceId, displayName); + + public void AddTool(AITool tool, string sourceId = McpGatewayDefaults.DefaultSourceId, string? displayName = null) + { + ArgumentNullException.ThrowIfNull(tool); + GetOrAddLocalRegistration(sourceId, displayName).AddTool(tool); + } + + public void AddTools(string sourceId, IEnumerable tools, string? displayName = null) + => AddTools(tools, sourceId, displayName); + + public void AddTools(IEnumerable tools, string sourceId = McpGatewayDefaults.DefaultSourceId, string? 
displayName = null) + { + ArgumentNullException.ThrowIfNull(tools); + + var registration = GetOrAddLocalRegistration(sourceId, displayName); + foreach (var tool in tools) + { + ArgumentNullException.ThrowIfNull(tool); + registration.AddTool(tool); + } + } + + public void AddHttpServer( + string sourceId, + Uri endpoint, + IReadOnlyDictionary? headers = null, + string? displayName = null) + { + ArgumentNullException.ThrowIfNull(endpoint); + _registrations.Enqueue(new McpGatewayHttpToolSourceRegistration(ValidateSourceId(sourceId), endpoint, headers, displayName)); + } + + public void AddStdioServer( + string sourceId, + string command, + IReadOnlyList? arguments = null, + string? workingDirectory = null, + IReadOnlyDictionary? environmentVariables = null, + string? displayName = null) + { + if (string.IsNullOrWhiteSpace(command)) + { + throw new ArgumentException(CommandRequiredMessage, nameof(command)); + } + + _registrations.Enqueue(new McpGatewayStdioToolSourceRegistration( + ValidateSourceId(sourceId), + command.Trim(), + arguments, + workingDirectory, + environmentVariables, + displayName)); + } + + public void AddMcpClient( + string sourceId, + McpClient client, + bool disposeClient = false, + string? displayName = null) + { + ArgumentNullException.ThrowIfNull(client); + _registrations.Enqueue(new McpGatewayProvidedClientToolSourceRegistration( + ValidateSourceId(sourceId), + _ => ValueTask.FromResult(client), + disposeClient, + displayName)); + } + + public void AddMcpClientFactory( + string sourceId, + Func> clientFactory, + bool disposeClient = true, + string? 
displayName = null) + { + ArgumentNullException.ThrowIfNull(clientFactory); + _registrations.Enqueue(new McpGatewayProvidedClientToolSourceRegistration( + ValidateSourceId(sourceId), + clientFactory, + disposeClient, + displayName)); + } + + public IReadOnlyList Snapshot() + => _registrations.ToArray(); + + public IReadOnlyList Drain() + { + Interlocked.Exchange( + ref _localRegistrations, + new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase)); + + return Interlocked.Exchange(ref _registrations, new ConcurrentQueue()) + .ToArray(); + } + + private McpGatewayLocalToolSourceRegistration GetOrAddLocalRegistration(string sourceId, string? displayName) + { + sourceId = ValidateSourceId(sourceId); + + McpGatewayLocalToolSourceRegistration? existing; + while (!Volatile.Read(ref _localRegistrations).TryGetValue(sourceId, out existing)) + { + var localRegistrations = Volatile.Read(ref _localRegistrations); + var created = new McpGatewayLocalToolSourceRegistration(sourceId, displayName); + if (localRegistrations.TryAdd(sourceId, created)) + { + Volatile.Read(ref _registrations).Enqueue(created); + return created; + } + } + + return existing; + } + + private static string ValidateSourceId(string sourceId) + { + if (string.IsNullOrWhiteSpace(sourceId)) + { + throw new ArgumentException(SourceIdRequiredMessage, nameof(sourceId)); + } + + return sourceId.Trim(); + } + + private static ConcurrentDictionary CreateLocalRegistrations( + IEnumerable? 
registrations) + { + var result = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + if (registrations is null) + { + return result; + } + + foreach (var registration in registrations.OfType()) + { + result.TryAdd(registration.SourceId, registration); + } + + return result; + } +} diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayOptions.cs b/src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayToolSourceRegistrations.cs similarity index 52% rename from src/ManagedCode.MCPGateway/Models/McpGatewayOptions.cs rename to src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayToolSourceRegistrations.cs index 6c6657b..4ceef74 100644 --- a/src/ManagedCode.MCPGateway/Models/McpGatewayOptions.cs +++ b/src/ManagedCode.MCPGateway/Internal/Catalog/Sources/McpGatewayToolSourceRegistrations.cs @@ -1,150 +1,11 @@ +using System.Collections.Concurrent; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using ModelContextProtocol; using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; namespace ManagedCode.MCPGateway; -public sealed class McpGatewayOptions -{ - private readonly List _sourceRegistrations = []; - - public int DefaultSearchLimit { get; set; } = 8; - - public int MaxSearchResults { get; set; } = 20; - - public int MaxDescriptorLength { get; set; } = 4096; - - internal IReadOnlyList SourceRegistrations => _sourceRegistrations; - - public McpGatewayOptions AddTool(string sourceId, AITool tool, string? displayName = null) - => AddTool(tool, sourceId, displayName); - - public McpGatewayOptions AddTool(AITool tool, string sourceId = "local", string? displayName = null) - { - ArgumentNullException.ThrowIfNull(tool); - - GetOrAddLocalRegistration(sourceId, displayName).AddTool(tool); - return this; - } - - public McpGatewayOptions AddTools(string sourceId, IEnumerable tools, string? 
displayName = null) - => AddTools(tools, sourceId, displayName); - - public McpGatewayOptions AddTools(IEnumerable tools, string sourceId = "local", string? displayName = null) - { - ArgumentNullException.ThrowIfNull(tools); - - var registration = GetOrAddLocalRegistration(sourceId, displayName); - foreach (var tool in tools) - { - ArgumentNullException.ThrowIfNull(tool); - registration.AddTool(tool); - } - - return this; - } - - public McpGatewayOptions AddHttpServer( - string sourceId, - Uri endpoint, - IReadOnlyDictionary? headers = null, - string? displayName = null) - { - ArgumentNullException.ThrowIfNull(endpoint); - - _sourceRegistrations.Add(new McpGatewayHttpToolSourceRegistration( - ValidateSourceId(sourceId), - endpoint, - headers, - displayName)); - return this; - } - - public McpGatewayOptions AddStdioServer( - string sourceId, - string command, - IReadOnlyList? arguments = null, - string? workingDirectory = null, - IReadOnlyDictionary? environmentVariables = null, - string? displayName = null) - { - if (string.IsNullOrWhiteSpace(command)) - { - throw new ArgumentException("A command is required.", nameof(command)); - } - - _sourceRegistrations.Add(new McpGatewayStdioToolSourceRegistration( - ValidateSourceId(sourceId), - command.Trim(), - arguments, - workingDirectory, - environmentVariables, - displayName)); - return this; - } - - public McpGatewayOptions AddMcpClient( - string sourceId, - McpClient client, - bool disposeClient = false, - string? displayName = null) - { - ArgumentNullException.ThrowIfNull(client); - - _sourceRegistrations.Add(new McpGatewayProvidedClientToolSourceRegistration( - ValidateSourceId(sourceId), - _ => ValueTask.FromResult(client), - disposeClient, - displayName)); - return this; - } - - public McpGatewayOptions AddMcpClientFactory( - string sourceId, - Func> clientFactory, - bool disposeClient = true, - string? 
displayName = null) - { - ArgumentNullException.ThrowIfNull(clientFactory); - - _sourceRegistrations.Add(new McpGatewayProvidedClientToolSourceRegistration( - ValidateSourceId(sourceId), - clientFactory, - disposeClient, - displayName)); - return this; - } - - private McpGatewayLocalToolSourceRegistration GetOrAddLocalRegistration(string sourceId, string? displayName) - { - sourceId = ValidateSourceId(sourceId); - - var existing = _sourceRegistrations - .OfType() - .FirstOrDefault(item => string.Equals(item.SourceId, sourceId, StringComparison.OrdinalIgnoreCase)); - if (existing is not null) - { - return existing; - } - - var created = new McpGatewayLocalToolSourceRegistration(sourceId, displayName); - _sourceRegistrations.Add(created); - return created; - } - - private static string ValidateSourceId(string sourceId) - { - if (string.IsNullOrWhiteSpace(sourceId)) - { - throw new ArgumentException("A source id is required.", nameof(sourceId)); - } - - return sourceId.Trim(); - } -} - internal enum McpGatewaySourceRegistrationKind { Local, @@ -172,16 +33,16 @@ public abstract ValueTask> LoadToolsAsync( internal sealed class McpGatewayLocalToolSourceRegistration(string sourceId, string? 
displayName) : McpGatewayToolSourceRegistration(sourceId, displayName) { - private readonly List _tools = []; + private readonly ConcurrentQueue _tools = new(); public override McpGatewaySourceRegistrationKind Kind => McpGatewaySourceRegistrationKind.Local; - public void AddTool(AITool tool) => _tools.Add(tool); + public void AddTool(AITool tool) => _tools.Enqueue(tool); public override ValueTask> LoadToolsAsync( ILoggerFactory loggerFactory, CancellationToken cancellationToken) - => ValueTask.FromResult>(_tools.ToList()); + => ValueTask.FromResult>(_tools.ToArray()); } internal sealed class McpGatewayHttpToolSourceRegistration( @@ -283,10 +144,10 @@ internal abstract class McpGatewayClientToolSourceRegistration( bool disposeClient) : McpGatewayToolSourceRegistration(sourceId, displayName) { - private readonly SemaphoreSlim _sync = new(1, 1); private readonly bool _disposeClient = disposeClient; private McpClient? _client; - private Task? _clientTask; + private ClientOperation? _clientOperation; + private int _disposed; public override async ValueTask> LoadToolsAsync( ILoggerFactory loggerFactory, @@ -303,12 +164,16 @@ protected abstract ValueTask CreateClientAsync( public override async ValueTask DisposeAsync() { - if (_disposeClient && _client is not null) + if (Interlocked.Exchange(ref _disposed, 1) != 0) { - await _client.DisposeAsync(); + return; + } + + if (_disposeClient && Volatile.Read(ref _client) is { } client) + { + await client.DisposeAsync(); } - _sync.Dispose(); await base.DisposeAsync(); } @@ -316,32 +181,74 @@ private async Task GetClientAsync( ILoggerFactory loggerFactory, CancellationToken cancellationToken) { - if (_client is not null) - { - return _client; - } + ObjectDisposedException.ThrowIf(Volatile.Read(ref _disposed) != 0, this); - if (_clientTask is not null) + if (Volatile.Read(ref _client) is { } client) { - return await AwaitClientTaskAsync(_clientTask, cancellationToken); + return client; } - await _sync.WaitAsync(cancellationToken); 
- try + var clientTask = Volatile.Read(ref _clientOperation); + while (!cancellationToken.IsCancellationRequested) { - if (_client is not null) + if (clientTask is null) + { + var clientSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var createdTask = new ClientOperation(clientSource.Task, cancellationToken); + if (Interlocked.CompareExchange(ref _clientOperation, createdTask, null) is null) + { + _ = RunCreateClientAsync(clientSource, loggerFactory, createdTask); + clientTask = createdTask; + break; + } + + clientTask = Volatile.Read(ref _clientOperation); + continue; + } + + if (clientTask.CancellationToken.IsCancellationRequested) { - return _client; + await AwaitCanceledClientCreationAsync(clientTask); + _ = Interlocked.CompareExchange(ref _clientOperation, null, clientTask); + clientTask = Volatile.Read(ref _clientOperation); + continue; } - _clientTask ??= CreateClientAsync(loggerFactory, CancellationToken.None).AsTask(); + if (clientTask.Task.IsCanceled || clientTask.Task.IsFaulted) + { + _ = Interlocked.CompareExchange(ref _clientOperation, null, clientTask); + clientTask = Volatile.Read(ref _clientOperation); + continue; + } + + break; } - finally + + if (clientTask is null) { - _sync.Release(); + cancellationToken.ThrowIfCancellationRequested(); } - return await AwaitClientTaskAsync(_clientTask, cancellationToken); + return await AwaitClientTaskAsync(clientTask!.Task, cancellationToken); + } + + private async Task RunCreateClientAsync( + TaskCompletionSource clientSource, + ILoggerFactory loggerFactory, + ClientOperation clientOperation) + { + try + { + clientSource.SetResult(await CreateClientAsync(loggerFactory, clientOperation.CancellationToken)); + } + catch (OperationCanceledException) when (clientOperation.CancellationToken.IsCancellationRequested) + { + clientSource.SetCanceled(clientOperation.CancellationToken); + } + catch (Exception ex) + { + clientSource.SetException(ex); + } } private async Task 
AwaitClientTaskAsync( @@ -350,38 +257,43 @@ private async Task AwaitClientTaskAsync( { try { - _client = await clientTask.WaitAsync(cancellationToken); - return _client; - } - catch when (clientTask.IsFaulted || clientTask.IsCanceled) - { - await _sync.WaitAsync(CancellationToken.None); - try + var client = await clientTask.WaitAsync(cancellationToken); + if (Volatile.Read(ref _disposed) != 0) { - if (ReferenceEquals(_clientTask, clientTask)) + if (_disposeClient) { - _clientTask = null; + await client.DisposeAsync(); } + + throw new ObjectDisposedException(GetType().Name); } - finally + + var cachedClient = Volatile.Read(ref _client); + return cachedClient ?? Interlocked.CompareExchange(ref _client, client, null) ?? client; + } + catch when (clientTask.IsFaulted || clientTask.IsCanceled) + { + if (Volatile.Read(ref _clientOperation) is { Task: { } currentTask } currentOperation && + ReferenceEquals(currentTask, clientTask)) { - _sync.Release(); + _ = Interlocked.CompareExchange(ref _clientOperation, null, currentOperation); } - throw; } } -} -internal static class McpGatewayClientFactory -{ - public static McpClientOptions CreateClientOptions() - => new() + private static async Task AwaitCanceledClientCreationAsync(ClientOperation clientOperation) + { + try { - ClientInfo = new Implementation - { - Name = "managedcode-mcpgateway", - Version = "1.0.0" - } - }; + await clientOperation.Task; + } + catch (OperationCanceledException) when (clientOperation.CancellationToken.IsCancellationRequested) + { + } + } + + private sealed record ClientOperation( + Task Task, + CancellationToken CancellationToken); } diff --git a/src/ManagedCode.MCPGateway/McpGatewayInMemoryToolEmbeddingStore.cs b/src/ManagedCode.MCPGateway/Internal/Embeddings/McpGatewayToolEmbeddingStoreIndex.cs similarity index 78% rename from src/ManagedCode.MCPGateway/McpGatewayInMemoryToolEmbeddingStore.cs rename to src/ManagedCode.MCPGateway/Internal/Embeddings/McpGatewayToolEmbeddingStoreIndex.cs index 
acc60a0..3f5c095 100644 --- a/src/ManagedCode.MCPGateway/McpGatewayInMemoryToolEmbeddingStore.cs +++ b/src/ManagedCode.MCPGateway/Internal/Embeddings/McpGatewayToolEmbeddingStoreIndex.cs @@ -1,13 +1,12 @@ using System.Collections.Concurrent; -using ManagedCode.MCPGateway.Abstractions; namespace ManagedCode.MCPGateway; -public sealed class McpGatewayInMemoryToolEmbeddingStore : IMcpGatewayToolEmbeddingStore +internal sealed class McpGatewayToolEmbeddingStoreIndex { private readonly ConcurrentDictionary _embeddings = new(); - public Task> GetAsync( + public IReadOnlyList Get( IReadOnlyList lookups, CancellationToken cancellationToken = default) { @@ -22,22 +21,38 @@ public Task> GetAsync( } } - return Task.FromResult>(results); + return results; } - public Task UpsertAsync( + public IReadOnlyList Upsert( IReadOnlyList embeddings, CancellationToken cancellationToken = default) { cancellationToken.ThrowIfCancellationRequested(); - foreach (var embedding in embeddings) + var clones = embeddings + .Select(Clone) + .ToList(); + + foreach (var embedding in clones) { - var clone = Clone(embedding); - _embeddings[StoreKey.FromEmbedding(clone)] = clone; + _embeddings[StoreKey.FromEmbedding(embedding)] = embedding; } - return Task.CompletedTask; + return clones; + } + + public void Remove(string toolId) + { + var normalizedToolId = NormalizeToolId(toolId); + var keys = _embeddings.Keys + .Where(key => string.Equals(key.NormalizedToolId, normalizedToolId, StringComparison.Ordinal)) + .ToList(); + + foreach (var key in keys) + { + _embeddings.TryRemove(key, out _); + } } private bool TryGetEmbedding( diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Catalog/McpGatewayRuntime.Descriptors.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Catalog/McpGatewayRuntime.Descriptors.cs new file mode 100644 index 0000000..6c56beb --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Catalog/McpGatewayRuntime.Descriptors.cs @@ -0,0 +1,227 @@ +using System.Text; +using 
System.Text.Json; +using Microsoft.Extensions.AI; +using ModelContextProtocol.Client; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + private static McpGatewayToolDescriptor? BuildDescriptor( + McpGatewayToolSourceRegistration registration, + AITool tool) + { + if (string.IsNullOrWhiteSpace(tool.Name)) + { + return null; + } + + var toolName = tool.Name.Trim(); + var sourceKind = registration.Kind switch + { + McpGatewaySourceRegistrationKind.Http => McpGatewaySourceKind.HttpMcp, + McpGatewaySourceRegistrationKind.Stdio => McpGatewaySourceKind.StdioMcp, + McpGatewaySourceRegistrationKind.CustomMcpClient => McpGatewaySourceKind.CustomMcpClient, + _ => McpGatewaySourceKind.Local + }; + + var inputSchema = ResolveInputSchema(tool); + + return new McpGatewayToolDescriptor( + ToolId: $"{registration.SourceId}:{toolName}", + SourceId: registration.SourceId, + SourceKind: sourceKind, + ToolName: toolName, + DisplayName: ResolveDisplayName(tool), + Description: tool.Description ?? string.Empty, + RequiredArguments: inputSchema.RequiredArguments, + InputSchemaJson: inputSchema.Json); + } + + private string BuildDescriptorDocument(McpGatewayToolDescriptor descriptor, AITool tool) + { + var builder = new StringBuilder(); + builder.Append(ToolNameLabel); + builder.AppendLine(descriptor.ToolName); + + if (!string.IsNullOrWhiteSpace(descriptor.DisplayName)) + { + builder.Append(DisplayNameLabel); + builder.AppendLine(descriptor.DisplayName); + } + + if (!string.IsNullOrWhiteSpace(descriptor.Description)) + { + builder.Append(DescriptionLabel); + builder.AppendLine(descriptor.Description); + } + + if (descriptor.RequiredArguments.Count > 0) + { + builder.Append(RequiredArgumentsLabel); + builder.AppendLine(string.Join(", ", descriptor.RequiredArguments)); + } + + AppendInputSchema(builder, descriptor.InputSchemaJson); + var document = builder.ToString().Trim(); + return document.Length <= _maxDescriptorLength + ? 
document + : document[.._maxDescriptorLength]; + } + + /* Appends one "Parameter name: ..." line per member of the schema's "properties" object, adding the description, the type and the string enum values when they are present. Does nothing for a null or blank schema, or when "properties" is missing or not an object. If the schema text is not valid JSON, the raw text is appended under the input-schema label instead. */ private static void AppendInputSchema(StringBuilder builder, string? inputSchemaJson) + { + if (string.IsNullOrWhiteSpace(inputSchemaJson)) + { + return; + } + + try + { + using var schemaDocument = JsonDocument.Parse(inputSchemaJson); + if (!schemaDocument.RootElement.TryGetProperty(InputSchemaPropertiesPropertyName, out var properties) || + properties.ValueKind != JsonValueKind.Object) + { + return; + } + + foreach (var property in properties.EnumerateObject()) + { + builder.Append(ParameterLabel); + builder.Append(property.Name); + builder.Append(": "); + + /* NOTE(review): JsonElement.TryGetProperty throws InvalidOperationException when the element is not an object, and only JsonException is caught below - confirm every property schema is an object (a boolean schema such as "x": true would not be). */ if (property.Value.TryGetProperty(InputSchemaDescriptionPropertyName, out var description) && + description.ValueKind == JsonValueKind.String) + { + builder.Append(description.GetString()); + builder.Append(". "); + } + + if (property.Value.TryGetProperty(InputSchemaTypePropertyName, out var type) && + type.ValueKind == JsonValueKind.String) + { + builder.Append(TypeLabel); + builder.Append(type.GetString()); + builder.Append(". "); + } + + if (property.Value.TryGetProperty(InputSchemaEnumPropertyName, out var enumValues) && + enumValues.ValueKind == JsonValueKind.Array) + { + var values = new List(); + foreach (var enumValue in enumValues.EnumerateArray()) + { + /* only string enum members are listed */ if (enumValue.ValueKind != JsonValueKind.String) + { + continue; + } + + var value = enumValue.GetString(); + if (!string.IsNullOrWhiteSpace(value)) + { + values.Add(value); + } + } + + if (values.Count > 0) + { + builder.Append(TypicalValuesLabel); + builder.Append(string.Join(", ", values)); + builder.Append(". "); + } + } + + builder.AppendLine(); + } + } + catch (JsonException) + { + builder.Append(InputSchemaLabel); + builder.AppendLine(inputSchemaJson); + } + } + + /* Resolves a display name: the protocol tool Title for McpClientTool instances; otherwise a non-blank string stored under the DisplayName key of the function's AdditionalProperties; otherwise null. */ private static string? ResolveDisplayName(AITool tool) + { + if (tool is McpClientTool mcpTool) + { + return mcpTool.ProtocolTool?.Title; + } + + var function = tool as AIFunction ?? 
tool.GetService(); + if (function?.AdditionalProperties is { Count: > 0 } && + function.AdditionalProperties.TryGetValue(DisplayNamePropertyName, out var displayName) && + displayName is string value && + !string.IsNullOrWhiteSpace(value)) + { + return value; + } + + return null; + } + + /* Serializes the tool's input schema: ProtocolTool.InputSchema for McpClientTool instances, otherwise the AIFunction JsonSchema. Returns SerializedSchema.Empty when no function can be obtained or its schema has ValueKind Undefined. */ private static SerializedSchema ResolveInputSchema(AITool tool) + { + if (tool is McpClientTool mcpTool) + { + return SerializeSchema(mcpTool.ProtocolTool?.InputSchema); + } + + var function = tool as AIFunction ?? tool.GetService(); + if (function is null) + { + return SerializedSchema.Empty; + } + + return function.JsonSchema.ValueKind == JsonValueKind.Undefined + ? SerializedSchema.Empty + : SerializeSchema(function.JsonSchema); + } + + /* Converts a schema object into its raw JSON text plus the required-argument list; returns SerializedSchema.Empty when McpGatewayJsonSerializer.TrySerializeToElement does not yield a JsonElement. */ private static SerializedSchema SerializeSchema(object? schema) + { + if (McpGatewayJsonSerializer.TrySerializeToElement(schema) is not JsonElement serializedSchema) + { + return SerializedSchema.Empty; + } + + return new SerializedSchema( + serializedSchema.GetRawText(), + ExtractRequiredArguments(serializedSchema)); + } + + /* Collects the non-blank string items of the schema's "required" array in their original order, dropping case-insensitive duplicates; returns an empty list when the property is missing or is not an array. */ private static IReadOnlyList ExtractRequiredArguments(JsonElement schemaElement) + { + if (!schemaElement.TryGetProperty(InputSchemaRequiredPropertyName, out var required) || + required.ValueKind != JsonValueKind.Array) + { + return []; + } + + var requiredArguments = new List(); + var seenArguments = new HashSet(StringComparer.OrdinalIgnoreCase); + foreach (var item in required.EnumerateArray()) + { + if (item.ValueKind != JsonValueKind.String) + { + continue; + } + + var value = item.GetString(); + if (string.IsNullOrWhiteSpace(value) || !seenArguments.Add(value)) + { + continue; + } + + requiredArguments.Add(value); + } + + return requiredArguments; + } + + /* Raw schema JSON (null when unavailable) paired with the required argument names extracted from it. */ private sealed record SerializedSchema(string? 
Json, IReadOnlyList RequiredArguments) + { + public static SerializedSchema Empty { get; } = new(null, []); + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Catalog/McpGatewayRuntime.Indexing.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Catalog/McpGatewayRuntime.Indexing.cs new file mode 100644 index 0000000..c8d615e --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Catalog/McpGatewayRuntime.Indexing.cs @@ -0,0 +1,367 @@ +using System.Globalization; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + /* Builds the tool index, sharing a single in-flight build between concurrent callers through the _buildOperation field. The build runs under the token of the caller that created it; every caller then waits on the shared task under its own token. A build whose token has been canceled is awaited and cleared, and a canceled or faulted build is cleared, before a replacement is started. Throws ObjectDisposedException via ThrowIfDisposed and OperationCanceledException when the caller's token is canceled. */ public async Task BuildIndexAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + var existingBuild = Volatile.Read(ref _buildOperation); + while (!cancellationToken.IsCancellationRequested) + { + ThrowIfDisposed(); + + if (existingBuild is null) + { + var buildSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var createdBuild = new BuildOperation(buildSource.Task, cancellationToken); + /* only the caller that wins this compare-exchange starts the build; losers re-read the field and retry */ if (Interlocked.CompareExchange(ref _buildOperation, createdBuild, null) is null) + { + _ = RunBuildIndexAsync(buildSource, createdBuild); + existingBuild = createdBuild; + break; + } + + existingBuild = Volatile.Read(ref _buildOperation); + continue; + } + + /* the owner of the running build canceled it: let it finish unwinding, clear the slot, then retry */ if (existingBuild.CancellationToken.IsCancellationRequested) + { + await AwaitCanceledBuildAsync(existingBuild); + _ = Interlocked.CompareExchange(ref _buildOperation, null, existingBuild); + existingBuild = Volatile.Read(ref _buildOperation); + continue; + } + + if (existingBuild.Task.IsCanceled || existingBuild.Task.IsFaulted) + { + _ = Interlocked.CompareExchange(ref _buildOperation, null, existingBuild); + existingBuild = Volatile.Read(ref _buildOperation); + continue; + } + + break; + } + + /* reached with a null build only when the loop exited because the caller's token was canceled */ if (existingBuild is null) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + return 
await existingBuild!.Task.WaitAsync(cancellationToken); + } + + /* Runs the core build and publishes its outcome (result, cancellation or exception) on buildSource; afterwards clears _buildOperation if it still refers to this operation. */ private async Task RunBuildIndexAsync( + TaskCompletionSource buildSource, + BuildOperation buildOperation) + { + try + { + buildSource.SetResult(await BuildIndexCoreAsync(buildOperation.CancellationToken)); + } + catch (OperationCanceledException) when (buildOperation.CancellationToken.IsCancellationRequested) + { + buildSource.SetCanceled(buildOperation.CancellationToken); + } + catch (Exception ex) + { + buildSource.SetException(ex); + } + finally + { + _ = Interlocked.CompareExchange(ref _buildOperation, null, buildOperation); + } + } + + /* Rebuilds the catalog: loads the tools of every registration in a registry snapshot, builds descriptors and lexical profiles, attaches stored or freshly generated embeddings unless the search strategy is Tokenizer, and publishes the resulting snapshot. Source-load failures and duplicate tool ids are recorded as diagnostics instead of being thrown; OperationCanceledException is allowed to propagate. */ private async Task BuildIndexCoreAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + ThrowIfDisposed(); + + var registrySnapshot = _catalogSource.CreateSnapshot(); + var diagnostics = new List(); + var entries = new List(); + var seenToolIds = new HashSet(StringComparer.OrdinalIgnoreCase); + + foreach (var registration in registrySnapshot.Registrations) + { + IReadOnlyList tools; + try + { + tools = await registration.LoadToolsAsync(_loggerFactory, cancellationToken); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + diagnostics.Add(new McpGatewayDiagnostic( + SourceLoadFailedDiagnosticCode, + string.Format( + CultureInfo.InvariantCulture, + SourceLoadFailedMessageFormat, + registration.SourceId, + ex.GetBaseException().Message))); + _logger.LogWarning(ex, FailedToLoadGatewaySourceLogMessage, registration.SourceId); + continue; + } + + foreach (var tool in tools) + { + var descriptor = BuildDescriptor(registration, tool); + if (descriptor is null) + { + continue; + } + + /* tool ids are compared case-insensitively; the first occurrence wins and later ones are reported */ if (!seenToolIds.Add(descriptor.ToolId)) + { + diagnostics.Add(new McpGatewayDiagnostic( + DuplicateToolIdDiagnosticCode, + string.Format(CultureInfo.InvariantCulture, DuplicateToolIdMessageFormat, descriptor.ToolId))); + continue; + } + + var tokenSearchSegments = BuildDescriptorTokenSearchSegments(descriptor); + var 
searchFields = BuildTokenizedSearchFields(tokenSearchSegments); + entries.Add(new ToolCatalogEntry( + descriptor, + tool, + BuildDescriptorDocument(descriptor, tool), + searchFields, + BuildLexicalTerms(tokenSearchSegments), + TokenSearchProfile.Empty, + TokenSearchProfile.Empty)); + } + } + + /* second pass: entries were added with empty profiles above; the inverse document frequencies are computed from all entries' raw profiles and then applied to each entry */ var rawTokenProfiles = entries + .Select(entry => BuildTokenSearchProfile(entry.SearchFields)) + .ToList(); + var rawCharacterNGramProfiles = entries + .Select(entry => BuildCharacterNGramProfile(entry.SearchFields)) + .ToList(); + var tokenInverseDocumentFrequencies = BuildTokenInverseDocumentFrequencies(rawTokenProfiles); + var characterNGramInverseDocumentFrequencies = BuildTokenInverseDocumentFrequencies(rawCharacterNGramProfiles); + var averageSearchFieldLength = CalculateAverageSearchFieldLength(entries); + for (var index = 0; index < entries.Count; index++) + { + entries[index] = entries[index] with + { + TokenProfile = ApplyTokenInverseDocumentFrequencies(rawTokenProfiles[index], tokenInverseDocumentFrequencies), + CharacterNGramProfile = ApplyTokenInverseDocumentFrequencies( + rawCharacterNGramProfiles[index], + characterNGramInverseDocumentFrequencies) + }; + } + + var vectorizedToolCount = 0; + var isVectorSearchEnabled = false; + /* embeddings are only considered when there are entries and the strategy is not Tokenizer */ if (entries.Count > 0 && _searchStrategy is not McpGatewaySearchStrategy.Tokenizer) + { + await using var embeddingGeneratorLease = ResolveEmbeddingGenerator(); + await using var embeddingStoreLease = ResolveToolEmbeddingStore(); + var embeddingGenerator = embeddingGeneratorLease.Generator; + var embeddingGeneratorFingerprint = ResolveEmbeddingGeneratorFingerprint(embeddingGenerator); + var embeddingStore = embeddingStoreLease.Store; + /* each lookup is keyed by tool id, a hash of the descriptor document and the generator fingerprint */ var storeCandidates = entries + .Select((entry, index) => new ToolEmbeddingCandidate( + index, + new McpGatewayToolEmbeddingLookup( + entry.Descriptor.ToolId, + ComputeDocumentHash(entry.Document), + embeddingGeneratorFingerprint), + entry.Descriptor.SourceId, + entry.Descriptor.ToolName)) + .ToList(); + + 
if (embeddingStore is not null) + { + try + { + var storedEmbeddings = await embeddingStore.GetAsync( + storeCandidates.Select(static candidate => candidate.Lookup).ToList(), + cancellationToken); + + foreach (var candidate in storeCandidates) + { + /* when several stored embeddings match a lookup, the last one returned by the store is used. NOTE(review): this scans the whole result list once per candidate, which is quadratic in catalog size. */ var storedEmbedding = storedEmbeddings.LastOrDefault(embedding => + MatchesStoredEmbedding(candidate.Lookup, embedding)); + if (storedEmbedding is not null) + { + ApplyEmbedding(entries, candidate.Index, storedEmbedding.Vector, ref vectorizedToolCount); + } + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + diagnostics.Add(new McpGatewayDiagnostic( + EmbeddingStoreLoadFailedDiagnosticCode, + string.Format(CultureInfo.InvariantCulture, EmbeddingStoreLoadFailedMessageFormat, ex.GetBaseException().Message))); + _logger.LogWarning(ex, EmbeddingStoreLoadFailedLogMessage); + } + } + + /* entries whose magnitude is still (near) zero received no usable stored vector and need a generated one */ var missingCandidates = storeCandidates + .Where(candidate => entries[candidate.Index].Magnitude <= double.Epsilon) + .ToList(); + + if (embeddingGenerator is null && vectorizedToolCount > 0) + { + diagnostics.Add(new McpGatewayDiagnostic( + EmbeddingGeneratorMissingDiagnosticCode, + EmbeddingGeneratorMissingMessage)); + } + + if (missingCandidates.Count > 0) + { + try + { + if (embeddingGenerator is not null) + { + var embeddings = (await embeddingGenerator.GenerateAsync( + missingCandidates.Select(candidate => entries[candidate.Index].Document), + cancellationToken: cancellationToken)) + .ToList(); + /* generated vectors are paired with candidates by position, so a count mismatch discards the whole batch */ if (embeddings.Count == missingCandidates.Count) + { + var generatedEmbeddings = new List(missingCandidates.Count); + for (var index = 0; index < missingCandidates.Count; index++) + { + var candidate = missingCandidates[index]; + var vector = embeddings[index].Vector.ToArray(); + if (ApplyEmbedding(entries, candidate.Index, vector, ref vectorizedToolCount)) + { + generatedEmbeddings.Add(new McpGatewayToolEmbedding( + candidate.Lookup.ToolId, + candidate.SourceId, + candidate.ToolName, + candidate.Lookup.DocumentHash, 
+ candidate.Lookup.EmbeddingGeneratorFingerprint, + vector)); + } + } + + if (generatedEmbeddings.Count > 0 && embeddingStore is not null) + { + /* persisting is best-effort: a store failure becomes a diagnostic and the vectors already applied to the entries are kept */ try + { + await embeddingStore.UpsertAsync(generatedEmbeddings, cancellationToken); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + diagnostics.Add(new McpGatewayDiagnostic( + EmbeddingStoreSaveFailedDiagnosticCode, + string.Format(CultureInfo.InvariantCulture, EmbeddingStoreSaveFailedMessageFormat, ex.GetBaseException().Message))); + _logger.LogWarning(ex, EmbeddingStoreSaveFailedLogMessage); + } + } + } + else + { + diagnostics.Add(new McpGatewayDiagnostic( + EmbeddingCountMismatchDiagnosticCode, + string.Format( + CultureInfo.InvariantCulture, + EmbeddingCountMismatchMessageFormat, + embeddings.Count, + missingCandidates.Count))); + } + } + else + { + diagnostics.Add(new McpGatewayDiagnostic( + EmbeddingGeneratorMissingDiagnosticCode, + EmbeddingGeneratorMissingMessage)); + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + diagnostics.Add(new McpGatewayDiagnostic( + EmbeddingFailedDiagnosticCode, + string.Format(CultureInfo.InvariantCulture, EmbeddingFailedMessageFormat, ex.GetBaseException().Message))); + _logger.LogWarning(ex, EmbeddingGenerationFailedLogMessage); + } + } + + /* vector search is enabled only when at least one tool was vectorized and a generator is available */ isVectorSearchEnabled = vectorizedToolCount > 0 && embeddingGenerator is not null; + } + + /* entries are published ordered by tool name, then source id, both case-insensitively */ var snapshot = new ToolCatalogSnapshot( + entries + .OrderBy(static item => item.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase) + .ThenBy(static item => item.Descriptor.SourceId, StringComparer.OrdinalIgnoreCase) + .ToList(), + isVectorSearchEnabled, + tokenInverseDocumentFrequencies, + characterNGramInverseDocumentFrequencies, + averageSearchFieldLength); + + TryUpdateState(snapshot, registrySnapshot.Version); + + _logger.LogInformation( + GatewayIndexRebuiltLogMessage, + snapshot.Entries.Count, + vectorizedToolCount); + + return new McpGatewayIndexBuildResult( + snapshot.Entries.Count, 
+ vectorizedToolCount, + snapshot.HasVectors, + diagnostics); + } + + /* Publishes the snapshot and its registry version into _state using a compare-exchange retry loop; returns without writing once the state is marked disposed. */ private void TryUpdateState(ToolCatalogSnapshot snapshot, int snapshotVersion) + { + var state = Volatile.Read(ref _state); + while (!state.IsDisposed) + { + var updatedState = state with + { + Snapshot = snapshot, + SnapshotVersion = snapshotVersion + }; + if (ReferenceEquals(Interlocked.CompareExchange(ref _state, updatedState, state), state)) + { + return; + } + + state = Volatile.Read(ref _state); + } + } + + /* Averages the Length of every search field across all entries, counting each field as at least 1; returns 1 when there are no fields. */ private static double CalculateAverageSearchFieldLength(IReadOnlyList entries) + { + var totalLength = 0d; + var fieldCount = 0; + foreach (var entry in entries) + { + foreach (var field in entry.SearchFields) + { + totalLength += Math.Max(1, field.Length); + fieldCount++; + } + } + + return fieldCount == 0 + ? 1d + : totalLength / fieldCount; + } + + /* Awaits a build and swallows its OperationCanceledException only when the build's own token was canceled; any other failure propagates. */ private static async Task AwaitCanceledBuildAsync(BuildOperation buildOperation) + { + try + { + await buildOperation.Task; + } + catch (OperationCanceledException) when (buildOperation.CancellationToken.IsCancellationRequested) + { + } + } + + /* An in-flight build task together with the cancellation token it runs under. */ private sealed record BuildOperation( + Task Task, + CancellationToken CancellationToken); +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.Snapshot.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.Snapshot.cs new file mode 100644 index 0000000..e9c0135 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.Snapshot.cs @@ -0,0 +1,25 @@ +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + /* Returns the current snapshot once its version equals the registry snapshot version; otherwise runs BuildIndexAsync and checks again until the versions match or the token is canceled. */ private async Task GetSnapshotAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + while (!cancellationToken.IsCancellationRequested) + { + ThrowIfDisposed(); + var registrySnapshot = _catalogSource.CreateSnapshot(); + var state = Volatile.Read(ref _state); + if (state.SnapshotVersion == registrySnapshot.Version) + { + return state.Snapshot; + } 
+ + await BuildIndexAsync(cancellationToken); + } + + cancellationToken.ThrowIfCancellationRequested(); + return Volatile.Read(ref _state).Snapshot; + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.Types.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.Types.cs new file mode 100644 index 0000000..d307c02 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.Types.cs @@ -0,0 +1,187 @@ +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + /* Outcome of resolving an invocation target: Success carries the catalog entry, Fail carries an error message. */ private sealed record InvocationResolution(bool IsSuccess, ToolCatalogEntry? Entry, string? Error) + { + public static InvocationResolution Success(ToolCatalogEntry entry) => new(true, entry, null); + + public static InvocationResolution Fail(string error) => new(false, null, error); + } + + /* Pairs an entry's index in the build list with its embedding-store lookup key and its source id and tool name. */ private sealed record ToolEmbeddingCandidate( + int Index, + McpGatewayToolEmbeddingLookup Lookup, + string SourceId, + string ToolName); + + /* A catalog entry with a single ranking score. */ private sealed record ScoredToolEntry(ToolCatalogEntry Entry, double Score); + + /* A catalog entry with the individual lexical scores computed for it. */ private sealed record RetrievalCandidate( + ToolCatalogEntry Entry, + double Bm25Score, + double TokenSimilarity, + double CharacterNGramSimilarity, + double LegacyLexicalScore); + + /* One indexed tool: descriptor, tool instance, descriptor document text, lexical search data, and an optional embedding vector with its magnitude (both default to "no vector"). */ private sealed record ToolCatalogEntry( + McpGatewayToolDescriptor Descriptor, + AITool Tool, + string Document, + IReadOnlyList SearchFields, + IReadOnlySet LexicalTerms, + TokenSearchProfile TokenProfile, + TokenSearchProfile CharacterNGramProfile, + float[]? 
Vector = null, + double Magnitude = 0d); + + /* Indexed entries together with the vector-availability flag, the token and character n-gram inverse-document-frequency tables, and the average search-field length. */ private sealed record ToolCatalogSnapshot( + IReadOnlyList Entries, + bool HasVectors, + IReadOnlyDictionary TokenInverseDocumentFrequencies, + IReadOnlyDictionary CharacterNGramInverseDocumentFrequencies, + double AverageSearchFieldLength) + { + public static ToolCatalogSnapshot Empty { get; } = new([], false, EmptyTokenWeights, EmptyTokenWeights, 1d); + } + + /* Current snapshot, the registry version it was built from, and the disposed flag. Empty and Disposed both hold the empty snapshot with version -1. */ private sealed record RuntimeState( + ToolCatalogSnapshot Snapshot, + int SnapshotVersion, + bool IsDisposed) + { + public static RuntimeState Empty { get; } = new(ToolCatalogSnapshot.Empty, -1, false); + + public static RuntimeState Disposed { get; } = new(ToolCatalogSnapshot.Empty, -1, true); + } + + private sealed record WeightedTextSegment(string Text, double Weight); + + private sealed record TokenizedSearchField( + double Weight, + int Length, + IReadOnlyDictionary TermWeights, + IReadOnlyDictionary CharacterNGramWeights); + + private sealed record TokenSearchProfile( + IReadOnlyDictionary TermWeights, + IReadOnlySet Terms, + double Magnitude, + double TotalWeight) + { + public static TokenSearchProfile Empty { get; } = new( + EmptyTokenWeights, + new HashSet(StringComparer.OrdinalIgnoreCase), + 0d, + 0d); + } + + /* The parts of a search request. EffectiveQuery joins the normalized (or, failing that, original) query, the context summary and the flattened context with " | ", prefixing the two context parts; BoostQuery prefers the normalized query, then the original, then EffectiveQuery. */ private sealed record SearchInput( + string? OriginalQuery, + string? NormalizedQuery, + string? ContextSummary, + string? FlattenedContext) + { + private const string SearchInputSegmentSeparator = " | "; + + public string EffectiveQuery + => BuildEffectiveQuery( + NormalizedQuery ?? OriginalQuery, + ContextSummary, + FlattenedContext); + + public string BoostQuery + => NormalizedQuery ?? OriginalQuery ?? EffectiveQuery; + + /* Concatenates whichever of the three parts are non-null, always in query / context summary / context order; returns an empty string when all three are null. */ private static string BuildEffectiveQuery( + string? query, + string? contextSummary, + string? flattenedContext) + { + if (query is null) + { + if (contextSummary is null) + { + return flattenedContext is null + ? 
string.Empty + : string.Concat(ContextPrefix, flattenedContext); + } + + if (flattenedContext is null) + { + return string.Concat(ContextSummaryPrefix, contextSummary); + } + + return string.Concat( + ContextSummaryPrefix, + contextSummary, + SearchInputSegmentSeparator, + ContextPrefix, + flattenedContext); + } + + if (contextSummary is null) + { + return flattenedContext is null + ? query + : string.Concat( + query, + SearchInputSegmentSeparator, + ContextPrefix, + flattenedContext); + } + + if (flattenedContext is null) + { + return string.Concat( + query, + SearchInputSegmentSeparator, + ContextSummaryPrefix, + contextSummary); + } + + return string.Concat( + query, + SearchInputSegmentSeparator, + ContextSummaryPrefix, + contextSummary, + SearchInputSegmentSeparator, + ContextPrefix, + flattenedContext); + } + } + + /* Holds a resolved embedding generator (possibly null) and, when one was supplied, the DI scope it came from; disposing the lease disposes that scope and is a no-op otherwise. */ private sealed class EmbeddingGeneratorLease( + IEmbeddingGenerator>? generator, + AsyncServiceScope? scope = null) + : IAsyncDisposable + { + public IEmbeddingGenerator>? Generator { get; } = generator; + + public ValueTask DisposeAsync() => scope?.DisposeAsync() ?? ValueTask.CompletedTask; + } + + /* Same lease shape for the tool embedding store: the store (possibly null) plus the optional scope disposed with the lease. */ private sealed class ToolEmbeddingStoreLease( + IMcpGatewayToolEmbeddingStore? store, + AsyncServiceScope? scope = null) + : IAsyncDisposable + { + public IMcpGatewayToolEmbeddingStore? Store { get; } = store; + + public ValueTask DisposeAsync() => scope?.DisposeAsync() ?? ValueTask.CompletedTask; + } + + /* Same lease shape for a chat client: the client (possibly null) plus the optional scope disposed with the lease. */ private sealed class ChatClientLease( + IChatClient? client, + AsyncServiceScope? scope = null) + : IAsyncDisposable + { + public IChatClient? Client { get; } = client; + + public ValueTask DisposeAsync() => scope?.DisposeAsync() ?? 
ValueTask.CompletedTask; + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.cs new file mode 100644 index 0000000..38f0a0e --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Core/McpGatewayRuntime.cs @@ -0,0 +1,246 @@ +using System.Text; +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.ML.Tokenizers; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime : IMcpGateway +{ + private const string QueryArgumentName = "query"; + private const string ContextArgumentName = "context"; + private const string ContextSummaryArgumentName = "contextSummary"; + private const string GatewayInvocationMetaKey = "managedCodeMcpGateway"; + private const string SearchModeEmpty = "empty"; + private const string SearchModeBrowse = "browse"; + private const string SearchModeLexical = "lexical"; + private const string SearchModeVector = "vector"; + private const string SourceLoadFailedDiagnosticCode = "source_load_failed"; + private const string DuplicateToolIdDiagnosticCode = "duplicate_tool_id"; + private const string EmbeddingCountMismatchDiagnosticCode = "embedding_count_mismatch"; + private const string EmbeddingGeneratorMissingDiagnosticCode = "embedding_generator_missing"; + private const string EmbeddingFailedDiagnosticCode = "embedding_failed"; + private const string EmbeddingStoreLoadFailedDiagnosticCode = "embedding_store_load_failed"; + private const string EmbeddingStoreSaveFailedDiagnosticCode = "embedding_store_save_failed"; + private const string QueryVectorEmptyDiagnosticCode = "query_vector_empty"; + private const string QueryNormalizedDiagnosticCode = "query_normalized"; + private const string QueryNormalizationFailedDiagnosticCode = 
"query_normalization_failed"; + private const string LexicalFallbackDiagnosticCode = "lexical_fallback"; + private const string VectorSearchFailedDiagnosticCode = "vector_search_failed"; + private const string SourceLoadFailedMessageTemplate = "Failed to load tools from source '{0}': {1}"; + private const string DuplicateToolIdMessageTemplate = "Skipped duplicate tool id '{0}'."; + private const string EmbeddingCountMismatchMessageTemplate = "Embedding generation returned {0} vectors for {1} tools."; + private const string EmbeddingGeneratorMissingMessage = "No keyed or unkeyed IEmbeddingGenerator> is registered. Stored tool embeddings may be reused, but search falls back lexically without a query embedding generator."; + private const string EmbeddingFailedMessageTemplate = "Embedding generation failed: {0}"; + private const string EmbeddingStoreLoadFailedMessageTemplate = "Loading stored tool embeddings failed: {0}"; + private const string EmbeddingStoreSaveFailedMessageTemplate = "Persisting generated tool embeddings failed: {0}"; + private const string QueryVectorEmptyMessage = "Embedding generator returned an empty query vector."; + private const string QueryNormalizedMessage = "Search query was normalized to English before ranking."; + private const string QueryNormalizationFailedMessageTemplate = "Search query normalization failed and the original query was used: {0}"; + private const string LexicalFallbackMessage = "Vector search is unavailable. Lexical ranking was used."; + private const string VectorSearchFailedMessageTemplate = "Vector ranking failed and lexical fallback was used: {0}"; + private const string ToolNotInvokableMessageTemplate = "Tool '{0}' is not invokable."; + private const string ToolIdOrToolNameRequiredMessage = "Either ToolId or ToolName is required."; + private const string ToolIdNotFoundMessageTemplate = "Tool '{0}' was not found."; + private const string ToolNameAmbiguousMessageTemplate = "Tool '{0}' is ambiguous. 
Use ToolId or specify SourceId explicitly."; + private const string CatalogSourceMissingMessage = "ManagedCode.MCPGateway requires IMcpGatewayRegistry to be registered in the service provider. Use AddManagedCodeMcpGateway(...) to wire the package services."; + private const string FailedToLoadGatewaySourceLogMessage = "Failed to load gateway source {SourceId}."; + private const string EmbeddingGenerationFailedLogMessage = "Gateway embedding generation failed. Falling back to lexical search."; + private const string GatewayIndexRebuiltLogMessage = "Gateway index rebuilt. Tools={ToolCount} VectorizedTools={VectorizedToolCount}."; + private const string GatewayVectorSearchFailedLogMessage = "Gateway vector search failed. Falling back to lexical ranking."; + private const string GatewayInvocationFailedLogMessage = "Gateway invocation failed for {ToolId}."; + private const string EmbeddingStoreLoadFailedLogMessage = "Loading stored tool embeddings failed. Falling back to generator-backed indexing."; + private const string EmbeddingStoreSaveFailedLogMessage = "Persisting generated tool embeddings failed."; + private const string GatewayQueryNormalizationFailedLogMessage = "Gateway search query normalization failed. 
Using original query."; + private const string InputSchemaPropertiesPropertyName = "properties"; + private const string InputSchemaRequiredPropertyName = "required"; + private const string InputSchemaDescriptionPropertyName = "description"; + private const string InputSchemaTypePropertyName = "type"; + private const string InputSchemaEnumPropertyName = "enum"; + private const string DisplayNamePropertyName = "DisplayName"; + private const string ToolNameLabel = "Tool name: "; + private const string DisplayNameLabel = "Display name: "; + private const string DescriptionLabel = "Description: "; + private const string RequiredArgumentsLabel = "Required arguments: "; + private const string ParameterLabel = "Parameter "; + private const string TypeLabel = "Type "; + private const string TypicalValuesLabel = "Typical values: "; + private const string InputSchemaLabel = "Input schema: "; + private const string ContextSummaryPrefix = "context summary: "; + private const string ContextPrefix = "context: "; + private const string PluralSuffixIes = "ies"; + private const string PluralSuffixEs = "es"; + private const string CharacterNGramPrefix = "tri:"; + private const string EmbeddingGeneratorFingerprintUnknownComponent = "unknown"; + private const string EmbeddingGeneratorFingerprintComponentSeparator = "\n"; + private const string SearchQueryNormalizationInstructions = "Rewrite the user search request as a concise English tool-search query. Preserve identifiers, emails, repository names, CVE references, order numbers, tracking numbers, SKUs, version strings, filenames, and product names exactly. Do not answer the request. Do not explain anything. Return only the rewritten English search query. 
If the request is already concise English, return it unchanged."; + private const int CharacterNGramLength = 3; + private const int SearchQueryNormalizationMaxOutputTokens = 96; + private const int StageOneCandidatePoolSize = 10; + private const double ToolNameTokenWeight = 5d; + private const double DisplayNameTokenWeight = 4d; + private const double DescriptionTokenWeight = 3d; + private const double RequiredArgumentTokenWeight = 2.25d; + private const double ParameterNameTokenWeight = 2.5d; + private const double ParameterDescriptionTokenWeight = 2d; + private const double ParameterTypeTokenWeight = 0.75d; + private const double EnumValuesTokenWeight = 0d; + private const double HumanizedIdentifierWeightFactor = 0.85d; + private const double CharacterNGramTokenWeightFactor = 0.35d; + private const double QueryTokenWeight = 3d; + private const double OriginalQueryBackoffWeight = 1.25d; + private const double ContextSummaryTokenWeight = 1.5d; + private const double ContextTokenWeight = 1d; + private const double ReciprocalRankFusionConstant = 60d; + private const double Bm25K1 = 1.2d; + private const double Bm25FieldLengthNormalization = 0.75d; + private const double Bm25FeatureWeight = 0.2d; + private const double TokenSimilarityWeight = 0.16d; + private const double CharacterNGramSimilarityWeight = 0.12d; + private const double TokenCoverageWeight = 0.1d; + private const double DistinctCoverageWeight = 0.15d; + private const double LexicalSimilarityWeight = 0.06d; + private const double LegacyLexicalFeatureWeight = 0.16d; + private const double ToolNameSignalWeight = 0.05d; + + private static readonly char[] TokenSeparators = + [ + ' ', + '\t', + '\r', + '\n', + '_', + '-', + '.', + ',', + ';', + ':', + '/', + '\\', + '(', + ')', + '[', + ']', + '{', + '}', + '"', + '\'', + '@', + '?', + '!' 
+ ]; + private static readonly IReadOnlySet IgnoredSearchTerms = new HashSet(StringComparer.OrdinalIgnoreCase) + { + "a", + "an", + "and", + "again", + "any", + "for", + "just", + "me", + "need", + "now", + "please", + "plz", + "really", + "something", + "stuff", + "that", + "the", + "thing", + "this", + "to", + "with" + }; + private static readonly CompositeFormat SourceLoadFailedMessageFormat = CompositeFormat.Parse(SourceLoadFailedMessageTemplate); + private static readonly CompositeFormat DuplicateToolIdMessageFormat = CompositeFormat.Parse(DuplicateToolIdMessageTemplate); + private static readonly CompositeFormat EmbeddingCountMismatchMessageFormat = CompositeFormat.Parse(EmbeddingCountMismatchMessageTemplate); + private static readonly CompositeFormat EmbeddingFailedMessageFormat = CompositeFormat.Parse(EmbeddingFailedMessageTemplate); + private static readonly CompositeFormat EmbeddingStoreLoadFailedMessageFormat = CompositeFormat.Parse(EmbeddingStoreLoadFailedMessageTemplate); + private static readonly CompositeFormat EmbeddingStoreSaveFailedMessageFormat = CompositeFormat.Parse(EmbeddingStoreSaveFailedMessageTemplate); + private static readonly CompositeFormat QueryNormalizationFailedMessageFormat = CompositeFormat.Parse(QueryNormalizationFailedMessageTemplate); + private static readonly CompositeFormat VectorSearchFailedMessageFormat = CompositeFormat.Parse(VectorSearchFailedMessageTemplate); + private static readonly CompositeFormat ToolNotInvokableMessageFormat = CompositeFormat.Parse(ToolNotInvokableMessageTemplate); + private static readonly CompositeFormat ToolIdNotFoundMessageFormat = CompositeFormat.Parse(ToolIdNotFoundMessageTemplate); + private static readonly CompositeFormat ToolNameAmbiguousMessageFormat = CompositeFormat.Parse(ToolNameAmbiguousMessageTemplate); + private static readonly IReadOnlyDictionary EmptyTokenWeights = + new Dictionary(StringComparer.OrdinalIgnoreCase); + + private readonly IServiceProvider _serviceProvider; + private 
readonly ILogger _logger; + private readonly ILoggerFactory _loggerFactory; + private readonly IMcpGatewayCatalogSource _catalogSource; + private readonly McpGatewaySearchStrategy _searchStrategy; + private readonly McpGatewaySearchQueryNormalization _searchQueryNormalization; + private readonly Tokenizer _searchTokenizer; + private readonly int _defaultSearchLimit; + private readonly int _maxSearchResults; + private readonly int _maxDescriptorLength; + /* _state and _buildOperation are only read and written through Volatile and Interlocked in the code shown for this class */ private RuntimeState _state = RuntimeState.Empty; + private BuildOperation? _buildOperation; + + /* Validates the dependencies, resolves the catalog source from the service provider, and copies option values, clamping both search limits to at least 1 and the descriptor length to at least 256. Throws ArgumentNullException for a null argument and InvalidOperationException (from ResolveCatalogSource) when no catalog source is registered. */ internal McpGatewayRuntime( + IServiceProvider serviceProvider, + IOptions options, + ILogger logger, + ILoggerFactory loggerFactory) + { + ArgumentNullException.ThrowIfNull(serviceProvider); + ArgumentNullException.ThrowIfNull(options); + ArgumentNullException.ThrowIfNull(logger); + ArgumentNullException.ThrowIfNull(loggerFactory); + + _serviceProvider = serviceProvider; + _logger = logger; + _loggerFactory = loggerFactory; + + var resolvedOptions = options.Value; + _catalogSource = ResolveCatalogSource(serviceProvider); + _searchStrategy = resolvedOptions.SearchStrategy; + _searchQueryNormalization = resolvedOptions.SearchQueryNormalization; + _searchTokenizer = McpGatewaySearchTokenizerFactory.GetTokenizer(); + _defaultSearchLimit = Math.Max(1, resolvedOptions.DefaultSearchLimit); + _maxSearchResults = Math.Max(1, resolvedOptions.MaxSearchResults); + _maxDescriptorLength = Math.Max(256, resolvedOptions.MaxDescriptorLength); + } + + /* Creates the search and invoke meta-tools by wrapping this runtime in a McpGatewayToolSet. */ public IReadOnlyList CreateMetaTools( + string searchToolName = McpGatewayToolSet.DefaultSearchToolName, + string invokeToolName = McpGatewayToolSet.DefaultInvokeToolName) + => new McpGatewayToolSet(this).CreateTools(searchToolName, invokeToolName); + + /* Marks the runtime disposed by atomically swapping in RuntimeState.Disposed. NOTE(review): both branches below return ValueTask.CompletedTask, so the IsDisposed check currently has no effect - either remove it or add the cleanup it was meant to guard. */ public ValueTask DisposeAsync() + { + var previousState = Interlocked.Exchange(ref _state, RuntimeState.Disposed); + if (previousState.IsDisposed) + { + return ValueTask.CompletedTask; + } + + return ValueTask.CompletedTask; + } + + private 
static IMcpGatewayCatalogSource ResolveCatalogSource(IServiceProvider serviceProvider) + { + if (serviceProvider.GetService() is IMcpGatewayCatalogSource catalogSource) + { + return catalogSource; + } + + if (serviceProvider.GetService() is IMcpGatewayCatalogSource registryCatalogSource) + { + return registryCatalogSource; + } + + throw new InvalidOperationException(CatalogSourceMissingMessage); + } + + private void ThrowIfDisposed() + { + ObjectDisposedException.ThrowIf(Volatile.Read(ref _state).IsDisposed, this); + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Embeddings/McpGatewayRuntime.Embeddings.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Embeddings/McpGatewayRuntime.Embeddings.cs new file mode 100644 index 0000000..eb17cba --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Embeddings/McpGatewayRuntime.Embeddings.cs @@ -0,0 +1,153 @@ +using System.Globalization; +using System.Security.Cryptography; +using System.Text; +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + private EmbeddingGeneratorLease ResolveEmbeddingGenerator() + { + if (_serviceProvider.GetService(typeof(IServiceScopeFactory)) is not IServiceScopeFactory scopeFactory) + { + return new EmbeddingGeneratorLease(ResolveEmbeddingGenerator(_serviceProvider)); + } + + var scope = scopeFactory.CreateAsyncScope(); + var generator = ResolveEmbeddingGenerator(scope.ServiceProvider); + return new EmbeddingGeneratorLease(generator, scope); + } + + private static IEmbeddingGenerator>? ResolveEmbeddingGenerator(IServiceProvider serviceProvider) + => serviceProvider.GetKeyedService>>(McpGatewayServiceKeys.EmbeddingGenerator) + ?? 
serviceProvider.GetService>>(); + + private ToolEmbeddingStoreLease ResolveToolEmbeddingStore() + { + if (_serviceProvider.GetService(typeof(IServiceScopeFactory)) is not IServiceScopeFactory scopeFactory) + { + return new ToolEmbeddingStoreLease(_serviceProvider.GetService()); + } + + var scope = scopeFactory.CreateAsyncScope(); + var store = scope.ServiceProvider.GetService(); + return new ToolEmbeddingStoreLease(store, scope); + } + + private static double CalculateCosine(ToolCatalogEntry entry, float[] queryVector, double queryMagnitude) + { + if (entry.Vector is null || entry.Magnitude <= double.Epsilon || queryMagnitude <= double.Epsilon) + { + return 0d; + } + + var overlap = Math.Min(entry.Vector.Length, queryVector.Length); + if (overlap == 0) + { + return 0d; + } + + var dot = 0d; + for (var index = 0; index < overlap; index++) + { + dot += entry.Vector[index] * queryVector[index]; + } + + return dot / (entry.Magnitude * queryMagnitude); + } + + private static double CalculateMagnitude(IReadOnlyList vector) + { + if (vector.Count == 0) + { + return 0d; + } + + var magnitudeSquared = 0d; + foreach (var component in vector) + { + magnitudeSquared += component * component; + } + + return Math.Sqrt(magnitudeSquared); + } + + private static bool ApplyEmbedding( + IList entries, + int index, + IReadOnlyList vector, + ref int vectorizedToolCount) + { + if (vector.Count == 0) + { + return false; + } + + var normalizedVector = vector.ToArray(); + var magnitude = CalculateMagnitude(normalizedVector); + entries[index] = entries[index] with + { + Vector = normalizedVector, + Magnitude = magnitude + }; + + if (magnitude <= double.Epsilon) + { + return false; + } + + vectorizedToolCount++; + return true; + } + + private static string ComputeDocumentHash(string value) + => Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(value))); + + private static bool MatchesStoredEmbedding( + McpGatewayToolEmbeddingLookup lookup, + McpGatewayToolEmbedding embedding) + { + 
if (!string.Equals(embedding.ToolId, lookup.ToolId, StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + if (!string.Equals(embedding.DocumentHash, lookup.DocumentHash, StringComparison.Ordinal)) + { + return false; + } + + if (lookup.EmbeddingGeneratorFingerprint is null) + { + return true; + } + + return string.Equals( + embedding.EmbeddingGeneratorFingerprint, + lookup.EmbeddingGeneratorFingerprint, + StringComparison.Ordinal); + } + + private static string? ResolveEmbeddingGeneratorFingerprint( + IEmbeddingGenerator>? embeddingGenerator) + { + if (embeddingGenerator is null) + { + return null; + } + + var metadata = embeddingGenerator.GetService(typeof(EmbeddingGeneratorMetadata)) as EmbeddingGeneratorMetadata; + var generatorTypeName = embeddingGenerator.GetType().FullName ?? embeddingGenerator.GetType().Name; + + return ComputeDocumentHash(string.Join( + EmbeddingGeneratorFingerprintComponentSeparator, + metadata?.ProviderName ?? EmbeddingGeneratorFingerprintUnknownComponent, + metadata?.ProviderUri?.AbsoluteUri ?? EmbeddingGeneratorFingerprintUnknownComponent, + metadata?.DefaultModelId ?? EmbeddingGeneratorFingerprintUnknownComponent, + metadata?.DefaultModelDimensions?.ToString(CultureInfo.InvariantCulture) ?? EmbeddingGeneratorFingerprintUnknownComponent, + generatorTypeName ?? 
EmbeddingGeneratorFingerprintUnknownComponent)); + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Invocation/McpGatewayRuntime.Invocation.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Invocation/McpGatewayRuntime.Invocation.cs new file mode 100644 index 0000000..e56fc9f --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Invocation/McpGatewayRuntime.Invocation.cs @@ -0,0 +1,191 @@ +using System.Globalization; +using System.Text.Json; +using System.Text.Json.Nodes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using ModelContextProtocol; +using ModelContextProtocol.Client; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + public async Task InvokeAsync( + McpGatewayInvokeRequest request, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(request); + + var snapshot = await GetSnapshotAsync(cancellationToken); + var resolution = ResolveInvocationTarget(snapshot, request); + if (!resolution.IsSuccess || resolution.Entry is null) + { + return new McpGatewayInvokeResult( + false, + request.ToolId ?? string.Empty, + request.SourceId ?? string.Empty, + request.ToolName ?? string.Empty, + Output: null, + Error: resolution.Error); + } + + var entry = resolution.Entry; + var arguments = request.Arguments is { Count: > 0 } + ? new Dictionary(request.Arguments, StringComparer.OrdinalIgnoreCase) + : new Dictionary(StringComparer.OrdinalIgnoreCase); + + if (!string.IsNullOrWhiteSpace(request.Query) && + !arguments.ContainsKey(QueryArgumentName) && + SupportsArgument(entry.Descriptor, QueryArgumentName)) + { + arguments[QueryArgumentName] = request.Query; + } + + MapRequestArgument(arguments, entry.Descriptor, ContextArgumentName, request.Context); + MapRequestArgument(arguments, entry.Descriptor, ContextSummaryArgumentName, request.ContextSummary); + + try + { + var resolvedMcpTool = entry.Tool as McpClientTool ?? 
entry.Tool.GetService(); + if (resolvedMcpTool is not null) + { + var result = await AttachInvocationMeta(resolvedMcpTool, request).CallAsync( + arguments, + progress: null, + options: new RequestOptions(), + cancellationToken: cancellationToken); + + return new McpGatewayInvokeResult( + true, + entry.Descriptor.ToolId, + entry.Descriptor.SourceId, + entry.Descriptor.ToolName, + ExtractMcpOutput(result)); + } + + var function = entry.Tool as AIFunction ?? entry.Tool.GetService(); + if (function is null) + { + return new McpGatewayInvokeResult( + false, + entry.Descriptor.ToolId, + entry.Descriptor.SourceId, + entry.Descriptor.ToolName, + Output: null, + Error: string.Format(CultureInfo.InvariantCulture, ToolNotInvokableMessageFormat, entry.Descriptor.ToolName)); + } + + var resultValue = await function.InvokeAsync( + new AIFunctionArguments(arguments, StringComparer.OrdinalIgnoreCase), + cancellationToken); + return new McpGatewayInvokeResult( + true, + entry.Descriptor.ToolId, + entry.Descriptor.SourceId, + entry.Descriptor.ToolName, + NormalizeFunctionOutput(resultValue)); + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + _logger.LogError(ex, GatewayInvocationFailedLogMessage, entry.Descriptor.ToolId); + return new McpGatewayInvokeResult( + false, + entry.Descriptor.ToolId, + entry.Descriptor.SourceId, + entry.Descriptor.ToolName, + Output: null, + Error: ex.GetBaseException().Message); + } + } + + private static void MapRequestArgument( + IDictionary arguments, + McpGatewayToolDescriptor descriptor, + string argumentName, + object? 
value) + { + if (value is null || + arguments.ContainsKey(argumentName) || + !SupportsArgument(descriptor, argumentName)) + { + return; + } + + if (value is string text && string.IsNullOrWhiteSpace(text)) + { + return; + } + + arguments[argumentName] = value; + } + + private static bool SupportsArgument( + McpGatewayToolDescriptor descriptor, + string argumentName) + { + if (descriptor.RequiredArguments.Contains(argumentName, StringComparer.OrdinalIgnoreCase)) + { + return true; + } + + if (string.IsNullOrWhiteSpace(descriptor.InputSchemaJson)) + { + return false; + } + + try + { + using var schemaDocument = JsonDocument.Parse(descriptor.InputSchemaJson); + if (!schemaDocument.RootElement.TryGetProperty(InputSchemaPropertiesPropertyName, out var properties) || + properties.ValueKind != JsonValueKind.Object) + { + return false; + } + + return properties + .EnumerateObject() + .Any(property => string.Equals(property.Name, argumentName, StringComparison.OrdinalIgnoreCase)); + } + catch (JsonException) + { + return false; + } + } + + private static McpClientTool AttachInvocationMeta(McpClientTool tool, McpGatewayInvokeRequest request) + { + var meta = BuildInvocationMeta(request); + return meta is null ? tool : tool.WithMeta(meta); + } + + private static JsonObject? BuildInvocationMeta(McpGatewayInvokeRequest request) + { + var payload = new JsonObject(); + if (!string.IsNullOrWhiteSpace(request.Query)) + { + payload[QueryArgumentName] = request.Query.Trim(); + } + + if (!string.IsNullOrWhiteSpace(request.ContextSummary)) + { + payload[ContextSummaryArgumentName] = request.ContextSummary.Trim(); + } + + if (request.Context is { Count: > 0 }) + { + var contextNode = McpGatewayJsonSerializer.TrySerializeToNode(request.Context); + if (contextNode is not null) + { + payload[ContextArgumentName] = contextNode; + } + } + + return payload.Count == 0 + ? 
null + : new JsonObject + { + [GatewayInvocationMetaKey] = payload + }; + } + +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Invocation/McpGatewayRuntime.InvocationResults.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Invocation/McpGatewayRuntime.InvocationResults.cs new file mode 100644 index 0000000..42d39ff --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Invocation/McpGatewayRuntime.InvocationResults.cs @@ -0,0 +1,92 @@ +using System.Globalization; +using System.Text.Json; +using ModelContextProtocol.Protocol; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + private static InvocationResolution ResolveInvocationTarget( + ToolCatalogSnapshot snapshot, + McpGatewayInvokeRequest request) + { + if (!string.IsNullOrWhiteSpace(request.ToolId)) + { + var byToolId = snapshot.Entries.FirstOrDefault(item => + string.Equals(item.Descriptor.ToolId, request.ToolId, StringComparison.OrdinalIgnoreCase)); + return byToolId is null + ? 
InvocationResolution.Fail(string.Format(CultureInfo.InvariantCulture, ToolIdNotFoundMessageFormat, request.ToolId)) + : InvocationResolution.Success(byToolId); + } + + if (string.IsNullOrWhiteSpace(request.ToolName)) + { + return InvocationResolution.Fail(ToolIdOrToolNameRequiredMessage); + } + + var candidates = snapshot.Entries + .Where(item => string.Equals(item.Descriptor.ToolName, request.ToolName, StringComparison.OrdinalIgnoreCase)) + .Where(item => string.IsNullOrWhiteSpace(request.SourceId) || + string.Equals(item.Descriptor.SourceId, request.SourceId, StringComparison.OrdinalIgnoreCase)) + .ToList(); + + return candidates.Count switch + { + 0 => InvocationResolution.Fail(string.Format(CultureInfo.InvariantCulture, ToolIdNotFoundMessageFormat, request.ToolName)), + 1 => InvocationResolution.Success(candidates[0]), + _ => InvocationResolution.Fail( + string.Format(CultureInfo.InvariantCulture, ToolNameAmbiguousMessageFormat, request.ToolName)) + }; + } + + private static object? ExtractMcpOutput(CallToolResult result) + { + if (result.StructuredContent is JsonElement element) + { + return element.Clone(); + } + + var text = result.Content? + .OfType() + .FirstOrDefault(static block => !string.IsNullOrWhiteSpace(block.Text)) + ?.Text; + if (string.IsNullOrWhiteSpace(text)) + { + return result; + } + + try + { + using var document = JsonDocument.Parse(text); + return document.RootElement.Clone(); + } + catch (JsonException) + { + return text; + } + } + + private static object? NormalizeFunctionOutput(object? value) + { + return value switch + { + JsonElement element => NormalizeJsonElement(element), + JsonDocument document => NormalizeJsonElement(document.RootElement), + _ => value + }; + } + + private static object? 
NormalizeJsonElement(JsonElement element) + { + return element.ValueKind switch + { + JsonValueKind.Null or JsonValueKind.Undefined => null, + JsonValueKind.String => element.GetString(), + JsonValueKind.True or JsonValueKind.False => element.GetBoolean(), + JsonValueKind.Number when element.TryGetInt64(out var int64Value) => int64Value, + JsonValueKind.Number when element.TryGetDecimal(out var decimalValue) => decimalValue, + JsonValueKind.Number when element.TryGetDouble(out var doubleValue) => doubleValue, + _ => element.Clone() + }; + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Context.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Context.cs new file mode 100644 index 0000000..24ad4bd --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Context.cs @@ -0,0 +1,96 @@ +using System.Text; +using System.Text.Json; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + private static SearchInput BuildSearchInput( + McpGatewaySearchRequest request, + string? normalizedQuery) + => new( + NormalizeSearchComponent(request.Query), + NormalizeSearchComponent(normalizedQuery), + NormalizeSearchComponent(request.ContextSummary), + NormalizeSearchComponent(FlattenContext(request.Context))); + + private static string? NormalizeSearchComponent(string? value) + => string.IsNullOrWhiteSpace(value) + ? null + : value.Trim(); + + private static string? FlattenContext(IReadOnlyDictionary? context) + { + if (context is not { Count: > 0 }) + { + return null; + } + + if (McpGatewayJsonSerializer.TrySerializeToElement(context) is not JsonElement contextElement || + contextElement.ValueKind != JsonValueKind.Object) + { + return null; + } + + var builder = new StringBuilder(); + foreach (var property in contextElement.EnumerateObject()) + { + AppendJsonElementTerms(builder, property.Name, property.Value); + } + + return builder.Length == 0 + ? 
null + : builder.ToString(); + } + + private static void AppendJsonElementTerms(StringBuilder builder, string key, JsonElement element) + { + switch (element.ValueKind) + { + case JsonValueKind.Object: + foreach (var property in element.EnumerateObject()) + { + AppendJsonElementTerms(builder, string.Concat(key, " ", property.Name), property.Value); + } + return; + + case JsonValueKind.Array: + foreach (var item in element.EnumerateArray()) + { + AppendJsonElementTerms(builder, key, item); + } + return; + + case JsonValueKind.String: + if (NormalizeSearchComponent(element.GetString()) is string text) + { + AppendContextTerm(builder, key, text); + } + return; + + case JsonValueKind.True: + case JsonValueKind.False: + AppendContextTerm(builder, key, element.GetBoolean() ? bool.TrueString : bool.FalseString); + return; + + case JsonValueKind.Number: + AppendContextTerm(builder, key, element.ToString()); + return; + + default: + return; + } + } + + private static void AppendContextTerm(StringBuilder builder, string key, string value) + { + if (builder.Length > 0) + { + builder.Append("; "); + } + + builder.Append(key); + builder.Append(' '); + builder.Append(value); + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.QueryNormalization.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.QueryNormalization.cs new file mode 100644 index 0000000..de06b34 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.QueryNormalization.cs @@ -0,0 +1,94 @@ +using System.Globalization; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + private async Task NormalizeSearchQueryAsync( + string? 
query, + ICollection diagnostics, + CancellationToken cancellationToken) + { + if (_searchQueryNormalization == McpGatewaySearchQueryNormalization.Disabled || + string.IsNullOrWhiteSpace(query)) + { + return null; + } + + try + { + await using var chatClientLease = ResolveSearchQueryChatClient(); + if (chatClientLease.Client is not IChatClient chatClient) + { + return null; + } + + var response = await chatClient.GetResponseAsync( + [new ChatMessage(ChatRole.User, query.Trim())], + new ChatOptions + { + Instructions = SearchQueryNormalizationInstructions, + Temperature = 0f, + MaxOutputTokens = SearchQueryNormalizationMaxOutputTokens + }, + cancellationToken); + + var normalizedQuery = NormalizeChatResponseText(response.Text); + if (string.IsNullOrWhiteSpace(normalizedQuery) || + string.Equals(normalizedQuery, query.Trim(), StringComparison.OrdinalIgnoreCase)) + { + return null; + } + + diagnostics.Add(new McpGatewayDiagnostic(QueryNormalizedDiagnosticCode, QueryNormalizedMessage)); + return normalizedQuery; + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + diagnostics.Add(new McpGatewayDiagnostic( + QueryNormalizationFailedDiagnosticCode, + string.Format( + CultureInfo.InvariantCulture, + QueryNormalizationFailedMessageFormat, + ex.GetBaseException().Message))); + _logger.LogWarning(ex, GatewayQueryNormalizationFailedLogMessage); + return null; + } + } + + private ChatClientLease ResolveSearchQueryChatClient() + { + if (_serviceProvider.GetService(typeof(IServiceScopeFactory)) is not IServiceScopeFactory scopeFactory) + { + return new ChatClientLease(ResolveSearchQueryChatClient(_serviceProvider)); + } + + var scope = scopeFactory.CreateAsyncScope(); + var chatClient = ResolveSearchQueryChatClient(scope.ServiceProvider); + return new ChatClientLease(chatClient, scope); + } + + private static IChatClient? 
ResolveSearchQueryChatClient(IServiceProvider serviceProvider) + => serviceProvider.GetKeyedService(McpGatewayServiceKeys.SearchQueryChatClient); + + private static string? NormalizeChatResponseText(string? responseText) + { + if (string.IsNullOrWhiteSpace(responseText)) + { + return null; + } + + var normalized = responseText.Trim(); + normalized = normalized.Trim('`', '"', '\'', ' '); + normalized = normalized + .Split(['\r', '\n'], StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .FirstOrDefault(); + + return string.IsNullOrWhiteSpace(normalized) + ? null + : normalized.Trim().Trim('`', '"', '\''); + } +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Search.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Search.cs new file mode 100644 index 0000000..6e72e4a --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Search.cs @@ -0,0 +1,133 @@ +using System.Globalization; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; + +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + public async Task> ListToolsAsync(CancellationToken cancellationToken = default) + { + var snapshot = await GetSnapshotAsync(cancellationToken); + return snapshot.Entries + .Select(static item => item.Descriptor) + .ToList(); + } + + public async Task SearchAsync( + string? query, + int? 
maxResults = null, + CancellationToken cancellationToken = default) + => await SearchAsync( + new McpGatewaySearchRequest( + Query: query, + MaxResults: maxResults), + cancellationToken); + + public async Task SearchAsync( + McpGatewaySearchRequest request, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(request); + + var snapshot = await GetSnapshotAsync(cancellationToken); + var limit = Math.Clamp(request.MaxResults.GetValueOrDefault(_defaultSearchLimit), 1, _maxSearchResults); + var diagnostics = new List(); + + if (snapshot.Entries.Count == 0) + { + return new McpGatewaySearchResult([], diagnostics, SearchModeEmpty); + } + + var normalizedQuery = await NormalizeSearchQueryAsync(request.Query, diagnostics, cancellationToken); + var searchInput = BuildSearchInput(request, normalizedQuery); + if (string.IsNullOrWhiteSpace(searchInput.EffectiveQuery)) + { + var browse = snapshot.Entries + .Take(limit) + .Select(static entry => ToSearchMatch(entry, 0d)) + .ToList(); + return new McpGatewaySearchResult(browse, diagnostics, SearchModeBrowse); + } + + IReadOnlyList ranked; + var rankingMode = SearchModeLexical; + var shouldPreferVectorSearch = _searchStrategy is not McpGatewaySearchStrategy.Tokenizer && snapshot.HasVectors; + if (shouldPreferVectorSearch) + { + try + { + await using var embeddingGeneratorLease = ResolveEmbeddingGenerator(); + if (embeddingGeneratorLease.Generator is IEmbeddingGenerator> generator) + { + var embedding = await generator.GenerateAsync(searchInput.EffectiveQuery, cancellationToken: cancellationToken); + var queryVector = embedding.Vector.ToArray(); + var queryMagnitude = CalculateMagnitude(queryVector); + if (queryMagnitude > double.Epsilon) + { + ranked = snapshot.Entries + .Select(entry => new ScoredToolEntry( + entry, + ApplySearchBoosts( + entry, + searchInput.BoostQuery, + CalculateCosine(entry, queryVector, queryMagnitude)))) + .OrderByDescending(static item => item.Score) + .ThenBy(static 
item => item.Entry.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase) + .ToList(); + rankingMode = SearchModeVector; + } + else + { + ranked = RankLexically(snapshot, searchInput); + diagnostics.Add(new McpGatewayDiagnostic(QueryVectorEmptyDiagnosticCode, QueryVectorEmptyMessage)); + } + } + else + { + ranked = RankLexically(snapshot, searchInput); + diagnostics.Add(new McpGatewayDiagnostic( + LexicalFallbackDiagnosticCode, + LexicalFallbackMessage)); + } + } + catch (Exception ex) when (ex is not OperationCanceledException) + { + ranked = RankLexically(snapshot, searchInput); + diagnostics.Add(new McpGatewayDiagnostic( + VectorSearchFailedDiagnosticCode, + string.Format(CultureInfo.InvariantCulture, VectorSearchFailedMessageFormat, ex.GetBaseException().Message))); + _logger.LogWarning(ex, GatewayVectorSearchFailedLogMessage); + } + } + else + { + ranked = RankLexically(snapshot, searchInput); + if (_searchStrategy is not McpGatewaySearchStrategy.Tokenizer) + { + diagnostics.Add(new McpGatewayDiagnostic( + LexicalFallbackDiagnosticCode, + LexicalFallbackMessage)); + } + } + + var matches = ranked + .Take(limit) + .Select(item => ToSearchMatch(item.Entry, item.Score)) + .ToList(); + + return new McpGatewaySearchResult(matches, diagnostics, rankingMode); + } + + private static McpGatewaySearchMatch ToSearchMatch(ToolCatalogEntry entry, double score) + => new( + entry.Descriptor.ToolId, + entry.Descriptor.SourceId, + entry.Descriptor.SourceKind, + entry.Descriptor.ToolName, + entry.Descriptor.DisplayName, + entry.Descriptor.Description, + entry.Descriptor.RequiredArguments, + entry.Descriptor.InputSchemaJson, + score); +} diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.TokenSearch.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.TokenSearch.cs new file mode 100644 index 0000000..1f148ac --- /dev/null +++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.TokenSearch.cs @@ -0,0 +1,632 
@@ +namespace ManagedCode.MCPGateway; + +internal sealed partial class McpGatewayRuntime +{ + private static double ApplySearchBoosts(ToolCatalogEntry entry, string query, double score) + => Math.Clamp(score + (CalculateToolNameSignal(entry, query, TokenSearchProfile.Empty) * ToolNameSignalWeight), 0d, 1d); + + private IReadOnlyList RankLexically( + ToolCatalogSnapshot snapshot, + SearchInput searchInput) + { + var querySegments = BuildQueryTokenSearchSegments(searchInput); + var queryFields = BuildTokenizedSearchFields(querySegments); + var rawQueryProfile = BuildTokenSearchProfile(queryFields); + if (rawQueryProfile.TermWeights.Count == 0) + { + return RankLegacyLexically(snapshot.Entries, searchInput.BoostQuery); + } + + var queryProfile = ApplyTokenInverseDocumentFrequencies( + rawQueryProfile, + snapshot.TokenInverseDocumentFrequencies); + var characterNGramProfile = BuildCharacterNGramProfile( + queryFields, + snapshot.CharacterNGramInverseDocumentFrequencies); + var lexicalSearchTerms = BuildLexicalTerms(querySegments); + var rawLexicalSearchTerms = BuildSearchTerms(searchInput.BoostQuery); + var candidates = RetrieveCandidates( + snapshot, + queryProfile, + rawQueryProfile, + characterNGramProfile, + rawLexicalSearchTerms, + searchInput.BoostQuery); + + return candidates + .Select(candidate => new ScoredToolEntry( + candidate.Entry, + CalculateTokenSearchScore( + candidate, + queryProfile, + rawQueryProfile, + lexicalSearchTerms, + searchInput.BoostQuery))) + .OrderByDescending(static item => item.Score) + .ThenBy(static item => item.Entry.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase) + .ToList(); + } + + private static IReadOnlyList RetrieveCandidates( + ToolCatalogSnapshot snapshot, + TokenSearchProfile queryProfile, + TokenSearchProfile rawQueryProfile, + TokenSearchProfile characterNGramProfile, + IReadOnlySet rawLexicalSearchTerms, + string rawQuery) + { + var retrievalCandidates = snapshot.Entries + .Select(entry => new RetrievalCandidate( + 
entry, + CalculateBm25Score( + entry, + rawQueryProfile, + snapshot.TokenInverseDocumentFrequencies, + snapshot.AverageSearchFieldLength), + CalculateSparseCosine(entry.TokenProfile, queryProfile), + CalculateSparseCosine(entry.CharacterNGramProfile, characterNGramProfile), + CalculateLegacyLexicalScore(entry, rawQuery, rawLexicalSearchTerms))) + .ToList(); + + var bm25Ranks = BuildRankLookup(retrievalCandidates, static candidate => candidate.Bm25Score); + var tokenRanks = BuildRankLookup(retrievalCandidates, static candidate => candidate.TokenSimilarity); + var characterRanks = BuildRankLookup(retrievalCandidates, static candidate => candidate.CharacterNGramSimilarity); + var legacyRanks = BuildRankLookup(retrievalCandidates, static candidate => candidate.LegacyLexicalScore); + + return retrievalCandidates + .OrderByDescending(candidate => + CalculateReciprocalRankFusionScore(candidate.Entry, bm25Ranks, tokenRanks, characterRanks, legacyRanks)) + .ThenByDescending(static candidate => candidate.Bm25Score) + .ThenBy(static candidate => candidate.Entry.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase) + .Take(ResolveCandidatePoolSize(retrievalCandidates.Count)) + .ToList(); + } + + private static int ResolveCandidatePoolSize(int candidateCount) + => candidateCount <= 128 + ? 
candidateCount + : Math.Min(StageOneCandidatePoolSize, candidateCount); + + private static IReadOnlyDictionary BuildRankLookup( + IReadOnlyList candidates, + Func scoreSelector) + => candidates + .OrderByDescending(scoreSelector) + .ThenBy(static candidate => candidate.Entry.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase) + .Select((candidate, index) => new KeyValuePair(candidate.Entry.Descriptor.ToolId, index + 1)) + .ToDictionary(static item => item.Key, static item => item.Value, StringComparer.OrdinalIgnoreCase); + + private static double CalculateReciprocalRankFusionScore( + ToolCatalogEntry entry, + IReadOnlyDictionary bm25Ranks, + IReadOnlyDictionary tokenRanks, + IReadOnlyDictionary characterRanks, + IReadOnlyDictionary legacyRanks) + => CalculateReciprocalRankComponent(entry, bm25Ranks) + + CalculateReciprocalRankComponent(entry, tokenRanks) + + CalculateReciprocalRankComponent(entry, characterRanks) + + CalculateReciprocalRankComponent(entry, legacyRanks); + + private static double CalculateReciprocalRankComponent( + ToolCatalogEntry entry, + IReadOnlyDictionary ranks) + => ranks.TryGetValue(entry.Descriptor.ToolId, out var rank) + ? 
1d / (ReciprocalRankFusionConstant + rank) + : 0d; + + private static double CalculateTokenSearchScore( + RetrievalCandidate candidate, + TokenSearchProfile queryProfile, + TokenSearchProfile rawQueryProfile, + IReadOnlySet lexicalSearchTerms, + string rawQuery) + { + var bm25Score = NormalizePositiveScore(candidate.Bm25Score, 4d); + var tokenCoverage = Math.Max( + CalculateQueryCoverage(candidate.Entry.TokenProfile, queryProfile), + CalculateQueryCoverage(candidate.Entry.TokenProfile, rawQueryProfile)); + var distinctCoverage = CalculateDistinctQueryCoverage(candidate.Entry, rawQueryProfile); + var approximateCoverage = CalculateApproximateQueryCoverage(candidate.Entry, rawQueryProfile); + var weightedApproximateCoverage = CalculateWeightedApproximateCoverage(candidate.Entry, queryProfile); + var matchBreadth = Math.Max(Math.Max(distinctCoverage, approximateCoverage), weightedApproximateCoverage); + var lexicalSimilarity = CalculateLexicalSimilarity(candidate.Entry, rawQuery, lexicalSearchTerms); + var toolNameSignal = CalculateToolNameSignal(candidate.Entry, rawQuery, rawQueryProfile); + + var score = + (bm25Score * Bm25FeatureWeight) + + (candidate.TokenSimilarity * TokenSimilarityWeight) + + (candidate.CharacterNGramSimilarity * CharacterNGramSimilarityWeight) + + (tokenCoverage * TokenCoverageWeight) + + (matchBreadth * DistinctCoverageWeight) + + (lexicalSimilarity * LexicalSimilarityWeight) + + (candidate.LegacyLexicalScore * LegacyLexicalFeatureWeight) + + (toolNameSignal * ToolNameSignalWeight); + + var evidenceCalibration = Math.Max(tokenCoverage, matchBreadth); + return Math.Clamp(score * evidenceCalibration, 0d, 1d); + } + + private static double NormalizePositiveScore(double value, double scale) + => value <= double.Epsilon + ? 
0d + : 1d - Math.Exp(-value / scale); + + private static double CalculateBm25Score( + ToolCatalogEntry entry, + TokenSearchProfile queryProfile, + IReadOnlyDictionary inverseDocumentFrequencies, + double averageFieldLength) + { + if (queryProfile.TermWeights.Count == 0) + { + return 0d; + } + + var score = 0d; + foreach (var (term, queryWeight) in queryProfile.TermWeights) + { + var fieldScore = 0d; + foreach (var field in entry.SearchFields) + { + if (!field.TermWeights.TryGetValue(term, out var termFrequency) || + termFrequency <= double.Epsilon) + { + continue; + } + + var normalizedLength = + 1d - Bm25FieldLengthNormalization + + (Bm25FieldLengthNormalization * (Math.Max(1, field.Length) / Math.Max(1d, averageFieldLength))); + var denominator = termFrequency + (Bm25K1 * normalizedLength); + if (denominator <= double.Epsilon) + { + continue; + } + + fieldScore += field.Weight * ((termFrequency * (Bm25K1 + 1d)) / denominator); + } + + if (fieldScore <= double.Epsilon) + { + continue; + } + + var inverseDocumentFrequency = inverseDocumentFrequencies.TryGetValue(term, out var value) + ? value + : 1d; + score += inverseDocumentFrequency * Math.Sqrt(queryWeight) * fieldScore; + } + + return score; + } + + private static double CalculateSparseCosine( + TokenSearchProfile entryProfile, + TokenSearchProfile queryProfile) + { + if (entryProfile.Magnitude <= double.Epsilon || + queryProfile.Magnitude <= double.Epsilon) + { + return 0d; + } + + var smaller = entryProfile.TermWeights.Count <= queryProfile.TermWeights.Count + ? entryProfile.TermWeights + : queryProfile.TermWeights; + var larger = ReferenceEquals(smaller, entryProfile.TermWeights) + ? queryProfile.TermWeights + : entryProfile.TermWeights; + + var dot = 0d; + foreach (var (term, weight) in smaller) + { + if (larger.TryGetValue(term, out var otherWeight)) + { + dot += weight * otherWeight; + } + } + + return dot <= double.Epsilon + ? 
0d
+            : dot / (entryProfile.Magnitude * queryProfile.Magnitude);
+    }
+
+    // Share of the query profile's total weight carried by terms that also occur in the entry
+    // profile, clamped to [0, 1]. Returns 0 when the query has no weight or nothing matches.
+    private static double CalculateQueryCoverage(
+        TokenSearchProfile entryProfile,
+        TokenSearchProfile queryProfile)
+    {
+        var totalQueryWeight = queryProfile.TotalWeight;
+        if (totalQueryWeight <= double.Epsilon)
+        {
+            return 0d;
+        }
+
+        var entryTerms = entryProfile.TermWeights;
+        var coveredWeight = 0d;
+        foreach (var queryTerm in queryProfile.TermWeights)
+        {
+            if (entryTerms.ContainsKey(queryTerm.Key))
+            {
+                coveredWeight += queryTerm.Value;
+            }
+        }
+
+        if (coveredWeight <= double.Epsilon)
+        {
+            return 0d;
+        }
+
+        return Math.Clamp(coveredWeight / totalQueryWeight, 0d, 1d);
+    }
+
+    // Scores how well the search terms match the entry's lexical terms:
+    // 1 point per exact term, 0.35 per prefix-related term, plus 1.5 when the tool name
+    // contains the whole query. The sum is divided by (term count + 1.5) and clamped to [0, 1].
+    private static double CalculateLexicalSimilarity(
+        ToolCatalogEntry entry,
+        string query,
+        IReadOnlySet searchTerms)
+    {
+        if (searchTerms.Count == 0)
+        {
+            return 0d;
+        }
+
+        var lexicon = entry.LexicalTerms;
+        var points = 0d;
+        foreach (var term in searchTerms)
+        {
+            if (lexicon.Contains(term))
+            {
+                points += 1d;
+            }
+            else if (lexicon.Any(known => SharesPrefix(known, term)))
+            {
+                points += 0.35d;
+            }
+        }
+
+        if (entry.Descriptor.ToolName.Contains(query, StringComparison.OrdinalIgnoreCase))
+        {
+            points += 1.5d;
+        }
+
+        return Math.Clamp(points / (searchTerms.Count + 1.5d), 0d, 1d);
+
+        // True when either string starts with the other, ignoring case.
+        static bool SharesPrefix(string known, string term)
+            => known.StartsWith(term, StringComparison.OrdinalIgnoreCase) ||
+               term.StartsWith(known, StringComparison.OrdinalIgnoreCase);
+    }
+
+    private static double CalculateDistinctQueryCoverage(
+        ToolCatalogEntry entry,
+        TokenSearchProfile rawQueryProfile)
+    {
+        if (rawQueryProfile.TermWeights.Count == 0)
+        {
+            return 0d;
+        }
+
+        var matchedTerms = 0;
+        foreach (var term in rawQueryProfile.TermWeights.Keys)
+        {
+            if (entry.LexicalTerms.Contains(term) ||
+                entry.LexicalTerms.Any(candidate =>
+                    candidate.StartsWith(term, StringComparison.OrdinalIgnoreCase) ||
+                    term.StartsWith(candidate, StringComparison.OrdinalIgnoreCase)))
+            {
+                matchedTerms++;
+            }
+        }
+
+        return matchedTerms == 0
+            ?
0d
+            : Math.Clamp(matchedTerms / (double)rawQueryProfile.TermWeights.Count, 0d, 1d);
+    }
+
+    // Fraction of distinct raw-query terms that have an exact, prefix or fuzzy
+    // (approximate-similarity above zero) counterpart among the entry's lexical terms.
+    private static double CalculateApproximateQueryCoverage(
+        ToolCatalogEntry entry,
+        TokenSearchProfile rawQueryProfile)
+    {
+        var queryTermCount = rawQueryProfile.TermWeights.Count;
+        if (queryTermCount == 0)
+        {
+            return 0d;
+        }
+
+        var lexicon = entry.LexicalTerms;
+        var matched = 0;
+        foreach (var term in rawQueryProfile.TermWeights.Keys)
+        {
+            // The set lookup is tried first; the linear scan only runs when it misses.
+            if (lexicon.Contains(term) || lexicon.Any(known => IsPrefixOrFuzzyMatch(term, known)))
+            {
+                matched++;
+            }
+        }
+
+        if (matched == 0)
+        {
+            return 0d;
+        }
+
+        return Math.Clamp(matched / (double)queryTermCount, 0d, 1d);
+
+        static bool IsPrefixOrFuzzyMatch(string term, string known)
+            => known.StartsWith(term, StringComparison.OrdinalIgnoreCase) ||
+               term.StartsWith(known, StringComparison.OrdinalIgnoreCase) ||
+               CalculateApproximateTermSimilarity(term, known) > double.Epsilon;
+    }
+
+    private static double CalculateWeightedApproximateCoverage(
+        ToolCatalogEntry entry,
+        TokenSearchProfile queryProfile)
+    {
+        if (queryProfile.TotalWeight <= double.Epsilon)
+        {
+            return 0d;
+        }
+
+        var matchedWeight = 0d;
+        foreach (var (term, weight) in queryProfile.TermWeights)
+        {
+            if (entry.LexicalTerms.Contains(term) ||
+                entry.LexicalTerms.Any(candidate =>
+                    candidate.StartsWith(term, StringComparison.OrdinalIgnoreCase) ||
+                    term.StartsWith(candidate, StringComparison.OrdinalIgnoreCase) ||
+                    CalculateApproximateTermSimilarity(term, candidate) > double.Epsilon))
+            {
+                matchedWeight += weight;
+            }
+        }
+
+        return matchedWeight <= double.Epsilon
+            ?
0d
+            : Math.Clamp(matchedWeight / queryProfile.TotalWeight, 0d, 1d);
+    }
+
+    // Strength of the match between the raw query and the tool name:
+    // 1 for a case-insensitive exact match, 0.5 when the name (raw or humanized) contains the
+    // query, otherwise the number of query terms found among the tool-name terms divided by the
+    // number of tool-name terms, clamped to [0, 1].
+    private static double CalculateToolNameSignal(
+        ToolCatalogEntry entry,
+        string rawQuery,
+        TokenSearchProfile rawQueryProfile)
+    {
+        if (string.IsNullOrWhiteSpace(rawQuery))
+        {
+            return 0d;
+        }
+
+        var toolName = entry.Descriptor.ToolName;
+        if (string.Equals(toolName, rawQuery, StringComparison.OrdinalIgnoreCase))
+        {
+            return 1d;
+        }
+
+        var humanizedToolName = HumanizeIdentifier(toolName);
+        var nameContainsQuery =
+            humanizedToolName.Contains(rawQuery, StringComparison.OrdinalIgnoreCase) ||
+            toolName.Contains(rawQuery, StringComparison.OrdinalIgnoreCase);
+        if (nameContainsQuery)
+        {
+            return 0.5d;
+        }
+
+        if (rawQueryProfile.TermWeights.Count == 0)
+        {
+            return 0d;
+        }
+
+        var toolNameTerms = BuildSearchTerms(toolName);
+        if (toolNameTerms.Count == 0)
+        {
+            return 0d;
+        }
+
+        var matched = 0;
+        foreach (var queryTerm in rawQueryProfile.TermWeights.Keys)
+        {
+            if (toolNameTerms.Contains(queryTerm))
+            {
+                matched++;
+            }
+        }
+
+        if (matched == 0)
+        {
+            return 0d;
+        }
+
+        // toolNameTerms is known to be non-empty here, so the divisor is at least 1.
+        return Math.Clamp(matched / (double)toolNameTerms.Count, 0d, 1d);
+    }
+
+    // Scores every entry with the legacy lexical score and returns them best first;
+    // equal scores are ordered by tool name, ignoring case.
+    private static IReadOnlyList RankLegacyLexically(
+        IReadOnlyList entries,
+        string query)
+    {
+        var searchTerms = BuildSearchTerms(query);
+        var scored = entries.Select(entry =>
+            new ScoredToolEntry(entry, CalculateLegacyLexicalScore(entry, query, searchTerms)));
+        var ranked = scored
+            .OrderByDescending(static item => item.Score)
+            .ThenBy(static item => item.Entry.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase);
+        return ranked.ToList();
+    }
+
+    // Legacy lexical score: 70% lexical similarity plus 30% fuzzy term similarity, clamped to [0, 1].
+    private static double CalculateLegacyLexicalScore(
+        ToolCatalogEntry entry,
+        string query,
+        IReadOnlySet searchTerms)
+    {
+        var lexicalPart = CalculateLexicalSimilarity(entry, query, searchTerms) * 0.7d;
+        var fuzzyPart = CalculateApproximateTermSimilarity(entry, searchTerms) * 0.3d;
+        return Math.Clamp(lexicalPart + fuzzyPart, 0d, 1d);
+    }
+
+    private static double CalculateApproximateTermSimilarity(
+        ToolCatalogEntry entry,
+        IReadOnlySet
searchTerms)
+    {
+        // Average best fuzzy similarity over the search terms that have no exact or prefix match
+        // but do have a fuzzy match; 0 when there is no such term.
+        if (searchTerms.Count == 0)
+        {
+            return 0d;
+        }
+
+        var lexicon = entry.LexicalTerms;
+        var similaritySum = 0d;
+        var fuzzyMatches = 0;
+
+        foreach (var term in searchTerms)
+        {
+            // Terms with an exact or prefix match are skipped; only the rest are fuzzy-matched.
+            var hasDirectMatch =
+                lexicon.Contains(term) ||
+                lexicon.Any(known =>
+                    known.StartsWith(term, StringComparison.OrdinalIgnoreCase) ||
+                    term.StartsWith(known, StringComparison.OrdinalIgnoreCase));
+            if (hasDirectMatch)
+            {
+                continue;
+            }
+
+            var best = 0d;
+            foreach (var known in lexicon)
+            {
+                best = Math.Max(best, CalculateApproximateTermSimilarity(term, known));
+            }
+
+            if (best > double.Epsilon)
+            {
+                similaritySum += best;
+                fuzzyMatches++;
+            }
+        }
+
+        if (fuzzyMatches == 0)
+        {
+            return 0d;
+        }
+
+        return Math.Clamp(similaritySum / fuzzyMatches, 0d, 1d);
+    }
+
+    // Fuzzy similarity of two terms: 1 - distance / longer length, or 0 when either term is
+    // shorter than 4 characters, the lengths differ by more than 2, or the edit distance exceeds
+    // the allowance (2 edits when either term has 8+ characters, otherwise 1).
+    private static double CalculateApproximateTermSimilarity(string source, string candidate)
+    {
+        var longerLength = Math.Max(source.Length, candidate.Length);
+        var shorterLength = Math.Min(source.Length, candidate.Length);
+        if (shorterLength < 4 || longerLength - shorterLength > 2)
+        {
+            return 0d;
+        }
+
+        var allowedEdits = longerLength >= 8 ? 2 : 1;
+        var edits = CalculateDamerauLevenshteinDistance(source, candidate, allowedEdits);
+        if (edits > allowedEdits)
+        {
+            return 0d;
+        }
+
+        return 1d - (edits / (double)longerLength);
+    }
+
+    // Case-insensitive edit distance counting insertions, deletions, substitutions and swaps of
+    // two adjacent characters. Gives up early with maxDistance + 1 when the length difference or
+    // a whole row already exceeds maxDistance; callers only need to test "result > maxDistance".
+    private static int CalculateDamerauLevenshteinDistance(
+        string source,
+        string candidate,
+        int maxDistance)
+    {
+        if (string.Equals(source, candidate, StringComparison.OrdinalIgnoreCase))
+        {
+            return 0;
+        }
+
+        if (source.Length == 0)
+        {
+            return candidate.Length;
+        }
+
+        if (candidate.Length == 0)
+        {
+            return source.Length;
+        }
+
+        if (Math.Abs(source.Length - candidate.Length) > maxDistance)
+        {
+            return maxDistance + 1;
+        }
+
+        // Three rolling rows: the row being filled, the one above it and the one above that
+        // (the last one is only read for adjacent swaps).
+        var width = candidate.Length + 1;
+        var twoRowsUp = new int[width];
+        var rowAbove = new int[width];
+        var row = new int[width];
+
+        for (var column = 0; column < width; column++)
+        {
+            rowAbove[column] = column;
+        }
+
+        for (var sourceIndex = 1; sourceIndex <= source.Length; sourceIndex++)
+        {
+            var sourceChar = char.ToLowerInvariant(source[sourceIndex - 1]);
+            row[0] = sourceIndex;
+            var smallestInRow = sourceIndex;
+
+            for (var candidateIndex = 1; candidateIndex <= candidate.Length; candidateIndex++)
+            {
+                var candidateChar = char.ToLowerInvariant(candidate[candidateIndex - 1]);
+
+                var fromLeft = row[candidateIndex - 1] + 1;
+                var fromAbove = rowAbove[candidateIndex] + 1;
+                var fromDiagonal = rowAbove[candidateIndex - 1] + (sourceChar == candidateChar ? 0 : 1);
+                var cost = Math.Min(Math.Min(fromLeft, fromAbove), fromDiagonal);
+
+                var isAdjacentSwap =
+                    sourceIndex > 1 &&
+                    candidateIndex > 1 &&
+                    sourceChar == char.ToLowerInvariant(candidate[candidateIndex - 2]) &&
+                    char.ToLowerInvariant(source[sourceIndex - 2]) == candidateChar;
+                if (isAdjacentSwap)
+                {
+                    cost = Math.Min(cost, twoRowsUp[candidateIndex - 2] + 1);
+                }
+
+                row[candidateIndex] = cost;
+                smallestInRow = Math.Min(smallestInRow, cost);
+            }
+
+            if (smallestInRow > maxDistance)
+            {
+                return maxDistance + 1;
+            }
+
+            // Rotate the buffers; the oldest row is reused and fully overwritten next iteration.
+            (twoRowsUp, rowAbove, row) = (rowAbove, row, twoRowsUp);
+        }
+
+        return rowAbove[candidate.Length];
+    }
+
+    private static IReadOnlyDictionary BuildTokenInverseDocumentFrequencies(
+        IReadOnlyList rawProfiles)
+    {
+        if (rawProfiles.Count == 0)
+        {
+            return EmptyTokenWeights;
+        }
+
+        var documentFrequencies = new Dictionary(StringComparer.OrdinalIgnoreCase);
+        foreach (var profile in rawProfiles)
+        {
+            foreach (var term in profile.Terms)
+            {
+                documentFrequencies[term] = documentFrequencies.TryGetValue(term, out var count)
+                    ?
count + 1
+                    : 1;
+            }
+        }
+
+        var inverseDocumentFrequencies = new Dictionary(
+            documentFrequencies.Count,
+            StringComparer.OrdinalIgnoreCase);
+        foreach (var (term, documentFrequency) in documentFrequencies)
+        {
+            inverseDocumentFrequencies[term] =
+                1d + Math.Log((1d + rawProfiles.Count) / (1d + documentFrequency));
+        }
+
+        return inverseDocumentFrequencies;
+    }
+
+    // Multiplies every term weight of the raw profile by the term's inverse document frequency
+    // (1 when the term is unknown) and builds a new profile from the result.
+    // The raw profile is returned unchanged when it, or the frequency table, is empty.
+    private static TokenSearchProfile ApplyTokenInverseDocumentFrequencies(
+        TokenSearchProfile rawProfile,
+        IReadOnlyDictionary inverseDocumentFrequencies)
+    {
+        var rawWeights = rawProfile.TermWeights;
+        if (rawWeights.Count == 0 || inverseDocumentFrequencies.Count == 0)
+        {
+            return rawProfile;
+        }
+
+        var weightedTerms = new Dictionary(
+            rawWeights.Count,
+            StringComparer.OrdinalIgnoreCase);
+        foreach (var rawTerm in rawWeights)
+        {
+            if (!inverseDocumentFrequencies.TryGetValue(rawTerm.Key, out var inverseDocumentFrequency))
+            {
+                inverseDocumentFrequency = 1d;
+            }
+
+            weightedTerms[rawTerm.Key] = rawTerm.Value * inverseDocumentFrequency;
+        }
+
+        return CreateTokenSearchProfile(weightedTerms);
+    }
+}
diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.TokenSearchSegments.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.TokenSearchSegments.cs
new file mode 100644
index 0000000..9ecda3a
--- /dev/null
+++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.TokenSearchSegments.cs
@@ -0,0 +1,97 @@
+using System.Text.Json;
+
+namespace ManagedCode.MCPGateway;
+
+internal sealed partial class McpGatewayRuntime
+{
+    // Collects the weighted text segments that describe a tool for token search: tool name,
+    // display name, description, required arguments and, when the input schema parses as a JSON
+    // object with an object-valued properties member, each parameter's name, description, type
+    // and string enum values. A schema that is not valid JSON is indexed as plain text instead.
+    private static IReadOnlyList BuildDescriptorTokenSearchSegments(
+        McpGatewayToolDescriptor descriptor)
+    {
+        var segments = new List();
+
+        AddTokenSearchIdentifierSegment(segments, descriptor.ToolName, ToolNameTokenWeight);
+        AddTokenSearchIdentifierSegment(segments, descriptor.DisplayName, DisplayNameTokenWeight);
+        AddTokenSearchTextSegment(segments, descriptor.Description, DescriptionTokenWeight);
+
+        foreach (var requiredArgument in descriptor.RequiredArguments)
+        {
+            AddTokenSearchIdentifierSegment(segments, requiredArgument, RequiredArgumentTokenWeight);
+        }
+
+        if (string.IsNullOrWhiteSpace(descriptor.InputSchemaJson))
+        {
+            return segments;
+        }
+
+        try
+        {
+            using var schemaDocument = JsonDocument.Parse(descriptor.InputSchemaJson);
+            var schemaRoot = schemaDocument.RootElement;
+
+            // JsonElement.TryGetProperty throws InvalidOperationException on anything that is not
+            // a JSON object, and only JsonException is caught below, so the kind is checked first.
+            if (schemaRoot.ValueKind != JsonValueKind.Object ||
+                !schemaRoot.TryGetProperty(InputSchemaPropertiesPropertyName, out var properties) ||
+                properties.ValueKind != JsonValueKind.Object)
+            {
+                return segments;
+            }
+
+            foreach (var property in properties.EnumerateObject())
+            {
+                AddTokenSearchIdentifierSegment(segments, property.Name, ParameterNameTokenWeight);
+
+                // A non-object sub-schema (for example a bare true/false) still contributes its
+                // name above, but has no members to inspect.
+                var propertySchema = property.Value;
+                if (propertySchema.ValueKind != JsonValueKind.Object)
+                {
+                    continue;
+                }
+
+                if (propertySchema.TryGetProperty(InputSchemaDescriptionPropertyName, out var description) &&
+                    description.ValueKind == JsonValueKind.String)
+                {
+                    AddTokenSearchTextSegment(segments, description.GetString(), ParameterDescriptionTokenWeight);
+                }
+
+                if (propertySchema.TryGetProperty(InputSchemaTypePropertyName, out var type) &&
+                    type.ValueKind == JsonValueKind.String)
+                {
+                    AddTokenSearchIdentifierSegment(segments, type.GetString(), ParameterTypeTokenWeight);
+                }
+
+                if (propertySchema.TryGetProperty(InputSchemaEnumPropertyName, out var enumValues) &&
+                    enumValues.ValueKind == JsonValueKind.Array)
+                {
+                    foreach (var enumValue in enumValues.EnumerateArray())
+                    {
+                        if (enumValue.ValueKind == JsonValueKind.String)
+                        {
+                            AddTokenSearchIdentifierSegment(
+                                segments,
+                                enumValue.GetString(),
+                                EnumValuesTokenWeight);
+                        }
+                    }
+                }
+            }
+        }
+        catch (JsonException)
+        {
+            // Unparseable schema text is still searchable as free text.
+            AddTokenSearchTextSegment(segments, descriptor.InputSchemaJson, ParameterDescriptionTokenWeight);
+        }
+
+        return segments;
+    }
+
+    private static IReadOnlyList BuildQueryTokenSearchSegments(
+        SearchInput searchInput)
+    {
+        var segments = new List();
+
+        AddTokenSearchTextSegment(segments, searchInput.NormalizedQuery, QueryTokenWeight);
+
+        if (!string.IsNullOrWhiteSpace(searchInput.OriginalQuery) &&
+            !string.Equals(searchInput.OriginalQuery, searchInput.NormalizedQuery,
StringComparison.OrdinalIgnoreCase))
+        {
+            AddTokenSearchTextSegment(segments, searchInput.OriginalQuery, OriginalQueryBackoffWeight);
+        }
+        else
+        {
+            AddTokenSearchTextSegment(segments, searchInput.OriginalQuery, QueryTokenWeight);
+        }
+
+        AddTokenSearchTextSegment(segments, searchInput.ContextSummary, ContextSummaryTokenWeight);
+        AddTokenSearchTextSegment(segments, searchInput.FlattenedContext, ContextTokenWeight);
+
+        return segments;
+    }
+}
diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Tokenization.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Tokenization.cs
new file mode 100644
index 0000000..7882c50
--- /dev/null
+++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewayRuntime.Tokenization.cs
@@ -0,0 +1,333 @@
+using System.Text;
+
+namespace ManagedCode.MCPGateway;
+
+internal sealed partial class McpGatewayRuntime
+{
+    // Turns weighted text segments into tokenized search fields. Each field carries the segment
+    // weight, its term count, and term and character n-gram frequencies. The term list is the
+    // tokenizer output followed by any lexical search terms it did not already contain
+    // (compared ignoring case). Segments without weight, text or terms are skipped.
+    private IReadOnlyList BuildTokenizedSearchFields(
+        IEnumerable segments)
+    {
+        var fields = new List();
+        foreach (var segment in segments)
+        {
+            if (segment.Weight <= double.Epsilon || string.IsNullOrWhiteSpace(segment.Text))
+            {
+                continue;
+            }
+
+            var tokenTerms = ExtractTokenTerms(segment.Text).ToList();
+            var lexicalTerms = BuildSearchTerms(segment.Text);
+            if (tokenTerms.Count == 0)
+            {
+                tokenTerms = lexicalTerms.ToList();
+            }
+            else
+            {
+                // A set answers "already present?" in constant time instead of rescanning the
+                // list for every lexical term; the resulting list and its order are unchanged.
+                var knownTerms = tokenTerms.ToHashSet(StringComparer.OrdinalIgnoreCase);
+                foreach (var lexicalTerm in lexicalTerms)
+                {
+                    if (knownTerms.Add(lexicalTerm))
+                    {
+                        tokenTerms.Add(lexicalTerm);
+                    }
+                }
+            }
+
+            if (tokenTerms.Count == 0)
+            {
+                continue;
+            }
+
+            fields.Add(new TokenizedSearchField(
+                segment.Weight,
+                tokenTerms.Count,
+                BuildTermFrequencies(tokenTerms),
+                BuildCharacterNGramFrequencies(tokenTerms)));
+        }
+
+        return fields;
+    }
+
+    // Sums term weight * field weight over all fields into one profile and, when a non-empty
+    // inverse-document-frequency table is supplied, re-weights the profile with it.
+    private static TokenSearchProfile BuildTokenSearchProfile(
+        IEnumerable fields,
+        IReadOnlyDictionary? inverseDocumentFrequencies = null)
+    {
+        var combinedWeights = new Dictionary(StringComparer.OrdinalIgnoreCase);
+        foreach (var field in fields)
+        {
+            foreach (var fieldTerm in field.TermWeights)
+            {
+                AddWeightedSearchTerm(combinedWeights, fieldTerm.Key, fieldTerm.Value * field.Weight);
+            }
+        }
+
+        var rawProfile = CreateTokenSearchProfile(combinedWeights);
+        if (inverseDocumentFrequencies is { Count: > 0 })
+        {
+            return ApplyTokenInverseDocumentFrequencies(rawProfile, inverseDocumentFrequencies);
+        }
+
+        return rawProfile;
+    }
+
+    private static TokenSearchProfile BuildCharacterNGramProfile(
+        IEnumerable fields,
+        IReadOnlyDictionary? inverseDocumentFrequencies = null)
+    {
+        var termWeights = new Dictionary(StringComparer.OrdinalIgnoreCase);
+        foreach (var field in fields)
+        {
+            foreach (var (term, weight) in field.CharacterNGramWeights)
+            {
+                AddWeightedSearchTerm(
+                    termWeights,
+                    term,
+                    weight * field.Weight * CharacterNGramTokenWeightFactor);
+            }
+        }
+
+        var rawProfile = CreateTokenSearchProfile(termWeights);
+        return inverseDocumentFrequencies is { Count: > 0 }
+            ?
ApplyTokenInverseDocumentFrequencies(rawProfile, inverseDocumentFrequencies)
+            : rawProfile;
+    }
+
+    // Encodes the text with the search tokenizer and returns the normalized search terms of every
+    // token, in token order. If anything in that path throws, falls back to the search terms of
+    // the whole text (deliberate best effort).
+    private IReadOnlyList ExtractTokenTerms(string text)
+    {
+        if (string.IsNullOrWhiteSpace(text))
+        {
+            return [];
+        }
+
+        try
+        {
+            var collected = new List();
+            foreach (var token in _searchTokenizer.EncodeToTokens(text, out _))
+            {
+                // Tokens that normalize to nothing simply add nothing.
+                collected.AddRange(BuildSearchTerms(token.Value));
+            }
+
+            return collected;
+        }
+        catch
+        {
+            return BuildSearchTerms(text).ToList();
+        }
+    }
+
+    // Wraps term weights in a profile together with their key set, the Euclidean magnitude and
+    // the plain sum of the weights. An empty map yields the shared empty profile.
+    private static TokenSearchProfile CreateTokenSearchProfile(
+        IReadOnlyDictionary termWeights)
+    {
+        if (termWeights.Count == 0)
+        {
+            return TokenSearchProfile.Empty;
+        }
+
+        var sumOfSquares = 0d;
+        var sumOfWeights = 0d;
+        foreach (var termWeight in termWeights)
+        {
+            var weight = termWeight.Value;
+            sumOfSquares += weight * weight;
+            sumOfWeights += weight;
+        }
+
+        var terms = termWeights.Keys.ToHashSet(StringComparer.OrdinalIgnoreCase);
+        return new TokenSearchProfile(termWeights, terms, Math.Sqrt(sumOfSquares), sumOfWeights);
+    }
+
+    // Case-insensitive union of the search terms of every segment's text.
+    private static IReadOnlySet BuildLexicalTerms(
+        IEnumerable segments)
+    {
+        var lexicalTerms = new HashSet(StringComparer.OrdinalIgnoreCase);
+        foreach (var segment in segments)
+        {
+            lexicalTerms.UnionWith(BuildSearchTerms(segment.Text));
+        }
+
+        return lexicalTerms;
+    }
+
+    // Appends the trimmed text as a weighted segment; blank text or a non-positive weight adds nothing.
+    private static void AddTokenSearchTextSegment(
+        List segments,
+        string? text,
+        double weight)
+    {
+        var hasText = !string.IsNullOrWhiteSpace(text);
+        if (hasText && weight > double.Epsilon)
+        {
+            segments.Add(new WeightedTextSegment(text.Trim(), weight));
+        }
+    }
+
+    private static void AddTokenSearchIdentifierSegment(
+        List segments,
+        string?
identifier,
+        double weight)
+    {
+        if (string.IsNullOrWhiteSpace(identifier) || weight <= double.Epsilon)
+        {
+            return;
+        }
+
+        var trimmedIdentifier = identifier.Trim();
+        segments.Add(new WeightedTextSegment(trimmedIdentifier, weight));
+
+        var humanizedIdentifier = HumanizeIdentifier(trimmedIdentifier);
+        if (!string.IsNullOrWhiteSpace(humanizedIdentifier) &&
+            !string.Equals(humanizedIdentifier, trimmedIdentifier, StringComparison.OrdinalIgnoreCase))
+        {
+            segments.Add(new WeightedTextSegment(
+                humanizedIdentifier,
+                weight * HumanizedIdentifierWeightFactor));
+        }
+    }
+
+    // Converts an identifier into lower-case, space-separated words: whitespace and the
+    // characters _ - . , ; : / \ become single spaces, and a space is inserted where an
+    // upper-case character follows a lower-case character or a digit.
+    private static string HumanizeIdentifier(string identifier)
+    {
+        if (string.IsNullOrWhiteSpace(identifier))
+        {
+            return string.Empty;
+        }
+
+        var words = new StringBuilder(identifier.Length + 8);
+        var afterSeparator = false;
+        var afterLowerOrDigit = false;
+
+        foreach (var current in identifier.Trim())
+        {
+            if (IsSeparator(current))
+            {
+                // A run of separators collapses into one space; nothing is emitted before the first word.
+                if (!afterSeparator && words.Length > 0)
+                {
+                    words.Append(' ');
+                }
+
+                afterSeparator = true;
+                afterLowerOrDigit = false;
+                continue;
+            }
+
+            // afterLowerOrDigit is reset on every separator, so it alone marks a
+            // lower-case/digit -> upper-case word boundary.
+            if (afterLowerOrDigit && char.IsUpper(current))
+            {
+                words.Append(' ');
+            }
+
+            words.Append(char.ToLowerInvariant(current));
+            afterSeparator = false;
+            afterLowerOrDigit = char.IsLower(current) || char.IsDigit(current);
+        }
+
+        return words.ToString().Trim();
+
+        static bool IsSeparator(char value)
+            => char.IsWhiteSpace(value) ||
+               value is '_' or '-' or '.' or ',' or ';' or ':' or '/' or '\\';
+    }
+
+    private static HashSet BuildSearchTerms(string?
text)
+    {
+        if (string.IsNullOrWhiteSpace(text))
+        {
+            return [];
+        }
+
+        var terms = new HashSet(StringComparer.OrdinalIgnoreCase);
+        foreach (var token in text.Split(TokenSeparators, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries))
+        {
+            if (token.Length < 2)
+            {
+                continue;
+            }
+
+            var normalized = token.ToLowerInvariant();
+            if (IgnoredSearchTerms.Contains(normalized))
+            {
+                continue;
+            }
+
+            terms.Add(normalized);
+
+            if (normalized.Length > 3 && normalized.EndsWith(PluralSuffixIes, StringComparison.Ordinal))
+            {
+                terms.Add($"{normalized[..^3]}y");
+                continue;
+            }
+
+            if (normalized.Length > 3 && normalized.EndsWith(PluralSuffixEs, StringComparison.Ordinal))
+            {
+                terms.Add(normalized[..^2]);
+            }
+            else if (normalized.Length > 3 && normalized.EndsWith('s'))
+            {
+                terms.Add(normalized[..^1]);
+            }
+        }
+
+        return terms;
+    }
+
+    // Counts how often each term occurs (case-insensitive keys); every occurrence adds 1.
+    private static IReadOnlyDictionary BuildTermFrequencies(
+        IEnumerable terms)
+    {
+        var frequencies = new Dictionary(StringComparer.OrdinalIgnoreCase);
+        foreach (var occurrence in terms)
+        {
+            AddWeightedSearchTerm(frequencies, occurrence, 1d);
+        }
+
+        return frequencies;
+    }
+
+    // Counts the character n-grams produced by all terms; every n-gram of every term adds 1.
+    private static IReadOnlyDictionary BuildCharacterNGramFrequencies(
+        IEnumerable terms)
+    {
+        var frequencies = new Dictionary(StringComparer.OrdinalIgnoreCase);
+        foreach (var ngram in terms.SelectMany(term => BuildCharacterNGramTerms(term)))
+        {
+            AddWeightedSearchTerm(frequencies, ngram, 1d);
+        }
+
+        return frequencies;
+    }
+
+    private static void AddWeightedSearchTerm(
+        Dictionary termWeights,
+        string? term,
+        double weight)
+    {
+        if (string.IsNullOrWhiteSpace(term) || weight <= double.Epsilon)
+        {
+            return;
+        }
+
+        termWeights[term] = termWeights.TryGetValue(term, out var existingWeight)
+            ?
existingWeight + weight
+            : weight;
+    }
+
+    // Distinct character n-grams of a term, each prefixed with CharacterNGramPrefix.
+    // Blank terms and terms shorter than five characters produce none.
+    private static IReadOnlyList BuildCharacterNGramTerms(string term)
+    {
+        if (string.IsNullOrWhiteSpace(term) || term.Length < 5)
+        {
+            return [];
+        }
+
+        var distinctNGrams = new HashSet(StringComparer.OrdinalIgnoreCase);
+        var lastStart = term.Length - CharacterNGramLength;
+        for (var start = 0; start <= lastStart; start++)
+        {
+            var slice = term.Substring(start, CharacterNGramLength);
+            distinctNGrams.Add($"{CharacterNGramPrefix}{slice}");
+        }
+
+        return distinctNGrams.ToList();
+    }
+}
diff --git a/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewaySearchTokenizerFactory.cs b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewaySearchTokenizerFactory.cs
new file mode 100644
index 0000000..1547388
--- /dev/null
+++ b/src/ManagedCode.MCPGateway/Internal/Runtime/Search/McpGatewaySearchTokenizerFactory.cs
@@ -0,0 +1,14 @@
+using Microsoft.ML.Tokenizers;
+
+namespace ManagedCode.MCPGateway;
+
+// Provides the single, lazily created tokenizer instance used for search.
+internal static class McpGatewaySearchTokenizerFactory
+{
+    private const string ChatGptTokenizerModelName = "gpt-4o";
+
+    // Created on first use; ExecutionAndPublication lets only one thread run the factory.
+    private static readonly Lazy ChatGptTokenizer = new(
+        static () => TiktokenTokenizer.CreateForModel(ChatGptTokenizerModelName),
+        LazyThreadSafetyMode.ExecutionAndPublication);
+
+    public static Tokenizer GetTokenizer()
+    {
+        return ChatGptTokenizer.Value;
+    }
+}
diff --git a/src/ManagedCode.MCPGateway/Internal/Serialization/McpGatewayJsonSerializer.cs b/src/ManagedCode.MCPGateway/Internal/Serialization/McpGatewayJsonSerializer.cs
new file mode 100644
index 0000000..a57bf39
--- /dev/null
+++ b/src/ManagedCode.MCPGateway/Internal/Serialization/McpGatewayJsonSerializer.cs
@@ -0,0 +1,68 @@
+using System.Text.Json;
+using System.Text.Json.Nodes;
+
+namespace ManagedCode.MCPGateway;
+
+// Best-effort JSON conversion helpers: serialization failures yield null instead of throwing.
+internal static class McpGatewayJsonSerializer
+{
+    public static JsonSerializerOptions Options { get; } = CreateOptions();
+
+    // Converts a value to a detached JsonElement. Returns null for null input, for JSON
+    // null/undefined, and when serialization throws JsonException or NotSupportedException.
+    public static JsonElement? TrySerializeToElement(object? value)
+    {
+        try
+        {
+            switch (value)
+            {
+                case null:
+                    return null;
+                case JsonElement element:
+                    return NormalizeElement(element);
+                case JsonDocument document:
+                    return NormalizeElement(document.RootElement);
+                case JsonNode node:
+                    // A type pattern never matches null, so no extra null check is needed here.
+                    return NormalizeElement(JsonSerializer.SerializeToElement(node, Options));
+                default:
+                    return NormalizeElement(JsonSerializer.SerializeToElement(value, Options));
+            }
+        }
+        catch (Exception exception) when (exception is JsonException or NotSupportedException)
+        {
+            return null;
+        }
+    }
+
+    // Converts a value to a JsonNode. Existing nodes are deep-cloned. Returns null for null
+    // input, for JSON null/undefined, and when serialization throws JsonException or
+    // NotSupportedException.
+    public static JsonNode? TrySerializeToNode(object? value)
+    {
+        try
+        {
+            switch (value)
+            {
+                case null:
+                    return null;
+                case JsonNode node:
+                    return node.DeepClone();
+                case JsonElement element:
+                    return SerializeElementToNode(element);
+                case JsonDocument document:
+                    return SerializeElementToNode(document.RootElement);
+                default:
+                    return JsonSerializer.SerializeToNode(value, Options);
+            }
+        }
+        catch (Exception exception) when (exception is JsonException or NotSupportedException)
+        {
+            return null;
+        }
+    }
+
+    private static JsonSerializerOptions CreateOptions()
+    {
+        return new JsonSerializerOptions(JsonSerializerDefaults.Web);
+    }
+
+    // Maps JSON null/undefined to null; any other element is cloned so it outlives its document.
+    private static JsonElement? NormalizeElement(JsonElement element)
+    {
+        if (element.ValueKind is JsonValueKind.Null or JsonValueKind.Undefined)
+        {
+            return null;
+        }
+
+        return element.Clone();
+    }
+
+    private static JsonNode? SerializeElementToNode(JsonElement element)
+        => element.ValueKind is JsonValueKind.Null or JsonValueKind.Undefined
+            ?
null
+            : JsonSerializer.SerializeToNode(element, Options);
+}
diff --git a/src/ManagedCode.MCPGateway/Internal/Warmup/McpGatewayIndexWarmupService.cs b/src/ManagedCode.MCPGateway/Internal/Warmup/McpGatewayIndexWarmupService.cs
new file mode 100644
index 0000000..c1a9b95
--- /dev/null
+++ b/src/ManagedCode.MCPGateway/Internal/Warmup/McpGatewayIndexWarmupService.cs
@@ -0,0 +1,50 @@
+using ManagedCode.MCPGateway.Abstractions;
+using Microsoft.Extensions.Hosting;
+using Microsoft.Extensions.Logging;
+
+namespace ManagedCode.MCPGateway;
+
+// Hosted service that starts building the gateway index in the background when the host starts
+// and cancels that work when the host stops. Warmup failures are logged, never rethrown.
+internal sealed class McpGatewayIndexWarmupService(
+    IMcpGateway gateway,
+    ILogger logger) : IHostedService
+{
+    private const string WarmupFailedLogMessage = "ManagedCode.MCPGateway background index warmup failed.";
+
+    // Owned by this service so StopAsync can cancel the warmup; the token passed to StartAsync
+    // only signals that host startup was aborted.
+    private CancellationTokenSource? _warmupCancellation;
+    private Task? _warmupTask;
+
+    // Starts the warmup without awaiting it, so host startup is not held up by the index build.
+    public Task StartAsync(CancellationToken cancellationToken)
+    {
+        _warmupCancellation = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+        _warmupTask = WarmAsync(_warmupCancellation.Token);
+        return Task.CompletedTask;
+    }
+
+    // Requests cancellation of a running warmup and waits for it to finish, or for the
+    // shutdown token to fire, whichever comes first.
+    public async Task StopAsync(CancellationToken cancellationToken)
+    {
+        if (_warmupTask is null)
+        {
+            return;
+        }
+
+        _warmupCancellation?.Cancel();
+
+        try
+        {
+            await _warmupTask.WaitAsync(cancellationToken);
+        }
+        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
+        {
+            // Shutdown deadline reached before the warmup observed its cancellation.
+        }
+
+        // Dispose only once the warmup has really finished, so a still-running build never
+        // touches a disposed source.
+        if (_warmupTask.IsCompleted)
+        {
+            _warmupCancellation?.Dispose();
+            _warmupCancellation = null;
+        }
+    }
+
+    private async Task WarmAsync(CancellationToken cancellationToken)
+    {
+        try
+        {
+            await gateway.BuildIndexAsync(cancellationToken);
+        }
+        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
+        {
+            // Expected when the host stops, or aborts startup, during warmup.
+        }
+        catch (Exception ex)
+        {
+            logger.LogWarning(ex, WarmupFailedLogMessage);
+        }
+    }
+}
diff --git a/src/ManagedCode.MCPGateway/ManagedCode.MCPGateway.csproj b/src/ManagedCode.MCPGateway/ManagedCode.MCPGateway.csproj
index 4bac0e0..e36962f 100644
--- a/src/ManagedCode.MCPGateway/ManagedCode.MCPGateway.csproj
+++ b/src/ManagedCode.MCPGateway/ManagedCode.MCPGateway.csproj
@@ -8,7 +8,10 @@
+
+
+
diff --git a/src/ManagedCode.MCPGateway/McpGateway.cs b/src/ManagedCode.MCPGateway/McpGateway.cs
index 8b62cd2..7a711f0 100644
---
a/src/ManagedCode.MCPGateway/McpGateway.cs +++ b/src/ManagedCode.MCPGateway/McpGateway.cs @@ -1,16 +1,7 @@ -using System.Text; -using System.Text.Json; -using System.Text.Json.Nodes; -using System.Globalization; -using System.Security.Cryptography; using ManagedCode.MCPGateway.Abstractions; using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using ModelContextProtocol; -using ModelContextProtocol.Client; -using ModelContextProtocol.Protocol; namespace ManagedCode.MCPGateway; @@ -18,1519 +9,54 @@ public sealed class McpGateway( IServiceProvider serviceProvider, IOptions options, ILogger logger, - ILoggerFactory loggerFactory) - : IMcpGateway, IMcpGatewayRegistry + ILoggerFactory loggerFactory) : IMcpGateway { - private const string DefaultSourceId = "local"; - private const string QueryArgumentName = "query"; - private const string ContextArgumentName = "context"; - private const string ContextSummaryArgumentName = "contextSummary"; - private const string GatewayInvocationMetaKey = "managedCodeMcpGateway"; - private const string SearchModeEmpty = "empty"; - private const string SearchModeBrowse = "browse"; - private const string SearchModeLexical = "lexical"; - private const string SearchModeVector = "vector"; - private const string SourceLoadFailedDiagnosticCode = "source_load_failed"; - private const string DuplicateToolIdDiagnosticCode = "duplicate_tool_id"; - private const string EmbeddingCountMismatchDiagnosticCode = "embedding_count_mismatch"; - private const string EmbeddingGeneratorMissingDiagnosticCode = "embedding_generator_missing"; - private const string EmbeddingFailedDiagnosticCode = "embedding_failed"; - private const string EmbeddingStoreLoadFailedDiagnosticCode = "embedding_store_load_failed"; - private const string EmbeddingStoreSaveFailedDiagnosticCode = "embedding_store_save_failed"; - private const string QueryVectorEmptyDiagnosticCode = 
"query_vector_empty"; - private const string LexicalFallbackDiagnosticCode = "lexical_fallback"; - private const string VectorSearchFailedDiagnosticCode = "vector_search_failed"; - private const string CommandRequiredMessage = "A command is required."; - private const string SourceLoadFailedMessageTemplate = "Failed to load tools from source '{0}': {1}"; - private const string DuplicateToolIdMessageTemplate = "Skipped duplicate tool id '{0}'."; - private const string EmbeddingCountMismatchMessageTemplate = "Embedding generation returned {0} vectors for {1} tools."; - private const string EmbeddingGeneratorMissingMessage = "No keyed or unkeyed IEmbeddingGenerator> is registered. Stored tool embeddings may be reused, but search falls back lexically without a query embedding generator."; - private const string EmbeddingFailedMessageTemplate = "Embedding generation failed: {0}"; - private const string EmbeddingStoreLoadFailedMessageTemplate = "Loading stored tool embeddings failed: {0}"; - private const string EmbeddingStoreSaveFailedMessageTemplate = "Persisting generated tool embeddings failed: {0}"; - private const string QueryVectorEmptyMessage = "Embedding generator returned an empty query vector."; - private const string LexicalFallbackMessage = "Vector search is unavailable. Lexical ranking was used."; - private const string VectorSearchFailedMessageTemplate = "Vector ranking failed and lexical fallback was used: {0}"; - private const string ToolNotInvokableMessageTemplate = "Tool '{0}' is not invokable."; - private const string ToolIdOrToolNameRequiredMessage = "Either ToolId or ToolName is required."; - private const string ToolIdNotFoundMessageTemplate = "Tool '{0}' was not found."; - private const string ToolNameAmbiguousMessageTemplate = "Tool '{0}' is ambiguous. 
Use ToolId or specify SourceId explicitly."; - private const string FailedToLoadGatewaySourceLogMessage = "Failed to load gateway source {SourceId}."; - private const string EmbeddingGenerationFailedLogMessage = "Gateway embedding generation failed. Falling back to lexical search."; - private const string GatewayIndexRebuiltLogMessage = "Gateway index rebuilt. Tools={ToolCount} VectorizedTools={VectorizedToolCount}."; - private const string GatewayVectorSearchFailedLogMessage = "Gateway vector search failed. Falling back to lexical ranking."; - private const string GatewayInvocationFailedLogMessage = "Gateway invocation failed for {ToolId}."; - private const string EmbeddingStoreLoadFailedLogMessage = "Loading stored tool embeddings failed. Falling back to generator-backed indexing."; - private const string EmbeddingStoreSaveFailedLogMessage = "Persisting generated tool embeddings failed."; - private const string InputSchemaPropertiesPropertyName = "properties"; - private const string InputSchemaRequiredPropertyName = "required"; - private const string InputSchemaDescriptionPropertyName = "description"; - private const string InputSchemaTypePropertyName = "type"; - private const string InputSchemaEnumPropertyName = "enum"; - private const string DisplayNamePropertyName = "DisplayName"; - private const string ToolNameLabel = "Tool name: "; - private const string DisplayNameLabel = "Display name: "; - private const string DescriptionLabel = "Description: "; - private const string RequiredArgumentsLabel = "Required arguments: "; - private const string ParameterLabel = "Parameter "; - private const string TypeLabel = "Type "; - private const string TypicalValuesLabel = "Typical values: "; - private const string InputSchemaLabel = "Input schema: "; - private const string ContextSummaryPrefix = "context summary: "; - private const string ContextPrefix = "context: "; - private const string PluralSuffixIes = "ies"; - private const string PluralSuffixEs = "es"; - private 
const string EmbeddingGeneratorFingerprintUnknownComponent = "unknown"; - private const string EmbeddingGeneratorFingerprintComponentSeparator = "\n"; + private readonly McpGatewayRuntime _runtime = CreateRuntime(serviceProvider, options, logger, loggerFactory); - private static readonly char[] TokenSeparators = - [ - ' ', - '\t', - '\r', - '\n', - '_', - '-', - '.', - ',', - ';', - ':', - '/', - '\\', - '(', - ')', - '[', - ']', - '{', - '}', - '"', - '\'', - '?', - '!' - ]; - private static readonly CompositeFormat SourceLoadFailedMessageFormat = CompositeFormat.Parse(SourceLoadFailedMessageTemplate); - private static readonly CompositeFormat DuplicateToolIdMessageFormat = CompositeFormat.Parse(DuplicateToolIdMessageTemplate); - private static readonly CompositeFormat EmbeddingCountMismatchMessageFormat = CompositeFormat.Parse(EmbeddingCountMismatchMessageTemplate); - private static readonly CompositeFormat EmbeddingFailedMessageFormat = CompositeFormat.Parse(EmbeddingFailedMessageTemplate); - private static readonly CompositeFormat EmbeddingStoreLoadFailedMessageFormat = CompositeFormat.Parse(EmbeddingStoreLoadFailedMessageTemplate); - private static readonly CompositeFormat EmbeddingStoreSaveFailedMessageFormat = CompositeFormat.Parse(EmbeddingStoreSaveFailedMessageTemplate); - private static readonly CompositeFormat VectorSearchFailedMessageFormat = CompositeFormat.Parse(VectorSearchFailedMessageTemplate); - private static readonly CompositeFormat ToolNotInvokableMessageFormat = CompositeFormat.Parse(ToolNotInvokableMessageTemplate); - private static readonly CompositeFormat ToolIdNotFoundMessageFormat = CompositeFormat.Parse(ToolIdNotFoundMessageTemplate); - private static readonly CompositeFormat ToolNameAmbiguousMessageFormat = CompositeFormat.Parse(ToolNameAmbiguousMessageTemplate); + public Task BuildIndexAsync(CancellationToken cancellationToken = default) + => _runtime.BuildIndexAsync(cancellationToken); - private readonly object _gate = new(); - 
private readonly SemaphoreSlim _rebuildLock = new(1, 1); - private readonly IServiceProvider _serviceProvider = serviceProvider; - private readonly ILogger _logger = logger; - private readonly ILoggerFactory _loggerFactory = loggerFactory; - private readonly List _registrations = options.Value.SourceRegistrations.ToList(); - private readonly int _defaultSearchLimit = Math.Max(1, options.Value.DefaultSearchLimit); - private readonly int _maxSearchResults = Math.Max(1, options.Value.MaxSearchResults); - private readonly int _maxDescriptorLength = Math.Max(256, options.Value.MaxDescriptorLength); - private ToolCatalogSnapshot? _snapshot; - private bool _disposed; + public Task> ListToolsAsync(CancellationToken cancellationToken = default) + => _runtime.ListToolsAsync(cancellationToken); - public void AddTool(string sourceId, AITool tool, string? displayName = null) - => AddTool(tool, sourceId, displayName); - - public void AddTool(AITool tool, string sourceId = DefaultSourceId, string? displayName = null) - { - ArgumentNullException.ThrowIfNull(tool); - - lock (_gate) - { - ThrowIfDisposed(); - - var existing = _registrations - .OfType() - .FirstOrDefault(item => string.Equals(item.SourceId, sourceId, StringComparison.OrdinalIgnoreCase)); - if (existing is null) - { - existing = new McpGatewayLocalToolSourceRegistration(sourceId.Trim(), displayName); - _registrations.Add(existing); - } - - existing.AddTool(tool); - _snapshot = null; - } - } - - public void AddTools(string sourceId, IEnumerable tools, string? displayName = null) - => AddTools(tools, sourceId, displayName); - - public void AddTools(IEnumerable tools, string sourceId = DefaultSourceId, string? displayName = null) - { - ArgumentNullException.ThrowIfNull(tools); - - foreach (var tool in tools) - { - AddTool(tool, sourceId, displayName); - } - } - - public void AddHttpServer( - string sourceId, - Uri endpoint, - IReadOnlyDictionary? headers = null, - string? 
displayName = null) - { - ArgumentNullException.ThrowIfNull(endpoint); - AddRegistration(new McpGatewayHttpToolSourceRegistration(sourceId.Trim(), endpoint, headers, displayName)); - } - - public void AddStdioServer( - string sourceId, - string command, - IReadOnlyList? arguments = null, - string? workingDirectory = null, - IReadOnlyDictionary? environmentVariables = null, - string? displayName = null) - { - if (string.IsNullOrWhiteSpace(command)) - { - throw new ArgumentException(CommandRequiredMessage, nameof(command)); - } - - AddRegistration(new McpGatewayStdioToolSourceRegistration( - sourceId.Trim(), - command.Trim(), - arguments, - workingDirectory, - environmentVariables, - displayName)); - } - - public void AddMcpClient( - string sourceId, - McpClient client, - bool disposeClient = false, - string? displayName = null) - { - ArgumentNullException.ThrowIfNull(client); - AddRegistration(new McpGatewayProvidedClientToolSourceRegistration( - sourceId.Trim(), - _ => ValueTask.FromResult(client), - disposeClient, - displayName)); - } - - public void AddMcpClientFactory( - string sourceId, - Func> clientFactory, - bool disposeClient = true, - string? 
displayName = null) - { - ArgumentNullException.ThrowIfNull(clientFactory); - AddRegistration(new McpGatewayProvidedClientToolSourceRegistration( - sourceId.Trim(), - clientFactory, - disposeClient, - displayName)); - } - - public async Task BuildIndexAsync(CancellationToken cancellationToken = default) - { - await _rebuildLock.WaitAsync(cancellationToken); - try - { - ThrowIfDisposed(); - - var registrations = CopyRegistrations(); - var diagnostics = new List(); - var entries = new List(); - var seenToolIds = new HashSet(StringComparer.OrdinalIgnoreCase); - - foreach (var registration in registrations) - { - cancellationToken.ThrowIfCancellationRequested(); - - IReadOnlyList tools; - try - { - tools = await registration.LoadToolsAsync(_loggerFactory, cancellationToken); - } - catch (Exception ex) when (ex is not OperationCanceledException) - { - diagnostics.Add(new McpGatewayDiagnostic( - SourceLoadFailedDiagnosticCode, - string.Format( - CultureInfo.InvariantCulture, - SourceLoadFailedMessageFormat, - registration.SourceId, - ex.GetBaseException().Message))); - _logger.LogWarning(ex, FailedToLoadGatewaySourceLogMessage, registration.SourceId); - continue; - } - - foreach (var tool in tools) - { - var descriptor = BuildDescriptor(registration, tool); - if (descriptor is null) - { - continue; - } - - if (!seenToolIds.Add(descriptor.ToolId)) - { - diagnostics.Add(new McpGatewayDiagnostic( - DuplicateToolIdDiagnosticCode, - string.Format(CultureInfo.InvariantCulture, DuplicateToolIdMessageFormat, descriptor.ToolId))); - continue; - } - - entries.Add(new ToolCatalogEntry( - descriptor, - tool, - BuildDescriptorDocument(descriptor, tool))); - } - } - - var vectorizedToolCount = 0; - var isVectorSearchEnabled = false; - if (entries.Count > 0) - { - await using var embeddingGeneratorLease = ResolveEmbeddingGenerator(); - await using var embeddingStoreLease = ResolveToolEmbeddingStore(); - var embeddingGenerator = embeddingGeneratorLease.Generator; - var 
embeddingGeneratorFingerprint = ResolveEmbeddingGeneratorFingerprint(embeddingGenerator); - var embeddingStore = embeddingStoreLease.Store; - var storeCandidates = entries - .Select((entry, index) => new ToolEmbeddingCandidate( - index, - new McpGatewayToolEmbeddingLookup( - entry.Descriptor.ToolId, - ComputeDocumentHash(entry.Document), - embeddingGeneratorFingerprint), - entry.Descriptor.SourceId, - entry.Descriptor.ToolName)) - .ToList(); - - if (embeddingStore is not null) - { - try - { - var storedEmbeddings = await embeddingStore.GetAsync( - storeCandidates.Select(static candidate => candidate.Lookup).ToList(), - cancellationToken); - - foreach (var candidate in storeCandidates) - { - var storedEmbedding = storedEmbeddings.LastOrDefault(embedding => - MatchesStoredEmbedding(candidate.Lookup, embedding)); - if (storedEmbedding is not null) - { - ApplyEmbedding(entries, candidate.Index, storedEmbedding.Vector, ref vectorizedToolCount); - } - } - } - catch (Exception ex) when (ex is not OperationCanceledException) - { - diagnostics.Add(new McpGatewayDiagnostic( - EmbeddingStoreLoadFailedDiagnosticCode, - string.Format(CultureInfo.InvariantCulture, EmbeddingStoreLoadFailedMessageFormat, ex.GetBaseException().Message))); - _logger.LogWarning(ex, EmbeddingStoreLoadFailedLogMessage); - } - } - - var missingCandidates = storeCandidates - .Where(candidate => entries[candidate.Index].Magnitude <= double.Epsilon) - .ToList(); - - if (embeddingGenerator is null && vectorizedToolCount > 0) - { - diagnostics.Add(new McpGatewayDiagnostic( - EmbeddingGeneratorMissingDiagnosticCode, - EmbeddingGeneratorMissingMessage)); - } - - if (missingCandidates.Count > 0) - { - try - { - if (embeddingGenerator is not null) - { - var embeddings = (await embeddingGenerator.GenerateAsync( - missingCandidates.Select(candidate => entries[candidate.Index].Document), - cancellationToken: cancellationToken)) - .ToList(); - if (embeddings.Count == missingCandidates.Count) - { - var 
generatedEmbeddings = new List(missingCandidates.Count); - for (var index = 0; index < missingCandidates.Count; index++) - { - var candidate = missingCandidates[index]; - var vector = embeddings[index].Vector.ToArray(); - if (ApplyEmbedding(entries, candidate.Index, vector, ref vectorizedToolCount)) - { - generatedEmbeddings.Add(new McpGatewayToolEmbedding( - candidate.Lookup.ToolId, - candidate.SourceId, - candidate.ToolName, - candidate.Lookup.DocumentHash, - candidate.Lookup.EmbeddingGeneratorFingerprint, - vector)); - } - } - - if (generatedEmbeddings.Count > 0 && embeddingStore is not null) - { - try - { - await embeddingStore.UpsertAsync(generatedEmbeddings, cancellationToken); - } - catch (Exception ex) when (ex is not OperationCanceledException) - { - diagnostics.Add(new McpGatewayDiagnostic( - EmbeddingStoreSaveFailedDiagnosticCode, - string.Format(CultureInfo.InvariantCulture, EmbeddingStoreSaveFailedMessageFormat, ex.GetBaseException().Message))); - _logger.LogWarning(ex, EmbeddingStoreSaveFailedLogMessage); - } - } - } - else - { - diagnostics.Add(new McpGatewayDiagnostic( - EmbeddingCountMismatchDiagnosticCode, - string.Format( - CultureInfo.InvariantCulture, - EmbeddingCountMismatchMessageFormat, - embeddings.Count, - missingCandidates.Count))); - } - } - else - { - diagnostics.Add(new McpGatewayDiagnostic( - EmbeddingGeneratorMissingDiagnosticCode, - EmbeddingGeneratorMissingMessage)); - } - } - catch (Exception ex) when (ex is not OperationCanceledException) - { - diagnostics.Add(new McpGatewayDiagnostic( - EmbeddingFailedDiagnosticCode, - string.Format(CultureInfo.InvariantCulture, EmbeddingFailedMessageFormat, ex.GetBaseException().Message))); - _logger.LogWarning(ex, EmbeddingGenerationFailedLogMessage); - } - } - - isVectorSearchEnabled = vectorizedToolCount > 0 && embeddingGenerator is not null; - } - - var snapshot = new ToolCatalogSnapshot( - entries - .OrderBy(static item => item.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase) - 
.ThenBy(static item => item.Descriptor.SourceId, StringComparer.OrdinalIgnoreCase) - .ToList(), - isVectorSearchEnabled); - - lock (_gate) - { - _snapshot = snapshot; - } - - _logger.LogInformation( - GatewayIndexRebuiltLogMessage, - snapshot.Entries.Count, - vectorizedToolCount); - - return new McpGatewayIndexBuildResult( - snapshot.Entries.Count, - vectorizedToolCount, - snapshot.HasVectors, - diagnostics); - } - finally - { - _rebuildLock.Release(); - } - } - - public async Task> ListToolsAsync(CancellationToken cancellationToken = default) - { - var snapshot = await GetSnapshotAsync(cancellationToken); - return snapshot.Entries - .Select(static item => item.Descriptor) - .ToList(); - } - - public async Task SearchAsync( + public Task SearchAsync( string? query, int? maxResults = null, CancellationToken cancellationToken = default) - => await SearchAsync( - new McpGatewaySearchRequest( - Query: query, - MaxResults: maxResults), - cancellationToken); + => _runtime.SearchAsync(query, maxResults, cancellationToken); - public async Task SearchAsync( + public Task SearchAsync( McpGatewaySearchRequest request, CancellationToken cancellationToken = default) - { - ArgumentNullException.ThrowIfNull(request); - - var snapshot = await GetSnapshotAsync(cancellationToken); - var limit = Math.Clamp(request.MaxResults.GetValueOrDefault(_defaultSearchLimit), 1, _maxSearchResults); - var diagnostics = new List(); + => _runtime.SearchAsync(request, cancellationToken); - if (snapshot.Entries.Count == 0) - { - return new McpGatewaySearchResult([], diagnostics, SearchModeEmpty); - } - - var rawQuery = request.Query?.Trim(); - var effectiveQuery = BuildEffectiveSearchQuery(request); - if (string.IsNullOrWhiteSpace(effectiveQuery)) - { - var browse = snapshot.Entries - .Take(limit) - .Select(static entry => ToSearchMatch(entry, 0d)) - .ToList(); - return new McpGatewaySearchResult(browse, diagnostics, SearchModeBrowse); - } - - IReadOnlyList ranked; - var rankingMode = 
SearchModeLexical; - if (snapshot.HasVectors) - { - try - { - await using var embeddingGeneratorLease = ResolveEmbeddingGenerator(); - if (embeddingGeneratorLease.Generator is IEmbeddingGenerator> generator) - { - var embedding = await generator.GenerateAsync(effectiveQuery, cancellationToken: cancellationToken); - var queryVector = embedding.Vector.ToArray(); - var queryMagnitude = CalculateMagnitude(queryVector); - if (queryMagnitude > double.Epsilon) - { - ranked = snapshot.Entries - .Select(entry => new ScoredToolEntry( - entry, - ApplySearchBoosts( - entry, - rawQuery ?? effectiveQuery, - CalculateCosine(entry, queryVector, queryMagnitude)))) - .OrderByDescending(static item => item.Score) - .ThenBy(static item => item.Entry.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase) - .ToList(); - rankingMode = SearchModeVector; - } - else - { - ranked = RankLexically(snapshot.Entries, effectiveQuery); - diagnostics.Add(new McpGatewayDiagnostic(QueryVectorEmptyDiagnosticCode, QueryVectorEmptyMessage)); - } - } - else - { - ranked = RankLexically(snapshot.Entries, effectiveQuery); - diagnostics.Add(new McpGatewayDiagnostic( - LexicalFallbackDiagnosticCode, - LexicalFallbackMessage)); - } - } - catch (Exception ex) when (ex is not OperationCanceledException) - { - ranked = RankLexically(snapshot.Entries, effectiveQuery); - diagnostics.Add(new McpGatewayDiagnostic( - VectorSearchFailedDiagnosticCode, - string.Format(CultureInfo.InvariantCulture, VectorSearchFailedMessageFormat, ex.GetBaseException().Message))); - _logger.LogWarning(ex, GatewayVectorSearchFailedLogMessage); - } - } - else - { - ranked = RankLexically(snapshot.Entries, effectiveQuery); - diagnostics.Add(new McpGatewayDiagnostic( - LexicalFallbackDiagnosticCode, - LexicalFallbackMessage)); - } - - var matches = ranked - .Take(limit) - .Select(item => ToSearchMatch(item.Entry, item.Score)) - .ToList(); - - return new McpGatewaySearchResult(matches, diagnostics, rankingMode); - } - - private 
EmbeddingGeneratorLease ResolveEmbeddingGenerator() - { - if (_serviceProvider.GetService(typeof(IServiceScopeFactory)) is not IServiceScopeFactory scopeFactory) - { - return new EmbeddingGeneratorLease(ResolveEmbeddingGenerator(_serviceProvider)); - } - - var scope = scopeFactory.CreateAsyncScope(); - var generator = ResolveEmbeddingGenerator(scope.ServiceProvider); - return new EmbeddingGeneratorLease(generator, scope); - } - - private static IEmbeddingGenerator>? ResolveEmbeddingGenerator(IServiceProvider serviceProvider) - => serviceProvider.GetKeyedService>>(McpGatewayServiceKeys.EmbeddingGenerator) - ?? serviceProvider.GetService>>(); - - private ToolEmbeddingStoreLease ResolveToolEmbeddingStore() - { - if (_serviceProvider.GetService(typeof(IServiceScopeFactory)) is not IServiceScopeFactory scopeFactory) - { - return new ToolEmbeddingStoreLease(_serviceProvider.GetService()); - } - - var scope = scopeFactory.CreateAsyncScope(); - var store = scope.ServiceProvider.GetService(); - return new ToolEmbeddingStoreLease(store, scope); - } - - public async Task InvokeAsync( + public Task InvokeAsync( McpGatewayInvokeRequest request, CancellationToken cancellationToken = default) - { - ArgumentNullException.ThrowIfNull(request); - - var snapshot = await GetSnapshotAsync(cancellationToken); - var resolution = ResolveInvocationTarget(snapshot, request); - if (!resolution.IsSuccess || resolution.Entry is null) - { - return new McpGatewayInvokeResult( - false, - request.ToolId ?? string.Empty, - request.SourceId ?? string.Empty, - request.ToolName ?? string.Empty, - Output: null, - Error: resolution.Error); - } - - var entry = resolution.Entry; - var arguments = request.Arguments is { Count: > 0 } - ? 
new Dictionary(request.Arguments, StringComparer.OrdinalIgnoreCase) - : new Dictionary(StringComparer.OrdinalIgnoreCase); - - if (!string.IsNullOrWhiteSpace(request.Query) && - !arguments.ContainsKey(QueryArgumentName) && - SupportsArgument(entry.Descriptor, QueryArgumentName)) - { - arguments[QueryArgumentName] = request.Query; - } - - MapRequestArgument(arguments, entry.Descriptor, ContextArgumentName, request.Context); - MapRequestArgument(arguments, entry.Descriptor, ContextSummaryArgumentName, request.ContextSummary); - - try - { - var resolvedMcpTool = entry.Tool as McpClientTool ?? entry.Tool.GetService(); - if (resolvedMcpTool is not null) - { - var result = await AttachInvocationMeta(resolvedMcpTool, request).CallAsync( - arguments, - progress: null, - options: new RequestOptions(), - cancellationToken: cancellationToken); - - return new McpGatewayInvokeResult( - true, - entry.Descriptor.ToolId, - entry.Descriptor.SourceId, - entry.Descriptor.ToolName, - ExtractMcpOutput(result)); - } - - var function = entry.Tool as AIFunction ?? 
entry.Tool.GetService(); - if (function is null) - { - return new McpGatewayInvokeResult( - false, - entry.Descriptor.ToolId, - entry.Descriptor.SourceId, - entry.Descriptor.ToolName, - Output: null, - Error: string.Format(CultureInfo.InvariantCulture, ToolNotInvokableMessageFormat, entry.Descriptor.ToolName)); - } - - var resultValue = await function.InvokeAsync( - new AIFunctionArguments(arguments, StringComparer.OrdinalIgnoreCase), - cancellationToken); - return new McpGatewayInvokeResult( - true, - entry.Descriptor.ToolId, - entry.Descriptor.SourceId, - entry.Descriptor.ToolName, - NormalizeFunctionOutput(resultValue)); - } - catch (Exception ex) when (ex is not OperationCanceledException) - { - _logger.LogError(ex, GatewayInvocationFailedLogMessage, entry.Descriptor.ToolId); - return new McpGatewayInvokeResult( - false, - entry.Descriptor.ToolId, - entry.Descriptor.SourceId, - entry.Descriptor.ToolName, - Output: null, - Error: ex.GetBaseException().Message); - } - } + => _runtime.InvokeAsync(request, cancellationToken); public IReadOnlyList CreateMetaTools( string searchToolName = McpGatewayToolSet.DefaultSearchToolName, string invokeToolName = McpGatewayToolSet.DefaultInvokeToolName) - => new McpGatewayToolSet(this).CreateTools(searchToolName, invokeToolName); - - public async ValueTask DisposeAsync() - { - List registrations; - lock (_gate) - { - if (_disposed) - { - return; - } - - _disposed = true; - registrations = _registrations.ToList(); - _registrations.Clear(); - _snapshot = null; - } - - foreach (var registration in registrations) - { - await registration.DisposeAsync(); - } - - _rebuildLock.Dispose(); - } - - private void AddRegistration(McpGatewayToolSourceRegistration registration) - { - lock (_gate) - { - ThrowIfDisposed(); - _registrations.Add(registration); - _snapshot = null; - } - } - - private List CopyRegistrations() - { - lock (_gate) - { - ThrowIfDisposed(); - return _registrations.ToList(); - } - } - - private async Task 
GetSnapshotAsync(CancellationToken cancellationToken) - { - ToolCatalogSnapshot? snapshot; - lock (_gate) - { - snapshot = _snapshot; - } - - if (snapshot is not null) - { - return snapshot; - } - - await BuildIndexAsync(cancellationToken); - - lock (_gate) - { - return _snapshot ?? ToolCatalogSnapshot.Empty; - } - } - - private static McpGatewayToolDescriptor? BuildDescriptor( - McpGatewayToolSourceRegistration registration, - AITool tool) - { - if (string.IsNullOrWhiteSpace(tool.Name)) - { - return null; - } - - var toolName = tool.Name.Trim(); - var sourceKind = registration.Kind switch - { - McpGatewaySourceRegistrationKind.Http => McpGatewaySourceKind.HttpMcp, - McpGatewaySourceRegistrationKind.Stdio => McpGatewaySourceKind.StdioMcp, - McpGatewaySourceRegistrationKind.CustomMcpClient => McpGatewaySourceKind.CustomMcpClient, - _ => McpGatewaySourceKind.Local - }; - - var inputSchemaJson = ResolveInputSchemaJson(tool); - var requiredArguments = ExtractRequiredArguments(inputSchemaJson); - - return new McpGatewayToolDescriptor( - ToolId: $"{registration.SourceId}:{toolName}", - SourceId: registration.SourceId, - SourceKind: sourceKind, - ToolName: toolName, - DisplayName: ResolveDisplayName(tool), - Description: tool.Description ?? 
string.Empty, - RequiredArguments: requiredArguments, - InputSchemaJson: inputSchemaJson); - } - - private string BuildDescriptorDocument(McpGatewayToolDescriptor descriptor, AITool tool) - { - var builder = new StringBuilder(); - builder.Append(ToolNameLabel); - builder.AppendLine(descriptor.ToolName); - - if (!string.IsNullOrWhiteSpace(descriptor.DisplayName)) - { - builder.Append(DisplayNameLabel); - builder.AppendLine(descriptor.DisplayName); - } - - if (!string.IsNullOrWhiteSpace(descriptor.Description)) - { - builder.Append(DescriptionLabel); - builder.AppendLine(descriptor.Description); - } - - if (descriptor.RequiredArguments.Count > 0) - { - builder.Append(RequiredArgumentsLabel); - builder.AppendLine(string.Join(", ", descriptor.RequiredArguments)); - } - - AppendInputSchema(builder, descriptor.InputSchemaJson); - var document = builder.ToString().Trim(); - return document.Length <= _maxDescriptorLength - ? document - : document[.._maxDescriptorLength]; - } - - private static void AppendInputSchema(StringBuilder builder, string? inputSchemaJson) - { - if (string.IsNullOrWhiteSpace(inputSchemaJson)) - { - return; - } - - try - { - using var schemaDocument = JsonDocument.Parse(inputSchemaJson); - if (!schemaDocument.RootElement.TryGetProperty(InputSchemaPropertiesPropertyName, out var properties) || - properties.ValueKind != JsonValueKind.Object) - { - return; - } - - foreach (var property in properties.EnumerateObject()) - { - builder.Append(ParameterLabel); - builder.Append(property.Name); - builder.Append(": "); - - if (property.Value.TryGetProperty(InputSchemaDescriptionPropertyName, out var description) && - description.ValueKind == JsonValueKind.String) - { - builder.Append(description.GetString()); - builder.Append(". "); - } - - if (property.Value.TryGetProperty(InputSchemaTypePropertyName, out var type) && - type.ValueKind == JsonValueKind.String) - { - builder.Append(TypeLabel); - builder.Append(type.GetString()); - builder.Append(". 
"); - } - - if (property.Value.TryGetProperty(InputSchemaEnumPropertyName, out var enumValues) && - enumValues.ValueKind == JsonValueKind.Array) - { - var values = enumValues - .EnumerateArray() - .Where(static item => item.ValueKind == JsonValueKind.String) - .Select(static item => item.GetString()) - .Where(static value => !string.IsNullOrWhiteSpace(value)) - .Select(static value => value!) - .ToList(); - if (values.Count > 0) - { - builder.Append(TypicalValuesLabel); - builder.Append(string.Join(", ", values)); - builder.Append(". "); - } - } - - builder.AppendLine(); - } - } - catch (JsonException) - { - builder.Append(InputSchemaLabel); - builder.AppendLine(inputSchemaJson); - } - } - - private static string? ResolveDisplayName(AITool tool) - { - if (tool is McpClientTool mcpTool) - { - return mcpTool.ProtocolTool?.Title; - } - - var function = tool as AIFunction ?? tool.GetService(); - if (function?.AdditionalProperties is { Count: > 0 } && - function.AdditionalProperties.TryGetValue(DisplayNamePropertyName, out var displayName) && - displayName is string value && - !string.IsNullOrWhiteSpace(value)) - { - return value; - } - - return null; - } - - private static string? ResolveInputSchemaJson(AITool tool) - { - if (tool is McpClientTool mcpTool) - { - return SerializeSchema(mcpTool.ProtocolTool?.InputSchema); - } - - var function = tool as AIFunction ?? tool.GetService(); - if (function is null) - { - return null; - } - - return function.JsonSchema.ValueKind == JsonValueKind.Undefined - ? null - : function.JsonSchema.GetRawText(); - } - - private static string? SerializeSchema(object? schema) - { - return schema switch - { - null => null, - JsonElement element when element.ValueKind is JsonValueKind.Null or JsonValueKind.Undefined => null, - JsonElement element => element.GetRawText(), - JsonNode node => node.ToJsonString(), - _ => JsonSerializer.Serialize(schema) - }; - } - - private static IReadOnlyList ExtractRequiredArguments(string? 
inputSchemaJson) - { - if (string.IsNullOrWhiteSpace(inputSchemaJson)) - { - return []; - } - - try - { - using var schemaDocument = JsonDocument.Parse(inputSchemaJson); - if (!schemaDocument.RootElement.TryGetProperty(InputSchemaRequiredPropertyName, out var required) || - required.ValueKind != JsonValueKind.Array) - { - return []; - } - - return required - .EnumerateArray() - .Where(static item => item.ValueKind == JsonValueKind.String) - .Select(static item => item.GetString()) - .Where(static value => !string.IsNullOrWhiteSpace(value)) - .Select(static value => value!) - .Distinct(StringComparer.OrdinalIgnoreCase) - .ToList(); - } - catch (JsonException) - { - return []; - } - } - - private static double CalculateCosine(ToolCatalogEntry entry, float[] queryVector, double queryMagnitude) - { - if (entry.Vector is null || entry.Magnitude <= double.Epsilon || queryMagnitude <= double.Epsilon) - { - return 0d; - } - - var overlap = Math.Min(entry.Vector.Length, queryVector.Length); - if (overlap == 0) - { - return 0d; - } - - var dot = 0d; - for (var index = 0; index < overlap; index++) - { - dot += entry.Vector[index] * queryVector[index]; - } - - return dot / (entry.Magnitude * queryMagnitude); - } - - private static double CalculateMagnitude(IReadOnlyList vector) - { - if (vector.Count == 0) - { - return 0d; - } - - var magnitudeSquared = 0d; - foreach (var component in vector) - { - magnitudeSquared += component * component; - } - - return Math.Sqrt(magnitudeSquared); - } - - private static bool ApplyEmbedding( - IList entries, - int index, - IReadOnlyList vector, - ref int vectorizedToolCount) - { - if (vector.Count == 0) - { - return false; - } - - var normalizedVector = vector.ToArray(); - var magnitude = CalculateMagnitude(normalizedVector); - entries[index] = entries[index] with - { - Vector = normalizedVector, - Magnitude = magnitude - }; - - if (magnitude <= double.Epsilon) - { - return false; - } - - vectorizedToolCount++; - return true; - } - - private 
static string ComputeDocumentHash(string value) - => Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(value))); - - private static bool MatchesStoredEmbedding( - McpGatewayToolEmbeddingLookup lookup, - McpGatewayToolEmbedding embedding) - => string.Equals(embedding.ToolId, lookup.ToolId, StringComparison.OrdinalIgnoreCase) - && string.Equals(embedding.DocumentHash, lookup.DocumentHash, StringComparison.Ordinal) - && (lookup.EmbeddingGeneratorFingerprint is null - || string.Equals( - embedding.EmbeddingGeneratorFingerprint, - lookup.EmbeddingGeneratorFingerprint, - StringComparison.Ordinal)); - - private static string? ResolveEmbeddingGeneratorFingerprint( - IEmbeddingGenerator>? embeddingGenerator) - { - if (embeddingGenerator is null) - { - return null; - } - - var metadata = embeddingGenerator.GetService(typeof(EmbeddingGeneratorMetadata)) as EmbeddingGeneratorMetadata; - var generatorTypeName = embeddingGenerator.GetType().FullName ?? embeddingGenerator.GetType().Name; - - return ComputeDocumentHash(string.Join( - EmbeddingGeneratorFingerprintComponentSeparator, - metadata?.ProviderName ?? EmbeddingGeneratorFingerprintUnknownComponent, - metadata?.ProviderUri?.AbsoluteUri ?? EmbeddingGeneratorFingerprintUnknownComponent, - metadata?.DefaultModelId ?? EmbeddingGeneratorFingerprintUnknownComponent, - metadata?.DefaultModelDimensions?.ToString(CultureInfo.InvariantCulture) ?? EmbeddingGeneratorFingerprintUnknownComponent, - generatorTypeName ?? 
EmbeddingGeneratorFingerprintUnknownComponent)); - } - - private static string BuildEffectiveSearchQuery(McpGatewaySearchRequest request) - { - List parts = []; - - if (!string.IsNullOrWhiteSpace(request.Query)) - { - parts.Add(request.Query.Trim()); - } - - if (!string.IsNullOrWhiteSpace(request.ContextSummary)) - { - parts.Add(string.Concat(ContextSummaryPrefix, request.ContextSummary.Trim())); - } - - var flattenedContext = FlattenContext(request.Context); - if (!string.IsNullOrWhiteSpace(flattenedContext)) - { - parts.Add(string.Concat(ContextPrefix, flattenedContext)); - } - - return string.Join(" | ", parts); - } - - private static string? FlattenContext(IReadOnlyDictionary? context) - { - if (context is not { Count: > 0 }) - { - return null; - } - - var terms = new List(); - foreach (var (key, value) in context) - { - AppendContextTerms(terms, key, value); - } - - return terms.Count == 0 - ? null - : string.Join("; ", terms); - } - - private static void AppendContextTerms(List terms, string key, object? 
value) - { - if (value is null) - { - return; - } - - switch (value) - { - case string text when !string.IsNullOrWhiteSpace(text): - terms.Add(FormattableString.Invariant($"{key} {text.Trim()}")); - return; - - case JsonElement element: - AppendJsonElementTerms(terms, key, element); - return; - - case JsonNode node: - if (node is not null) - { - AppendJsonElementTerms(terms, key, JsonSerializer.SerializeToElement(node)); - } - return; - - case IReadOnlyDictionary dictionary: - foreach (var (childKey, childValue) in dictionary) - { - AppendContextTerms(terms, $"{key} {childKey}", childValue); - } - return; - - case IEnumerable> dictionaryEntries: - foreach (var (childKey, childValue) in dictionaryEntries) - { - AppendContextTerms(terms, $"{key} {childKey}", childValue); - } - return; - - case System.Collections.IDictionary legacyDictionary: - foreach (System.Collections.DictionaryEntry entry in legacyDictionary) - { - var childKey = Convert.ToString(entry.Key, CultureInfo.InvariantCulture); - if (!string.IsNullOrWhiteSpace(childKey)) - { - AppendContextTerms(terms, $"{key} {childKey}", entry.Value); - } - } - return; - - case System.Collections.IEnumerable enumerable when value is not string: - foreach (var item in enumerable) - { - AppendContextTerms(terms, key, item); - } - return; - - default: - var scalar = Convert.ToString(value, CultureInfo.InvariantCulture); - if (!string.IsNullOrWhiteSpace(scalar)) - { - terms.Add(FormattableString.Invariant($"{key} {scalar}")); - } - return; - } - } - - private static void AppendJsonElementTerms(List terms, string key, JsonElement element) - { - switch (element.ValueKind) - { - case JsonValueKind.Object: - foreach (var property in element.EnumerateObject()) - { - AppendJsonElementTerms(terms, $"{key} {property.Name}", property.Value); - } - return; - - case JsonValueKind.Array: - foreach (var item in element.EnumerateArray()) - { - AppendJsonElementTerms(terms, key, item); - } - return; - - case JsonValueKind.String: - var 
text = element.GetString(); - if (!string.IsNullOrWhiteSpace(text)) - { - terms.Add(FormattableString.Invariant($"{key} {text.Trim()}")); - } - return; - - case JsonValueKind.True: - case JsonValueKind.False: - terms.Add(FormattableString.Invariant($"{key} {element.GetBoolean()}")); - return; - - case JsonValueKind.Number: - terms.Add(FormattableString.Invariant($"{key} {element}")); - return; - - default: - return; - } - } - - private static void MapRequestArgument( - IDictionary arguments, - McpGatewayToolDescriptor descriptor, - string argumentName, - object? value) - { - if (value is null || - arguments.ContainsKey(argumentName) || - !SupportsArgument(descriptor, argumentName)) - { - return; - } - - if (value is string text && string.IsNullOrWhiteSpace(text)) - { - return; - } - - arguments[argumentName] = value; - } - - private static bool SupportsArgument( - McpGatewayToolDescriptor descriptor, - string argumentName) - { - if (descriptor.RequiredArguments.Contains(argumentName, StringComparer.OrdinalIgnoreCase)) - { - return true; - } - - if (string.IsNullOrWhiteSpace(descriptor.InputSchemaJson)) - { - return false; - } - - try - { - using var schemaDocument = JsonDocument.Parse(descriptor.InputSchemaJson); - if (!schemaDocument.RootElement.TryGetProperty(InputSchemaPropertiesPropertyName, out var properties) || - properties.ValueKind != JsonValueKind.Object) - { - return false; - } - - return properties - .EnumerateObject() - .Any(property => string.Equals(property.Name, argumentName, StringComparison.OrdinalIgnoreCase)); - } - catch (JsonException) - { - return false; - } - } - - private static McpClientTool AttachInvocationMeta(McpClientTool tool, McpGatewayInvokeRequest request) - { - var meta = BuildInvocationMeta(request); - return meta is null ? tool : tool.WithMeta(meta); - } - - private static JsonObject? 
BuildInvocationMeta(McpGatewayInvokeRequest request) - { - var payload = new JsonObject(); - if (!string.IsNullOrWhiteSpace(request.Query)) - { - payload[QueryArgumentName] = request.Query.Trim(); - } - - if (!string.IsNullOrWhiteSpace(request.ContextSummary)) - { - payload[ContextSummaryArgumentName] = request.ContextSummary.Trim(); - } - - if (request.Context is { Count: > 0 }) - { - var contextNode = JsonSerializer.SerializeToNode(request.Context); - if (contextNode is not null) - { - payload[ContextArgumentName] = contextNode; - } - } - - return payload.Count == 0 - ? null - : new JsonObject - { - [GatewayInvocationMetaKey] = payload - }; - } - - private static double ApplySearchBoosts(ToolCatalogEntry entry, string query, double score) - { - if (string.Equals(entry.Descriptor.ToolName, query, StringComparison.OrdinalIgnoreCase)) - { - score += 0.1d; - } - else if (entry.Descriptor.ToolName.Contains(query, StringComparison.OrdinalIgnoreCase)) - { - score += 0.03d; - } - - return Math.Clamp(score, 0d, 1d); - } - - private static IReadOnlyList RankLexically( - IReadOnlyList entries, - string query) - { - var searchTerms = BuildSearchTerms(query); - return entries - .Select(entry => new ScoredToolEntry(entry, CalculateLexicalScore(entry, query, searchTerms))) - .OrderByDescending(static item => item.Score) - .ThenBy(static item => item.Entry.Descriptor.ToolName, StringComparer.OrdinalIgnoreCase) - .ToList(); - } - - private static double CalculateLexicalScore( - ToolCatalogEntry entry, - string query, - IReadOnlySet searchTerms) - { - if (searchTerms.Count == 0) - { - return 0d; - } - - var corpus = BuildSearchTerms(entry.Document); - - var score = 0d; - foreach (var term in searchTerms) - { - if (corpus.Contains(term)) - { - score += 1d; - continue; - } - - if (corpus.Any(candidate => - candidate.StartsWith(term, StringComparison.OrdinalIgnoreCase) || - term.StartsWith(candidate, StringComparison.OrdinalIgnoreCase))) - { - score += 0.35d; - } - } - - if 
(entry.Descriptor.ToolName.Contains(query, StringComparison.OrdinalIgnoreCase)) - { - score += 2d; - } - - return score; - } - - private static HashSet BuildSearchTerms(string? text) - { - if (string.IsNullOrWhiteSpace(text)) - { - return []; - } - - var terms = new HashSet(StringComparer.OrdinalIgnoreCase); - foreach (var token in text.Split(TokenSeparators, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)) - { - if (token.Length < 2) - { - continue; - } - - var normalized = token.ToLowerInvariant(); - terms.Add(normalized); - - if (normalized.Length > 3 && normalized.EndsWith(PluralSuffixIes, StringComparison.Ordinal)) - { - terms.Add($"{normalized[..^3]}y"); - continue; - } - - if (normalized.Length > 3 && normalized.EndsWith(PluralSuffixEs, StringComparison.Ordinal)) - { - terms.Add(normalized[..^2]); - } - else if (normalized.Length > 3 && normalized.EndsWith('s')) - { - terms.Add(normalized[..^1]); - } - } - - return terms; - } - - private static McpGatewaySearchMatch ToSearchMatch(ToolCatalogEntry entry, double score) - => new( - entry.Descriptor.ToolId, - entry.Descriptor.SourceId, - entry.Descriptor.SourceKind, - entry.Descriptor.ToolName, - entry.Descriptor.DisplayName, - entry.Descriptor.Description, - entry.Descriptor.RequiredArguments, - entry.Descriptor.InputSchemaJson, - score); - - private static InvocationResolution ResolveInvocationTarget( - ToolCatalogSnapshot snapshot, - McpGatewayInvokeRequest request) - { - if (!string.IsNullOrWhiteSpace(request.ToolId)) - { - var byToolId = snapshot.Entries.FirstOrDefault(item => - string.Equals(item.Descriptor.ToolId, request.ToolId, StringComparison.OrdinalIgnoreCase)); - return byToolId is null - ? 
InvocationResolution.Fail(string.Format(CultureInfo.InvariantCulture, ToolIdNotFoundMessageFormat, request.ToolId)) - : InvocationResolution.Success(byToolId); - } - - if (string.IsNullOrWhiteSpace(request.ToolName)) - { - return InvocationResolution.Fail(ToolIdOrToolNameRequiredMessage); - } - - var candidates = snapshot.Entries - .Where(item => string.Equals(item.Descriptor.ToolName, request.ToolName, StringComparison.OrdinalIgnoreCase)) - .Where(item => string.IsNullOrWhiteSpace(request.SourceId) || - string.Equals(item.Descriptor.SourceId, request.SourceId, StringComparison.OrdinalIgnoreCase)) - .ToList(); - - return candidates.Count switch - { - 0 => InvocationResolution.Fail(string.Format(CultureInfo.InvariantCulture, ToolIdNotFoundMessageFormat, request.ToolName)), - 1 => InvocationResolution.Success(candidates[0]), - _ => InvocationResolution.Fail( - string.Format(CultureInfo.InvariantCulture, ToolNameAmbiguousMessageFormat, request.ToolName)) - }; - } - - private static object? ExtractMcpOutput(CallToolResult result) - { - if (result.StructuredContent is JsonElement element) - { - return element.Clone(); - } - - var text = result.Content? - .OfType() - .FirstOrDefault(static block => !string.IsNullOrWhiteSpace(block.Text)) - ?.Text; - if (string.IsNullOrWhiteSpace(text)) - { - return result; - } - - try - { - using var document = JsonDocument.Parse(text); - return document.RootElement.Clone(); - } - catch (JsonException) - { - return text; - } - } - - private static object? NormalizeFunctionOutput(object? value) - { - return value switch - { - JsonElement element => NormalizeJsonElement(element), - JsonDocument document => NormalizeJsonElement(document.RootElement), - _ => value - }; - } - - private static object? 
NormalizeJsonElement(JsonElement element) - { - return element.ValueKind switch - { - JsonValueKind.Null or JsonValueKind.Undefined => null, - JsonValueKind.String => element.GetString(), - JsonValueKind.True or JsonValueKind.False => element.GetBoolean(), - JsonValueKind.Number when element.TryGetInt64(out var int64Value) => int64Value, - JsonValueKind.Number when element.TryGetDecimal(out var decimalValue) => decimalValue, - JsonValueKind.Number when element.TryGetDouble(out var doubleValue) => doubleValue, - _ => element.Clone() - }; - } - - private void ThrowIfDisposed() - { - ObjectDisposedException.ThrowIf(_disposed, this); - } - - private sealed record InvocationResolution(bool IsSuccess, ToolCatalogEntry? Entry, string? Error) - { - public static InvocationResolution Success(ToolCatalogEntry entry) => new(true, entry, null); - - public static InvocationResolution Fail(string error) => new(false, null, error); - } - - private sealed record ToolEmbeddingCandidate( - int Index, - McpGatewayToolEmbeddingLookup Lookup, - string SourceId, - string ToolName); - - private sealed record ScoredToolEntry(ToolCatalogEntry Entry, double Score); - - private sealed record ToolCatalogEntry( - McpGatewayToolDescriptor Descriptor, - AITool Tool, - string Document, - float[]? Vector = null, - double Magnitude = 0d); - - private sealed record ToolCatalogSnapshot(IReadOnlyList Entries, bool HasVectors) - { - public static ToolCatalogSnapshot Empty { get; } = new([], false); - } - - private sealed class EmbeddingGeneratorLease( - IEmbeddingGenerator>? generator, - AsyncServiceScope? scope = null) - : IAsyncDisposable - { - public IEmbeddingGenerator>? Generator { get; } = generator; - - public ValueTask DisposeAsync() => scope?.DisposeAsync() ?? ValueTask.CompletedTask; - } - - private sealed class ToolEmbeddingStoreLease( - IMcpGatewayToolEmbeddingStore? store, - AsyncServiceScope? scope = null) - : IAsyncDisposable - { - public IMcpGatewayToolEmbeddingStore? 
Store { get; } = store; - - public ValueTask DisposeAsync() => scope?.DisposeAsync() ?? ValueTask.CompletedTask; + => _runtime.CreateMetaTools(searchToolName, invokeToolName); + + public ValueTask DisposeAsync() => _runtime.DisposeAsync(); + + private static McpGatewayRuntime CreateRuntime( + IServiceProvider serviceProvider, + IOptions options, + ILogger logger, + ILoggerFactory loggerFactory) + { + ArgumentNullException.ThrowIfNull(serviceProvider); + ArgumentNullException.ThrowIfNull(options); + ArgumentNullException.ThrowIfNull(logger); + ArgumentNullException.ThrowIfNull(loggerFactory); + + return new McpGatewayRuntime( + serviceProvider, + options, + loggerFactory.CreateLogger(), + loggerFactory); } } diff --git a/src/ManagedCode.MCPGateway/McpGatewayToolSet.cs b/src/ManagedCode.MCPGateway/McpGatewayToolSet.cs index 1043ae9..07be5e3 100644 --- a/src/ManagedCode.MCPGateway/McpGatewayToolSet.cs +++ b/src/ManagedCode.MCPGateway/McpGatewayToolSet.cs @@ -7,6 +7,8 @@ public sealed class McpGatewayToolSet(IMcpGateway gateway) { public const string DefaultSearchToolName = "gateway_tools_search"; public const string DefaultInvokeToolName = "gateway_tool_invoke"; + public const string SearchToolDescription = "Search the gateway catalog and return the best matching tools for a user task."; + public const string InvokeToolDescription = "Invoke a gateway tool by tool id. Search first when the correct tool is unknown."; public IReadOnlyList CreateTools( string searchToolName = DefaultSearchToolName, @@ -17,7 +19,7 @@ public IReadOnlyList CreateTools( new AIFunctionFactoryOptions { Name = searchToolName, - Description = "Search the gateway catalog and return the best matching tools for a user task." + Description = SearchToolDescription }); var invokeTool = AIFunctionFactory.Create( @@ -25,7 +27,7 @@ public IReadOnlyList CreateTools( new AIFunctionFactoryOptions { Name = invokeToolName, - Description = "Invoke a gateway tool by tool id. 
Search first when the correct tool is unknown." + Description = InvokeToolDescription }); return [searchTool, invokeTool]; diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayDiagnostic.cs b/src/ManagedCode.MCPGateway/Models/Catalog/McpGatewayDiagnostic.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewayDiagnostic.cs rename to src/ManagedCode.MCPGateway/Models/Catalog/McpGatewayDiagnostic.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayIndexBuildResult.cs b/src/ManagedCode.MCPGateway/Models/Catalog/McpGatewayIndexBuildResult.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewayIndexBuildResult.cs rename to src/ManagedCode.MCPGateway/Models/Catalog/McpGatewayIndexBuildResult.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewaySourceKind.cs b/src/ManagedCode.MCPGateway/Models/Catalog/McpGatewaySourceKind.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewaySourceKind.cs rename to src/ManagedCode.MCPGateway/Models/Catalog/McpGatewaySourceKind.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayToolDescriptor.cs b/src/ManagedCode.MCPGateway/Models/Catalog/McpGatewayToolDescriptor.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewayToolDescriptor.cs rename to src/ManagedCode.MCPGateway/Models/Catalog/McpGatewayToolDescriptor.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayToolEmbedding.cs b/src/ManagedCode.MCPGateway/Models/Embeddings/McpGatewayToolEmbedding.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewayToolEmbedding.cs rename to src/ManagedCode.MCPGateway/Models/Embeddings/McpGatewayToolEmbedding.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayToolEmbeddingLookup.cs b/src/ManagedCode.MCPGateway/Models/Embeddings/McpGatewayToolEmbeddingLookup.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewayToolEmbeddingLookup.cs rename to 
src/ManagedCode.MCPGateway/Models/Embeddings/McpGatewayToolEmbeddingLookup.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayInvokeRequest.cs b/src/ManagedCode.MCPGateway/Models/Invocation/McpGatewayInvokeRequest.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewayInvokeRequest.cs rename to src/ManagedCode.MCPGateway/Models/Invocation/McpGatewayInvokeRequest.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewayInvokeResult.cs b/src/ManagedCode.MCPGateway/Models/Invocation/McpGatewayInvokeResult.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewayInvokeResult.cs rename to src/ManagedCode.MCPGateway/Models/Invocation/McpGatewayInvokeResult.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewaySearchMatch.cs b/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchMatch.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewaySearchMatch.cs rename to src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchMatch.cs diff --git a/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchQueryNormalization.cs b/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchQueryNormalization.cs new file mode 100644 index 0000000..ae90f92 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchQueryNormalization.cs @@ -0,0 +1,7 @@ +namespace ManagedCode.MCPGateway; + +public enum McpGatewaySearchQueryNormalization +{ + Disabled = 0, + TranslateToEnglishWhenAvailable = 1 +} diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewaySearchRequest.cs b/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchRequest.cs similarity index 100% rename from src/ManagedCode.MCPGateway/Models/McpGatewaySearchRequest.cs rename to src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchRequest.cs diff --git a/src/ManagedCode.MCPGateway/Models/McpGatewaySearchResult.cs b/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchResult.cs similarity index 100% rename 
from src/ManagedCode.MCPGateway/Models/McpGatewaySearchResult.cs rename to src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchResult.cs diff --git a/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchStrategy.cs b/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchStrategy.cs new file mode 100644 index 0000000..d9a2f36 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Models/Search/McpGatewaySearchStrategy.cs @@ -0,0 +1,8 @@ +namespace ManagedCode.MCPGateway; + +public enum McpGatewaySearchStrategy +{ + Auto = 0, + Embeddings = 1, + Tokenizer = 2 +} diff --git a/src/ManagedCode.MCPGateway/Properties/AssemblyInfo.cs b/src/ManagedCode.MCPGateway/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..a8af786 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Properties/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("ManagedCode.MCPGateway.Tests")] diff --git a/src/ManagedCode.MCPGateway/Registration/McpGatewayServiceCollectionExtensions.cs b/src/ManagedCode.MCPGateway/Registration/McpGatewayServiceCollectionExtensions.cs index be4d6d4..b8da418 100644 --- a/src/ManagedCode.MCPGateway/Registration/McpGatewayServiceCollectionExtensions.cs +++ b/src/ManagedCode.MCPGateway/Registration/McpGatewayServiceCollectionExtensions.cs @@ -1,6 +1,7 @@ using ManagedCode.MCPGateway.Abstractions; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Hosting; namespace ManagedCode.MCPGateway; @@ -10,17 +11,26 @@ public static IServiceCollection AddManagedCodeMcpGateway( this IServiceCollection services, Action? 
configure = null) { + ArgumentNullException.ThrowIfNull(services); + services.AddOptions(); if (configure is not null) { services.Configure(configure); } - services.TryAddSingleton(); - services.TryAddSingleton(serviceProvider => serviceProvider.GetRequiredService()); - services.TryAddSingleton(serviceProvider => serviceProvider.GetRequiredService()); + services.TryAddSingleton(); + services.TryAddSingleton(); services.TryAddSingleton(); return services; } + + public static IServiceCollection AddManagedCodeMcpGatewayIndexWarmup(this IServiceCollection services) + { + ArgumentNullException.ThrowIfNull(services); + + services.TryAddEnumerable(ServiceDescriptor.Singleton()); + return services; + } } diff --git a/src/ManagedCode.MCPGateway/Registration/McpGatewayServiceProviderExtensions.cs b/src/ManagedCode.MCPGateway/Registration/McpGatewayServiceProviderExtensions.cs new file mode 100644 index 0000000..f315467 --- /dev/null +++ b/src/ManagedCode.MCPGateway/Registration/McpGatewayServiceProviderExtensions.cs @@ -0,0 +1,15 @@ +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.MCPGateway; + +public static class McpGatewayServiceProviderExtensions +{ + public static Task InitializeManagedCodeMcpGatewayAsync( + this IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(serviceProvider); + return serviceProvider.GetRequiredService().BuildIndexAsync(cancellationToken); + } +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationLocalTests.cs b/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationLocalTests.cs new file mode 100644 index 0000000..f72d9e7 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationLocalTests.cs @@ -0,0 +1,220 @@ +using ManagedCode.MCPGateway.Abstractions; + +using Microsoft.Extensions.DependencyInjection; + +namespace 
ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewayInvocationTests +{ + [TUnit.Core.Test] + public async Task InvokeAsync_InvokesLocalFunctionAndMapsQueryArgument() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool( + "local", + TestFunctionFactory.CreateFunction(TextUppercase, "text_uppercase", "Convert query text to uppercase.")); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "local:text_uppercase", + Query: "hello gateway")); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That(invokeResult.Output).IsTypeOf(); + await Assert.That((string)invokeResult.Output!).IsEqualTo("HELLO GATEWAY"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_MapsQueryArgumentWhenSchemaMarksItOptional() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool( + "local", + TestFunctionFactory.CreateFunction(OptionalQueryEcho, "optional_query_echo", "Echo optional query text in uppercase.")); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "local:optional_query_echo", + Query: "hello gateway")); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That(invokeResult.Output).IsTypeOf(); + await Assert.That((string)invokeResult.Output!).IsEqualTo("HELLO GATEWAY"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_MapsContextSummaryToRequiredLocalArguments() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool( + "local", + TestFunctionFactory.CreateFunction(EchoContextSummary, "context_summary_echo", "Echo query and context summary.")); + }); + + var gateway = 
serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "local:context_summary_echo", + Query: "open github", + ContextSummary: "user is on repository settings page")); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That((string)invokeResult.Output!).IsEqualTo("open github|user is on repository settings page"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_MapsStructuredContextToRequiredLocalArguments() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool( + "local", + TestFunctionFactory.CreateFunction(ReadStructuredContext, "structured_context_echo", "Read structured context payload.")); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "local:structured_context_echo", + Context: new Dictionary + { + ["domain"] = "genealogy", + ["page"] = "tree-profile" + })); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That((string)invokeResult.Output!).IsEqualTo("genealogy|tree-profile"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_PrefersExplicitArgumentsOverMappedValues() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool( + "local", + TestFunctionFactory.CreateFunction(EchoContextSummary, "context_summary_echo", "Echo query and context summary.")); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "local:context_summary_echo", + Query: "mapped query", + ContextSummary: "mapped summary", + Arguments: new Dictionary + { + ["query"] = "explicit query", + ["contextSummary"] = "explicit summary" + })); + + await 
Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That((string)invokeResult.Output!).IsEqualTo("explicit query|explicit summary"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_ResolvesByToolNameAndSourceId() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSharedSearchTools); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolName: "shared_search", + SourceId: "beta", + Query: "hello")); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That((string)invokeResult.Output!).IsEqualTo("beta:hello"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_ReturnsAmbiguousErrorWhenToolNameExistsInMultipleSources() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSharedSearchTools); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolName: "shared_search", + Query: "hello")); + + await Assert.That(invokeResult.IsSuccess).IsFalse(); + await Assert.That(invokeResult.Error!.Contains("ambiguous", StringComparison.OrdinalIgnoreCase)).IsTrue(); + } + + private static void ConfigureSharedSearchTools(McpGatewayOptions options) + { + options.AddTool("alpha", TestFunctionFactory.CreateFunction(AlphaSharedSearch, "shared_search", "Alpha search tool.")); + options.AddTool("beta", TestFunctionFactory.CreateFunction(BetaSharedSearch, "shared_search", "Beta search tool.")); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_ReturnsNotFoundWhenToolDoesNotExist() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool("local", TestFunctionFactory.CreateFunction(TextUppercase, "text_uppercase", "Convert query text to uppercase.")); + }); + + var gateway = 
serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "local:missing_tool")); + + await Assert.That(invokeResult.IsSuccess).IsFalse(); + await Assert.That(invokeResult.Error!.Contains("was not found", StringComparison.OrdinalIgnoreCase)).IsTrue(); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_NormalizesJsonScalarOutputs() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool("local", TestFunctionFactory.CreateFunction(ReturnJsonString, "json_string_result", "Return a JSON string scalar.")); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "local:json_string_result")); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That(invokeResult.Output).IsTypeOf(); + await Assert.That((string)invokeResult.Output!).IsEqualTo("done"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_ReturnsFailureWhenLocalFunctionThrows() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool("local", TestFunctionFactory.CreateFunction(ThrowingTool, "throwing_tool", "Throw an exception for test coverage.")); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "local:throwing_tool")); + + await Assert.That(invokeResult.IsSuccess).IsFalse(); + await Assert.That(invokeResult.Error).IsEqualTo("boom"); + } +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationMcpTests.cs b/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationMcpTests.cs new file mode 100644 index 0000000..6f46df3 --- /dev/null +++ 
b/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationMcpTests.cs @@ -0,0 +1,152 @@ +using System.Text.Json; + +using ManagedCode.MCPGateway.Abstractions; + +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewayInvocationTests +{ + [TUnit.Core.Test] + public async Task InvokeAsync_InvokesStructuredMcpToolAndMapsQueryArgument() + { + await using var serverHost = await TestMcpServerHost.StartAsync(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "test-mcp:github_repository_search", + Query: "managedcode")); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That(invokeResult.Output).IsTypeOf(); + + var output = (JsonElement)invokeResult.Output!; + await Assert.That(GetJsonProperty(output, "query").GetString()).IsEqualTo("managedcode"); + await Assert.That(GetJsonProperty(output, "source").GetString()).IsEqualTo("mcp"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_PassesContextMetaToMcpToolRequests() + { + await using var serverHost = await TestMcpServerHost.StartAsync(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "test-mcp:github_repository_search", + Query: "managedcode", + ContextSummary: "user is on repository settings page", + Context: new Dictionary + { + ["page"] = "settings", + ["domain"] = "github" + })); + + await 
Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That(serverHost.CapturedMeta.Count > 0).IsTrue(); + + var payload = serverHost.CapturedMeta[^1]; + await Assert.That(payload.TryGetPropertyValue("managedCodeMcpGateway", out var gatewayNode)).IsTrue(); + + var gatewayMeta = gatewayNode!.AsObject(); + await Assert.That(gatewayMeta["query"]!.GetValue()).IsEqualTo("managedcode"); + await Assert.That(gatewayMeta["contextSummary"]!.GetValue()).IsEqualTo("user is on repository settings page"); + await Assert.That(gatewayMeta["context"]!["page"]!.GetValue()).IsEqualTo("settings"); + await Assert.That(gatewayMeta["context"]!["domain"]!.GetValue()).IsEqualTo("github"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_IgnoresUnserializableContextMeta() + { + await using var serverHost = await TestMcpServerHost.StartAsync(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); + }); + + var gateway = serviceProvider.GetRequiredService(); + var cyclicContext = new CyclicInvocationContext(); + cyclicContext.Self = cyclicContext; + + await gateway.BuildIndexAsync(); + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "test-mcp:github_repository_search", + Query: "managedcode", + Context: new Dictionary + { + ["broken"] = cyclicContext + })); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That(serverHost.CapturedMeta.Count > 0).IsTrue(); + + var payload = serverHost.CapturedMeta[^1]; + await Assert.That(payload.TryGetPropertyValue("managedCodeMcpGateway", out var gatewayNode)).IsTrue(); + + var gatewayMeta = gatewayNode!.AsObject(); + await Assert.That(gatewayMeta["query"]!.GetValue()).IsEqualTo("managedcode"); + await Assert.That(gatewayMeta.ContainsKey("context")).IsFalse(); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_ParsesJsonTextContentFromMcpTool() + { + await using var 
serverHost = await TestMcpServerHost.StartAsync(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "test-mcp:json_text_search", + Query: "managedcode")); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That(invokeResult.Output).IsTypeOf(); + + var output = (JsonElement)invokeResult.Output!; + await Assert.That(GetJsonProperty(output, "query").GetString()).IsEqualTo("managedcode"); + await Assert.That(GetJsonProperty(output, "source").GetString()).IsEqualTo("text-json"); + } + + [TUnit.Core.Test] + public async Task InvokeAsync_ReturnsPlainTextWhenMcpTextContentIsNotJson() + { + await using var serverHost = await TestMcpServerHost.StartAsync(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); + }); + + var gateway = serviceProvider.GetRequiredService(); + await gateway.BuildIndexAsync(); + + var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( + ToolId: "test-mcp:plain_text_search", + Query: "managedcode")); + + await Assert.That(invokeResult.IsSuccess).IsTrue(); + await Assert.That(invokeResult.Output).IsTypeOf(); + await Assert.That((string)invokeResult.Output!).IsEqualTo("plain:managedcode"); + } + + private sealed class CyclicInvocationContext + { + public CyclicInvocationContext? 
Self { get; set; } + } +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationTests.cs b/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationTests.cs index d592e26..24c2b86 100644 --- a/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationTests.cs +++ b/tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationTests.cs @@ -1,342 +1,10 @@ using System.ComponentModel; using System.Text.Json; -using ManagedCode.MCPGateway.Abstractions; - -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; - namespace ManagedCode.MCPGateway.Tests; -public sealed class McpGatewayInvocationTests +public sealed partial class McpGatewayInvocationTests { - [TUnit.Core.Test] - public async Task InvokeAsync_InvokesLocalFunctionAndMapsQueryArgument() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool( - "local", - CreateFunction(TextUppercase, "text_uppercase", "Convert query text to uppercase.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "local:text_uppercase", - Query: "hello gateway")); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That(invokeResult.Output).IsTypeOf(); - await Assert.That((string)invokeResult.Output!).IsEqualTo("HELLO GATEWAY"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_MapsQueryArgumentWhenSchemaMarksItOptional() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool( - "local", - CreateFunction(OptionalQueryEcho, "optional_query_echo", "Echo optional query text in uppercase.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: 
"local:optional_query_echo", - Query: "hello gateway")); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That(invokeResult.Output).IsTypeOf(); - await Assert.That((string)invokeResult.Output!).IsEqualTo("HELLO GATEWAY"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_MapsContextSummaryToRequiredLocalArguments() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool( - "local", - CreateFunction(EchoContextSummary, "context_summary_echo", "Echo query and context summary.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "local:context_summary_echo", - Query: "open github", - ContextSummary: "user is on repository settings page")); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That((string)invokeResult.Output!).IsEqualTo("open github|user is on repository settings page"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_MapsStructuredContextToRequiredLocalArguments() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool( - "local", - CreateFunction(ReadStructuredContext, "structured_context_echo", "Read structured context payload.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "local:structured_context_echo", - Context: new Dictionary - { - ["domain"] = "genealogy", - ["page"] = "tree-profile" - })); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That((string)invokeResult.Output!).IsEqualTo("genealogy|tree-profile"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_PrefersExplicitArgumentsOverMappedValues() - { - await using var serviceProvider = 
GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool( - "local", - CreateFunction(EchoContextSummary, "context_summary_echo", "Echo query and context summary.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "local:context_summary_echo", - Query: "mapped query", - ContextSummary: "mapped summary", - Arguments: new Dictionary - { - ["query"] = "explicit query", - ["contextSummary"] = "explicit summary" - })); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That((string)invokeResult.Output!).IsEqualTo("explicit query|explicit summary"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_ResolvesByToolNameAndSourceId() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool("alpha", CreateFunction(AlphaSharedSearch, "shared_search", "Alpha search tool.")); - options.AddTool("beta", CreateFunction(BetaSharedSearch, "shared_search", "Beta search tool.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolName: "shared_search", - SourceId: "beta", - Query: "hello")); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That((string)invokeResult.Output!).IsEqualTo("beta:hello"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_ReturnsAmbiguousErrorWhenToolNameExistsInMultipleSources() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool("alpha", CreateFunction(AlphaSharedSearch, "shared_search", "Alpha search tool.")); - options.AddTool("beta", CreateFunction(BetaSharedSearch, "shared_search", "Beta search tool.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - 
- var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolName: "shared_search", - Query: "hello")); - - await Assert.That(invokeResult.IsSuccess).IsFalse(); - await Assert.That(invokeResult.Error!.Contains("ambiguous", StringComparison.OrdinalIgnoreCase)).IsTrue(); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_ReturnsNotFoundWhenToolDoesNotExist() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool("local", CreateFunction(TextUppercase, "text_uppercase", "Convert query text to uppercase.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "local:missing_tool")); - - await Assert.That(invokeResult.IsSuccess).IsFalse(); - await Assert.That(invokeResult.Error!.Contains("was not found", StringComparison.OrdinalIgnoreCase)).IsTrue(); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_NormalizesJsonScalarOutputs() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool("local", CreateFunction(ReturnJsonString, "json_string_result", "Return a JSON string scalar.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "local:json_string_result")); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That(invokeResult.Output).IsTypeOf(); - await Assert.That((string)invokeResult.Output!).IsEqualTo("done"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_ReturnsFailureWhenLocalFunctionThrows() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool("local", CreateFunction(ThrowingTool, "throwing_tool", "Throw an exception for test coverage.")); - }); - - var gateway 
= serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "local:throwing_tool")); - - await Assert.That(invokeResult.IsSuccess).IsFalse(); - await Assert.That(invokeResult.Error).IsEqualTo("boom"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_InvokesStructuredMcpToolAndMapsQueryArgument() - { - await using var serverHost = await TestMcpServerHost.StartAsync(); - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "test-mcp:github_repository_search", - Query: "managedcode")); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That(invokeResult.Output).IsTypeOf(); - - var output = (JsonElement)invokeResult.Output!; - await Assert.That(GetJsonProperty(output, "query").GetString()).IsEqualTo("managedcode"); - await Assert.That(GetJsonProperty(output, "source").GetString()).IsEqualTo("mcp"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_PassesContextMetaToMcpToolRequests() - { - await using var serverHost = await TestMcpServerHost.StartAsync(); - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "test-mcp:github_repository_search", - Query: "managedcode", - ContextSummary: "user is on repository settings page", - Context: new Dictionary - { - ["page"] = "settings", - ["domain"] = "github" - })); - - await 
Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That(serverHost.CapturedMeta.Count > 0).IsTrue(); - - var payload = serverHost.CapturedMeta[^1]; - await Assert.That(payload.TryGetPropertyValue("managedCodeMcpGateway", out var gatewayNode)).IsTrue(); - - var gatewayMeta = gatewayNode!.AsObject(); - await Assert.That(gatewayMeta["query"]!.GetValue()).IsEqualTo("managedcode"); - await Assert.That(gatewayMeta["contextSummary"]!.GetValue()).IsEqualTo("user is on repository settings page"); - await Assert.That(gatewayMeta["context"]!["page"]!.GetValue()).IsEqualTo("settings"); - await Assert.That(gatewayMeta["context"]!["domain"]!.GetValue()).IsEqualTo("github"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_ParsesJsonTextContentFromMcpTool() - { - await using var serverHost = await TestMcpServerHost.StartAsync(); - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); - }); - - var gateway = serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "test-mcp:json_text_search", - Query: "managedcode")); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That(invokeResult.Output).IsTypeOf(); - - var output = (JsonElement)invokeResult.Output!; - await Assert.That(GetJsonProperty(output, "query").GetString()).IsEqualTo("managedcode"); - await Assert.That(GetJsonProperty(output, "source").GetString()).IsEqualTo("text-json"); - } - - [TUnit.Core.Test] - public async Task InvokeAsync_ReturnsPlainTextWhenMcpTextContentIsNotJson() - { - await using var serverHost = await TestMcpServerHost.StartAsync(); - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); - }); - - var gateway = 
serviceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - - var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest( - ToolId: "test-mcp:plain_text_search", - Query: "managedcode")); - - await Assert.That(invokeResult.IsSuccess).IsTrue(); - await Assert.That(invokeResult.Output).IsTypeOf(); - await Assert.That((string)invokeResult.Output!).IsEqualTo("plain:managedcode"); - } - - private static AIFunction CreateFunction(Delegate callback, string name, string description) - => AIFunctionFactory.Create( - callback, - new AIFunctionFactoryOptions - { - Name = name, - Description = description - }); - private static string TextUppercase([Description("Text to uppercase.")] string query) => query.ToUpperInvariant(); private static string OptionalQueryEcho([Description("Text to uppercase.")] string? query = null) diff --git a/tests/ManagedCode.MCPGateway.Tests/MetaTools/McpGatewayMetaToolTests.cs b/tests/ManagedCode.MCPGateway.Tests/MetaTools/McpGatewayMetaToolTests.cs index eddb4a0..808368c 100644 --- a/tests/ManagedCode.MCPGateway.Tests/MetaTools/McpGatewayMetaToolTests.cs +++ b/tests/ManagedCode.MCPGateway.Tests/MetaTools/McpGatewayMetaToolTests.cs @@ -16,8 +16,8 @@ public async Task CreateMetaTools_SearchToolSupportsContextAwareRequests() var embeddingGenerator = new TestEmbeddingGenerator(); await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => { - options.AddTool("local", CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); - options.AddTool("local", CreateFunction(SearchWeather, "weather_search_forecast", "Search weather forecast and temperature information by city name.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "weather_search_forecast", "Search 
weather forecast and temperature information by city name.")); }, embeddingGenerator); var gateway = serviceProvider.GetRequiredService(); @@ -45,7 +45,7 @@ public async Task CreateMetaTools_InvokeToolSupportsContextSummary() { await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => { - options.AddTool("local", CreateFunction(EchoContextSummary, "context_summary_echo", "Echo query and context summary.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(EchoContextSummary, "context_summary_echo", "Echo query and context summary.")); }); var gateway = serviceProvider.GetRequiredService(); @@ -73,15 +73,6 @@ private static AIFunction GetFunction(IReadOnlyList tools, string toolNa => (tools.Single(tool => tool.Name == toolName) as AIFunction) ?? throw new InvalidOperationException($"Tool '{toolName}' is not an AIFunction."); - private static AIFunction CreateFunction(Delegate callback, string name, string description) - => AIFunctionFactory.Create( - callback, - new AIFunctionFactoryOptions - { - Name = name, - Description = description - }); - private static string SearchGitHub([Description("Search query text.")] string query) => $"github:{query}"; private static string SearchWeather([Description("City or weather request text.")] string query) => $"weather:{query}"; diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayInMemoryToolEmbeddingStoreTests.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayInMemoryToolEmbeddingStoreTests.cs index 99da675..73f977d 100644 --- a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayInMemoryToolEmbeddingStoreTests.cs +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayInMemoryToolEmbeddingStoreTests.cs @@ -5,30 +5,15 @@ public sealed class McpGatewayInMemoryToolEmbeddingStoreTests [TUnit.Core.Test] public async Task GetAsync_ReturnsEmbeddingsForMatchingLookups() { - var store = new McpGatewayInMemoryToolEmbeddingStore(); - await store.UpsertAsync( - [ - new 
McpGatewayToolEmbedding( - "local:github_search_issues", - "local", - "github_search_issues", - "hash-1", - "fingerprint-a", - [1f, 2f, 3f]), - new McpGatewayToolEmbedding( - "local:weather_search_forecast", - "local", - "weather_search_forecast", - "hash-2", - "fingerprint-a", - [4f, 5f, 6f]) - ]); + var store = await CreateStoreAsync( + CreateEmbedding("local:github_search_issues", "github_search_issues", "hash-1", "fingerprint-a", [1f, 2f, 3f]), + CreateEmbedding("local:weather_search_forecast", "weather_search_forecast", "hash-2", "fingerprint-a", [4f, 5f, 6f])); var result = await store.GetAsync( [ - new McpGatewayToolEmbeddingLookup("local:weather_search_forecast", "hash-2", "fingerprint-a"), - new McpGatewayToolEmbeddingLookup("local:missing", "hash-3", "fingerprint-a"), - new McpGatewayToolEmbeddingLookup("local:github_search_issues", "hash-1", "fingerprint-a") + CreateLookup("local:weather_search_forecast", "hash-2", "fingerprint-a"), + CreateLookup("local:missing", "hash-3", "fingerprint-a"), + CreateLookup("local:github_search_issues", "hash-1", "fingerprint-a") ]); await Assert.That(result.Count).IsEqualTo(2); @@ -42,28 +27,19 @@ public async Task UpsertAsync_ClonesVectorsOnWriteAndRead() var store = new McpGatewayInMemoryToolEmbeddingStore(); var inputVector = new[] { 1f, 2f, 3f }; - await store.UpsertAsync( - [ - new McpGatewayToolEmbedding( - "local:github_search_issues", - "local", - "github_search_issues", - "hash-1", - "fingerprint-a", - inputVector) - ]); + await store.UpsertAsync([CreateEmbedding("local:github_search_issues", "github_search_issues", "hash-1", "fingerprint-a", inputVector)]); inputVector[0] = 99f; var firstRead = await store.GetAsync( [ - new McpGatewayToolEmbeddingLookup("local:github_search_issues", "hash-1", "fingerprint-a") + CreateLookup("local:github_search_issues", "hash-1", "fingerprint-a") ]); firstRead[0].Vector[1] = 77f; var secondRead = await store.GetAsync( [ - new 
McpGatewayToolEmbeddingLookup("local:github_search_issues", "hash-1", "fingerprint-a") + CreateLookup("local:github_search_issues", "hash-1", "fingerprint-a") ]); await Assert.That(secondRead[0].Vector[0]).IsEqualTo(1f); @@ -74,29 +50,47 @@ await store.UpsertAsync( [TUnit.Core.Test] public async Task GetAsync_TreatsToolIdsCaseInsensitivelyAndSupportsFingerprintFallback() { - var store = new McpGatewayInMemoryToolEmbeddingStore(); - await store.UpsertAsync( - [ - new McpGatewayToolEmbedding( - "local:github_search_issues", - "local", - "github_search_issues", - "hash-1", - "fingerprint-a", - [1f, 2f, 3f]) - ]); + var store = await CreateStoreAsync( + CreateEmbedding("local:github_search_issues", "github_search_issues", "hash-1", "fingerprint-a", [1f, 2f, 3f])); var fingerprintMatch = await store.GetAsync( [ - new McpGatewayToolEmbeddingLookup("LOCAL:GITHUB_SEARCH_ISSUES", "hash-1", "fingerprint-a") + CreateLookup("LOCAL:GITHUB_SEARCH_ISSUES", "hash-1", "fingerprint-a") ]); var fingerprintAgnosticMatch = await store.GetAsync( [ - new McpGatewayToolEmbeddingLookup("LOCAL:GITHUB_SEARCH_ISSUES", "hash-1") + CreateLookup("LOCAL:GITHUB_SEARCH_ISSUES", "hash-1") ]); await Assert.That(fingerprintMatch.Count).IsEqualTo(1); await Assert.That(fingerprintAgnosticMatch.Count).IsEqualTo(1); await Assert.That(fingerprintAgnosticMatch[0].ToolId).IsEqualTo("local:github_search_issues"); } + + private static async Task CreateStoreAsync(params McpGatewayToolEmbedding[] embeddings) + { + var store = new McpGatewayInMemoryToolEmbeddingStore(); + await store.UpsertAsync(embeddings); + return store; + } + + private static McpGatewayToolEmbedding CreateEmbedding( + string toolId, + string toolName, + string documentHash, + string fingerprint, + float[] vector) + => new( + toolId, + "local", + toolName, + documentHash, + fingerprint, + vector); + + private static McpGatewayToolEmbeddingLookup CreateLookup( + string toolId, + string documentHash, + string? 
fingerprint = null) + => new(toolId, documentHash, fingerprint); } diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayInitializationTests.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayInitializationTests.cs new file mode 100644 index 0000000..bf68069 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayInitializationTests.cs @@ -0,0 +1,90 @@ +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewaySearchTests +{ + [TUnit.Core.Test] + public async Task InitializeManagedCodeMcpGatewayAsync_BuildsIndexThroughServiceProviderExtension() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); + + var buildResult = await serviceProvider.InitializeManagedCodeMcpGatewayAsync(); + var gateway = serviceProvider.GetRequiredService(); + var tools = await gateway.ListToolsAsync(); + + await Assert.That(buildResult.ToolCount).IsEqualTo(2); + await Assert.That(tools.Count).IsEqualTo(2); + } + + [TUnit.Core.Test] + public async Task AddManagedCodeMcpGatewayIndexWarmup_StartsBackgroundIndexBuild() + { + var probeGateway = new WarmupProbeGateway(); + var services = new ServiceCollection(); + services.AddLogging(static logging => logging.SetMinimumLevel(LogLevel.Debug)); + services.AddSingleton(probeGateway); + services.AddManagedCodeMcpGatewayIndexWarmup(); + + await using var serviceProvider = services.BuildServiceProvider(); + var hostedServices = serviceProvider.GetServices().ToList(); + using var cancellationSource = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + + await Assert.That(hostedServices.Count).IsEqualTo(1); + + var hostedService = hostedServices.Single(); + await hostedService.StartAsync(cancellationSource.Token); + await 
probeGateway.BuildStarted.WaitAsync(cancellationSource.Token); + await hostedService.StopAsync(cancellationSource.Token); + + await Assert.That(probeGateway.BuildIndexCallCount).IsEqualTo(1); + } +} + +internal sealed class WarmupProbeGateway : IMcpGateway +{ + private readonly TaskCompletionSource _buildStarted = + new(TaskCreationOptions.RunContinuationsAsynchronously); + private int _buildIndexCallCount; + + public int BuildIndexCallCount => Volatile.Read(ref _buildIndexCallCount); + + public Task BuildStarted => _buildStarted.Task; + + public Task BuildIndexAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + Interlocked.Increment(ref _buildIndexCallCount); + _buildStarted.TrySetResult(null); + return Task.FromResult(new McpGatewayIndexBuildResult(0, 0, false, [])); + } + + public Task> ListToolsAsync(CancellationToken cancellationToken = default) + => Task.FromResult>([]); + + public Task SearchAsync( + string? query, + int? 
maxResults = null, + CancellationToken cancellationToken = default) + => Task.FromResult(new McpGatewaySearchResult([], [], string.Empty)); + + public Task SearchAsync( + McpGatewaySearchRequest request, + CancellationToken cancellationToken = default) + => Task.FromResult(new McpGatewaySearchResult([], [], string.Empty)); + + public Task InvokeAsync( + McpGatewayInvokeRequest request, + CancellationToken cancellationToken = default) + => Task.FromResult(new McpGatewayInvokeResult(false, string.Empty, string.Empty, string.Empty, null)); + + public IReadOnlyList CreateMetaTools( + string searchToolName = McpGatewayToolSet.DefaultSearchToolName, + string invokeToolName = McpGatewayToolSet.DefaultInvokeToolName) + => []; + + public ValueTask DisposeAsync() => ValueTask.CompletedTask; +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchBuildTests.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchBuildTests.cs new file mode 100644 index 0000000..fd3b794 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchBuildTests.cs @@ -0,0 +1,371 @@ +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Client; +using System.Reflection; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewaySearchTests +{ + [TUnit.Core.Test] + public async Task BuildIndexAsync_ReportsEmbeddingCountMismatch() + { + var embeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions + { + ReturnMismatchedBatchCount = true + }); + + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + var buildResult = await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 1); + + await 
Assert.That(buildResult.IsVectorSearchEnabled).IsFalse(); + await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(0); + await Assert.That(buildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "embedding_count_mismatch")).IsTrue(); + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + } + + [TUnit.Core.Test] + public async Task McpGatewayOptions_DefaultSearchConfigurationUsesAutoAndTopFiveLimit() + { + var options = new McpGatewayOptions(); + + await Assert.That(options.SearchStrategy).IsEqualTo(McpGatewaySearchStrategy.Auto); + await Assert.That(options.SearchQueryNormalization).IsEqualTo(McpGatewaySearchQueryNormalization.TranslateToEnglishWhenAvailable); + await Assert.That(options.DefaultSearchLimit).IsEqualTo(5); + } + + [TUnit.Core.Test] + public async Task McpGatewayClientFactory_UsesAssemblyBuildVersionForClientInfo() + { + var clientOptions = McpGatewayClientFactory.CreateClientOptions(); + var expectedVersion = typeof(McpGatewayClientFactory).Assembly + .GetCustomAttribute()?.InformationalVersion + ?? 
typeof(McpGatewayClientFactory).Assembly.GetName().Version?.ToString(); + + await Assert.That(clientOptions.ClientInfo?.Version).IsEqualTo(expectedVersion); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_ReportsEmbeddingFailure() + { + var embeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions + { + ThrowOnInput = static _ => true + }); + + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + var buildResult = await gateway.BuildIndexAsync(); + + await Assert.That(buildResult.IsVectorSearchEnabled).IsFalse(); + await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(0); + await Assert.That(buildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "embedding_failed")).IsTrue(); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_SkipsDuplicateToolIds() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHubAgain, "github_search_issues", "Duplicate tool id for test coverage.")); + }); + var gateway = serviceProvider.GetRequiredService(); + + var buildResult = await gateway.BuildIndexAsync(); + + await Assert.That(buildResult.ToolCount).IsEqualTo(1); + await Assert.That(buildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "duplicate_tool_id")).IsTrue(); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_RebuildsAfterNewToolIsRegistered() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user 
query.")); + }); + + var gateway = serviceProvider.GetRequiredService(); + var registry = serviceProvider.GetRequiredService(); + + var firstBuild = await gateway.BuildIndexAsync(); + + registry.AddTool( + "local", + TestFunctionFactory.CreateFunction(SearchWeather, "weather_search_forecast", "Search weather forecast and temperature information by city name.")); + + var secondBuild = await gateway.BuildIndexAsync(); + + await Assert.That(firstBuild.ToolCount).IsEqualTo(1); + await Assert.That(secondBuild.ToolCount).IsEqualTo(2); + } + + [TUnit.Core.Test] + public async Task Registry_ConcurrentToolRegistrationRetainsAllTools() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(static _ => { }); + var gateway = serviceProvider.GetRequiredService(); + var registry = serviceProvider.GetRequiredService(); + + await Task.WhenAll(Enumerable.Range(0, 40).Select(index => Task.Run(() => + registry.AddTool( + "local", + TestFunctionFactory.CreateFunction( + SearchWeather, + $"weather_search_forecast_{index}", + $"Search weather forecast and temperature information for city {index}."))))); + + var buildResult = await gateway.BuildIndexAsync(); + + await Assert.That(buildResult.ToolCount).IsEqualTo(40); + } + + [TUnit.Core.Test] + public async Task AddManagedCodeMcpGateway_ResolvesRegistryAsSeparateService() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); + }); + + var gateway = serviceProvider.GetRequiredService(); + var registry = serviceProvider.GetRequiredService(); + + await Assert.That(ReferenceEquals(gateway, registry)).IsFalse(); + } + + [TUnit.Core.Test] + public async Task AddManagedCodeMcpGateway_RegistryAlsoActsAsCatalogSource() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(static _ => { }); + var 
registry = serviceProvider.GetRequiredService(); + + await Assert.That(registry).IsTypeOf(); + await Assert.That(registry is IMcpGatewayCatalogSource).IsTrue(); + } + + [TUnit.Core.Test] + public async Task McpGateway_ThrowsClearErrorWhenRegistryServiceIsMissing() + { + var services = new ServiceCollection(); + services.AddLogging(); + services.AddOptions(); + + await using var serviceProvider = services.BuildServiceProvider(); + var options = serviceProvider.GetRequiredService>(); + var logger = serviceProvider.GetRequiredService>(); + var loggerFactory = serviceProvider.GetRequiredService(); + + InvalidOperationException? exception = null; + try + { + _ = new McpGateway(serviceProvider, options, logger, loggerFactory); + } + catch (InvalidOperationException ex) + { + exception = ex; + } + + await Assert.That(exception).IsNotNull(); + await Assert.That(exception!.Message).Contains("AddManagedCodeMcpGateway"); + } + + [TUnit.Core.Test] + public async Task McpGateway_DoesNotExposeRegistryMutations() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(static _ => { }); + var gateway = serviceProvider.GetRequiredService(); + var registry = serviceProvider.GetRequiredService(); + + registry.AddTool( + "local", + TestFunctionFactory.CreateFunction(SearchWeather, "weather_search_forecast", "Search weather forecast and temperature information by city name.")); + + var tools = await gateway.ListToolsAsync(); + + await Assert.That(typeof(IMcpGatewayRegistry).IsAssignableFrom(gateway.GetType())).IsFalse(); + await Assert.That(tools.Count).IsEqualTo(1); + await Assert.That(tools.Single().ToolId).IsEqualTo("local:weather_search_forecast"); + } + + [TUnit.Core.Test] + public async Task Registry_RejectsMutationsAfterServiceProviderIsDisposed() + { + var serviceProvider = GatewayTestServiceProviderFactory.Create(static _ => { }); + var registry = serviceProvider.GetRequiredService(); + + await serviceProvider.DisposeAsync(); + + 
ObjectDisposedException? exception = null; + try + { + registry.AddTool( + "local", + TestFunctionFactory.CreateFunction(SearchWeather, "weather_search_forecast", "Search weather forecast and temperature information by city name.")); + } + catch (ObjectDisposedException ex) + { + exception = ex; + } + + await Assert.That(exception).IsNotNull(); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_RetriesFailedMcpClientFactoryOnNextBuild() + { + await using var serverHost = await TestMcpServerHost.StartAsync(); + + var attempts = 0; + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClientFactory( + "test-mcp", + _ => + { + attempts++; + if (attempts == 1) + { + throw new InvalidOperationException("temporary startup failure"); + } + + return ValueTask.FromResult(serverHost.Client); + }, + disposeClient: false); + }); + var gateway = serviceProvider.GetRequiredService(); + + var firstBuild = await gateway.BuildIndexAsync(); + var secondBuild = await gateway.BuildIndexAsync(); + + await Assert.That(attempts).IsEqualTo(2); + await Assert.That(firstBuild.ToolCount).IsEqualTo(0); + await Assert.That(firstBuild.Diagnostics.Any(static diagnostic => diagnostic.Code == "source_load_failed")).IsTrue(); + await Assert.That(secondBuild.ToolCount).IsEqualTo(3); + await Assert.That(secondBuild.Diagnostics.Any(static diagnostic => diagnostic.Code == "source_load_failed")).IsFalse(); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_ConcurrentCallsShareSingleBuild() + { + await using var serverHost = await TestMcpServerHost.StartAsync(); + + var attempts = 0; + var factoryStarted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var releaseFactory = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClientFactory( + "test-mcp", + async _ => 
+ { + Interlocked.Increment(ref attempts); + factoryStarted.TrySetResult(null); + return await releaseFactory.Task; + }, + disposeClient: false); + }); + var gateway = serviceProvider.GetRequiredService(); + + var buildTasks = Enumerable.Range(0, 20) + .Select(_ => gateway.BuildIndexAsync()) + .ToArray(); + + await factoryStarted.Task; + releaseFactory.TrySetResult(serverHost.Client); + + var results = await Task.WhenAll(buildTasks); + + await Assert.That(attempts).IsEqualTo(1); + await Assert.That(results.All(static result => result.ToolCount == 3)).IsTrue(); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_CancelsUnderlyingSourceLoadAndAllowsRetry() + { + await using var serverHost = await TestMcpServerHost.StartAsync(); + + var attempts = 0; + var factoryStarted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClientFactory( + "test-mcp", + async cancellationToken => + { + var attempt = Interlocked.Increment(ref attempts); + if (attempt == 1) + { + factoryStarted.TrySetResult(null); + await Task.Delay(Timeout.InfiniteTimeSpan, cancellationToken); + } + + return serverHost.Client; + }, + disposeClient: false); + }); + var gateway = serviceProvider.GetRequiredService(); + using var cancellationSource = new CancellationTokenSource(); + + var firstBuildTask = gateway.BuildIndexAsync(cancellationSource.Token); + await factoryStarted.Task.WaitAsync(TimeSpan.FromSeconds(5)); + cancellationSource.Cancel(); + + OperationCanceledException? 
cancellationException = null; + try + { + await firstBuildTask.WaitAsync(TimeSpan.FromSeconds(5)); + } + catch (OperationCanceledException ex) + { + cancellationException = ex; + } + + var secondBuild = await gateway.BuildIndexAsync().WaitAsync(TimeSpan.FromSeconds(5)); + + await Assert.That(cancellationException).IsNotNull(); + await Assert.That(attempts).IsEqualTo(2); + await Assert.That(secondBuild.ToolCount).IsEqualTo(3); + } + + [TUnit.Core.Test] + public async Task ListToolsAsync_ExtractsRequiredArgumentsFromSerializedMcpSchema() + { + await using var serverHost = await TestMcpServerHost.StartAsync(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddMcpClient("test-mcp", serverHost.Client, disposeClient: false); + }); + var gateway = serviceProvider.GetRequiredService(); + + var tools = await gateway.ListToolsAsync(); + var descriptor = tools.Single(static tool => tool.ToolId == "test-mcp:github_repository_search"); + + await Assert.That(string.IsNullOrWhiteSpace(descriptor.InputSchemaJson)).IsFalse(); + await Assert.That(descriptor.RequiredArguments.Any(static argument => string.Equals(argument, "query", StringComparison.OrdinalIgnoreCase))).IsTrue(); + } + + [TUnit.Core.Test] + public async Task ListToolsAsync_BuildsIndexOnDemand() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); + var gateway = serviceProvider.GetRequiredService(); + + var tools = await gateway.ListToolsAsync(); + + await Assert.That(tools.Count).IsEqualTo(2); + await Assert.That(tools.Any(static tool => tool.ToolId == "local:github_search_issues")).IsTrue(); + await Assert.That(tools.Any(static tool => tool.ToolId == "local:weather_search_forecast")).IsTrue(); + } +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchEmbeddingStoreTests.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchEmbeddingStoreTests.cs new file mode 100644 index 
0000000..71b7c49 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchEmbeddingStoreTests.cs @@ -0,0 +1,145 @@ +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewaySearchTests +{ + [TUnit.Core.Test] + public async Task BuildIndexAsync_ReusesStoredToolEmbeddingsOnNextBuild() + { + var embeddingStore = new McpGatewayInMemoryToolEmbeddingStore(); + var firstEmbeddingGenerator = new TestEmbeddingGenerator(); + + var firstBuildResult = await BuildSearchIndexAsync(embeddingStore, firstEmbeddingGenerator); + + await Assert.That(firstBuildResult.VectorizedToolCount).IsEqualTo(2); + await Assert.That(firstEmbeddingGenerator.Calls.Count).IsEqualTo(1); + + var secondEmbeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions + { + ThrowOnInput = static _ => true + }); + + await using var secondServiceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + secondEmbeddingGenerator, + embeddingStore); + var secondGateway = secondServiceProvider.GetRequiredService(); + + var secondBuildResult = await secondGateway.BuildIndexAsync(); + + await Assert.That(secondBuildResult.VectorizedToolCount).IsEqualTo(2); + await Assert.That(secondBuildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "embedding_failed")).IsFalse(); + await Assert.That(secondEmbeddingGenerator.Calls.Count).IsEqualTo(0); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_RegeneratesStoredToolEmbeddingsWhenGeneratorFingerprintChanges() + { + var embeddingStore = new TestToolEmbeddingStore(); + var firstEmbeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions + { + Metadata = new EmbeddingGeneratorMetadata( + "ManagedCode.MCPGateway.Tests", + new Uri("https://example.test"), + "test-embedding-a", + 21) + }); + + await 
SeedSearchEmbeddingsAsync(embeddingStore, firstEmbeddingGenerator); + + var secondEmbeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions + { + Metadata = new EmbeddingGeneratorMetadata( + "ManagedCode.MCPGateway.Tests", + new Uri("https://example.test"), + "test-embedding-b", + 21) + }); + + await using var secondServiceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + secondEmbeddingGenerator, + embeddingStore); + var secondGateway = secondServiceProvider.GetRequiredService(); + + var secondBuildResult = await secondGateway.BuildIndexAsync(); + + await Assert.That(secondBuildResult.VectorizedToolCount).IsEqualTo(2); + await Assert.That(secondEmbeddingGenerator.Calls.Count).IsEqualTo(1); + await Assert.That(secondEmbeddingGenerator.Calls[0].Count).IsEqualTo(2); + await Assert.That(embeddingStore.UpsertCalls.Count).IsEqualTo(2); + await Assert.That(embeddingStore.UpsertCalls[1].Count).IsEqualTo(2); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_DisablesVectorSearchWhenStoreHasVectorsButQueryGeneratorIsMissing() + { + var embeddingStore = new McpGatewayInMemoryToolEmbeddingStore(); + var initialEmbeddingGenerator = new TestEmbeddingGenerator(); + + await SeedSearchEmbeddingsAsync(embeddingStore, initialEmbeddingGenerator); + + await using var secondServiceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingStore: embeddingStore); + var secondGateway = secondServiceProvider.GetRequiredService(); + + var buildResult = await secondGateway.BuildIndexAsync(); + var searchResult = await secondGateway.SearchAsync("github pull requests", maxResults: 1); + + await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(2); + await Assert.That(buildResult.IsVectorSearchEnabled).IsFalse(); + await Assert.That(buildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "embedding_generator_missing")).IsTrue(); + await 
Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + } + + [TUnit.Core.Test] + public async Task BuildIndexAsync_GeneratesAndPersistsOnlyMissingStoredToolEmbeddings() + { + var embeddingStore = new TestToolEmbeddingStore(); + var initialEmbeddingGenerator = new TestEmbeddingGenerator(); + + await SeedSearchEmbeddingsAsync(embeddingStore, initialEmbeddingGenerator); + + embeddingStore.Remove("local:weather_search_forecast"); + + var incrementalEmbeddingGenerator = new TestEmbeddingGenerator(); + await using var incrementalServiceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + incrementalEmbeddingGenerator, + embeddingStore); + var incrementalGateway = incrementalServiceProvider.GetRequiredService(); + + var buildResult = await incrementalGateway.BuildIndexAsync(); + + await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(2); + await Assert.That(incrementalEmbeddingGenerator.Calls.Count).IsEqualTo(1); + await Assert.That(incrementalEmbeddingGenerator.Calls[0].Count).IsEqualTo(1); + await Assert.That(incrementalEmbeddingGenerator.Calls[0].Single().Contains("weather_search_forecast", StringComparison.Ordinal)).IsTrue(); + await Assert.That(embeddingStore.UpsertCalls.Count).IsEqualTo(2); + await Assert.That(embeddingStore.UpsertCalls[1].Count).IsEqualTo(1); + await Assert.That(embeddingStore.UpsertCalls[1].Single().ToolId).IsEqualTo("local:weather_search_forecast"); + } + + private static async Task BuildSearchIndexAsync( + IMcpGatewayToolEmbeddingStore embeddingStore, + IEmbeddingGenerator>? 
embeddingGenerator = null) + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingGenerator, + embeddingStore); + var gateway = serviceProvider.GetRequiredService(); + return await gateway.BuildIndexAsync(); + } + + private static async Task SeedSearchEmbeddingsAsync( + IMcpGatewayToolEmbeddingStore embeddingStore, + IEmbeddingGenerator> embeddingGenerator) + { + await BuildSearchIndexAsync(embeddingStore, embeddingGenerator); + } +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchLexicalTests.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchLexicalTests.cs new file mode 100644 index 0000000..994070d --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchLexicalTests.cs @@ -0,0 +1,258 @@ +using System.Collections; +using System.Text.Json; +using System.Text.Json.Nodes; +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewaySearchTests +{ + [TUnit.Core.Test] + public async Task SearchAsync_UsesContextDictionaryForLexicalFallback() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest( + Query: "search", + MaxResults: 2, + Context: new Dictionary + { + ["page"] = "weather forecast", + ["intent"] = "temperature lookup" + })); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "lexical_fallback")).IsTrue(); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:weather_search_forecast"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_UsesSchemaTermsForLexicalFallback() + { + await using 
var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(FilterAdvisories, "advisory_lookup", "Lookup advisory records.")); + }); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("severity filter", maxResults: 1); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:advisory_lookup"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_DefaultAutoStrategyUsesTokenizerFallbackAndTopFiveLimit() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureDefaultAutoTokenizerFallbackTools); + var gateway = serviceProvider.GetRequiredService(); + + var searchResult = await gateway.SearchAsync("search"); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "lexical_fallback")).IsTrue(); + await Assert.That(searchResult.Matches.Count).IsEqualTo(5); + } + + [TUnit.Core.Test] + public async Task SearchAsync_DefaultAutoStrategyHandlesTypoHeavyQueryWithoutEmbeddings() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureDefaultAutoTokenizerFallbackTools); + var gateway = serviceProvider.GetRequiredService(); + + var searchResult = await gateway.SearchAsync("track shipmnt 1z999"); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "lexical_fallback")).IsTrue(); + await Assert.That(searchResult.Matches.Any(static match => match.ToolId == "local:commerce_shipping_tracking")).IsTrue(); + } 
+ + [TUnit.Core.Test] + public async Task SearchAsync_UsesEnglishNormalizationWhenKeyedChatClientIsRegistered() + { + var chatClient = new TestChatClient(new TestChatClientOptions + { + RewriteQuery = static query => query.Contains("petit déjeuner", StringComparison.Ordinal) + ? "hotel with breakfast near city center" + : query + }); + + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureTravelTokenizerTools, + searchQueryChatClient: chatClient); + var gateway = serviceProvider.GetRequiredService(); + + var searchResult = await gateway.SearchAsync("trouver un hôtel avec petit déjeuner près du centre", maxResults: 1); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "query_normalized")).IsTrue(); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:travel_hotel_search"); + await Assert.That(chatClient.Calls.Count).IsEqualTo(1); + } + + [TUnit.Core.Test] + public async Task SearchAsync_FallsBackToOriginalQueryWhenNormalizationFails() + { + var chatClient = new TestChatClient(new TestChatClientOptions + { + ThrowOnInput = static query => query.Contains("demande de fusion", StringComparison.Ordinal) + }); + + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureTokenizerSearchToolsForNormalizationFallback, + searchQueryChatClient: chatClient); + var gateway = serviceProvider.GetRequiredService(); + + var searchResult = await gateway.SearchAsync("demande de fusion pour le depot managedcode", maxResults: 1); + + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "query_normalization_failed")).IsTrue(); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_pull_request_search"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_UsesNestedJsonAndEnumerableContextWhenQueryIsMissing() + { + await using var 
serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest( + Query: null, + MaxResults: 1, + ContextSummary: "user is browsing operational dashboards", + Context: new Dictionary + { + ["page"] = JsonSerializer.SerializeToElement(new + { + section = "forecast", + filters = new List { "temperature", "weekend" } + }), + ["intent"] = new JsonObject + { + ["category"] = "weather", + ["mode"] = "lookup" + }, + ["legacy"] = new Hashtable + { + ["location"] = "Paris", + ["active"] = true + }, + ["signals"] = new object?[] { "forecast", 5 } + })); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "lexical_fallback")).IsTrue(); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:weather_search_forecast"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_IgnoresUnserializableContextPayloads() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); + var gateway = serviceProvider.GetRequiredService(); + var cyclicContext = new CyclicContextPayload(); + cyclicContext.Self = cyclicContext; + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest( + Query: "weather forecast", + MaxResults: 1, + Context: new Dictionary + { + ["broken"] = cyclicContext + })); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:weather_search_forecast"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_UsesBrowseModeWhenQueryAndContextAreMissing() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); + var gateway = 
serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest()); + + await Assert.That(searchResult.RankingMode).IsEqualTo("browse"); + await Assert.That(searchResult.Matches.Count).IsEqualTo(2); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); + await Assert.That(searchResult.Matches[1].ToolId).IsEqualTo("local:weather_search_forecast"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_FallsBackWhenQueryEmbeddingFails() + { + var embeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions + { + ThrowOnInput = static input => input.Contains("explode query", StringComparison.Ordinal) + }); + + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("explode query", maxResults: 1); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "vector_search_failed")).IsTrue(); + } + + [TUnit.Core.Test] + public async Task SearchAsync_FallsBackWhenQueryVectorIsEmpty() + { + var embeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions + { + ReturnZeroVectorOnInput = static input => input.Contains("empty query vector", StringComparison.Ordinal) + }); + + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("empty query vector", maxResults: 1); + + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => 
diagnostic.Code == "query_vector_empty")).IsTrue(); + } + + private static void ConfigureDefaultAutoTokenizerFallbackTools(McpGatewayOptions options) + { + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "weather_search_forecast", "Search weather forecast and temperature information by city name.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "weather_air_quality_lookup", "Lookup air quality index, smoke exposure, and pollution levels for a location.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "commerce_shipping_tracking", "Track shipment status, carrier events, and delivery estimates.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "finance_invoice_search", "Find invoices by customer, invoice number, payment state, or due date.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "crm_contact_search", "Find CRM contacts by name, email, title, account, or segment.")); + } + + private static void ConfigureTravelTokenizerTools(McpGatewayOptions options) + { + options.SearchStrategy = McpGatewaySearchStrategy.Tokenizer; + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "travel_hotel_search", "Find hotels by city, district, amenities, breakfast, or cancellation policy.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "travel_itinerary_builder", "Build a travel itinerary with flights, stays, meetings, and transfer timing.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "travel_booking_lookup", "Lookup booking confirmation details, ticket numbers, and reservation status.")); + } + + private static void 
ConfigureTokenizerSearchToolsForNormalizationFallback(McpGatewayOptions options) + { + options.SearchStrategy = McpGatewaySearchStrategy.Tokenizer; + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_pull_request_search", "Search GitHub pull requests by repository, reviewer queue, review approvals, branch, or merge status. Aliases: merge request, demande de fusion.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_code_search", "Search GitHub source code for files, symbols, snippets, or API usages inside repositories.")); + } + + private sealed class CyclicContextPayload + { + public CyclicContextPayload? Self { get; set; } + } +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchTests.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchTests.cs index 796a972..12d9398 100644 --- a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchTests.cs +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchTests.cs @@ -1,536 +1,17 @@ using System.ComponentModel; -using ManagedCode.MCPGateway.Abstractions; - using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; - -using ModelContextProtocol.Client; namespace ManagedCode.MCPGateway.Tests; -public sealed class McpGatewaySearchTests +public sealed partial class McpGatewaySearchTests { - [TUnit.Core.Test] - public async Task BuildIndexAsync_VectorizesToolDescriptorsAndCapturesSemanticDocuments() - { - var embeddingGenerator = new TestEmbeddingGenerator(); - await using var serviceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - embeddingGenerator); - var gateway = serviceProvider.GetRequiredService(); - - var buildResult = await gateway.BuildIndexAsync(); - - await Assert.That(buildResult.ToolCount).IsEqualTo(2); - await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(2); - await 
Assert.That(buildResult.IsVectorSearchEnabled).IsTrue(); - await Assert.That(embeddingGenerator.Calls.Count).IsEqualTo(1); - await Assert.That(embeddingGenerator.Calls[0].Count).IsEqualTo(2); - await Assert.That(embeddingGenerator.Calls[0].Any(static text => - text.Contains("github_search_issues", StringComparison.Ordinal) && - text.Contains("Search GitHub issues and pull requests by user query.", StringComparison.Ordinal))).IsTrue(); - await Assert.That(embeddingGenerator.Calls[0].Any(static text => - text.Contains("weather_search_forecast", StringComparison.Ordinal) && - text.Contains("Search weather forecast and temperature information by city name.", StringComparison.Ordinal))).IsTrue(); - } - - [TUnit.Core.Test] - public async Task SearchAsync_RanksLocalToolsWithEmbeddings() - { - var embeddingGenerator = new TestEmbeddingGenerator(); - await using var serviceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - embeddingGenerator); - var gateway = serviceProvider.GetRequiredService(); - - await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 2); - - await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); - await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); - await Assert.That(searchResult.Matches[0].Score >= searchResult.Matches[1].Score).IsTrue(); - await Assert.That(embeddingGenerator.Calls.Count).IsEqualTo(2); - await Assert.That(embeddingGenerator.Calls[1].Single()).IsEqualTo("github pull requests"); - } - - [TUnit.Core.Test] - public async Task SearchAsync_UsesContextSummaryInEmbeddingQuery() - { - var embeddingGenerator = new TestEmbeddingGenerator(); - await using var serviceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - embeddingGenerator); - var gateway = serviceProvider.GetRequiredService(); - - await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync(new 
McpGatewaySearchRequest( - Query: "search", - MaxResults: 2, - ContextSummary: "github pull requests")); - - await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); - await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); - await Assert.That(embeddingGenerator.Calls[1].Single().Contains("context summary: github pull requests", StringComparison.Ordinal)).IsTrue(); - } - - [TUnit.Core.Test] - public async Task SearchAsync_UsesContextOnlyInputWhenQueryIsMissing() - { - var embeddingGenerator = new TestEmbeddingGenerator(); - await using var serviceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - embeddingGenerator); - var gateway = serviceProvider.GetRequiredService(); - - await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest( - ContextSummary: "weather forecast", - MaxResults: 1)); - - await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); - await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:weather_search_forecast"); - await Assert.That(embeddingGenerator.Calls[1].Single()).IsEqualTo("context summary: weather forecast"); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_ReusesStoredToolEmbeddingsOnNextBuild() - { - var embeddingStore = new McpGatewayInMemoryToolEmbeddingStore(); - var firstEmbeddingGenerator = new TestEmbeddingGenerator(); - - await using (var firstServiceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - firstEmbeddingGenerator, - embeddingStore)) - { - var gateway = firstServiceProvider.GetRequiredService(); - var buildResult = await gateway.BuildIndexAsync(); - - await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(2); - await Assert.That(firstEmbeddingGenerator.Calls.Count).IsEqualTo(1); - } - - var secondEmbeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions - { - ThrowOnInput = static _ => true - }); - - await using 
var secondServiceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - secondEmbeddingGenerator, - embeddingStore); - var secondGateway = secondServiceProvider.GetRequiredService(); - - var secondBuildResult = await secondGateway.BuildIndexAsync(); - - await Assert.That(secondBuildResult.VectorizedToolCount).IsEqualTo(2); - await Assert.That(secondBuildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "embedding_failed")).IsFalse(); - await Assert.That(secondEmbeddingGenerator.Calls.Count).IsEqualTo(0); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_RegeneratesStoredToolEmbeddingsWhenGeneratorFingerprintChanges() - { - var embeddingStore = new TestToolEmbeddingStore(); - var firstEmbeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions - { - Metadata = new EmbeddingGeneratorMetadata( - "ManagedCode.MCPGateway.Tests", - new Uri("https://example.test"), - "test-embedding-a", - 21) - }); - - await using (var firstServiceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - firstEmbeddingGenerator, - embeddingStore)) - { - var gateway = firstServiceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - } - - var secondEmbeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions - { - Metadata = new EmbeddingGeneratorMetadata( - "ManagedCode.MCPGateway.Tests", - new Uri("https://example.test"), - "test-embedding-b", - 21) - }); - - await using var secondServiceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - secondEmbeddingGenerator, - embeddingStore); - var secondGateway = secondServiceProvider.GetRequiredService(); - - var secondBuildResult = await secondGateway.BuildIndexAsync(); - - await Assert.That(secondBuildResult.VectorizedToolCount).IsEqualTo(2); - await Assert.That(secondEmbeddingGenerator.Calls.Count).IsEqualTo(1); - await Assert.That(secondEmbeddingGenerator.Calls[0].Count).IsEqualTo(2); - 
await Assert.That(embeddingStore.UpsertCalls.Count).IsEqualTo(2); - await Assert.That(embeddingStore.UpsertCalls[1].Count).IsEqualTo(2); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_DisablesVectorSearchWhenStoreHasVectorsButQueryGeneratorIsMissing() - { - var embeddingStore = new McpGatewayInMemoryToolEmbeddingStore(); - var initialEmbeddingGenerator = new TestEmbeddingGenerator(); - - await using (var initialServiceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - initialEmbeddingGenerator, - embeddingStore)) - { - var gateway = initialServiceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - } - - await using var secondServiceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - embeddingStore: embeddingStore); - var secondGateway = secondServiceProvider.GetRequiredService(); - - var buildResult = await secondGateway.BuildIndexAsync(); - var searchResult = await secondGateway.SearchAsync("github pull requests", maxResults: 1); - - await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(2); - await Assert.That(buildResult.IsVectorSearchEnabled).IsFalse(); - await Assert.That(buildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "embedding_generator_missing")).IsTrue(); - await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_GeneratesAndPersistsOnlyMissingStoredToolEmbeddings() - { - var embeddingStore = new TestToolEmbeddingStore(); - var initialEmbeddingGenerator = new TestEmbeddingGenerator(); - - await using (var initialServiceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - initialEmbeddingGenerator, - embeddingStore)) - { - var gateway = initialServiceProvider.GetRequiredService(); - await gateway.BuildIndexAsync(); - } - - embeddingStore.Remove("local:weather_search_forecast"); - - var incrementalEmbeddingGenerator = new TestEmbeddingGenerator(); - 
await using var incrementalServiceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - incrementalEmbeddingGenerator, - embeddingStore); - var incrementalGateway = incrementalServiceProvider.GetRequiredService(); - - var buildResult = await incrementalGateway.BuildIndexAsync(); - - await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(2); - await Assert.That(incrementalEmbeddingGenerator.Calls.Count).IsEqualTo(1); - await Assert.That(incrementalEmbeddingGenerator.Calls[0].Count).IsEqualTo(1); - await Assert.That(incrementalEmbeddingGenerator.Calls[0].Single().Contains("weather_search_forecast", StringComparison.Ordinal)).IsTrue(); - await Assert.That(embeddingStore.UpsertCalls.Count).IsEqualTo(2); - await Assert.That(embeddingStore.UpsertCalls[1].Count).IsEqualTo(1); - await Assert.That(embeddingStore.UpsertCalls[1].Single().ToolId).IsEqualTo("local:weather_search_forecast"); - } - - [TUnit.Core.Test] - public async Task SearchAsync_PrefersKeyedEmbeddingGeneratorOverUnkeyedRegistration() - { - var keyedEmbeddingGenerator = new TestEmbeddingGenerator(); - var fallbackEmbeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions - { - ThrowOnInput = static _ => true - }); - - var services = new ServiceCollection(); - services.AddLogging(static logging => logging.SetMinimumLevel(LogLevel.Debug)); - services.AddSingleton>>(fallbackEmbeddingGenerator); - services.AddKeyedSingleton>>( - McpGatewayServiceKeys.EmbeddingGenerator, - keyedEmbeddingGenerator); - services.AddManagedCodeMcpGateway(ConfigureSearchTools); - - await using var serviceProvider = services.BuildServiceProvider(); - var gateway = serviceProvider.GetRequiredService(); - - var buildResult = await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 2); - - await Assert.That(buildResult.IsVectorSearchEnabled).IsTrue(); - await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); - await 
Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); - await Assert.That(keyedEmbeddingGenerator.Calls.Count).IsEqualTo(2); - await Assert.That(fallbackEmbeddingGenerator.Calls.Count).IsEqualTo(0); - } - - [TUnit.Core.Test] - public async Task SearchAsync_ResolvesScopedEmbeddingGeneratorPerOperation() - { - var tracker = new ScopedEmbeddingGeneratorTracker(); - var services = new ServiceCollection(); - services.AddLogging(static logging => logging.SetMinimumLevel(LogLevel.Debug)); - services.AddScoped>>(_ => new ScopedTestEmbeddingGenerator(tracker)); - services.AddManagedCodeMcpGateway(ConfigureSearchTools); - - await using var serviceProvider = services.BuildServiceProvider(new ServiceProviderOptions - { - ValidateScopes = true - }); - var gateway = serviceProvider.GetRequiredService(); - - var buildResult = await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 2); - - await Assert.That(buildResult.IsVectorSearchEnabled).IsTrue(); - await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); - await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); - await Assert.That(tracker.InstanceIds.Distinct().Count()).IsEqualTo(2); - await Assert.That(tracker.Calls.Count).IsEqualTo(2); - } - - [TUnit.Core.Test] - public async Task SearchAsync_UsesContextDictionaryForLexicalFallback() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); - var gateway = serviceProvider.GetRequiredService(); - - await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest( - Query: "search", - MaxResults: 2, - Context: new Dictionary - { - ["page"] = "weather forecast", - ["intent"] = "temperature lookup" - })); - - await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); - await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code 
== "lexical_fallback")).IsTrue(); - await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:weather_search_forecast"); - } - - [TUnit.Core.Test] - public async Task SearchAsync_UsesSchemaTermsForLexicalFallback() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool("local", CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); - options.AddTool("local", CreateFunction(FilterAdvisories, "advisory_lookup", "Lookup advisory records.")); - }); - var gateway = serviceProvider.GetRequiredService(); - - await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync("severity filter", maxResults: 1); - - await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); - await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:advisory_lookup"); - } - - [TUnit.Core.Test] - public async Task SearchAsync_UsesBrowseModeWhenQueryAndContextAreMissing() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); - var gateway = serviceProvider.GetRequiredService(); - - await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest()); - - await Assert.That(searchResult.RankingMode).IsEqualTo("browse"); - await Assert.That(searchResult.Matches.Count).IsEqualTo(2); - await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); - await Assert.That(searchResult.Matches[1].ToolId).IsEqualTo("local:weather_search_forecast"); - } - - [TUnit.Core.Test] - public async Task SearchAsync_FallsBackWhenQueryEmbeddingFails() - { - var embeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions - { - ThrowOnInput = static input => input.Contains("explode query", StringComparison.Ordinal) - }); - - await using var serviceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - 
embeddingGenerator); - var gateway = serviceProvider.GetRequiredService(); - - await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync("explode query", maxResults: 1); - - await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); - await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "vector_search_failed")).IsTrue(); - } - - [TUnit.Core.Test] - public async Task SearchAsync_FallsBackWhenQueryVectorIsEmpty() - { - var embeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions - { - ReturnZeroVectorOnInput = static input => input.Contains("empty query vector", StringComparison.Ordinal) - }); - - await using var serviceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - embeddingGenerator); - var gateway = serviceProvider.GetRequiredService(); - - await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync("empty query vector", maxResults: 1); - - await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); - await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "query_vector_empty")).IsTrue(); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_ReportsEmbeddingCountMismatch() - { - var embeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions - { - ReturnMismatchedBatchCount = true - }); - - await using var serviceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - embeddingGenerator); - var gateway = serviceProvider.GetRequiredService(); - - var buildResult = await gateway.BuildIndexAsync(); - var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 1); - - await Assert.That(buildResult.IsVectorSearchEnabled).IsFalse(); - await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(0); - await Assert.That(buildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "embedding_count_mismatch")).IsTrue(); 
- await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_ReportsEmbeddingFailure() - { - var embeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions - { - ThrowOnInput = static _ => true - }); - - await using var serviceProvider = GatewayTestServiceProviderFactory.Create( - ConfigureSearchTools, - embeddingGenerator); - var gateway = serviceProvider.GetRequiredService(); - - var buildResult = await gateway.BuildIndexAsync(); - - await Assert.That(buildResult.IsVectorSearchEnabled).IsFalse(); - await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(0); - await Assert.That(buildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "embedding_failed")).IsTrue(); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_SkipsDuplicateToolIds() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool("local", CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); - options.AddTool("local", CreateFunction(SearchGitHubAgain, "github_search_issues", "Duplicate tool id for test coverage.")); - }); - var gateway = serviceProvider.GetRequiredService(); - - var buildResult = await gateway.BuildIndexAsync(); - - await Assert.That(buildResult.ToolCount).IsEqualTo(1); - await Assert.That(buildResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "duplicate_tool_id")).IsTrue(); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_RebuildsAfterNewToolIsRegistered() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddTool("local", CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); - }); - - var gateway = serviceProvider.GetRequiredService(); - var registry = serviceProvider.GetRequiredService(); - - var firstBuild = await 
gateway.BuildIndexAsync(); - - registry.AddTool( - "local", - CreateFunction(SearchWeather, "weather_search_forecast", "Search weather forecast and temperature information by city name.")); - - var secondBuild = await gateway.BuildIndexAsync(); - - await Assert.That(firstBuild.ToolCount).IsEqualTo(1); - await Assert.That(secondBuild.ToolCount).IsEqualTo(2); - } - - [TUnit.Core.Test] - public async Task BuildIndexAsync_RetriesFailedMcpClientFactoryOnNextBuild() - { - await using var serverHost = await TestMcpServerHost.StartAsync(); - - var attempts = 0; - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => - { - options.AddMcpClientFactory( - "test-mcp", - async _ => - { - attempts++; - if (attempts == 1) - { - throw new InvalidOperationException("temporary startup failure"); - } - - return serverHost.Client; - }, - disposeClient: false); - }); - var gateway = serviceProvider.GetRequiredService(); - - var firstBuild = await gateway.BuildIndexAsync(); - var secondBuild = await gateway.BuildIndexAsync(); - - await Assert.That(attempts).IsEqualTo(2); - await Assert.That(firstBuild.ToolCount).IsEqualTo(0); - await Assert.That(firstBuild.Diagnostics.Any(static diagnostic => diagnostic.Code == "source_load_failed")).IsTrue(); - await Assert.That(secondBuild.ToolCount).IsEqualTo(3); - await Assert.That(secondBuild.Diagnostics.Any(static diagnostic => diagnostic.Code == "source_load_failed")).IsFalse(); - } - - [TUnit.Core.Test] - public async Task ListToolsAsync_BuildsIndexOnDemand() - { - await using var serviceProvider = GatewayTestServiceProviderFactory.Create(ConfigureSearchTools); - var gateway = serviceProvider.GetRequiredService(); - - var tools = await gateway.ListToolsAsync(); - - await Assert.That(tools.Count).IsEqualTo(2); - await Assert.That(tools.Any(static tool => tool.ToolId == "local:github_search_issues")).IsTrue(); - await Assert.That(tools.Any(static tool => tool.ToolId == "local:weather_search_forecast")).IsTrue(); - } 
- private static void ConfigureSearchTools(McpGatewayOptions options) { - options.AddTool("local", CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); - options.AddTool("local", CreateFunction(SearchWeather, "weather_search_forecast", "Search weather forecast and temperature information by city name.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query.")); + options.AddTool("local", TestFunctionFactory.CreateFunction(SearchWeather, "weather_search_forecast", "Search weather forecast and temperature information by city name.")); } - private static AIFunction CreateFunction(Delegate callback, string name, string description) - => AIFunctionFactory.Create( - callback, - new AIFunctionFactoryOptions - { - Name = name, - Description = description - }); - private static string SearchGitHub([Description("Search query text.")] string query) => $"github:{query}"; private static string SearchGitHubAgain([Description("Search query text.")] string query) => $"github-duplicate:{query}"; diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchVectorTests.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchVectorTests.cs new file mode 100644 index 0000000..71bc500 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchVectorTests.cs @@ -0,0 +1,146 @@ +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewaySearchTests +{ + [TUnit.Core.Test] + public async Task BuildIndexAsync_VectorizesToolDescriptorsAndCapturesSemanticDocuments() + { + var embeddingGenerator = new TestEmbeddingGenerator(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + 
ConfigureSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + var buildResult = await gateway.BuildIndexAsync(); + + await Assert.That(buildResult.ToolCount).IsEqualTo(2); + await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(2); + await Assert.That(buildResult.IsVectorSearchEnabled).IsTrue(); + await Assert.That(embeddingGenerator.Calls.Count).IsEqualTo(1); + await Assert.That(embeddingGenerator.Calls[0].Count).IsEqualTo(2); + await Assert.That(embeddingGenerator.Calls[0].Any(static text => + text.Contains("github_search_issues", StringComparison.Ordinal) && + text.Contains("Search GitHub issues and pull requests by user query.", StringComparison.Ordinal))).IsTrue(); + await Assert.That(embeddingGenerator.Calls[0].Any(static text => + text.Contains("weather_search_forecast", StringComparison.Ordinal) && + text.Contains("Search weather forecast and temperature information by city name.", StringComparison.Ordinal))).IsTrue(); + } + + [TUnit.Core.Test] + public async Task SearchAsync_RanksLocalToolsWithEmbeddings() + { + var embeddingGenerator = new TestEmbeddingGenerator(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 2); + + await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); + await Assert.That(searchResult.Matches[0].Score >= searchResult.Matches[1].Score).IsTrue(); + await Assert.That(embeddingGenerator.Calls.Count).IsEqualTo(2); + await Assert.That(embeddingGenerator.Calls[1].Single()).IsEqualTo("github pull requests"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_UsesContextSummaryInEmbeddingQuery() + { + var embeddingGenerator = new 
TestEmbeddingGenerator(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest( + Query: "search", + MaxResults: 2, + ContextSummary: "github pull requests")); + + await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); + await Assert.That(embeddingGenerator.Calls[1].Single().Contains("context summary: github pull requests", StringComparison.Ordinal)).IsTrue(); + } + + [TUnit.Core.Test] + public async Task SearchAsync_UsesContextOnlyInputWhenQueryIsMissing() + { + var embeddingGenerator = new TestEmbeddingGenerator(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync(new McpGatewaySearchRequest( + ContextSummary: "weather forecast", + MaxResults: 1)); + + await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:weather_search_forecast"); + await Assert.That(embeddingGenerator.Calls[1].Single()).IsEqualTo("context summary: weather forecast"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_PrefersKeyedEmbeddingGeneratorOverUnkeyedRegistration() + { + var keyedEmbeddingGenerator = new TestEmbeddingGenerator(); + var fallbackEmbeddingGenerator = new TestEmbeddingGenerator(new TestEmbeddingGeneratorOptions + { + ThrowOnInput = static _ => true + }); + + var services = new ServiceCollection(); + services.AddLogging(static logging => logging.SetMinimumLevel(LogLevel.Debug)); + services.AddSingleton>>(fallbackEmbeddingGenerator); + 
services.AddKeyedSingleton>>( + McpGatewayServiceKeys.EmbeddingGenerator, + keyedEmbeddingGenerator); + services.AddManagedCodeMcpGateway(ConfigureSearchTools); + + await using var serviceProvider = services.BuildServiceProvider(); + var gateway = serviceProvider.GetRequiredService(); + + var buildResult = await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 2); + + await Assert.That(buildResult.IsVectorSearchEnabled).IsTrue(); + await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); + await Assert.That(keyedEmbeddingGenerator.Calls.Count).IsEqualTo(2); + await Assert.That(fallbackEmbeddingGenerator.Calls.Count).IsEqualTo(0); + } + + [TUnit.Core.Test] + public async Task SearchAsync_ResolvesScopedEmbeddingGeneratorPerOperation() + { + var tracker = new ScopedEmbeddingGeneratorTracker(); + var services = new ServiceCollection(); + services.AddLogging(static logging => logging.SetMinimumLevel(LogLevel.Debug)); + services.AddScoped>>(_ => new ScopedTestEmbeddingGenerator(tracker)); + services.AddManagedCodeMcpGateway(ConfigureSearchTools); + + await using var serviceProvider = services.BuildServiceProvider(new ServiceProviderOptions + { + ValidateScopes = true + }); + var gateway = serviceProvider.GetRequiredService(); + + var buildResult = await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 2); + + await Assert.That(buildResult.IsVectorSearchEnabled).IsTrue(); + await Assert.That(searchResult.RankingMode).IsEqualTo("vector"); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); + await Assert.That(tracker.InstanceIds.Distinct().Count()).IsEqualTo(2); + await Assert.That(tracker.Calls.Count).IsEqualTo(2); + } +} diff --git 
a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationCatalog.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationCatalog.cs new file mode 100644 index 0000000..11a90f6 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationCatalog.cs @@ -0,0 +1,153 @@ +using System.ComponentModel; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewayTokenizerSearchEvaluationTests +{ + // The catalog intentionally reuses common operation words such as search, lookup, + // timeline, and summary across different domains. The descriptions carry the + // domain-specific separation so tokenizer search is evaluated on semantics, not luck. + private static readonly EvaluationToolSpec[] EvaluationTools = + [ + .. CreateSpecs( + GitHubTool, + ("github_repository_search", "Find GitHub repositories by owner, topic, star range, or programming language."), + ("github_issue_search", "Find GitHub issues by repository, label, milestone, or bug triage keywords."), + ("github_pull_request_search", "Search GitHub pull requests by repository, reviewer queue, review approvals, branch, or merge status. Aliases: merge request, repo review queue, demande de fusion, рев'ю пулреквестів."), + ("github_release_notes_lookup", "Lookup GitHub release notes, tags, and changelog entries for a repository."), + ("github_code_search", "Search GitHub source code for files, symbols, snippets, or API usages inside repositories.")), + .. CreateSpecs( + WeatherTool, + ("weather_current_conditions", "Get current weather conditions, temperature, and feels-like values for a place."), + ("weather_forecast_lookup", "Get weather forecast for a city, including rain probability, temperature trend, and wind."), + ("weather_alert_search", "Find severe weather alerts, storm warnings, and hazard bulletins for a region. 
Aliases: alerta de tormenta, alerte tempete, штормове попередження."), + ("weather_air_quality_lookup", "Lookup air quality index, smoke exposure, and pollution levels for a location."), + ("weather_historical_climate", "Review historical climate averages, temperature normals, and past precipitation patterns.")), + .. CreateSpecs( + CalendarTool, + ("calendar_find_free_slots", "Find who can meet, meeting availability, open calendar slots, and free time windows for attendees. Aliases: free slot, creneau libre, вільний слот."), + ("calendar_create_event", "Create a calendar event with attendees, title, time window, and location."), + ("calendar_reschedule_event", "Reschedule or move an existing meeting to a new date, time window, or attendee availability."), + ("calendar_cancel_event", "Cancel a calendar event and notify participants or owners."), + ("calendar_list_daily_agenda", "List daily agenda items, meetings, reminders, and upcoming schedule blocks. Aliases: agenda du jour, порядок денний.")), + .. CreateSpecs( + FilesystemTool, + ("filesystem_find_files", "Find or locate files, PDFs, and documents by folder, glob pattern, extension, or text content. Not for invoice payment status or financial reconciliation."), + ("filesystem_read_file", "Read file contents from a path, document, or note inside a workspace."), + ("filesystem_write_file", "Write or replace file contents at a target path or note."), + ("filesystem_move_file", "Move, rename, archive, or relocate files between folders."), + ("filesystem_list_directory", "List directory contents, recent files, and nested folders.")), + .. 
CreateSpecs( + SupportTool, + ("support_ticket_search", "Find support tickets by customer, issue summary, severity, or product area."), + ("support_ticket_create", "Create a new support ticket with customer details, summary, severity, and product area."), + ("support_ticket_update", "Update support ticket status, owner, severity, or resolution notes."), + ("support_sla_lookup", "Lookup support SLA response times, escalation policy, and entitlement by plan."), + ("support_customer_timeline", "Review the support timeline, escalations, and prior issues for a customer.")), + .. CreateSpecs( + FinanceTool, + ("finance_exchange_rate_lookup", "Lookup currency exchange rates, FX conversions, and quote timestamps. Aliases: tipo de cambio, taux de change, валютний курс."), + ("finance_invoice_search", "Find invoices, bills, and billing records by customer, invoice number, payment state, or due date. Aliases: facture, factura, рахунок."), + ("finance_payment_reconciliation", "Reconcile payments, settlements, and bank references to invoices."), + ("finance_refund_lookup", "Lookup refund requests, refund status, order reimbursement, payout status, and reversal references."), + ("finance_tax_summary", "Summarize tax amounts, VAT, and filing totals by period or jurisdiction.")), + .. CreateSpecs( + TravelTool, + ("travel_flight_search", "Search flights by origin, destination, travel date, airline, or cabin preference."), + ("travel_hotel_search", "Find hotels by city, district, amenities, breakfast, or cancellation policy. Aliases: hotel, hôtel, готель."), + ("travel_booking_lookup", "Lookup booking confirmation details, ticket numbers, and reservation status."), + ("travel_itinerary_builder", "Build a travel itinerary with flights, stays, meetings, and transfer timing."), + ("travel_ground_transport_lookup", "Find airport transfer, train, taxi, or metro options for a route.")), + .. 
CreateSpecs( + SecurityTool, + ("security_vulnerability_search", "Find vulnerabilities, CVEs, package advisories, or image scan findings."), + ("security_secret_rotation", "Rotate secrets, API keys, credentials, or certificates for a target system."), + ("security_access_review", "Review user access, permissions, and privileged roles for a system or team."), + ("security_audit_log_search", "Search audit logs for sign-ins, admin actions, and sensitive operations."), + ("security_incident_timeline", "Review security incident timeline, responders, and evidence events.")), + .. CreateSpecs( + CrmTool, + ("crm_contact_search", "Find CRM contacts and contact details by name, email, title, account, or segment. Aliases: contacto, contact, контакт."), + ("crm_company_lookup", "Lookup CRM company records, account details, industry, or territory."), + ("crm_deal_pipeline_search", "Search deal pipeline by account, stage, owner, amount, or forecast."), + ("crm_activity_timeline", "Review CRM activity timeline including recent touches, calls, emails, meetings, and notes."), + ("crm_lead_enrichment", "Enrich leads with company facts, role context, and routing signals.")), + .. CreateSpecs( + CommerceTool, + ("commerce_catalog_search", "Search product catalog by keyword, category, brand, or attribute filters."), + ("commerce_inventory_lookup", "Lookup inventory, stock by SKU, warehouse balance, and availability."), + ("commerce_order_search", "Find customer orders by email, order number, payment status, or channel."), + ("commerce_return_lookup", "Lookup return requests, RMA status, reasons, and refund linkage."), + ("commerce_shipping_tracking", "Track shipment status, package tracking, parcel events, carrier scans, and delivery estimates. 
Aliases: suivi de colis, seguimiento de envío, відстеження посилки.")) + ]; + + private static string GitHubTool( + [Description("Repository owner or organization handle.")] string owner, + [Description("Repository name, topic, or code area to inspect.")] string repository, + [Description("Search words, labels, reviewers, tags, or symbol names.")] string query, + [Description("Desired work item state such as open, closed, or merged.")] WorkItemState state) + => $"{owner}:{repository}:{query}:{state}"; + + private static string WeatherTool( + [Description("City, region, airport code, or destination name.")] string location, + [Description("Time range such as now, today, weekend, or next 5 days.")] string timeRange, + [Description("Weather focus such as rain, wind, storms, smoke, or pollution.")] string focus, + [Description("Preferred temperature unit.")] TemperatureUnit unit) + => $"{location}:{timeRange}:{focus}:{unit}"; + + private static string CalendarTool( + [Description("Person, attendee, or room involved in the meeting request.")] string attendee, + [Description("Date or day phrase such as today, tomorrow, or next friday.")] string date, + [Description("Time window such as morning, afternoon, or 14:00-16:00.")] string timeWindow, + [Description("Meeting title, agenda, or subject to create or change.")] string subject) + => $"{attendee}:{date}:{timeWindow}:{subject}"; + + private static string FilesystemTool( + [Description("Root folder, workspace path, or directory to inspect.")] string rootPath, + [Description("File name, glob pattern, extension, or exact target.")] string filePattern, + [Description("Text to read, write, match, or move alongside the file operation.")] string contentOrDestination, + [Description("Filesystem intent such as find, read, write, move, or list.")] FileIntent intent) + => $"{rootPath}:{filePattern}:{contentOrDestination}:{intent}"; + + private static string SupportTool( + [Description("Customer account, company, or requester 
name.")] string customer, + [Description("Issue summary, symptom, or ticket subject to investigate.")] string issueQuery, + [Description("Requested severity or urgency for the support workflow.")] TicketSeverity severity, + [Description("Product area, service, or plan linked to the support request.")] string productArea) + => $"{customer}:{issueQuery}:{severity}:{productArea}"; + + private static string FinanceTool( + [Description("Customer, vendor, or ledger subject linked to the finance request.")] string party, + [Description("Invoice number, refund reference, currency pair, or tax period.")] string reference, + [Description("Money amount, due state, or reconciliation clue such as unpaid or settled.")] string amountOrState, + [Description("Finance operation such as invoice, refund, exchange, reconciliation, or tax.")] FinanceIntent intent) + => $"{party}:{reference}:{amountOrState}:{intent}"; + + private static string TravelTool( + [Description("Departure city, airport, or transfer starting point.")] string origin, + [Description("Destination city, hotel district, venue, or route target.")] string destination, + [Description("Travel date, stay window, or itinerary timing.")] string travelDate, + [Description("Preference such as nonstop, breakfast, cancellation, or rail transfer.")] string preference) + => $"{origin}:{destination}:{travelDate}:{preference}"; + + private static string SecurityTool( + [Description("Host, image, system, account, or secret target to inspect.")] string asset, + [Description("Timeframe, event window, or audit period to search.")] string timeWindow, + [Description("Severity level, sensitivity, or privilege scope.")] string severityOrScope, + [Description("Security operation such as vulnerability, rotation, access, audit, or incident.")] string securityIntent) + => $"{asset}:{timeWindow}:{severityOrScope}:{securityIntent}"; + + private static string CrmTool( + [Description("Contact, company, account, or lead to search or enrich.")] 
string entity, + [Description("Email, stage, territory, or identifying CRM qualifier.")] string qualifier, + [Description("Time range for activity history, pipeline review, or follow-up window.")] string timeWindow, + [Description("CRM workflow such as contact lookup, activity review, deal search, or enrichment.")] string crmIntent) + => $"{entity}:{qualifier}:{timeWindow}:{crmIntent}"; + + private static string CommerceTool( + [Description("SKU, order number, tracking id, or return reference.")] string orderOrSku, + [Description("Customer email, buyer name, or account identifier.")] string customer, + [Description("Sales channel, warehouse, or current status like shipped or returned.")] string statusOrLocation, + [Description("Commerce workflow such as catalog, inventory, order, return, or tracking.")] string commerceIntent) + => $"{orderOrSku}:{customer}:{statusOrLocation}:{commerceIntent}"; +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationQueries.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationQueries.cs new file mode 100644 index 0000000..c56fd4c --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationQueries.cs @@ -0,0 +1,89 @@ +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewayTokenizerSearchEvaluationTests +{ + private static readonly EvaluationQuerySpec[] HighRelevanceQueries = + [ + new("developer review queue for managedcode repo", "github_pull_request_search"), + new("latest changelog tag for repository", "github_release_notes_lookup"), + new("where is AddManagedCodeMcpGateway defined", "github_code_search"), + new("rain this weekend in paris", "weather_forecast_lookup"), + new("smoke and pollution level in warsaw", "weather_air_quality_lookup"), + new("storm warning near berlin tonight", "weather_alert_search"), + new("who can meet tomorrow after lunch", "calendar_find_free_slots"), + new("move design sync 
to friday morning", "calendar_reschedule_event"), + new("show my agenda for today", "calendar_list_daily_agenda"), + new("where did the invoice pdf go in reports", "filesystem_find_files"), + new("open docs readme file", "filesystem_read_file"), + new("rename archive.zip into backups folder", "filesystem_move_file"), + new("enterprise customer tickets about login timeout", "support_ticket_search"), + new("what sla does platinum support get", "support_sla_lookup"), + new("unpaid invoice for acme", "finance_invoice_search"), + new("eur usd exchange rate today", "finance_exchange_rate_lookup"), + new("cheap nonstop flight paris berlin", "travel_flight_search"), + new("hotel near conference center with breakfast", "travel_hotel_search"), + new("critical cves in nginx image", "security_vulnerability_search"), + new("who accessed admin panel yesterday", "security_audit_log_search"), + new("find anna@example.com in contacts", "crm_contact_search"), + new("track shipment 1z999", "commerce_shipping_tracking"), + new("is sku keyboard-ergo available in warehouse", "commerce_inventory_lookup") + ]; + + private static readonly EvaluationQuerySpec[] BorderlineQueries = + [ + new("status of refund for order 1488", "finance_refund_lookup", "commerce_return_lookup"), + new("recent touches for contoso account", "crm_activity_timeline", "support_customer_timeline"), + new("customer timeline with escalations for contoso", "support_customer_timeline", "crm_activity_timeline"), + new("meeting notes and customer calls for contoso", "crm_activity_timeline", "calendar_list_daily_agenda"), + new("order payment issue with refund history", "finance_refund_lookup", "commerce_return_lookup", "support_ticket_search") + ]; + + private static readonly EvaluationQuerySpec[] MultilingualQueries = + [ + new("demande de fusion pour le depot managedcode", "github_pull_request_search"), + new("alerta de tormenta en berlin esta noche", "weather_alert_search"), + new("хто має вільний слот завтра після 
обіду", "calendar_find_free_slots"), + new("рахунок для клієнта acme", "finance_invoice_search"), + new("trouver un hôtel avec petit déjeuner près du centre", "travel_hotel_search"), + new("suivi de colis 1z999", "commerce_shipping_tracking"), + new("buscar contacto anna@example.com", "crm_contact_search"), + new("порядок денний на сьогодні", "calendar_list_daily_agenda") + ]; + + private static readonly EvaluationQuerySpec[] TypoQueries = + [ + new("review qeue for managedcode prs", "github_pull_request_search"), + new("weather forcast rain in paris weekend", "weather_forecast_lookup"), + new("whos free tomorrow afternon", "calendar_find_free_slots"), + new("open the readme fie in docs", "filesystem_read_file"), + new("unpaid invoce for acme", "finance_invoice_search"), + new("track shipmnt 1z999", "commerce_shipping_tracking"), + new("critical cve in nginx imgae", "security_vulnerability_search"), + new("contcat anna@example.com", "crm_contact_search") + ]; + + private static readonly EvaluationQuerySpec[] WeakIntentQueries = + [ + new("managedcode review approvals", "github_pull_request_search"), + new("air bad in warsaw", "weather_air_quality_lookup"), + new("who is free after lunch tomorrow", "calendar_find_free_slots"), + new("invoice pdf in reports", "filesystem_find_files", "finance_invoice_search"), + new("money back for order 1488", "finance_refund_lookup", "commerce_return_lookup"), + new("where is the parcel", "commerce_shipping_tracking"), + new("admin actions yesterday", "security_audit_log_search"), + new("anna contact details", "crm_contact_search"), + new("hotel breakfast near venue berlin", "travel_hotel_search") + ]; + + private static readonly string[] IrrelevantQueries = + [ + "best sourdough hydration ratio for rye bread", + "how to tune a jazz drum solo for swing practice", + "dragon habitat migration pattern in fantasy novels", + "origami crane folding sequence for beginners", + "volcanic lava viscosity classroom experiment", + "ambient 
techno playlist for yoga sunset session", + "ancient pottery glaze chemistry reference table", + "ballet toe shoe ribbon sewing tutorial" + ]; +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationSupport.cs b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationSupport.cs new file mode 100644 index 0000000..1fb6e57 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationSupport.cs @@ -0,0 +1,117 @@ +using ManagedCode.MCPGateway.Abstractions; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewayTokenizerSearchEvaluationTests +{ + private static void RegisterEvaluationTools(McpGatewayOptions options) + { + foreach (var tool in EvaluationTools) + { + options.AddTool( + "local", + TestFunctionFactory.CreateFunction(tool.Callback, tool.Name, tool.Description)); + } + } + + private static async Task EvaluateMatchBucketAsync( + IMcpGateway gateway, + string bucketName, + IReadOnlyList evaluationQueries) + { + var top1Hits = 0; + var top3Hits = 0; + var top5Hits = 0; + var reciprocalRankSum = 0d; + + foreach (var evaluationQuery in evaluationQueries) + { + var searchResult = await gateway.SearchAsync(evaluationQuery.Query); + var expectedToolIds = evaluationQuery.AcceptableToolNames + .Select(static toolName => $"local:{toolName}") + .ToHashSet(StringComparer.Ordinal); + + if (searchResult.Matches.Count > 0 && + expectedToolIds.Contains(searchResult.Matches[0].ToolId)) + { + top1Hits++; + } + + var rank = searchResult.Matches + .Select((match, index) => new { match.ToolId, Rank = index + 1 }) + .FirstOrDefault(item => expectedToolIds.Contains(item.ToolId)) + ?.Rank; + + if (rank is int value) + { + if (value <= 3) + { + top3Hits++; + } + + if (value <= 5) + { + top5Hits++; + } + + reciprocalRankSum += 1d / value; + } + else + { + Console.WriteLine( + $"MISS {bucketName}: '{evaluationQuery.Query}' expected [{string.Join(", ", 
expectedToolIds)}] but got [{string.Join(", ", searchResult.Matches.Select(static match => match.ToolId))}]"); + } + } + + return new EvaluationMetrics( + Top1Accuracy: (double)top1Hits / evaluationQueries.Count, + Top3Accuracy: (double)top3Hits / evaluationQueries.Count, + Top5Accuracy: (double)top5Hits / evaluationQueries.Count, + MeanReciprocalRank: reciprocalRankSum / evaluationQueries.Count); + } + + private static async Task EvaluateIrrelevantBucketAsync(IMcpGateway gateway) + { + var lowConfidenceHits = 0; + var topScoreSum = 0d; + + foreach (var query in IrrelevantQueries) + { + var searchResult = await gateway.SearchAsync(query); + var topScore = searchResult.Matches.Count > 0 + ? searchResult.Matches[0].Score + : 0d; + topScoreSum += topScore; + + if (topScore <= 0.20d) + { + lowConfidenceHits++; + } + else + { + Console.WriteLine( + $"IRRELEVANT HIGH SCORE: '{query}' returned {topScore:F2} for [{string.Join(", ", searchResult.Matches.Select(static match => match.ToolId))}]"); + } + } + + return new IrrelevantMetrics( + LowConfidenceRate: (double)lowConfidenceHits / IrrelevantQueries.Length, + AverageTopScore: topScoreSum / IrrelevantQueries.Length); + } + + private static EvaluationToolSpec[] CreateSpecs( + Delegate callback, + params (string Name, string Description)[] definitions) + => definitions + .Select(definition => new EvaluationToolSpec(definition.Name, definition.Description, callback)) + .ToArray(); + + private sealed record EvaluationToolSpec(string Name, string Description, Delegate Callback); + + private sealed record EvaluationQuerySpec(string Query, params string[] AcceptableToolNames); + + private sealed record EvaluationMetrics(double Top1Accuracy, double Top3Accuracy, double Top5Accuracy, double MeanReciprocalRank); + + private sealed record IrrelevantMetrics(double LowConfidenceRate, double AverageTopScore); + +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationTests.cs 
b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationTests.cs new file mode 100644 index 0000000..815eb1b --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchEvaluationTests.cs @@ -0,0 +1,70 @@ +using ManagedCode.MCPGateway.Abstractions; +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed partial class McpGatewayTokenizerSearchEvaluationTests +{ + [TUnit.Core.Test] + public async Task SearchAsync_TokenizerSearchMeetsEvaluationThresholds() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + TokenizerSearchTestSupport.UseTokenizerSearch(options); + RegisterEvaluationTools(options); + }); + var gateway = serviceProvider.GetRequiredService(); + + var buildResult = await gateway.BuildIndexAsync(); + var highRelevanceMetrics = await EvaluateMatchBucketAsync( + gateway, + "high-relevance", + HighRelevanceQueries); + var borderlineMetrics = await EvaluateMatchBucketAsync( + gateway, + "borderline", + BorderlineQueries); + var multilingualMetrics = await EvaluateMatchBucketAsync( + gateway, + "multilingual", + MultilingualQueries); + var typoMetrics = await EvaluateMatchBucketAsync( + gateway, + "typo", + TypoQueries); + var weakIntentMetrics = await EvaluateMatchBucketAsync( + gateway, + "weak-intent", + WeakIntentQueries); + var irrelevantMetrics = await EvaluateIrrelevantBucketAsync(gateway); + + Console.WriteLine( + $"ChatGptO200kBase / high-relevance: top1={highRelevanceMetrics.Top1Accuracy:P2}; top3={highRelevanceMetrics.Top3Accuracy:P2}; top5={highRelevanceMetrics.Top5Accuracy:P2}; mrr={highRelevanceMetrics.MeanReciprocalRank:F2}"); + Console.WriteLine( + $"ChatGptO200kBase / borderline: top1={borderlineMetrics.Top1Accuracy:P2}; top3={borderlineMetrics.Top3Accuracy:P2}; top5={borderlineMetrics.Top5Accuracy:P2}; mrr={borderlineMetrics.MeanReciprocalRank:F2}"); + Console.WriteLine( + $"ChatGptO200kBase / 
multilingual: top1={multilingualMetrics.Top1Accuracy:P2}; top3={multilingualMetrics.Top3Accuracy:P2}; top5={multilingualMetrics.Top5Accuracy:P2}; mrr={multilingualMetrics.MeanReciprocalRank:F2}"); + Console.WriteLine( + $"ChatGptO200kBase / typo: top1={typoMetrics.Top1Accuracy:P2}; top3={typoMetrics.Top3Accuracy:P2}; top5={typoMetrics.Top5Accuracy:P2}; mrr={typoMetrics.MeanReciprocalRank:F2}"); + Console.WriteLine( + $"ChatGptO200kBase / weak-intent: top1={weakIntentMetrics.Top1Accuracy:P2}; top3={weakIntentMetrics.Top3Accuracy:P2}; top5={weakIntentMetrics.Top5Accuracy:P2}; mrr={weakIntentMetrics.MeanReciprocalRank:F2}"); + Console.WriteLine( + $"ChatGptO200kBase / irrelevant: low-confidence={irrelevantMetrics.LowConfidenceRate:P2}; avg-top-score={irrelevantMetrics.AverageTopScore:F2}"); + + await Assert.That(buildResult.ToolCount).IsEqualTo(50); + await Assert.That(highRelevanceMetrics.Top1Accuracy >= 0.82d).IsTrue(); + await Assert.That(highRelevanceMetrics.Top3Accuracy >= 0.95d).IsTrue(); + await Assert.That(highRelevanceMetrics.Top5Accuracy >= 0.95d).IsTrue(); + await Assert.That(highRelevanceMetrics.MeanReciprocalRank >= 0.90d).IsTrue(); + await Assert.That(borderlineMetrics.Top3Accuracy >= 0.80d).IsTrue(); + await Assert.That(borderlineMetrics.Top5Accuracy >= 0.95d).IsTrue(); + await Assert.That(multilingualMetrics.Top3Accuracy >= 0.85d).IsTrue(); + await Assert.That(multilingualMetrics.Top5Accuracy >= 0.95d).IsTrue(); + await Assert.That(typoMetrics.Top3Accuracy >= 0.85d).IsTrue(); + await Assert.That(typoMetrics.Top5Accuracy >= 0.95d).IsTrue(); + await Assert.That(weakIntentMetrics.Top3Accuracy >= 0.85d).IsTrue(); + await Assert.That(weakIntentMetrics.Top5Accuracy >= 0.88d).IsTrue(); + await Assert.That(irrelevantMetrics.LowConfidenceRate >= 0.95d).IsTrue(); + await Assert.That(irrelevantMetrics.AverageTopScore <= 0.15d).IsTrue(); + } +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchTests.cs 
b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchTests.cs new file mode 100644 index 0000000..579aca7 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/McpGatewayTokenizerSearchTests.cs @@ -0,0 +1,152 @@ +using System.ComponentModel; + +using ManagedCode.MCPGateway.Abstractions; + +using Microsoft.Extensions.DependencyInjection; + +namespace ManagedCode.MCPGateway.Tests; + +public sealed class McpGatewayTokenizerSearchTests +{ + [TUnit.Core.Test] + public async Task BuildIndexAsync_SkipsEmbeddingsWhenTokenizerStrategyIsSelected() + { + var embeddingGenerator = new TestEmbeddingGenerator(); + await using var serviceProvider = GatewayTestServiceProviderFactory.Create( + ConfigureTokenizerSearchTools, + embeddingGenerator); + var gateway = serviceProvider.GetRequiredService(); + + var buildResult = await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("github pull requests", maxResults: 1); + + await Assert.That(buildResult.VectorizedToolCount).IsEqualTo(0); + await Assert.That(buildResult.IsVectorSearchEnabled).IsFalse(); + await Assert.That(embeddingGenerator.Calls.Count).IsEqualTo(0); + await Assert.That(searchResult.RankingMode).IsEqualTo("lexical"); + await Assert.That(searchResult.Diagnostics.Any(static diagnostic => diagnostic.Code == "lexical_fallback")).IsFalse(); + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:github_search_issues"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_UsesTheBuiltInChatGptTokenizerProfile() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + TokenizerSearchTestSupport.UseTokenizerSearch(options); + options.AddTool( + "local", + TestFunctionFactory.CreateFunction( + SearchGitHubPullRequests, + "github_pull_request_search", + "Search GitHub pull requests by repository, reviewer, branch, or merge status.")); + options.AddTool( + "local", + TestFunctionFactory.CreateFunction( + FilterAdvisories, + 
"advisory_lookup", + "Lookup advisory records by severity, ecosystem, package, or CVE reference.")); + }); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var advisorySearch = await gateway.SearchAsync("critical severity advisory", maxResults: 1); + var pullRequestSearch = await gateway.SearchAsync("review queue for pull requests", maxResults: 1); + + await Assert.That(advisorySearch.RankingMode).IsEqualTo("lexical"); + await Assert.That(advisorySearch.Matches[0].ToolId).IsEqualTo("local:advisory_lookup"); + await Assert.That(pullRequestSearch.Matches[0].ToolId).IsEqualTo("local:github_pull_request_search"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_PrefersFilesystemLookupOverFinanceStatusForInvoicePdfQueries() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + TokenizerSearchTestSupport.UseTokenizerSearch(options); + options.AddTool( + "local", + TestFunctionFactory.CreateFunction( + SearchWeatherForecast, + "filesystem_find_files", + "Find or locate files, PDFs, report documents, and exported invoice files by folder, reports workspace, glob pattern, extension, or text content. 
Not for invoice payment status.")); + options.AddTool( + "local", + TestFunctionFactory.CreateFunction( + FilterAdvisories, + "finance_invoice_search", + "Find invoices, bills, and billing records by customer, invoice number, payment state, or due date.")); + }); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("where did the invoice pdf go in reports", maxResults: 1); + + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:filesystem_find_files"); + } + + [TUnit.Core.Test] + public async Task SearchAsync_PrefersInventoryLookupOverCatalogSearchForWarehouseAvailability() + { + await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options => + { + TokenizerSearchTestSupport.UseTokenizerSearch(options); + options.AddTool( + "local", + TestFunctionFactory.CreateFunction( + SearchWeatherForecast, + "commerce_catalog_search", + "Search product catalog by keyword, category, brand, or attribute filters.")); + options.AddTool( + "local", + TestFunctionFactory.CreateFunction( + SearchWeatherForecast, + "commerce_inventory_lookup", + "Lookup inventory, stock by SKU, warehouse balance, and availability.")); + }); + var gateway = serviceProvider.GetRequiredService(); + + await gateway.BuildIndexAsync(); + var searchResult = await gateway.SearchAsync("is sku keyboard-ergo available in warehouse", maxResults: 1); + + await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:commerce_inventory_lookup"); + } + + private static void ConfigureTokenizerSearchTools(McpGatewayOptions options) + { + TokenizerSearchTestSupport.UseTokenizerSearch(options); + options.AddTool( + "local", + TestFunctionFactory.CreateFunction( + SearchGitHubPullRequests, + "github_search_issues", + "Search GitHub issues and pull requests by user query.")); + options.AddTool( + "local", + TestFunctionFactory.CreateFunction( + SearchWeatherForecast, + "weather_search_forecast", + 
"Search weather forecast and temperature information by city name.")); + } + + private static string SearchGitHubPullRequests( + [Description("Repository owner or organization handle.")] string owner, + [Description("Repository name or slug.")] string repository, + [Description("Free-form pull request search terms or reviewer names.")] string query, + [Description("Workflow state such as open, closed, or merged.")] WorkItemState state) + => $"{owner}/{repository}:{query}:{state}"; + + private static string SearchWeatherForecast( + [Description("City, airport code, or geo place name.")] string location, + [Description("Forecast window such as today, weekend, or next 5 days.")] string timeRange, + [Description("Preferred temperature unit.")] TemperatureUnit unit, + [Description("Optional weather focus such as rain, wind, or air quality.")] string focus) + => $"{location}:{timeRange}:{unit}:{focus}"; + + private static string FilterAdvisories( + [Description("Severity level to filter advisories.")] TicketSeverity severity, + [Description("Software ecosystem such as npm, nuget, or container image.")] string ecosystem, + [Description("Package name, image, or CVE reference to inspect.")] string packageOrReference) + => $"{severity}:{ecosystem}:{packageOrReference}"; + +} diff --git a/tests/ManagedCode.MCPGateway.Tests/Search/SearchTestEnums.cs b/tests/ManagedCode.MCPGateway.Tests/Search/SearchTestEnums.cs new file mode 100644 index 0000000..bd67851 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/Search/SearchTestEnums.cs @@ -0,0 +1,46 @@ +namespace ManagedCode.MCPGateway.Tests; + +internal static class TokenizerSearchTestSupport +{ + public static void UseTokenizerSearch(McpGatewayOptions options) + => options.SearchStrategy = McpGatewaySearchStrategy.Tokenizer; +} + +internal enum WorkItemState +{ + Open, + Closed, + Merged +} + +internal enum TemperatureUnit +{ + Celsius, + Fahrenheit +} + +internal enum FileIntent +{ + Find, + Read, + Write, + Move, + List +} + 
+internal enum TicketSeverity +{ + Low, + Medium, + High, + Critical +} + +internal enum FinanceIntent +{ + Invoice, + Refund, + Exchange, + Reconciliation, + Tax +} diff --git a/tests/ManagedCode.MCPGateway.Tests/TestSupport/GatewayTestServiceProviderFactory.cs b/tests/ManagedCode.MCPGateway.Tests/TestSupport/GatewayTestServiceProviderFactory.cs index 4a472b3..ad463d0 100644 --- a/tests/ManagedCode.MCPGateway.Tests/TestSupport/GatewayTestServiceProviderFactory.cs +++ b/tests/ManagedCode.MCPGateway.Tests/TestSupport/GatewayTestServiceProviderFactory.cs @@ -1,7 +1,7 @@ +using ManagedCode.MCPGateway.Abstractions; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using ManagedCode.MCPGateway.Abstractions; namespace ManagedCode.MCPGateway.Tests; @@ -10,7 +10,8 @@ internal static class GatewayTestServiceProviderFactory public static ServiceProvider Create( Action configure, IEmbeddingGenerator>? embeddingGenerator = null, - IMcpGatewayToolEmbeddingStore? embeddingStore = null) + IMcpGatewayToolEmbeddingStore? embeddingStore = null, + IChatClient? 
searchQueryChatClient = null) { var services = new ServiceCollection(); services.AddLogging(static logging => logging.SetMinimumLevel(LogLevel.Debug)); @@ -25,6 +26,11 @@ public static ServiceProvider Create( services.AddSingleton(embeddingStore); } + if (searchQueryChatClient is not null) + { + services.AddKeyedSingleton(McpGatewayServiceKeys.SearchQueryChatClient, searchQueryChatClient); + } + services.AddManagedCodeMcpGateway(configure); return services.BuildServiceProvider(); } diff --git a/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestChatClient.cs b/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestChatClient.cs new file mode 100644 index 0000000..84fe9e2 --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestChatClient.cs @@ -0,0 +1,55 @@ +using Microsoft.Extensions.AI; + +namespace ManagedCode.MCPGateway.Tests; + +internal sealed class TestChatClient(TestChatClientOptions? options = null) : IChatClient +{ + private readonly TestChatClientOptions _options = options ?? new(); + + public List Calls { get; } = []; + + public Task GetResponseAsync( + IEnumerable messages, + ChatOptions? options = null, + CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + var query = messages.LastOrDefault(static message => message.Role == ChatRole.User)?.Text ?? string.Empty; + Calls.Add(query); + + if (_options.ThrowOnInput?.Invoke(query) == true) + { + throw new InvalidOperationException("Query normalization failed for a test input."); + } + + var rewrittenQuery = _options.RewriteQuery?.Invoke(query) ?? query; + return Task.FromResult(new ChatResponse(new ChatMessage(ChatRole.Assistant, rewrittenQuery))); + } + + public async IAsyncEnumerable GetStreamingResponseAsync( + IEnumerable messages, + ChatOptions? 
options = null, + [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken cancellationToken = default) + { + _ = await GetResponseAsync(messages, options, cancellationToken); + yield break; + } + + public object? GetService(Type serviceType, object? serviceKey = null) + { + ArgumentNullException.ThrowIfNull(serviceType); + return null; + } + + public void Dispose() + { + } +} + +internal sealed class TestChatClientOptions +{ + public Func? RewriteQuery { get; init; } + + public Func? ThrowOnInput { get; init; } +} diff --git a/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestFunctionFactory.cs b/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestFunctionFactory.cs new file mode 100644 index 0000000..17306db --- /dev/null +++ b/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestFunctionFactory.cs @@ -0,0 +1,15 @@ +using Microsoft.Extensions.AI; + +namespace ManagedCode.MCPGateway.Tests; + +internal static class TestFunctionFactory +{ + public static AIFunction CreateFunction(Delegate callback, string name, string description) + => AIFunctionFactory.Create( + callback, + new AIFunctionFactoryOptions + { + Name = name, + Description = description + }); +} diff --git a/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestToolEmbeddingStore.cs b/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestToolEmbeddingStore.cs index 7c98a76..af14511 100644 --- a/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestToolEmbeddingStore.cs +++ b/tests/ManagedCode.MCPGateway.Tests/TestSupport/TestToolEmbeddingStore.cs @@ -4,7 +4,7 @@ namespace ManagedCode.MCPGateway.Tests; internal sealed class TestToolEmbeddingStore : IMcpGatewayToolEmbeddingStore { - private readonly Dictionary _embeddings = []; + private readonly McpGatewayToolEmbeddingStoreIndex _index = new(); public List> GetCalls { get; } = []; @@ -17,105 +17,18 @@ public Task> GetAsync( cancellationToken.ThrowIfCancellationRequested(); GetCalls.Add(lookups.ToList()); - - var matches = new List(lookups.Count); 
- foreach (var lookup in lookups) - { - if (TryGetEmbedding(lookup, out var embedding)) - { - matches.Add(Clone(embedding)); - } - } - - return Task.FromResult>(matches); + return Task.FromResult(_index.Get(lookups, cancellationToken)); } public Task UpsertAsync( IReadOnlyList embeddings, CancellationToken cancellationToken = default) { - cancellationToken.ThrowIfCancellationRequested(); - - var clonedBatch = embeddings - .Select(Clone) - .ToList(); + var clonedBatch = _index.Upsert(embeddings, cancellationToken); UpsertCalls.Add(clonedBatch); - foreach (var embedding in clonedBatch) - { - _embeddings[StoreKey.FromEmbedding(embedding)] = embedding; - } - return Task.CompletedTask; } - public void Remove(string toolId) - { - var keys = _embeddings.Keys - .Where(key => string.Equals(key.NormalizedToolId, NormalizeToolId(toolId), StringComparison.Ordinal)) - .ToList(); - - foreach (var key in keys) - { - _embeddings.Remove(key); - } - } - - private static McpGatewayToolEmbedding Clone(McpGatewayToolEmbedding embedding) - => embedding with - { - Vector = [.. embedding.Vector] - }; - - private bool TryGetEmbedding( - McpGatewayToolEmbeddingLookup lookup, - out McpGatewayToolEmbedding embedding) - { - var storeKey = StoreKey.FromLookup(lookup); - if (lookup.EmbeddingGeneratorFingerprint is not null) - { - return _embeddings.TryGetValue(storeKey, out embedding!); - } - - foreach (var (key, value) in _embeddings) - { - if (key.Matches(storeKey)) - { - embedding = value; - return true; - } - } - - embedding = default!; - return false; - } - - private static string NormalizeToolId(string toolId) => toolId.ToUpperInvariant(); - - private readonly record struct StoreKey( - string NormalizedToolId, - string DocumentHash, - string? 
EmbeddingGeneratorFingerprint) - { - public static StoreKey FromLookup(McpGatewayToolEmbeddingLookup lookup) - => new( - NormalizeToolId(lookup.ToolId), - lookup.DocumentHash, - lookup.EmbeddingGeneratorFingerprint); - - public static StoreKey FromEmbedding(McpGatewayToolEmbedding embedding) - => new( - NormalizeToolId(embedding.ToolId), - embedding.DocumentHash, - embedding.EmbeddingGeneratorFingerprint); - - public bool Matches(StoreKey other) - => string.Equals(NormalizedToolId, other.NormalizedToolId, StringComparison.Ordinal) - && string.Equals(DocumentHash, other.DocumentHash, StringComparison.Ordinal) - && (other.EmbeddingGeneratorFingerprint is null - || string.Equals( - EmbeddingGeneratorFingerprint, - other.EmbeddingGeneratorFingerprint, - StringComparison.Ordinal)); - } + public void Remove(string toolId) => _index.Remove(toolId); }