From f69fc618987819fd875468c10ab637f659cb82e3 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 16 Dec 2025 10:05:48 -0800 Subject: [PATCH 01/18] Add distributed SSE event stream store --- Directory.Packages.props | 2 + .../McpJsonUtilities.cs | 5 + .../ModelContextProtocol.Core.csproj | 1 + .../DistributedCacheEventStreamStore.cs | 353 ++++ ...DistributedCacheEventStreamStoreOptions.cs | 51 + .../ModelContextProtocol.Tests.csproj | 1 + .../DistributedCacheEventStreamStoreTests.cs | 1594 +++++++++++++++++ 7 files changed, 2007 insertions(+) create mode 100644 src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs create mode 100644 src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs create mode 100644 tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 8a09ce3e1..c0ed13101 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -11,6 +11,8 @@ + + diff --git a/src/ModelContextProtocol.Core/McpJsonUtilities.cs b/src/ModelContextProtocol.Core/McpJsonUtilities.cs index b3d98dd0e..081ce0005 100644 --- a/src/ModelContextProtocol.Core/McpJsonUtilities.cs +++ b/src/ModelContextProtocol.Core/McpJsonUtilities.cs @@ -1,6 +1,7 @@ using Microsoft.Extensions.AI; using ModelContextProtocol.Authentication; using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; @@ -158,6 +159,10 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(BlobResourceContents))] [JsonSerializable(typeof(TextResourceContents))] + // Distributed cache event stream store + [JsonSerializable(typeof(DistributedCacheEventStreamStore.StreamMetadata))] + [JsonSerializable(typeof(DistributedCacheEventStreamStore.StoredEvent))] + // Other MCP Types [JsonSerializable(typeof(IReadOnlyDictionary))] [JsonSerializable(typeof(ProgressToken))] diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj index 9e22a5c0e..9e18c7b76 100644 --- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj +++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj @@ -49,6 +49,7 @@ + diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs new file mode 100644 index 000000000..02f6b8be4 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -0,0 +1,353 @@ +using Microsoft.Extensions.Caching.Distributed; +using ModelContextProtocol.Protocol; +using System.Net.ServerSentEvents; +using System.Runtime.CompilerServices; +using System.Text.Json; + +namespace ModelContextProtocol.Server; + +/// +/// An implementation backed by . +/// +/// +/// +/// This implementation stores SSE events in a distributed cache, enabling resumability across +/// multiple server instances. Event IDs are encoded with session, stream, and sequence information +/// to allow efficient retrieval of events after a given point. +/// +/// +/// The writer maintains in-memory state for sequence number generation, as there is guaranteed +/// to be only one writer per stream. Readers may be created from separate processes. +/// +/// +public sealed class DistributedCacheEventStreamStore : ISseEventStreamStore +{ + private readonly IDistributedCache _cache; + private readonly DistributedCacheEventStreamStoreOptions _options; + + /// + /// Initializes a new instance of the class. + /// + /// The distributed cache to use for storage. + /// Optional configuration options for the store. + public DistributedCacheEventStreamStore(IDistributedCache cache, DistributedCacheEventStreamStoreOptions? options = null) + { + Throw.IfNull(cache); + _cache = cache; + _options = options ?? new(); + } + + /// + public ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default) + { + Throw.IfNull(options); + var writer = new DistributedCacheEventStreamWriter(_cache, options.SessionId, options.StreamId, options.Mode, _options); + return new ValueTask(writer); + } + + /// + public async ValueTask GetStreamReaderAsync(string lastEventId, CancellationToken cancellationToken = default) + { + Throw.IfNull(lastEventId); + + // Parse the event ID to get session, stream, and sequence information + if (!EventIdCodec.TryParse(lastEventId, out var sessionId, out var streamId, out var sequence)) + { + return null; + } + + // Check if the stream exists by looking for its metadata + var metadataKey = CacheKeys.StreamMetadata(sessionId, streamId); + var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false); + if (metadataBytes is null) + { + return null; + } + + var metadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata); + if (metadata is null) + { + return null; + } + + return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, sequence, metadata, _options); + } + + /// + /// Provides methods for encoding and decoding event IDs. + /// + internal static class EventIdCodec + { + private const char Separator = ':'; + + /// + /// Encodes session ID, stream ID, and sequence number into an event ID string. + /// + public static string Encode(string sessionId, string streamId, long sequence) + { + // Base64-encode session and stream IDs so the event ID can be parsed + // even if the original IDs contain the ':' separator character + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(sessionId)); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(streamId)); + return $"{sessionBase64}{Separator}{streamBase64}{Separator}{sequence}"; + } + + /// + /// Attempts to parse an event ID into its component parts. + /// + public static bool TryParse(string eventId, out string sessionId, out string streamId, out long sequence) + { + sessionId = string.Empty; + streamId = string.Empty; + sequence = 0; + + var parts = eventId.Split(Separator); + if (parts.Length != 3) + { + return false; + } + + try + { + sessionId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[0])); + streamId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[1])); + return long.TryParse(parts[2], out sequence); + } + catch + { + return false; + } + } + } + + /// + /// Provides methods for generating cache keys. + /// + internal static class CacheKeys + { + private const string Prefix = "mcp:sse:"; + + public static string StreamMetadata(string sessionId, string streamId) => + $"{Prefix}meta:{sessionId}:{streamId}"; + + public static string Event(string eventId) => + $"{Prefix}event:{eventId}"; + + public static string StreamEventCount(string sessionId, string streamId) => + $"{Prefix}count:{sessionId}:{streamId}"; + } + + /// + /// Metadata about a stream stored in the cache. + /// + internal sealed class StreamMetadata + { + public SseEventStreamMode Mode { get; set; } + public bool IsCompleted { get; set; } + public long LastSequence { get; set; } + } + + /// + /// Serialized representation of an SSE event stored in the cache. + /// + internal sealed class StoredEvent + { + public string? EventType { get; set; } + public string? EventId { get; set; } + public JsonRpcMessage? Data { get; set; } + } + + private sealed class DistributedCacheEventStreamWriter : ISseEventStreamWriter + { + private readonly IDistributedCache _cache; + private readonly DistributedCacheEventStreamStoreOptions _options; + private long _sequence; + private bool _disposed; + + public DistributedCacheEventStreamWriter( + IDistributedCache cache, + string sessionId, + string streamId, + SseEventStreamMode mode, + DistributedCacheEventStreamStoreOptions options) + { + _cache = cache; + SessionId = sessionId; + StreamId = streamId; + Mode = mode; + _options = options; + } + + public string SessionId { get; } + public string StreamId { get; } + public SseEventStreamMode Mode { get; private set; } + + public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default) + { + Mode = mode; + await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); + } + + public async ValueTask> WriteEventAsync(SseItem sseItem, CancellationToken cancellationToken = default) + { + // Skip if already has an event ID + if (sseItem.EventId is not null) + { + return sseItem; + } + + // Generate a new sequence number and event ID + var sequence = Interlocked.Increment(ref _sequence); + var eventId = EventIdCodec.Encode(SessionId, StreamId, sequence); + var newItem = sseItem with { EventId = eventId }; + + // Store the event in the cache + var storedEvent = new StoredEvent + { + EventType = newItem.EventType, + EventId = eventId, + Data = newItem.Data, + }; + + var eventBytes = JsonSerializer.SerializeToUtf8Bytes(storedEvent, McpJsonUtilities.JsonContext.Default.StoredEvent); + var eventKey = CacheKeys.Event(eventId); + + await _cache.SetAsync(eventKey, eventBytes, new DistributedCacheEntryOptions + { + SlidingExpiration = _options.EventSlidingExpiration, + AbsoluteExpirationRelativeToNow = _options.EventAbsoluteExpiration, + }, cancellationToken).ConfigureAwait(false); + + // Update metadata with the latest sequence + await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); + + return newItem; + } + + private async ValueTask UpdateMetadataAsync(CancellationToken cancellationToken) + { + var metadata = new StreamMetadata + { + Mode = Mode, + IsCompleted = _disposed, + LastSequence = Interlocked.Read(ref _sequence), + }; + + var metadataBytes = JsonSerializer.SerializeToUtf8Bytes(metadata, McpJsonUtilities.JsonContext.Default.StreamMetadata); + var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId); + + await _cache.SetAsync(metadataKey, metadataBytes, new DistributedCacheEntryOptions + { + SlidingExpiration = _options.MetadataSlidingExpiration, + AbsoluteExpirationRelativeToNow = _options.MetadataAbsoluteExpiration, + }, cancellationToken).ConfigureAwait(false); + } + + public async ValueTask DisposeAsync() + { + if (_disposed) + { + return; + } + + _disposed = true; + + // Mark the stream as completed in the metadata + await UpdateMetadataAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + private sealed class DistributedCacheEventStreamReader : ISseEventStreamReader + { + private readonly IDistributedCache _cache; + private readonly long _startSequence; + private readonly StreamMetadata _metadata; + private readonly DistributedCacheEventStreamStoreOptions _options; + + public DistributedCacheEventStreamReader( + IDistributedCache cache, + string sessionId, + string streamId, + long startSequence, + StreamMetadata metadata, + DistributedCacheEventStreamStoreOptions options) + { + _cache = cache; + SessionId = sessionId; + StreamId = streamId; + _startSequence = startSequence; + _metadata = metadata; + _options = options; + } + + public string SessionId { get; } + public string StreamId { get; } + + public async IAsyncEnumerable> ReadEventsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + // Start from the sequence after the last received event + var currentSequence = _startSequence; + + while (!cancellationToken.IsCancellationRequested) + { + // Refresh metadata to get the latest sequence and completion status + var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId); + var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false); + + StreamMetadata? currentMetadata = null; + if (metadataBytes is not null) + { + currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata); + } + + var lastSequence = currentMetadata?.LastSequence ?? _metadata.LastSequence; + var isCompleted = currentMetadata?.IsCompleted ?? _metadata.IsCompleted; + var mode = currentMetadata?.Mode ?? _metadata.Mode; + + // Read all available events from currentSequence + 1 to lastSequence + while (currentSequence < lastSequence) + { + cancellationToken.ThrowIfCancellationRequested(); + + var nextSequence = currentSequence + 1; + var eventId = EventIdCodec.Encode(SessionId, StreamId, nextSequence); + var eventKey = CacheKeys.Event(eventId); + var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false); + + if (eventBytes is null) + { + // Event may have expired; skip to next + currentSequence = nextSequence; + continue; + } + + var storedEvent = JsonSerializer.Deserialize(eventBytes, McpJsonUtilities.JsonContext.Default.StoredEvent); + if (storedEvent is not null) + { + yield return new SseItem(storedEvent.Data, storedEvent.EventType) + { + EventId = storedEvent.EventId, + }; + } + + currentSequence = nextSequence; + } + + // If in polling mode, stop after returning currently available events + if (mode == SseEventStreamMode.Polling) + { + yield break; + } + + // If the stream is completed, stop + if (isCompleted) + { + yield break; + } + + // Wait before polling again for new events + await Task.Delay(_options.PollingInterval, cancellationToken).ConfigureAwait(false); + } + } + } +} diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs new file mode 100644 index 000000000..b6641d3fe --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs @@ -0,0 +1,51 @@ +namespace ModelContextProtocol.Server; + +/// +/// Configuration options for . +/// +public sealed class DistributedCacheEventStreamStoreOptions +{ + /// + /// Gets or sets the sliding expiration for individual events in the cache. + /// + /// + /// Events are refreshed on each access. If an event is not accessed within this + /// time period, it may be evicted from the cache. + /// + public TimeSpan? EventSlidingExpiration { get; set; } = TimeSpan.FromMinutes(30); + + /// + /// Gets or sets the absolute expiration for individual events in the cache. + /// + /// + /// Events will be evicted from the cache after this time period, regardless of access. + /// + public TimeSpan? EventAbsoluteExpiration { get; set; } = TimeSpan.FromHours(2); + + /// + /// Gets or sets the sliding expiration for stream metadata in the cache. + /// + /// + /// Stream metadata includes mode and completion status. This should typically be + /// set to a longer duration than event expiration to allow for resumability. + /// + public TimeSpan? MetadataSlidingExpiration { get; set; } = TimeSpan.FromHours(1); + + /// + /// Gets or sets the absolute expiration for stream metadata in the cache. + /// + /// + /// Stream metadata will be evicted from the cache after this time period, regardless of access. + /// + public TimeSpan? MetadataAbsoluteExpiration { get; set; } = TimeSpan.FromHours(4); + + /// + /// Gets or sets the interval between polling attempts when a reader is waiting for new events + /// in mode. + /// + /// + /// This only affects readers. A shorter interval provides lower latency for new events + /// but increases cache access frequency. + /// + public TimeSpan PollingInterval { get; set; } = TimeSpan.FromMilliseconds(100); +} diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index e0fb3d1fa..8bf01df43 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -41,6 +41,7 @@ + diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs new file mode 100644 index 000000000..b061acdfc --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -0,0 +1,1594 @@ +using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Net.ServerSentEvents; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for . +/// +public class DistributedCacheEventStreamStoreTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +{ + private CancellationToken CancellationToken => TestContext.Current.CancellationToken; + + private static IDistributedCache CreateMemoryCache() + { + var options = Options.Create(new MemoryDistributedCacheOptions()); + return new MemoryDistributedCache(options); + } + + #region Constructor & Initialization Tests + + [Fact] + public void Constructor_ThrowsArgumentNullException_WhenCacheIsNull() + { + Assert.Throws("cache", () => new DistributedCacheEventStreamStore(null!)); + } + + [Fact] + public async Task Constructor_UsesDefaultOptions_WhenOptionsParameterIsNull() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, options: null); + + // Act - Create a stream to verify the store works with default options + var streamOptions = new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }; + var writer = await store.CreateStreamAsync(streamOptions, CancellationToken); + + // Assert - The store should work normally with default options + Assert.NotNull(writer); + Assert.Equal("stream-1", writer.StreamId); + Assert.Equal(SseEventStreamMode.Default, writer.Mode); + } + + [Fact] + public async Task Constructor_UsesProvidedOptions_WhenOptionsParameterIsSpecified() + { + // Arrange + var cache = CreateMemoryCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventSlidingExpiration = TimeSpan.FromMinutes(10), + EventAbsoluteExpiration = TimeSpan.FromHours(1), + MetadataSlidingExpiration = TimeSpan.FromMinutes(20), + MetadataAbsoluteExpiration = TimeSpan.FromHours(2), + PollingInterval = TimeSpan.FromMilliseconds(50) + }; + var store = new DistributedCacheEventStreamStore(cache, customOptions); + + // Act - Create a stream to verify the store works with custom options + var streamOptions = new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }; + var writer = await store.CreateStreamAsync(streamOptions, CancellationToken); + + // Assert - The store should work with custom options + Assert.NotNull(writer); + Assert.Equal("stream-1", writer.StreamId); + } + + #endregion + + #region CreateStreamAsync Tests + + [Fact] + public async Task CreateStreamAsync_ReturnsWriter_WithCorrectStreamId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var options = new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "my-stream-id", + Mode = SseEventStreamMode.Default + }; + + // Act + var writer = await store.CreateStreamAsync(options, CancellationToken); + + // Assert + Assert.Equal("my-stream-id", writer.StreamId); + } + + [Fact] + public async Task CreateStreamAsync_ReturnsWriter_WithCorrectSessionId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var options = new SseEventStreamOptions + { + SessionId = "my-session-id", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }; + + // Act + var writer = await store.CreateStreamAsync(options, CancellationToken); + + // Assert - Write an event and verify the reader can find it by session + var item = new SseItem(null); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + Assert.NotNull(reader); + Assert.Equal("my-session-id", reader.SessionId); + } + + [Fact] + public async Task CreateStreamAsync_ReturnsWriter_WithDefaultMode() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var options = new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }; + + // Act + var writer = await store.CreateStreamAsync(options, CancellationToken); + + // Assert + Assert.Equal(SseEventStreamMode.Default, writer.Mode); + } + + [Fact] + public async Task CreateStreamAsync_ReturnsWriter_WithPollingMode() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var options = new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }; + + // Act + var writer = await store.CreateStreamAsync(options, CancellationToken); + + // Assert + Assert.Equal(SseEventStreamMode.Polling, writer.Mode); + } + + [Fact] + public async Task CreateStreamAsync_ThrowsArgumentNullException_WhenOptionsIsNull() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Act & Assert + await Assert.ThrowsAsync("options", + async () => await store.CreateStreamAsync(null!, CancellationToken)); + } + + #endregion + + #region WriteEventAsync Tests + + [Fact] + public async Task WriteEventAsync_AssignsUniqueEventId_WhenItemHasNoEventId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var item = new SseItem(null); + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert + Assert.NotNull(result.EventId); + Assert.NotEmpty(result.EventId); + } + + [Fact] + public async Task WriteEventAsync_SkipsAssigningEventId_WhenItemAlreadyHasEventId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var existingEventId = "existing-event-id"; + var item = new SseItem(null) { EventId = existingEventId }; + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert + Assert.Equal(existingEventId, result.EventId); + } + + [Fact] + public async Task WriteEventAsync_PreservesDataProperty_InReturnedItem() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var message = new JsonRpcNotification { Method = "test/notification" }; + var item = new SseItem(message); + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert - Data should be preserved in the returned item (same reference) + Assert.Same(message, result.Data); + } + + [Fact] + public async Task WriteEventAsync_PreservesEventTypeProperty_InReturnedItem() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var item = new SseItem(null, "custom-event-type"); + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert + Assert.Equal("custom-event-type", result.EventType); + } + + [Fact] + public async Task WriteEventAsync_HandlesNullData_AssignsEventIdAndStoresEvent() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var item = new SseItem(null); + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert - Event ID should be assigned + Assert.NotNull(result.EventId); + + // Assert - Event should be retrievable + var reader = await store.GetStreamReaderAsync(result.EventId, CancellationToken); + Assert.NotNull(reader); + } + + [Fact] + public async Task WriteEventAsync_StoresEventWithCorrectSlidingExpiration() + { + // Arrange - Use a mock cache to verify expiration options + var mockCache = new TrackingDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventSlidingExpiration = TimeSpan.FromMinutes(15) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var item = new SseItem(null); + + // Act + await writer.WriteEventAsync(item, CancellationToken); + + // Assert - Verify at least one call used the expected sliding expiration + Assert.Contains(mockCache.SetCalls, call => + call.Key.Contains("event:") && + call.Options.SlidingExpiration == TimeSpan.FromMinutes(15)); + } + + [Fact] + public async Task WriteEventAsync_StoresEventWithCorrectAbsoluteExpiration() + { + // Arrange + var mockCache = new TrackingDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventAbsoluteExpiration = TimeSpan.FromHours(3) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var item = new SseItem(null); + + // Act + await writer.WriteEventAsync(item, CancellationToken); + + // Assert + Assert.Contains(mockCache.SetCalls, call => + call.Key.Contains("event:") && + call.Options.AbsoluteExpirationRelativeToNow == TimeSpan.FromHours(3)); + } + + [Fact] + public async Task WriteEventAsync_UpdatesStreamMetadata_AfterEachWrite() + { + // Arrange + var mockCache = new TrackingDistributedCache(); + var store = new DistributedCacheEventStreamStore(mockCache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + var item = new SseItem(null); + + // Act + await writer.WriteEventAsync(item, CancellationToken); + + // Assert - Metadata should have been updated + Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); + } + + #endregion + + #region SetModeAsync (Writer) Tests + + [Fact] + public async Task SetModeAsync_UpdatesModeProperty_OnWriter() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Act + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Assert + Assert.Equal(SseEventStreamMode.Polling, writer.Mode); + } + + [Fact] + public async Task SetModeAsync_PersistsModeChangeToMetadata() + { + // Arrange + var mockCache = new TrackingDistributedCache(); + var store = new DistributedCacheEventStreamStore(mockCache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + mockCache.SetCalls.Clear(); // Clear calls from CreateStreamAsync setup + + // Act + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Assert - Metadata should have been updated with the new mode + Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); + } + + [Fact] + public async Task SetModeAsync_ModeChangeReflectedInReader() + { + // Arrange + var cache = CreateMemoryCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(10) + }; + var store = new DistributedCacheEventStreamStore(cache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write an event to have something to read + var item = new SseItem(new JsonRpcNotification { Method = "test" }); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + // Get a reader based on the event ID (starting at sequence 1, reader will wait for seq 2+) + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Change mode to Polling while reader exists + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Assert - Reader should complete immediately in polling mode (no new events to read) + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(500)); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + } + + // In polling mode, reader should complete without waiting for new events + Assert.Empty(events); // No events after the one we used to create the reader + } + + #endregion + + #region DisposeAsync (Writer) Tests + + [Fact] + public async Task DisposeAsync_MarksStreamAsCompleted() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write an event so we can get a reader + var item = new SseItem(null); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + // Act + await writer.DisposeAsync(); + + // Assert - Reader should see the stream as completed and exit immediately + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + Assert.NotNull(reader); + + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(500)); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + } + + // The reader should complete without waiting for new events because stream is completed + Assert.Empty(events); // No new events after the one we used to create the reader + } + + [Fact] + public async Task DisposeAsync_IsIdempotent() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Act - Call DisposeAsync multiple times + await writer.DisposeAsync(); + await writer.DisposeAsync(); + await writer.DisposeAsync(); + + // Assert - No exception thrown, operation is idempotent + // If we got here without exception, the test passes + } + + [Fact] + public async Task DisposeAsync_UpdatesMetadata_WithIsCompletedFlag() + { + // Arrange + var mockCache = new TrackingDistributedCache(); + var store = new DistributedCacheEventStreamStore(mockCache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + mockCache.SetCalls.Clear(); // Clear calls from CreateStreamAsync + + // Act + await writer.DisposeAsync(); + + // Assert - Metadata should have been updated + Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); + } + + #endregion + + #region GetStreamReaderAsync Tests + + [Fact] + public async Task GetStreamReaderAsync_ThrowsArgumentNullException_WhenLastEventIdIsNull() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Act & Assert + await Assert.ThrowsAsync("lastEventId", + async () => await store.GetStreamReaderAsync(null!, CancellationToken)); + } + + [Fact] + public async Task GetStreamReaderAsync_ReturnsNull_WhenEventIdIsUnparseable() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Act - Try various invalid event ID formats + var result1 = await store.GetStreamReaderAsync("invalid-format", CancellationToken); + var result2 = await store.GetStreamReaderAsync("only:two:parts:here", CancellationToken); + var result3 = await store.GetStreamReaderAsync("", CancellationToken); + + // Assert + Assert.Null(result1); + Assert.Null(result2); + Assert.Null(result3); + } + + [Fact] + public async Task GetStreamReaderAsync_ReturnsNull_WhenStreamMetadataDoesNotExist() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Create a valid-looking event ID for a stream that doesn't exist + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("nonexistent-session")); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("nonexistent-stream")); + var fakeEventId = $"{sessionBase64}:{streamBase64}:1"; + + // Act + var reader = await store.GetStreamReaderAsync(fakeEventId, CancellationToken); + + // Assert + Assert.Null(reader); + } + + [Fact] + public async Task GetStreamReaderAsync_ReturnsReaderWithCorrectSessionIdAndStreamId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "my-session", + StreamId = "my-stream", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write an event to get a valid event ID + var item = new SseItem(null); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + // Act + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + + // Assert + Assert.NotNull(reader); + Assert.Equal("my-session", reader.SessionId); + Assert.Equal("my-stream", reader.StreamId); + } + + #endregion + + #region ReadEventsAsync (Reader) Tests + + [Fact] + public async Task ReadEventsAsync_ReturnsEventsInOrder() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write multiple events + var event1 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method1" }), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method2" }), CancellationToken); + var event3 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method3" }), CancellationToken); + + // Create a reader starting from before the first event (use a fake event ID with sequence 0) + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); + var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert - Events should be in order + Assert.Equal(3, events.Count); + Assert.Equal(event1.EventId, events[0].EventId); + Assert.Equal(event2.EventId, events[1].EventId); + Assert.Equal(event3.EventId, events[2].EventId); + } + + [Fact] + public async Task ReadEventsAsync_ReturnsEmpty_WhenNoNewEventsExist() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write one event + var writtenItem = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Create a reader starting from the last event (so there are no new events to read) + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert + Assert.Empty(events); + } + + [Fact] + public async Task ReadEventsAsync_PreservesCorrectDataEventTypeAndEventId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var message = new JsonRpcNotification { Method = "test/method" }; + var writtenItem = await writer.WriteEventAsync(new SseItem(message, "custom-event-type"), CancellationToken); + + // Create a reader starting from before the event + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); + var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert + Assert.Single(events); + var readEvent = events[0]; + Assert.Equal(writtenItem.EventId, readEvent.EventId); + Assert.Equal("custom-event-type", readEvent.EventType); + + var readMessage = Assert.IsType(readEvent.Data); + Assert.Equal("test/method", readMessage.Method); + } + + [Fact] + public async Task ReadEventsAsync_HandlesNullData() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writtenItem = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Create a reader starting from before the event + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); + var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert + Assert.Single(events); + Assert.Null(events[0].Data); + Assert.Equal(writtenItem.EventId, events[0].EventId); + } + + #endregion + + #region ReadEventsAsync - Polling Mode Tests + + [Fact] + public async Task ReadEventsAsync_InPollingMode_CompletesImmediatelyAfterReturningAvailableEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write events + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Create a reader from sequence 0 + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); + var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act - Should complete quickly without waiting for new events + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + stopwatch.Stop(); + + // Assert - Should have returned both events and completed quickly + Assert.Equal(2, events.Count); + Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Polling mode should complete quickly, took {stopwatch.ElapsedMilliseconds}ms"); + } + + [Fact] + public async Task ReadEventsAsync_InPollingMode_ReturnsOnlyEventsAfterLastEventId() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write 3 events + var event1 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event3 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Create a reader starting from event2 (should only return event3) + var reader = await store.GetStreamReaderAsync(event2.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert - Only event3 should be returned + Assert.Single(events); + Assert.Equal(event3.EventId, events[0].EventId); + } + + [Fact] + public async Task ReadEventsAsync_InPollingMode_ReturnsEmptyIfNoNewEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write one event and create a reader from that event (no events after it) + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Assert - No new events should be returned + Assert.Empty(events); + } + + [Fact] + public async Task ReadEventsAsync_InPollingMode_DoesNotWaitForNewEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write one event so we have a valid event ID, then create reader from it + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Should complete immediately without waiting (no new events after the one we started from) + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + stopwatch.Stop(); + + // Assert - Should complete quickly with no events + Assert.Empty(events); + Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Polling mode should complete quickly, took {stopwatch.ElapsedMilliseconds}ms"); + } + + #endregion + + #region ReadEventsAsync - Default Mode Tests + + [Fact] + public async Task ReadEventsAsync_InDefaultMode_WaitsForNewEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write one event so we have a valid event ID + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Start reading and then write a new event + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(2)); + var events = new List>(); + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + if (events.Count >= 1) + { + // Got the event we were waiting for, cancel to stop + await cts.CancelAsync(); + } + } + }, CancellationToken); + + // Wait a bit, then write a new event + await Task.Delay(100, CancellationToken); + var newEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Wait for read to complete (either event received or timeout) + try + { + await readTask; + } + catch (OperationCanceledException) + { + // Expected when we cancel after receiving event + } + + // Assert - Should have received the new event + Assert.Single(events); + Assert.Equal(newEvent.EventId, events[0].EventId); + } + + [Fact] + public async Task ReadEventsAsync_InDefaultMode_YieldsNewlyWrittenEvents() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write initial event + var initialEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(initialEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Write multiple events while reader is active + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(3)); + var events = new List>(); + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + if (events.Count >= 3) + { + await cts.CancelAsync(); + } + } + }, CancellationToken); + + // Write 3 new events + await Task.Delay(100, CancellationToken); + var event1 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event3 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + try + { + await readTask; + } + catch (OperationCanceledException) + { + // Expected + } + + // Assert - Should have received all 3 events in order + Assert.Equal(3, events.Count); + Assert.Equal(event1.EventId, events[0].EventId); + Assert.Equal(event2.EventId, events[1].EventId); + Assert.Equal(event3.EventId, events[2].EventId); + } + + [Fact] + public async Task ReadEventsAsync_InDefaultMode_CompletesWhenStreamIsDisposed() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write event to create a valid reader + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Start reading, then dispose the stream + var events = new List>(); + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + }, CancellationToken); + + // Wait a bit, then dispose the writer + await Task.Delay(100, CancellationToken); + await writer.DisposeAsync(); + + // Wait for read to complete with a timeout + var timeoutTask = Task.Delay(TimeSpan.FromSeconds(2), CancellationToken); + var completedTask = await Task.WhenAny(readTask, timeoutTask); + + // Assert - The read should complete gracefully (not timeout) + Assert.Same(readTask, completedTask); + } + + [Fact] + public async Task ReadEventsAsync_InDefaultMode_RespectsCancellation() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write event to create a valid reader + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Act - Start reading and then cancel + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var events = new List>(); + OperationCanceledException? capturedException = null; + + var readTask = Task.Run(async () => + { + try + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + } + } + catch (OperationCanceledException ex) + { + capturedException = ex; + } + }, CancellationToken); + + // Cancel after a short delay + await Task.Delay(100, CancellationToken); + await cts.CancelAsync(); + + // Wait for read task to complete + await readTask; + stopwatch.Stop(); + + // Assert - Either cancelled exception or graceful exit, but should complete quickly + Assert.Empty(events); // No events should have been received + Assert.True(stopwatch.ElapsedMilliseconds < 1000, $"Should complete quickly after cancellation, took {stopwatch.ElapsedMilliseconds}ms"); + } + + #endregion + + #region ReadEventsAsync - Mode Changes Tests + + [Fact] + public async Task ReadEventsAsync_RespectsModeSwitchFromDefaultToPolling() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write an event to create a valid reader + var writtenEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Start reading in default mode (will wait for new events) + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(3)); + var events = new List>(); + var readCompleted = false; + + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + } + readCompleted = true; + }, CancellationToken); + + // Wait a bit, then switch to polling mode + await Task.Delay(100, CancellationToken); + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Wait for read task to complete (should complete quickly after mode switch) + var timeoutTask = Task.Delay(TimeSpan.FromSeconds(1), CancellationToken); + var completedTask = await Task.WhenAny(readTask, timeoutTask); + + // Assert - Read should have completed after switching to polling mode + Assert.Same(readTask, completedTask); + Assert.True(readCompleted); + Assert.Empty(events); // No new events were written after the one we used to create the reader + } + + [Fact] + public async Task ReadEventsAsync_PollingModeReturnsEventsThenCompletes() + { + // Arrange - Start in default mode, write some events, switch to polling, reader should return remaining events + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(50) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write initial event and create reader from sequence 0 + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); + var startEventId = $"{sessionBase64}:{streamBase64}:0"; + + // Write events first + var event1 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Switch to polling mode + await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); + + // Get reader + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Act - Read should return events and complete immediately + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + stopwatch.Stop(); + + // Assert + Assert.Equal(2, events.Count); + Assert.Equal(event1.EventId, events[0].EventId); + Assert.Equal(event2.EventId, events[1].EventId); + Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Should complete quickly, took {stopwatch.ElapsedMilliseconds}ms"); + } + + #endregion + + #region Cross-Session Isolation Tests + + [Fact] + public async Task MultipleStreams_AreIsolated_EventsDoNotLeakBetweenStreams() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Create two streams with different session/stream IDs + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-2", + StreamId = "stream-2", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write events to each stream + var event1 = await writer1.WriteEventAsync(new SseItem(null, "event-from-stream1"), CancellationToken); + var event2 = await writer2.WriteEventAsync(new SseItem(null, "event-from-stream2"), CancellationToken); + + // Create readers for each stream from sequence 0 + var session1Base64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); + var stream1Base64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); + var start1 = $"{session1Base64}:{stream1Base64}:0"; + + var session2Base64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-2")); + var stream2Base64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-2")); + var start2 = $"{session2Base64}:{stream2Base64}:0"; + + var reader1 = await store.GetStreamReaderAsync(start1, CancellationToken); + var reader2 = await store.GetStreamReaderAsync(start2, CancellationToken); + Assert.NotNull(reader1); + Assert.NotNull(reader2); + + // Act - Read from each reader + var events1 = new List>(); + await foreach (var evt in reader1.ReadEventsAsync(CancellationToken)) + { + events1.Add(evt); + } + + var events2 = new List>(); + await foreach (var evt in reader2.ReadEventsAsync(CancellationToken)) + { + events2.Add(evt); + } + + // Assert - Each reader should only see its own stream's events + Assert.Single(events1); + Assert.Equal("event-from-stream1", events1[0].EventType); + Assert.Equal(event1.EventId, events1[0].EventId); + + Assert.Single(events2); + Assert.Equal("event-from-stream2", events2[0].EventType); + Assert.Equal(event2.EventId, events2[0].EventId); + } + + [Fact] + public async Task MultipleStreams_SameSession_DifferentStreamIds_AreIsolated() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + // Create two streams with same session but different stream IDs + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "shared-session", + StreamId = "stream-A", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "shared-session", + StreamId = "stream-B", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write events to each stream + await writer1.WriteEventAsync(new SseItem(null, "from-A"), CancellationToken); + await writer2.WriteEventAsync(new SseItem(null, "from-B"), CancellationToken); + + // Create readers from sequence 0 + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("shared-session")); + var streamABase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-A")); + var streamBBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-B")); + + var reader1 = await store.GetStreamReaderAsync($"{sessionBase64}:{streamABase64}:0", CancellationToken); + var reader2 = await store.GetStreamReaderAsync($"{sessionBase64}:{streamBBase64}:0", CancellationToken); + Assert.NotNull(reader1); + Assert.NotNull(reader2); + + // Act + var events1 = new List>(); + await foreach (var evt in reader1.ReadEventsAsync(CancellationToken)) + { + events1.Add(evt); + } + + var events2 = new List>(); + await foreach (var evt in reader2.ReadEventsAsync(CancellationToken)) + { + events2.Add(evt); + } + + // Assert + Assert.Single(events1); + Assert.Equal("from-A", events1[0].EventType); + + Assert.Single(events2); + Assert.Equal("from-B", events2[0].EventType); + } + + [Fact] + public async Task EventIds_AreGloballyUnique_AcrossStreams() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + + var writer1 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var writer2 = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-2", + StreamId = "stream-2", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Act - Write events to each stream + var event1a = await writer1.WriteEventAsync(new SseItem(null), CancellationToken); + var event1b = await writer1.WriteEventAsync(new SseItem(null), CancellationToken); + var event2a = await writer2.WriteEventAsync(new SseItem(null), CancellationToken); + var event2b = await writer2.WriteEventAsync(new SseItem(null), CancellationToken); + + // Assert - All event IDs should be unique + var allEventIds = new[] { event1a.EventId, event1b.EventId, event2a.EventId, event2b.EventId }; + Assert.Equal(4, allEventIds.Distinct().Count()); + } + + #endregion + + #region Distributed Cache Integration Tests + + [Fact] + public async Task WriteEventAsync_UsesConfiguredSlidingExpiration() + { + // Arrange + var mockCache = new TrackingDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventSlidingExpiration = TimeSpan.FromMinutes(30) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + mockCache.SetCalls.Clear(); + + // Act + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Assert - Event should be written with the configured sliding expiration + Assert.Contains(mockCache.SetCalls, call => + call.Key.Contains("event:") && + call.Options.SlidingExpiration == TimeSpan.FromMinutes(30)); + } + + [Fact] + public async Task WriteEventAsync_UsesConfiguredAbsoluteExpiration() + { + // Arrange + var mockCache = new TrackingDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + EventAbsoluteExpiration = TimeSpan.FromHours(6) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + mockCache.SetCalls.Clear(); + + // Act + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Assert - Event should be written with the configured absolute expiration (relative to now) + var eventCall = mockCache.SetCalls.FirstOrDefault(call => call.Key.Contains("event:")); + Assert.NotNull(eventCall.Key); + Assert.NotNull(eventCall.Options.AbsoluteExpirationRelativeToNow); + Assert.Equal(TimeSpan.FromHours(6), eventCall.Options.AbsoluteExpirationRelativeToNow); + } + + [Fact] + public async Task WriteEventAsync_UsesConfiguredMetadataExpiration() + { + // Arrange - Metadata is written when events are written + var mockCache = new TrackingDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions + { + MetadataSlidingExpiration = TimeSpan.FromMinutes(45), + MetadataAbsoluteExpiration = TimeSpan.FromHours(12) + }; + var store = new DistributedCacheEventStreamStore(mockCache, customOptions); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Act - Write an event, which also updates metadata + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Assert + var metadataCall = mockCache.SetCalls.FirstOrDefault(call => call.Key.Contains("meta:")); + Assert.NotNull(metadataCall.Key); + Assert.Equal(TimeSpan.FromMinutes(45), metadataCall.Options.SlidingExpiration); + Assert.Equal(TimeSpan.FromHours(12), metadataCall.Options.AbsoluteExpirationRelativeToNow); + } + + #endregion + + #region Options Configuration Tests + + [Fact] + public async Task CustomPollingInterval_AffectsDefaultModePolling() + { + // Arrange - Use very short polling interval + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + { + PollingInterval = TimeSpan.FromMilliseconds(20) + }); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Default + }, CancellationToken); + + // Write an initial event + var initialEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var reader = await store.GetStreamReaderAsync(initialEvent.EventId!, CancellationToken); + Assert.NotNull(reader); + + // Start reading and measure time to receive a new event + using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + cts.CancelAfter(TimeSpan.FromSeconds(2)); + var events = new List>(); + var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + + var readTask = Task.Run(async () => + { + await foreach (var evt in reader.ReadEventsAsync(cts.Token)) + { + events.Add(evt); + if (events.Count >= 1) + { + await cts.CancelAsync(); + } + } + }, CancellationToken); + + // Wait a bit, write event, measure time for it to be detected + await Task.Delay(50, CancellationToken); + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + try + { + await readTask; + } + catch (OperationCanceledException) + { + // Expected + } + + stopwatch.Stop(); + + // Assert - Event should be detected quickly due to short polling interval + Assert.Single(events); + // With 20ms polling, detection should be fast (well under 500ms) + Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Expected quick detection, took {stopwatch.ElapsedMilliseconds}ms"); + } + + [Fact] + public void DefaultOptions_HaveReasonableDefaults() + { + // Arrange & Act + var options = new DistributedCacheEventStreamStoreOptions(); + + // Assert - Check that defaults are set reasonably + Assert.True(options.PollingInterval >= TimeSpan.FromMilliseconds(50), "Polling interval should be at least 50ms"); + Assert.True(options.EventSlidingExpiration > TimeSpan.Zero, "Event sliding expiration should be positive"); + Assert.True(options.EventAbsoluteExpiration > TimeSpan.Zero, "Event absolute expiration should be positive"); + Assert.True(options.MetadataSlidingExpiration > TimeSpan.Zero, "Metadata sliding expiration should be positive"); + Assert.True(options.MetadataAbsoluteExpiration > TimeSpan.Zero, "Metadata absolute expiration should be positive"); + } + + #endregion + + #region Helper Classes + + /// + /// A distributed cache that tracks all operations for verification in tests. + /// + private sealed class TrackingDistributedCache : IDistributedCache + { + private readonly MemoryDistributedCache _innerCache = new(Options.Create(new MemoryDistributedCacheOptions())); + + public List<(string Key, DistributedCacheEntryOptions Options)> SetCalls { get; } = []; + + public byte[]? Get(string key) => _innerCache.Get(key); + public Task GetAsync(string key, CancellationToken token = default) => _innerCache.GetAsync(key, token); + public void Refresh(string key) => _innerCache.Refresh(key); + public Task RefreshAsync(string key, CancellationToken token = default) => _innerCache.RefreshAsync(key, token); + public void Remove(string key) => _innerCache.Remove(key); + public Task RemoveAsync(string key, CancellationToken token = default) => _innerCache.RemoveAsync(key, token); + + public void Set(string key, byte[] value, DistributedCacheEntryOptions options) + { + SetCalls.Add((key, options)); + _innerCache.Set(key, value, options); + } + + public Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions options, CancellationToken token = default) + { + SetCalls.Add((key, options)); + return _innerCache.SetAsync(key, value, options, token); + } + } + + #endregion +} From e34492781f69034a60de002d2ce5b78562a10809 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 6 Jan 2026 15:26:20 -0800 Subject: [PATCH 02/18] Small fixes and test improvements --- .../DistributedCacheEventIdFormatter.cs | 61 +++ .../DistributedCacheEventStreamStore.cs | 109 ++--- .../ModelContextProtocol.Tests.csproj | 4 + .../DistributedCacheEventStreamStoreTests.cs | 424 ++++++++++++------ 4 files changed, 383 insertions(+), 215 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs new file mode 100644 index 000000000..8aec3f0bd --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +// This is a shared source file included in both ModelContextProtocol.Core and the test project. +// Do not reference symbols internal to the core project, as they won't be available in tests. + +using System; + +namespace ModelContextProtocol.Server; + +/// +/// Provides methods for formatting and parsing event IDs used by . +/// +/// +/// Event IDs are formatted as "{base64(sessionId)}:{base64(streamId)}:{sequence}". +/// Base64 encoding is used because the MCP specification allows session IDs to contain +/// any visible ASCII character (0x21-0x7E), including the ':' separator character. +/// +internal static class DistributedCacheEventIdFormatter +{ + private const char Separator = ':'; + + /// + /// Formats session ID, stream ID, and sequence number into an event ID string. + /// + public static string Format(string sessionId, string streamId, long sequence) + { + // Base64-encode session and stream IDs so the event ID can be parsed + // even if the original IDs contain the ':' separator character + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(sessionId)); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(streamId)); + return $"{sessionBase64}{Separator}{streamBase64}{Separator}{sequence}"; + } + + /// + /// Attempts to parse an event ID into its component parts. + /// + public static bool TryParse(string eventId, out string sessionId, out string streamId, out long sequence) + { + sessionId = string.Empty; + streamId = string.Empty; + sequence = 0; + + var parts = eventId.Split(Separator); + if (parts.Length != 3) + { + return false; + } + + try + { + sessionId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[0])); + streamId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[1])); + return long.TryParse(parts[2], out sequence); + } + catch + { + return false; + } + } +} diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs index 02f6b8be4..1b93c1d12 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -51,7 +51,7 @@ public ValueTask CreateStreamAsync(SseEventStreamOptions Throw.IfNull(lastEventId); // Parse the event ID to get session, stream, and sequence information - if (!EventIdCodec.TryParse(lastEventId, out var sessionId, out var streamId, out var sequence)) + if (!DistributedCacheEventIdFormatter.TryParse(lastEventId, out var sessionId, out var streamId, out var sequence)) { return null; } @@ -70,54 +70,8 @@ public ValueTask CreateStreamAsync(SseEventStreamOptions return null; } - return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, sequence, metadata, _options); - } - - /// - /// Provides methods for encoding and decoding event IDs. - /// - internal static class EventIdCodec - { - private const char Separator = ':'; - - /// - /// Encodes session ID, stream ID, and sequence number into an event ID string. - /// - public static string Encode(string sessionId, string streamId, long sequence) - { - // Base64-encode session and stream IDs so the event ID can be parsed - // even if the original IDs contain the ':' separator character - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(sessionId)); - var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(streamId)); - return $"{sessionBase64}{Separator}{streamBase64}{Separator}{sequence}"; - } - - /// - /// Attempts to parse an event ID into its component parts. - /// - public static bool TryParse(string eventId, out string sessionId, out string streamId, out long sequence) - { - sessionId = string.Empty; - streamId = string.Empty; - sequence = 0; - - var parts = eventId.Split(Separator); - if (parts.Length != 3) - { - return false; - } - - try - { - sessionId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[0])); - streamId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[1])); - return long.TryParse(parts[2], out sequence); - } - catch - { - return false; - } - } + var startSequence = sequence + 1; + return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, startSequence, metadata, _options); } /// @@ -198,7 +152,7 @@ public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken c // Generate a new sequence number and event ID var sequence = Interlocked.Increment(ref _sequence); - var eventId = EventIdCodec.Encode(SessionId, StreamId, sequence); + var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, sequence); var newItem = sseItem with { EventId = eventId }; // Store the event in the cache @@ -261,7 +215,7 @@ private sealed class DistributedCacheEventStreamReader : ISseEventStreamReader { private readonly IDistributedCache _cache; private readonly long _startSequence; - private readonly StreamMetadata _metadata; + private readonly StreamMetadata _initialMetadata; private readonly DistributedCacheEventStreamStoreOptions _options; public DistributedCacheEventStreamReader( @@ -269,14 +223,14 @@ public DistributedCacheEventStreamReader( string sessionId, string streamId, long startSequence, - StreamMetadata metadata, + StreamMetadata initialMetadata, DistributedCacheEventStreamStoreOptions options) { _cache = cache; SessionId = sessionId; StreamId = streamId; _startSequence = startSequence; - _metadata = metadata; + _initialMetadata = initialMetadata; _options = options; } @@ -288,36 +242,25 @@ public DistributedCacheEventStreamReader( // Start from the sequence after the last received event var currentSequence = _startSequence; + // Use the initial metadata passed to the constructor for the first read. + var lastSequence = _initialMetadata.LastSequence; + var isCompleted = _initialMetadata.IsCompleted; + var mode = _initialMetadata.Mode; + while (!cancellationToken.IsCancellationRequested) { - // Refresh metadata to get the latest sequence and completion status - var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId); - var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false); - - StreamMetadata? currentMetadata = null; - if (metadataBytes is not null) - { - currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata); - } - - var lastSequence = currentMetadata?.LastSequence ?? _metadata.LastSequence; - var isCompleted = currentMetadata?.IsCompleted ?? _metadata.IsCompleted; - var mode = currentMetadata?.Mode ?? _metadata.Mode; - // Read all available events from currentSequence + 1 to lastSequence - while (currentSequence < lastSequence) + for (; currentSequence <= lastSequence; currentSequence++) { cancellationToken.ThrowIfCancellationRequested(); - var nextSequence = currentSequence + 1; - var eventId = EventIdCodec.Encode(SessionId, StreamId, nextSequence); + var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, currentSequence); var eventKey = CacheKeys.Event(eventId); var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false); if (eventBytes is null) { // Event may have expired; skip to next - currentSequence = nextSequence; continue; } @@ -329,8 +272,6 @@ public DistributedCacheEventStreamReader( EventId = storedEvent.EventId, }; } - - currentSequence = nextSequence; } // If in polling mode, stop after returning currently available events @@ -339,7 +280,7 @@ public DistributedCacheEventStreamReader( yield break; } - // If the stream is completed, stop + // If the stream is completed and we've read all events, stop if (isCompleted) { yield break; @@ -347,6 +288,26 @@ public DistributedCacheEventStreamReader( // Wait before polling again for new events await Task.Delay(_options.PollingInterval, cancellationToken).ConfigureAwait(false); + + // Refresh metadata to get the latest sequence and completion status + var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId); + var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false); + + if (metadataBytes is null) + { + // Metadata expired - treat stream as complete to avoid infinite loop + yield break; + } + + var currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata); + if (currentMetadata is null) + { + yield break; + } + + lastSequence = currentMetadata.LastSequence; + isCompleted = currentMetadata.IsCompleted; + mode = currentMetadata.Mode; } } } diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 8bf01df43..84b0ee994 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -29,6 +29,10 @@ + + + + diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index b061acdfc..6245b4ccb 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -13,7 +13,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; /// public class DistributedCacheEventStreamStoreTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { - private CancellationToken CancellationToken => TestContext.Current.CancellationToken; + private static CancellationToken CancellationToken => TestContext.Current.CancellationToken; private static IDistributedCache CreateMemoryCache() { @@ -21,8 +21,6 @@ private static IDistributedCache CreateMemoryCache() return new MemoryDistributedCache(options); } - #region Constructor & Initialization Tests - [Fact] public void Constructor_ThrowsArgumentNullException_WhenCacheIsNull() { @@ -80,10 +78,6 @@ public async Task Constructor_UsesProvidedOptions_WhenOptionsParameterIsSpecifie Assert.Equal("stream-1", writer.StreamId); } - #endregion - - #region CreateStreamAsync Tests - [Fact] public async Task CreateStreamAsync_ReturnsWriter_WithCorrectStreamId() { @@ -181,10 +175,6 @@ await Assert.ThrowsAsync("options", async () => await store.CreateStreamAsync(null!, CancellationToken)); } - #endregion - - #region WriteEventAsync Tests - [Fact] public async Task WriteEventAsync_AssignsUniqueEventId_WhenItemHasNoEventId() { @@ -306,7 +296,7 @@ public async Task WriteEventAsync_HandlesNullData_AssignsEventIdAndStoresEvent() public async Task WriteEventAsync_StoresEventWithCorrectSlidingExpiration() { // Arrange - Use a mock cache to verify expiration options - var mockCache = new TrackingDistributedCache(); + var mockCache = new TestDistributedCache(); var customOptions = new DistributedCacheEventStreamStoreOptions { EventSlidingExpiration = TimeSpan.FromMinutes(15) @@ -334,7 +324,7 @@ public async Task WriteEventAsync_StoresEventWithCorrectSlidingExpiration() public async Task WriteEventAsync_StoresEventWithCorrectAbsoluteExpiration() { // Arrange - var mockCache = new TrackingDistributedCache(); + var mockCache = new TestDistributedCache(); var customOptions = new DistributedCacheEventStreamStoreOptions { EventAbsoluteExpiration = TimeSpan.FromHours(3) @@ -362,7 +352,7 @@ public async Task WriteEventAsync_StoresEventWithCorrectAbsoluteExpiration() public async Task WriteEventAsync_UpdatesStreamMetadata_AfterEachWrite() { // Arrange - var mockCache = new TrackingDistributedCache(); + var mockCache = new TestDistributedCache(); var store = new DistributedCacheEventStreamStore(mockCache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { @@ -380,10 +370,6 @@ public async Task WriteEventAsync_UpdatesStreamMetadata_AfterEachWrite() Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); } - #endregion - - #region SetModeAsync (Writer) Tests - [Fact] public async Task SetModeAsync_UpdatesModeProperty_OnWriter() { @@ -408,7 +394,7 @@ public async Task SetModeAsync_UpdatesModeProperty_OnWriter() public async Task SetModeAsync_PersistsModeChangeToMetadata() { // Arrange - var mockCache = new TrackingDistributedCache(); + var mockCache = new TestDistributedCache(); var store = new DistributedCacheEventStreamStore(mockCache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { @@ -466,10 +452,6 @@ public async Task SetModeAsync_ModeChangeReflectedInReader() Assert.Empty(events); // No events after the one we used to create the reader } - #endregion - - #region DisposeAsync (Writer) Tests - [Fact] public async Task DisposeAsync_MarksStreamAsCompleted() { @@ -531,7 +513,7 @@ public async Task DisposeAsync_IsIdempotent() public async Task DisposeAsync_UpdatesMetadata_WithIsCompletedFlag() { // Arrange - var mockCache = new TrackingDistributedCache(); + var mockCache = new TestDistributedCache(); var store = new DistributedCacheEventStreamStore(mockCache); var writer = await store.CreateStreamAsync(new SseEventStreamOptions { @@ -549,10 +531,6 @@ public async Task DisposeAsync_UpdatesMetadata_WithIsCompletedFlag() Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); } - #endregion - - #region GetStreamReaderAsync Tests - [Fact] public async Task GetStreamReaderAsync_ThrowsArgumentNullException_WhenLastEventIdIsNull() { @@ -591,9 +569,7 @@ public async Task GetStreamReaderAsync_ReturnsNull_WhenStreamMetadataDoesNotExis var store = new DistributedCacheEventStreamStore(cache); // Create a valid-looking event ID for a stream that doesn't exist - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("nonexistent-session")); - var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("nonexistent-stream")); - var fakeEventId = $"{sessionBase64}:{streamBase64}:1"; + var fakeEventId = DistributedCacheEventIdFormatter.Format("nonexistent-session", "nonexistent-stream", 1); // Act var reader = await store.GetStreamReaderAsync(fakeEventId, CancellationToken); @@ -628,10 +604,6 @@ public async Task GetStreamReaderAsync_ReturnsReaderWithCorrectSessionIdAndStrea Assert.Equal("my-stream", reader.StreamId); } - #endregion - - #region ReadEventsAsync (Reader) Tests - [Fact] public async Task ReadEventsAsync_ReturnsEventsInOrder() { @@ -651,9 +623,7 @@ public async Task ReadEventsAsync_ReturnsEventsInOrder() var event3 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method3" }), CancellationToken); // Create a reader starting from before the first event (use a fake event ID with sequence 0) - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); - var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); - var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); Assert.NotNull(reader); @@ -719,9 +689,7 @@ public async Task ReadEventsAsync_PreservesCorrectDataEventTypeAndEventId() var writtenItem = await writer.WriteEventAsync(new SseItem(message, "custom-event-type"), CancellationToken); // Create a reader starting from before the event - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); - var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); - var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); Assert.NotNull(reader); @@ -758,9 +726,7 @@ public async Task ReadEventsAsync_HandlesNullData() var writtenItem = await writer.WriteEventAsync(new SseItem(null), CancellationToken); // Create a reader starting from before the event - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); - var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); - var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); Assert.NotNull(reader); @@ -777,10 +743,6 @@ public async Task ReadEventsAsync_HandlesNullData() Assert.Equal(writtenItem.EventId, events[0].EventId); } - #endregion - - #region ReadEventsAsync - Polling Mode Tests - [Fact] public async Task ReadEventsAsync_InPollingMode_CompletesImmediatelyAfterReturningAvailableEvents() { @@ -799,9 +761,7 @@ public async Task ReadEventsAsync_InPollingMode_CompletesImmediatelyAfterReturni await writer.WriteEventAsync(new SseItem(null), CancellationToken); // Create a reader from sequence 0 - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); - var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); - var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); Assert.NotNull(reader); @@ -914,10 +874,6 @@ public async Task ReadEventsAsync_InPollingMode_DoesNotWaitForNewEvents() Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Polling mode should complete quickly, took {stopwatch.ElapsedMilliseconds}ms"); } - #endregion - - #region ReadEventsAsync - Default Mode Tests - [Fact] public async Task ReadEventsAsync_InDefaultMode_WaitsForNewEvents() { @@ -1132,10 +1088,6 @@ public async Task ReadEventsAsync_InDefaultMode_RespectsCancellation() Assert.True(stopwatch.ElapsedMilliseconds < 1000, $"Should complete quickly after cancellation, took {stopwatch.ElapsedMilliseconds}ms"); } - #endregion - - #region ReadEventsAsync - Mode Changes Tests - [Fact] public async Task ReadEventsAsync_RespectsModeSwitchFromDefaultToPolling() { @@ -1203,9 +1155,7 @@ public async Task ReadEventsAsync_PollingModeReturnsEventsThenCompletes() }, CancellationToken); // Write initial event and create reader from sequence 0 - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); - var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); - var startEventId = $"{sessionBase64}:{streamBase64}:0"; + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); // Write events first var event1 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); @@ -1234,10 +1184,6 @@ public async Task ReadEventsAsync_PollingModeReturnsEventsThenCompletes() Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Should complete quickly, took {stopwatch.ElapsedMilliseconds}ms"); } - #endregion - - #region Cross-Session Isolation Tests - [Fact] public async Task MultipleStreams_AreIsolated_EventsDoNotLeakBetweenStreams() { @@ -1265,13 +1211,8 @@ public async Task MultipleStreams_AreIsolated_EventsDoNotLeakBetweenStreams() var event2 = await writer2.WriteEventAsync(new SseItem(null, "event-from-stream2"), CancellationToken); // Create readers for each stream from sequence 0 - var session1Base64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-1")); - var stream1Base64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-1")); - var start1 = $"{session1Base64}:{stream1Base64}:0"; - - var session2Base64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session-2")); - var stream2Base64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-2")); - var start2 = $"{session2Base64}:{stream2Base64}:0"; + var start1 = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var start2 = DistributedCacheEventIdFormatter.Format("session-2", "stream-2", 0); var reader1 = await store.GetStreamReaderAsync(start1, CancellationToken); var reader2 = await store.GetStreamReaderAsync(start2, CancellationToken); @@ -1328,12 +1269,8 @@ public async Task MultipleStreams_SameSession_DifferentStreamIds_AreIsolated() await writer2.WriteEventAsync(new SseItem(null, "from-B"), CancellationToken); // Create readers from sequence 0 - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("shared-session")); - var streamABase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-A")); - var streamBBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream-B")); - - var reader1 = await store.GetStreamReaderAsync($"{sessionBase64}:{streamABase64}:0", CancellationToken); - var reader2 = await store.GetStreamReaderAsync($"{sessionBase64}:{streamBBase64}:0", CancellationToken); + var reader1 = await store.GetStreamReaderAsync(DistributedCacheEventIdFormatter.Format("shared-session", "stream-A", 0), CancellationToken); + var reader2 = await store.GetStreamReaderAsync(DistributedCacheEventIdFormatter.Format("shared-session", "stream-B", 0), CancellationToken); Assert.NotNull(reader1); Assert.NotNull(reader2); @@ -1390,15 +1327,11 @@ public async Task EventIds_AreGloballyUnique_AcrossStreams() Assert.Equal(4, allEventIds.Distinct().Count()); } - #endregion - - #region Distributed Cache Integration Tests - [Fact] public async Task WriteEventAsync_UsesConfiguredSlidingExpiration() { // Arrange - var mockCache = new TrackingDistributedCache(); + var mockCache = new TestDistributedCache(); var customOptions = new DistributedCacheEventStreamStoreOptions { EventSlidingExpiration = TimeSpan.FromMinutes(30) @@ -1426,7 +1359,7 @@ public async Task WriteEventAsync_UsesConfiguredSlidingExpiration() public async Task WriteEventAsync_UsesConfiguredAbsoluteExpiration() { // Arrange - var mockCache = new TrackingDistributedCache(); + var mockCache = new TestDistributedCache(); var customOptions = new DistributedCacheEventStreamStoreOptions { EventAbsoluteExpiration = TimeSpan.FromHours(6) @@ -1455,7 +1388,7 @@ public async Task WriteEventAsync_UsesConfiguredAbsoluteExpiration() public async Task WriteEventAsync_UsesConfiguredMetadataExpiration() { // Arrange - Metadata is written when events are written - var mockCache = new TrackingDistributedCache(); + var mockCache = new TestDistributedCache(); var customOptions = new DistributedCacheEventStreamStoreOptions { MetadataSlidingExpiration = TimeSpan.FromMinutes(45), @@ -1479,99 +1412,310 @@ public async Task WriteEventAsync_UsesConfiguredMetadataExpiration() Assert.Equal(TimeSpan.FromHours(12), metadataCall.Options.AbsoluteExpirationRelativeToNow); } - #endregion + [Fact] + public void DefaultOptions_HaveReasonableDefaults() + { + // Arrange & Act + var options = new DistributedCacheEventStreamStoreOptions(); - #region Options Configuration Tests + // Assert - Check that defaults are set reasonably + Assert.True(options.PollingInterval >= TimeSpan.FromMilliseconds(50), "Polling interval should be at least 50ms"); + Assert.True(options.EventSlidingExpiration > TimeSpan.Zero, "Event sliding expiration should be positive"); + Assert.True(options.EventAbsoluteExpiration > TimeSpan.Zero, "Event absolute expiration should be positive"); + Assert.True(options.MetadataSlidingExpiration > TimeSpan.Zero, "Metadata sliding expiration should be positive"); + Assert.True(options.MetadataAbsoluteExpiration > TimeSpan.Zero, "Metadata absolute expiration should be positive"); + } [Fact] - public async Task CustomPollingInterval_AffectsDefaultModePolling() + public async Task ReadEventsAsync_Completes_WhenMetadataExpires() { - // Arrange - Use very short polling interval - var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache, new DistributedCacheEventStreamStoreOptions + // Arrange - Use a cache that allows us to simulate metadata expiration + var trackingCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions { - PollingInterval = TimeSpan.FromMilliseconds(20) - }); + PollingInterval = TimeSpan.FromMilliseconds(10) // Fast polling to detect the bug quickly + }; + var store = new DistributedCacheEventStreamStore(trackingCache, customOptions); + + // Create a stream and write an event var writer = await store.CreateStreamAsync(new SseEventStreamOptions { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Default // Non-polling mode to trigger the waiting loop }, CancellationToken); - // Write an initial event - var initialEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); - var reader = await store.GetStreamReaderAsync(initialEvent.EventId!, CancellationToken); + var item = new SseItem(new JsonRpcNotification { Method = "test" }); + var writtenItem = await writer.WriteEventAsync(item, CancellationToken); + + // Get a reader starting after the first event (so it will wait for more events) + var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); Assert.NotNull(reader); - // Start reading and measure time to receive a new event - using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); - cts.CancelAfter(TimeSpan.FromSeconds(2)); + // Now simulate metadata expiration + trackingCache.ExpireMetadata(); + + // Act - Read events; the reader should complete gracefully when metadata expires + // instead of looping indefinitely with the stale initial metadata var events = new List>(); - var stopwatch = System.Diagnostics.Stopwatch.StartNew(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } - var readTask = Task.Run(async () => + // If we reach here without timeout, the reader correctly handled metadata expiration + Assert.Empty(events); // No new events after the initial one used to create the reader + } + + [Fact] + public async Task ReadEventsAsync_DoesNotReadMetadata_InPollingMode() + { + // Arrange - Use a tracking cache to count metadata reads + var trackingCache = new TestDistributedCache(); + var customOptions = new DistributedCacheEventStreamStoreOptions { - await foreach (var evt in reader.ReadEventsAsync(cts.Token)) - { - events.Add(evt); - if (events.Count >= 1) - { - await cts.CancelAsync(); - } - } + PollingInterval = TimeSpan.FromMilliseconds(10) + }; + var store = new DistributedCacheEventStreamStore(trackingCache, customOptions); + + // Create a stream in POLLING mode - this allows the reader to exit after reading available events + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling }, CancellationToken); - // Wait a bit, write event, measure time for it to be detected - await Task.Delay(50, CancellationToken); - await writer.WriteEventAsync(new SseItem(null), CancellationToken); + var item1 = new SseItem(new JsonRpcNotification { Method = "test1" }); + var item2 = new SseItem(new JsonRpcNotification { Method = "test2" }); + await writer.WriteEventAsync(item1, CancellationToken); + await writer.WriteEventAsync(item2, CancellationToken); - try - { - await readTask; - } - catch (OperationCanceledException) + // Get a reader starting before all events (use a fake event ID at sequence 0) + var zeroSequenceEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var reader = await store.GetStreamReaderAsync(zeroSequenceEventId, CancellationToken); + Assert.NotNull(reader); + + // GetStreamReaderAsync should have read metadata exactly once + Assert.Equal(1, trackingCache.MetadataReadCount); + + // Act - Read all events + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) { - // Expected + events.Add(evt); } - stopwatch.Stop(); + // Assert - In polling mode, the reader should: + // 1. Use initial metadata from GetStreamReaderAsync (no additional read needed) + // 2. Read all available events (2 events) + // 3. Exit immediately because mode is Polling + // + // Metadata read count should remain at 1 (only the initial read from GetStreamReaderAsync) + Assert.Equal(2, events.Count); + Assert.Equal(1, trackingCache.MetadataReadCount); + } - // Assert - Event should be detected quickly due to short polling interval - Assert.Single(events); - // With 20ms polling, detection should be fast (well under 500ms) - Assert.True(stopwatch.ElapsedMilliseconds < 500, $"Expected quick detection, took {stopwatch.ElapsedMilliseconds}ms"); + [Fact] + public void EventIdFormatter_Format_CreatesValidEventId() + { + // Act + var eventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 42); + + // Assert + Assert.NotNull(eventId); + Assert.NotEmpty(eventId); + Assert.Contains(":", eventId); // Should contain separators } [Fact] - public void DefaultOptions_HaveReasonableDefaults() + public void EventIdFormatter_TryParse_RoundTripsSuccessfully() { - // Arrange & Act - var options = new DistributedCacheEventStreamStoreOptions(); + // Arrange + var originalSessionId = "my-session-id"; + var originalStreamId = "my-stream-id"; + var originalSequence = 12345L; - // Assert - Check that defaults are set reasonably - Assert.True(options.PollingInterval >= TimeSpan.FromMilliseconds(50), "Polling interval should be at least 50ms"); - Assert.True(options.EventSlidingExpiration > TimeSpan.Zero, "Event sliding expiration should be positive"); - Assert.True(options.EventAbsoluteExpiration > TimeSpan.Zero, "Event absolute expiration should be positive"); - Assert.True(options.MetadataSlidingExpiration > TimeSpan.Zero, "Metadata sliding expiration should be positive"); - Assert.True(options.MetadataAbsoluteExpiration > TimeSpan.Zero, "Metadata absolute expiration should be positive"); + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); } - #endregion + [Fact] + public void EventIdFormatter_TryParse_HandlesSpecialCharactersInSessionId() + { + // Arrange - Session IDs can contain any visible ASCII character per MCP spec + var originalSessionId = "session:with:colons:and|pipes"; + var originalStreamId = "stream-1"; + var originalSequence = 1L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } - #region Helper Classes + [Fact] + public void EventIdFormatter_TryParse_HandlesSpecialCharactersInStreamId() + { + // Arrange + var originalSessionId = "session-1"; + var originalStreamId = "stream:with:colons:and|special!chars@#$%"; + var originalSequence = 1L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesUnicodeCharacters() + { + // Arrange + var originalSessionId = "session-日本語-émojis-🎉"; + var originalStreamId = "stream-中文-العربية"; + var originalSequence = 999L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesZeroSequence() + { + // Act + var eventId = DistributedCacheEventIdFormatter.Format("session", "stream", 0); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out _, out _, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(0, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_HandlesLargeSequence() + { + // Act + var eventId = DistributedCacheEventIdFormatter.Format("session", "stream", long.MaxValue); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out _, out _, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(long.MaxValue, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_ReturnsFalse_ForEmptyString() + { + // Act + var parsed = DistributedCacheEventIdFormatter.TryParse("", out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.False(parsed); + Assert.Equal(string.Empty, sessionId); + Assert.Equal(string.Empty, streamId); + Assert.Equal(0, sequence); + } + + [Fact] + public void EventIdFormatter_TryParse_ReturnsFalse_ForInvalidFormat() + { + // Act & Assert - Various invalid formats + Assert.False(DistributedCacheEventIdFormatter.TryParse("no-separators", out _, out _, out _)); + Assert.False(DistributedCacheEventIdFormatter.TryParse("only:one", out _, out _, out _)); + Assert.False(DistributedCacheEventIdFormatter.TryParse("too:many:parts:here", out _, out _, out _)); + } + + [Fact] + public void EventIdFormatter_TryParse_ReturnsFalse_ForInvalidBase64() + { + // Act - Invalid base64 in first part + var parsed = DistributedCacheEventIdFormatter.TryParse("!!!invalid!!!:c3RyZWFt:1", out _, out _, out _); + + // Assert + Assert.False(parsed); + } + + [Fact] + public void EventIdFormatter_TryParse_ReturnsFalse_ForNonNumericSequence() + { + // Arrange - Valid base64 but non-numeric sequence + var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("session")); + var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes("stream")); + var invalidEventId = $"{sessionBase64}:{streamBase64}:not-a-number"; + + // Act + var parsed = DistributedCacheEventIdFormatter.TryParse(invalidEventId, out _, out _, out _); + + // Assert + Assert.False(parsed); + } /// /// A distributed cache that tracks all operations for verification in tests. + /// Supports tracking Set calls, counting metadata reads, and simulating metadata expiration. /// - private sealed class TrackingDistributedCache : IDistributedCache + private sealed class TestDistributedCache : IDistributedCache { private readonly MemoryDistributedCache _innerCache = new(Options.Create(new MemoryDistributedCacheOptions())); + private int _metadataReadCount; + private bool _metadataExpired; public List<(string Key, DistributedCacheEntryOptions Options)> SetCalls { get; } = []; + public int MetadataReadCount => _metadataReadCount; + + public void ExpireMetadata() => _metadataExpired = true; + + public byte[]? Get(string key) + { + if (key.Contains("meta:")) + { + Interlocked.Increment(ref _metadataReadCount); + if (_metadataExpired) + { + return null; + } + } + return _innerCache.Get(key); + } + + public Task GetAsync(string key, CancellationToken token = default) + { + if (key.Contains("meta:")) + { + Interlocked.Increment(ref _metadataReadCount); + if (_metadataExpired) + { + return Task.FromResult(null); + } + } + return _innerCache.GetAsync(key, token); + } - public byte[]? Get(string key) => _innerCache.Get(key); - public Task GetAsync(string key, CancellationToken token = default) => _innerCache.GetAsync(key, token); public void Refresh(string key) => _innerCache.Refresh(key); public Task RefreshAsync(string key, CancellationToken token = default) => _innerCache.RefreshAsync(key, token); public void Remove(string key) => _innerCache.Remove(key); @@ -1589,6 +1733,4 @@ public Task SetAsync(string key, byte[] value, DistributedCacheEntryOptions opti return _innerCache.SetAsync(key, value, options, token); } } - - #endregion } From f70bd12e738c3bb4172eb9b724e16de839d28388 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Fri, 9 Jan 2026 14:05:42 -0800 Subject: [PATCH 03/18] Adjust for latest changes to `ISseEventStreamStore` --- .../DistributedCacheEventStreamStore.cs | 21 +- ...DistributedCacheEventStreamStoreOptions.cs | 2 +- .../DistributedCacheEventStreamStoreTests.cs | 206 +++--------------- 3 files changed, 36 insertions(+), 193 deletions(-) diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs index 1b93c1d12..8a9e35c33 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -114,6 +114,9 @@ internal sealed class StoredEvent private sealed class DistributedCacheEventStreamWriter : ISseEventStreamWriter { private readonly IDistributedCache _cache; + private readonly string _sessionId; + private readonly string _streamId; + private SseEventStreamMode _mode; private readonly DistributedCacheEventStreamStoreOptions _options; private long _sequence; private bool _disposed; @@ -126,19 +129,15 @@ public DistributedCacheEventStreamWriter( DistributedCacheEventStreamStoreOptions options) { _cache = cache; - SessionId = sessionId; - StreamId = streamId; - Mode = mode; + _sessionId = sessionId; + _streamId = streamId; + _mode = mode; _options = options; } - public string SessionId { get; } - public string StreamId { get; } - public SseEventStreamMode Mode { get; private set; } - public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default) { - Mode = mode; + _mode = mode; await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); } @@ -152,7 +151,7 @@ public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken c // Generate a new sequence number and event ID var sequence = Interlocked.Increment(ref _sequence); - var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, sequence); + var eventId = DistributedCacheEventIdFormatter.Format(_sessionId, _streamId, sequence); var newItem = sseItem with { EventId = eventId }; // Store the event in the cache @@ -182,13 +181,13 @@ private async ValueTask UpdateMetadataAsync(CancellationToken cancellationToken) { var metadata = new StreamMetadata { - Mode = Mode, + Mode = _mode, IsCompleted = _disposed, LastSequence = Interlocked.Read(ref _sequence), }; var metadataBytes = JsonSerializer.SerializeToUtf8Bytes(metadata, McpJsonUtilities.JsonContext.Default.StreamMetadata); - var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId); + var metadataKey = CacheKeys.StreamMetadata(_sessionId, _streamId); await _cache.SetAsync(metadataKey, metadataBytes, new DistributedCacheEntryOptions { diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs index b6641d3fe..1c8452136 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStoreOptions.cs @@ -41,7 +41,7 @@ public sealed class DistributedCacheEventStreamStoreOptions /// /// Gets or sets the interval between polling attempts when a reader is waiting for new events - /// in mode. + /// in mode. /// /// /// This only affects readers. A shorter interval provides lower latency for new events diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index 6245b4ccb..a78b77b41 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -27,142 +27,6 @@ public void Constructor_ThrowsArgumentNullException_WhenCacheIsNull() Assert.Throws("cache", () => new DistributedCacheEventStreamStore(null!)); } - [Fact] - public async Task Constructor_UsesDefaultOptions_WhenOptionsParameterIsNull() - { - // Arrange - var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache, options: null); - - // Act - Create a stream to verify the store works with default options - var streamOptions = new SseEventStreamOptions - { - SessionId = "session-1", - StreamId = "stream-1", - Mode = SseEventStreamMode.Default - }; - var writer = await store.CreateStreamAsync(streamOptions, CancellationToken); - - // Assert - The store should work normally with default options - Assert.NotNull(writer); - Assert.Equal("stream-1", writer.StreamId); - Assert.Equal(SseEventStreamMode.Default, writer.Mode); - } - - [Fact] - public async Task Constructor_UsesProvidedOptions_WhenOptionsParameterIsSpecified() - { - // Arrange - var cache = CreateMemoryCache(); - var customOptions = new DistributedCacheEventStreamStoreOptions - { - EventSlidingExpiration = TimeSpan.FromMinutes(10), - EventAbsoluteExpiration = TimeSpan.FromHours(1), - MetadataSlidingExpiration = TimeSpan.FromMinutes(20), - MetadataAbsoluteExpiration = TimeSpan.FromHours(2), - PollingInterval = TimeSpan.FromMilliseconds(50) - }; - var store = new DistributedCacheEventStreamStore(cache, customOptions); - - // Act - Create a stream to verify the store works with custom options - var streamOptions = new SseEventStreamOptions - { - SessionId = "session-1", - StreamId = "stream-1", - Mode = SseEventStreamMode.Default - }; - var writer = await store.CreateStreamAsync(streamOptions, CancellationToken); - - // Assert - The store should work with custom options - Assert.NotNull(writer); - Assert.Equal("stream-1", writer.StreamId); - } - - [Fact] - public async Task CreateStreamAsync_ReturnsWriter_WithCorrectStreamId() - { - // Arrange - var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); - var options = new SseEventStreamOptions - { - SessionId = "session-1", - StreamId = "my-stream-id", - Mode = SseEventStreamMode.Default - }; - - // Act - var writer = await store.CreateStreamAsync(options, CancellationToken); - - // Assert - Assert.Equal("my-stream-id", writer.StreamId); - } - - [Fact] - public async Task CreateStreamAsync_ReturnsWriter_WithCorrectSessionId() - { - // Arrange - var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); - var options = new SseEventStreamOptions - { - SessionId = "my-session-id", - StreamId = "stream-1", - Mode = SseEventStreamMode.Default - }; - - // Act - var writer = await store.CreateStreamAsync(options, CancellationToken); - - // Assert - Write an event and verify the reader can find it by session - var item = new SseItem(null); - var writtenItem = await writer.WriteEventAsync(item, CancellationToken); - - var reader = await store.GetStreamReaderAsync(writtenItem.EventId!, CancellationToken); - Assert.NotNull(reader); - Assert.Equal("my-session-id", reader.SessionId); - } - - [Fact] - public async Task CreateStreamAsync_ReturnsWriter_WithDefaultMode() - { - // Arrange - var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); - var options = new SseEventStreamOptions - { - SessionId = "session-1", - StreamId = "stream-1", - Mode = SseEventStreamMode.Default - }; - - // Act - var writer = await store.CreateStreamAsync(options, CancellationToken); - - // Assert - Assert.Equal(SseEventStreamMode.Default, writer.Mode); - } - - [Fact] - public async Task CreateStreamAsync_ReturnsWriter_WithPollingMode() - { - // Arrange - var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); - var options = new SseEventStreamOptions - { - SessionId = "session-1", - StreamId = "stream-1", - Mode = SseEventStreamMode.Polling - }; - - // Act - var writer = await store.CreateStreamAsync(options, CancellationToken); - - // Assert - Assert.Equal(SseEventStreamMode.Polling, writer.Mode); - } - [Fact] public async Task CreateStreamAsync_ThrowsArgumentNullException_WhenOptionsIsNull() { @@ -185,7 +49,7 @@ public async Task WriteEventAsync_AssignsUniqueEventId_WhenItemHasNoEventId() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); var item = new SseItem(null); @@ -208,7 +72,7 @@ public async Task WriteEventAsync_SkipsAssigningEventId_WhenItemAlreadyHasEventI { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); var existingEventId = "existing-event-id"; @@ -231,7 +95,7 @@ public async Task WriteEventAsync_PreservesDataProperty_InReturnedItem() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); var message = new JsonRpcNotification { Method = "test/notification" }; @@ -254,7 +118,7 @@ public async Task WriteEventAsync_PreservesEventTypeProperty_InReturnedItem() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); var item = new SseItem(null, "custom-event-type"); @@ -306,7 +170,7 @@ public async Task WriteEventAsync_StoresEventWithCorrectSlidingExpiration() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); var item = new SseItem(null); @@ -334,7 +198,7 @@ public async Task WriteEventAsync_StoresEventWithCorrectAbsoluteExpiration() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); var item = new SseItem(null); @@ -358,7 +222,7 @@ public async Task WriteEventAsync_UpdatesStreamMetadata_AfterEachWrite() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); var item = new SseItem(null); @@ -370,26 +234,6 @@ public async Task WriteEventAsync_UpdatesStreamMetadata_AfterEachWrite() Assert.Contains(mockCache.SetCalls, call => call.Key.Contains("meta:")); } - [Fact] - public async Task SetModeAsync_UpdatesModeProperty_OnWriter() - { - // Arrange - var cache = CreateMemoryCache(); - var store = new DistributedCacheEventStreamStore(cache); - var writer = await store.CreateStreamAsync(new SseEventStreamOptions - { - SessionId = "session-1", - StreamId = "stream-1", - Mode = SseEventStreamMode.Default - }, CancellationToken); - - // Act - await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); - - // Assert - Assert.Equal(SseEventStreamMode.Polling, writer.Mode); - } - [Fact] public async Task SetModeAsync_PersistsModeChangeToMetadata() { @@ -400,7 +244,7 @@ public async Task SetModeAsync_PersistsModeChangeToMetadata() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); mockCache.SetCalls.Clear(); // Clear calls from CreateStreamAsync setup @@ -426,7 +270,7 @@ public async Task SetModeAsync_ModeChangeReflectedInReader() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write an event to have something to read @@ -462,7 +306,7 @@ public async Task DisposeAsync_MarksStreamAsCompleted() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write an event so we can get a reader @@ -497,7 +341,7 @@ public async Task DisposeAsync_IsIdempotent() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Act - Call DisposeAsync multiple times @@ -519,7 +363,7 @@ public async Task DisposeAsync_UpdatesMetadata_WithIsCompletedFlag() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); mockCache.SetCalls.Clear(); // Clear calls from CreateStreamAsync @@ -588,7 +432,7 @@ public async Task GetStreamReaderAsync_ReturnsReaderWithCorrectSessionIdAndStrea { SessionId = "my-session", StreamId = "my-stream", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write an event to get a valid event ID @@ -875,7 +719,7 @@ public async Task ReadEventsAsync_InPollingMode_DoesNotWaitForNewEvents() } [Fact] - public async Task ReadEventsAsync_InDefaultMode_WaitsForNewEvents() + public async Task ReadEventsAsync_InStreamingMode_WaitsForNewEvents() { // Arrange var cache = CreateMemoryCache(); @@ -887,7 +731,7 @@ public async Task ReadEventsAsync_InDefaultMode_WaitsForNewEvents() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write one event so we have a valid event ID @@ -932,7 +776,7 @@ public async Task ReadEventsAsync_InDefaultMode_WaitsForNewEvents() } [Fact] - public async Task ReadEventsAsync_InDefaultMode_YieldsNewlyWrittenEvents() + public async Task ReadEventsAsync_InStreamingMode_YieldsNewlyWrittenEvents() { // Arrange var cache = CreateMemoryCache(); @@ -944,7 +788,7 @@ public async Task ReadEventsAsync_InDefaultMode_YieldsNewlyWrittenEvents() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write initial event @@ -991,7 +835,7 @@ public async Task ReadEventsAsync_InDefaultMode_YieldsNewlyWrittenEvents() } [Fact] - public async Task ReadEventsAsync_InDefaultMode_CompletesWhenStreamIsDisposed() + public async Task ReadEventsAsync_InStreamingMode_CompletesWhenStreamIsDisposed() { // Arrange var cache = CreateMemoryCache(); @@ -1003,7 +847,7 @@ public async Task ReadEventsAsync_InDefaultMode_CompletesWhenStreamIsDisposed() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write event to create a valid reader @@ -1034,7 +878,7 @@ public async Task ReadEventsAsync_InDefaultMode_CompletesWhenStreamIsDisposed() } [Fact] - public async Task ReadEventsAsync_InDefaultMode_RespectsCancellation() + public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation() { // Arrange var cache = CreateMemoryCache(); @@ -1046,7 +890,7 @@ public async Task ReadEventsAsync_InDefaultMode_RespectsCancellation() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write event to create a valid reader @@ -1089,7 +933,7 @@ public async Task ReadEventsAsync_InDefaultMode_RespectsCancellation() } [Fact] - public async Task ReadEventsAsync_RespectsModeSwitchFromDefaultToPolling() + public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling() { // Arrange var cache = CreateMemoryCache(); @@ -1101,7 +945,7 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromDefaultToPolling() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write an event to create a valid reader @@ -1151,7 +995,7 @@ public async Task ReadEventsAsync_PollingModeReturnsEventsThenCompletes() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default + Mode = SseEventStreamMode.Streaming }, CancellationToken); // Write initial event and create reader from sequence 0 @@ -1442,7 +1286,7 @@ public async Task ReadEventsAsync_Completes_WhenMetadataExpires() { SessionId = "session-1", StreamId = "stream-1", - Mode = SseEventStreamMode.Default // Non-polling mode to trigger the waiting loop + Mode = SseEventStreamMode.Streaming // Non-polling mode to trigger the waiting loop }, CancellationToken); var item = new SseItem(new JsonRpcNotification { Method = "test" }); From 984df247e68ff044d1db344fc53f48134696781b Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Fri, 9 Jan 2026 16:02:51 -0800 Subject: [PATCH 04/18] Clean up tests, throw on expired cache entries --- .../DistributedCacheEventIdFormatter.cs | 12 +- .../DistributedCacheEventStreamStore.cs | 25 +-- .../DistributedCacheEventStreamStoreTests.cs | 146 +++++++++++++----- 3 files changed, 117 insertions(+), 66 deletions(-) diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs index 8aec3f0bd..e54ac0b9d 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs @@ -4,7 +4,7 @@ // This is a shared source file included in both ModelContextProtocol.Core and the test project. // Do not reference symbols internal to the core project, as they won't be available in tests. -using System; +using System.Text; namespace ModelContextProtocol.Server; @@ -13,8 +13,6 @@ namespace ModelContextProtocol.Server; /// /// /// Event IDs are formatted as "{base64(sessionId)}:{base64(streamId)}:{sequence}". -/// Base64 encoding is used because the MCP specification allows session IDs to contain -/// any visible ASCII character (0x21-0x7E), including the ':' separator character. /// internal static class DistributedCacheEventIdFormatter { @@ -27,8 +25,8 @@ public static string Format(string sessionId, string streamId, long sequence) { // Base64-encode session and stream IDs so the event ID can be parsed // even if the original IDs contain the ':' separator character - var sessionBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(sessionId)); - var streamBase64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(streamId)); + var sessionBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(sessionId)); + var streamBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(streamId)); return $"{sessionBase64}{Separator}{streamBase64}{Separator}{sequence}"; } @@ -49,8 +47,8 @@ public static bool TryParse(string eventId, out string sessionId, out string str try { - sessionId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[0])); - streamId = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(parts[1])); + sessionId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[0])); + streamId = Encoding.UTF8.GetString(Convert.FromBase64String(parts[1])); return long.TryParse(parts[2], out sequence); } catch diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs index 8a9e35c33..2ba3f8345 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -255,13 +255,8 @@ public DistributedCacheEventStreamReader( var eventId = DistributedCacheEventIdFormatter.Format(SessionId, StreamId, currentSequence); var eventKey = CacheKeys.Event(eventId); - var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false); - - if (eventBytes is null) - { - // Event may have expired; skip to next - continue; - } + var eventBytes = await _cache.GetAsync(eventKey, cancellationToken).ConfigureAwait(false) + ?? throw new McpException($"SSE event with ID '{eventId}' was not found in the cache. The event may have expired."); var storedEvent = JsonSerializer.Deserialize(eventBytes, McpJsonUtilities.JsonContext.Default.StoredEvent); if (storedEvent is not null) @@ -290,19 +285,11 @@ public DistributedCacheEventStreamReader( // Refresh metadata to get the latest sequence and completion status var metadataKey = CacheKeys.StreamMetadata(SessionId, StreamId); - var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false); - - if (metadataBytes is null) - { - // Metadata expired - treat stream as complete to avoid infinite loop - yield break; - } + var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false) + ?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' was not found in the cache. The metadata may have expired."); - var currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata); - if (currentMetadata is null) - { - yield break; - } + var currentMetadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata) + ?? throw new McpException($"Stream metadata for session '{SessionId}' and stream '{StreamId}' could not be deserialized."); lastSequence = currentMetadata.LastSequence; isCompleted = currentMetadata.IsCompleted; diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index a78b77b41..eb5d5c309 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -756,8 +756,8 @@ public async Task ReadEventsAsync_InStreamingMode_WaitsForNewEvents() } }, CancellationToken); - // Wait a bit, then write a new event - await Task.Delay(100, CancellationToken); + // Write a new event - the reader should pick it up since it's in streaming mode + // and won't complete until cancelled or the stream is disposed var newEvent = await writer.WriteEventAsync(new SseItem(null), CancellationToken); // Wait for read to complete (either event received or timeout) @@ -812,8 +812,7 @@ public async Task ReadEventsAsync_InStreamingMode_YieldsNewlyWrittenEvents() } }, CancellationToken); - // Write 3 new events - await Task.Delay(100, CancellationToken); + // Write 3 new events - the reader should pick them up since it's in streaming mode var event1 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); var event2 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); var event3 = await writer.WriteEventAsync(new SseItem(null), CancellationToken); @@ -856,25 +855,20 @@ public async Task ReadEventsAsync_InStreamingMode_CompletesWhenStreamIsDisposed( Assert.NotNull(reader); // Act - Start reading, then dispose the stream - var events = new List>(); var readTask = Task.Run(async () => { await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) { - events.Add(evt); } }, CancellationToken); - // Wait a bit, then dispose the writer - await Task.Delay(100, CancellationToken); + // Dispose the writer - the reader should detect this and exit gracefully await writer.DisposeAsync(); - // Wait for read to complete with a timeout - var timeoutTask = Task.Delay(TimeSpan.FromSeconds(2), CancellationToken); - var completedTask = await Task.WhenAny(readTask, timeoutTask); - - // Assert - The read should complete gracefully (not timeout) - Assert.Same(readTask, completedTask); + // Assert - The read should complete gracefully within timeout + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(2)); + await readTask.WaitAsync(timeoutCts.Token); } [Fact] @@ -900,8 +894,9 @@ public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation() // Act - Start reading and then cancel using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); - var stopwatch = System.Diagnostics.Stopwatch.StartNew(); var events = new List>(); + var messageReceivedTcs = new TaskCompletionSource(); + var continueReadingTcs = new TaskCompletionSource(); OperationCanceledException? capturedException = null; var readTask = Task.Run(async () => @@ -911,6 +906,8 @@ public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation() await foreach (var evt in reader.ReadEventsAsync(cts.Token)) { events.Add(evt); + messageReceivedTcs.SetResult(true); + await continueReadingTcs.Task; } } catch (OperationCanceledException ex) @@ -919,17 +916,23 @@ public async Task ReadEventsAsync_InStreamingMode_RespectsCancellation() } }, CancellationToken); - // Cancel after a short delay - await Task.Delay(100, CancellationToken); + // Write a message for the reader to consume + await writer.WriteEventAsync(new SseItem(null), CancellationToken); + + // Wait for the first message to be received + await messageReceivedTcs.Task; + + // Cancel so that ReadEventsAsync throws before reading the next message await cts.CancelAsync(); + // Allow the message reader to continue + continueReadingTcs.SetResult(true); + // Wait for read task to complete await readTask; - stopwatch.Stop(); - // Assert - Either cancelled exception or graceful exit, but should complete quickly - Assert.Empty(events); // No events should have been received - Assert.True(stopwatch.ElapsedMilliseconds < 1000, $"Should complete quickly after cancellation, took {stopwatch.ElapsedMilliseconds}ms"); + Assert.Single(events); + Assert.NotNull(capturedException); } [Fact] @@ -953,7 +956,7 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling() var reader = await store.GetStreamReaderAsync(writtenEvent.EventId!, CancellationToken); Assert.NotNull(reader); - // Start reading in default mode (will wait for new events) + // Start reading in streaming mode (will wait for new events) using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); cts.CancelAfter(TimeSpan.FromSeconds(3)); var events = new List>(); @@ -968,16 +971,13 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling() readCompleted = true; }, CancellationToken); - // Wait a bit, then switch to polling mode - await Task.Delay(100, CancellationToken); + // Switch to polling mode - the reader should detect this and exit await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); - // Wait for read task to complete (should complete quickly after mode switch) - var timeoutTask = Task.Delay(TimeSpan.FromSeconds(1), CancellationToken); - var completedTask = await Task.WhenAny(readTask, timeoutTask); - - // Assert - Read should have completed after switching to polling mode - Assert.Same(readTask, completedTask); + // Assert - Read should complete within timeout after switching to polling mode + using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(1)); + await readTask.WaitAsync(timeoutCts.Token); Assert.True(readCompleted); Assert.Empty(events); // No new events were written after the one we used to create the reader } @@ -1271,7 +1271,7 @@ public void DefaultOptions_HaveReasonableDefaults() } [Fact] - public async Task ReadEventsAsync_Completes_WhenMetadataExpires() + public async Task ReadEventsAsync_ThrowsMcpException_WhenMetadataExpires() { // Arrange - Use a cache that allows us to simulate metadata expiration var trackingCache = new TestDistributedCache(); @@ -1299,16 +1299,59 @@ public async Task ReadEventsAsync_Completes_WhenMetadataExpires() // Now simulate metadata expiration trackingCache.ExpireMetadata(); - // Act - Read events; the reader should complete gracefully when metadata expires - // instead of looping indefinitely with the stale initial metadata - var events = new List>(); - await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + // Act & Assert - Reader should throw McpException when metadata expires + var exception = await Assert.ThrowsAsync(async () => { - events.Add(evt); - } + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + // Should not yield any events before throwing + } + }); - // If we reach here without timeout, the reader correctly handled metadata expiration - Assert.Empty(events); // No new events after the initial one used to create the reader + Assert.Contains("session-1", exception.Message); + Assert.Contains("stream-1", exception.Message); + Assert.Contains("metadata", exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ReadEventsAsync_ThrowsMcpException_WhenEventExpires() + { + // Arrange - Use a cache that allows us to simulate event expiration + var trackingCache = new TestDistributedCache(); + var store = new DistributedCacheEventStreamStore(trackingCache); + + // Create a stream and write multiple events + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var event1 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method1" }), CancellationToken); + var event2 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method2" }), CancellationToken); + var event3 = await writer.WriteEventAsync(new SseItem(new JsonRpcNotification { Method = "method3" }), CancellationToken); + + // Create a reader starting from before the first event + var startEventId = DistributedCacheEventIdFormatter.Format("session-1", "stream-1", 0); + var reader = await store.GetStreamReaderAsync(startEventId, CancellationToken); + Assert.NotNull(reader); + + // Simulate event2 expiring from the cache + trackingCache.ExpireEvent(event2.EventId!); + + // Act & Assert - Reader should throw McpException when an event is missing + var exception = await Assert.ThrowsAsync(async () => + { + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + }); + + Assert.Contains(event2.EventId!, exception.Message); + Assert.Contains("not found", exception.Message, StringComparison.OrdinalIgnoreCase); } [Fact] @@ -1521,18 +1564,20 @@ public void EventIdFormatter_TryParse_ReturnsFalse_ForNonNumericSequence() /// /// A distributed cache that tracks all operations for verification in tests. - /// Supports tracking Set calls, counting metadata reads, and simulating metadata expiration. + /// Supports tracking Set calls, counting metadata reads, and simulating metadata/event expiration. /// private sealed class TestDistributedCache : IDistributedCache { private readonly MemoryDistributedCache _innerCache = new(Options.Create(new MemoryDistributedCacheOptions())); private int _metadataReadCount; private bool _metadataExpired; + private readonly HashSet _expiredEventIds = []; public List<(string Key, DistributedCacheEntryOptions Options)> SetCalls { get; } = []; public int MetadataReadCount => _metadataReadCount; public void ExpireMetadata() => _metadataExpired = true; + public void ExpireEvent(string eventId) => _expiredEventIds.Add(eventId); public byte[]? Get(string key) { @@ -1544,6 +1589,10 @@ private sealed class TestDistributedCache : IDistributedCache return null; } } + if (IsExpiredEvent(key)) + { + return null; + } return _innerCache.Get(key); } @@ -1557,9 +1606,26 @@ private sealed class TestDistributedCache : IDistributedCache return Task.FromResult(null); } } + if (IsExpiredEvent(key)) + { + return Task.FromResult(null); + } return _innerCache.GetAsync(key, token); } + private bool IsExpiredEvent(string key) + { + // Cache key format is "mcp:sse:event:{eventId}" + foreach (var expiredEventId in _expiredEventIds) + { + if (key.EndsWith(expiredEventId)) + { + return true; + } + } + return false; + } + public void Refresh(string key) => _innerCache.Refresh(key); public Task RefreshAsync(string key, CancellationToken token = default) => _innerCache.RefreshAsync(key, token); public void Remove(string key) => _innerCache.Remove(key); From 80b2cb4cf9d5f1eff8a6550323a7e336432e9860 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Wed, 14 Jan 2026 16:33:24 -0800 Subject: [PATCH 05/18] Add logging --- .../DistributedCacheEventStreamStore.cs | 84 +++++++++++++++++-- 1 file changed, 76 insertions(+), 8 deletions(-) diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs index 2ba3f8345..d53170d06 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -1,4 +1,6 @@ using Microsoft.Extensions.Caching.Distributed; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.Net.ServerSentEvents; using System.Runtime.CompilerServices; @@ -20,28 +22,32 @@ namespace ModelContextProtocol.Server; /// to be only one writer per stream. Readers may be created from separate processes. /// /// -public sealed class DistributedCacheEventStreamStore : ISseEventStreamStore +public sealed partial class DistributedCacheEventStreamStore : ISseEventStreamStore { private readonly IDistributedCache _cache; private readonly DistributedCacheEventStreamStoreOptions _options; + private readonly ILogger _logger; /// /// Initializes a new instance of the class. /// /// The distributed cache to use for storage. /// Optional configuration options for the store. - public DistributedCacheEventStreamStore(IDistributedCache cache, DistributedCacheEventStreamStoreOptions? options = null) + /// Optional logger for diagnostic output. + public DistributedCacheEventStreamStore(IDistributedCache cache, DistributedCacheEventStreamStoreOptions? options = null, ILogger? logger = null) { Throw.IfNull(cache); _cache = cache; _options = options ?? new(); + _logger = logger ?? NullLogger.Instance; } /// public ValueTask CreateStreamAsync(SseEventStreamOptions options, CancellationToken cancellationToken = default) { Throw.IfNull(options); - var writer = new DistributedCacheEventStreamWriter(_cache, options.SessionId, options.StreamId, options.Mode, _options); + LogStreamCreated(options.SessionId, options.StreamId, options.Mode); + var writer = new DistributedCacheEventStreamWriter(_cache, options.SessionId, options.StreamId, options.Mode, _options, _logger); return new ValueTask(writer); } @@ -53,6 +59,7 @@ public ValueTask CreateStreamAsync(SseEventStreamOptions // Parse the event ID to get session, stream, and sequence information if (!DistributedCacheEventIdFormatter.TryParse(lastEventId, out var sessionId, out var streamId, out var sequence)) { + LogEventIdParsingFailed(lastEventId); return null; } @@ -61,17 +68,20 @@ public ValueTask CreateStreamAsync(SseEventStreamOptions var metadataBytes = await _cache.GetAsync(metadataKey, cancellationToken).ConfigureAwait(false); if (metadataBytes is null) { + LogStreamMetadataNotFound(sessionId, streamId); return null; } var metadata = JsonSerializer.Deserialize(metadataBytes, McpJsonUtilities.JsonContext.Default.StreamMetadata); if (metadata is null) { + LogStreamMetadataDeserializationFailed(sessionId, streamId); return null; } var startSequence = sequence + 1; - return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, startSequence, metadata, _options); + LogStreamReaderCreated(sessionId, streamId, startSequence, metadata.LastSequence); + return new DistributedCacheEventStreamReader(_cache, sessionId, streamId, startSequence, metadata, _options, _logger); } /// @@ -111,13 +121,14 @@ internal sealed class StoredEvent public JsonRpcMessage? Data { get; set; } } - private sealed class DistributedCacheEventStreamWriter : ISseEventStreamWriter + private sealed partial class DistributedCacheEventStreamWriter : ISseEventStreamWriter { private readonly IDistributedCache _cache; private readonly string _sessionId; private readonly string _streamId; private SseEventStreamMode _mode; private readonly DistributedCacheEventStreamStoreOptions _options; + private readonly ILogger _logger; private long _sequence; private bool _disposed; @@ -126,17 +137,20 @@ public DistributedCacheEventStreamWriter( string sessionId, string streamId, SseEventStreamMode mode, - DistributedCacheEventStreamStoreOptions options) + DistributedCacheEventStreamStoreOptions options, + ILogger logger) { _cache = cache; _sessionId = sessionId; _streamId = streamId; _mode = mode; _options = options; + _logger = logger; } public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken cancellationToken = default) { + LogStreamModeChanged(_sessionId, _streamId, mode); _mode = mode; await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); } @@ -146,6 +160,7 @@ public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken c // Skip if already has an event ID if (sseItem.EventId is not null) { + LogEventAlreadyHasId(_sessionId, _streamId, sseItem.EventId); return sseItem; } @@ -174,6 +189,7 @@ public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken c // Update metadata with the latest sequence await UpdateMetadataAsync(cancellationToken).ConfigureAwait(false); + LogEventWritten(_sessionId, _streamId, eventId, sequence); return newItem; } @@ -207,15 +223,29 @@ public async ValueTask DisposeAsync() // Mark the stream as completed in the metadata await UpdateMetadataAsync(CancellationToken.None).ConfigureAwait(false); + LogStreamWriterDisposed(_sessionId, _streamId, Interlocked.Read(ref _sequence)); } + + [LoggerMessage(Level = LogLevel.Debug, Message = "Stream mode changed for session '{SessionId}', stream '{StreamId}' to {Mode}.")] + private partial void LogStreamModeChanged(string sessionId, string streamId, SseEventStreamMode mode); + + [LoggerMessage(Level = LogLevel.Trace, Message = "Event already has ID '{EventId}' for session '{SessionId}', stream '{StreamId}'. Skipping ID generation.")] + private partial void LogEventAlreadyHasId(string sessionId, string streamId, string eventId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Event written to session '{SessionId}', stream '{StreamId}' with ID '{EventId}' (sequence {Sequence}).")] + private partial void LogEventWritten(string sessionId, string streamId, string eventId, long sequence); + + [LoggerMessage(Level = LogLevel.Information, Message = "Stream writer disposed for session '{SessionId}', stream '{StreamId}'. Total events written: {TotalEvents}.")] + private partial void LogStreamWriterDisposed(string sessionId, string streamId, long totalEvents); } - private sealed class DistributedCacheEventStreamReader : ISseEventStreamReader + private sealed partial class DistributedCacheEventStreamReader : ISseEventStreamReader { private readonly IDistributedCache _cache; private readonly long _startSequence; private readonly StreamMetadata _initialMetadata; private readonly DistributedCacheEventStreamStoreOptions _options; + private readonly ILogger _logger; public DistributedCacheEventStreamReader( IDistributedCache cache, @@ -223,7 +253,8 @@ public DistributedCacheEventStreamReader( string streamId, long startSequence, StreamMetadata initialMetadata, - DistributedCacheEventStreamStoreOptions options) + DistributedCacheEventStreamStoreOptions options, + ILogger logger) { _cache = cache; SessionId = sessionId; @@ -231,6 +262,7 @@ public DistributedCacheEventStreamReader( _startSequence = startSequence; _initialMetadata = initialMetadata; _options = options; + _logger = logger; } public string SessionId { get; } @@ -246,6 +278,8 @@ public DistributedCacheEventStreamReader( var isCompleted = _initialMetadata.IsCompleted; var mode = _initialMetadata.Mode; + LogReadingEventsStarted(SessionId, StreamId, _startSequence, lastSequence); + while (!cancellationToken.IsCancellationRequested) { // Read all available events from currentSequence + 1 to lastSequence @@ -261,6 +295,7 @@ public DistributedCacheEventStreamReader( var storedEvent = JsonSerializer.Deserialize(eventBytes, McpJsonUtilities.JsonContext.Default.StoredEvent); if (storedEvent is not null) { + LogEventRead(SessionId, StreamId, eventId, currentSequence); yield return new SseItem(storedEvent.Data, storedEvent.EventType) { EventId = storedEvent.EventId, @@ -271,16 +306,19 @@ public DistributedCacheEventStreamReader( // If in polling mode, stop after returning currently available events if (mode == SseEventStreamMode.Polling) { + LogReadingEventsCompletedPolling(SessionId, StreamId, currentSequence - 1); yield break; } // If the stream is completed and we've read all events, stop if (isCompleted) { + LogReadingEventsCompletedStreamEnded(SessionId, StreamId, currentSequence - 1); yield break; } // Wait before polling again for new events + LogWaitingForNewEvents(SessionId, StreamId, _options.PollingInterval); await Task.Delay(_options.PollingInterval, cancellationToken).ConfigureAwait(false); // Refresh metadata to get the latest sequence and completion status @@ -296,5 +334,35 @@ public DistributedCacheEventStreamReader( mode = currentMetadata.Mode; } } + + [LoggerMessage(Level = LogLevel.Debug, Message = "Starting to read events for session '{SessionId}', stream '{StreamId}' from sequence {StartSequence} to {LastSequence}.")] + private partial void LogReadingEventsStarted(string sessionId, string streamId, long startSequence, long lastSequence); + + [LoggerMessage(Level = LogLevel.Trace, Message = "Event read from session '{SessionId}', stream '{StreamId}' with ID '{EventId}' (sequence {Sequence}).")] + private partial void LogEventRead(string sessionId, string streamId, string eventId, long sequence); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Reading events completed for session '{SessionId}', stream '{StreamId}' in polling mode. Last sequence read: {LastSequence}.")] + private partial void LogReadingEventsCompletedPolling(string sessionId, string streamId, long lastSequence); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Reading events completed for session '{SessionId}', stream '{StreamId}' as stream has ended. Last sequence read: {LastSequence}.")] + private partial void LogReadingEventsCompletedStreamEnded(string sessionId, string streamId, long lastSequence); + + [LoggerMessage(Level = LogLevel.Trace, Message = "Waiting for new events on session '{SessionId}', stream '{StreamId}'. Polling interval: {PollingInterval}.")] + private partial void LogWaitingForNewEvents(string sessionId, string streamId, TimeSpan pollingInterval); } + + [LoggerMessage(Level = LogLevel.Information, Message = "Stream created for session '{SessionId}', stream '{StreamId}' with mode {Mode}.")] + private partial void LogStreamCreated(string sessionId, string streamId, SseEventStreamMode mode); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Stream reader created for session '{SessionId}', stream '{StreamId}' starting at sequence {StartSequence}. Last available sequence: {LastSequence}.")] + private partial void LogStreamReaderCreated(string sessionId, string streamId, long startSequence, long lastSequence); + + [LoggerMessage(Level = LogLevel.Warning, Message = "Failed to parse event ID '{EventId}'. Unable to create stream reader.")] + private partial void LogEventIdParsingFailed(string eventId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Stream metadata not found for session '{SessionId}', stream '{StreamId}'.")] + private partial void LogStreamMetadataNotFound(string sessionId, string streamId); + + [LoggerMessage(Level = LogLevel.Warning, Message = "Failed to deserialize stream metadata for session '{SessionId}', stream '{StreamId}'.")] + private partial void LogStreamMetadataDeserializationFailed(string sessionId, string streamId); } From bcbf247aa51c1811d1715c32ce27a188dcf87e48 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 15 Jan 2026 14:26:09 -0800 Subject: [PATCH 06/18] Store retry interval --- .../DistributedCacheEventStreamStore.cs | 7 ++ .../DistributedCacheEventStreamStoreTests.cs | 89 +++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs index d53170d06..d9c96fda5 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -118,6 +118,7 @@ internal sealed class StoredEvent { public string? EventType { get; set; } public string? EventId { get; set; } + public int? ReconnectionIntervalMs { get; set; } public JsonRpcMessage? Data { get; set; } } @@ -174,6 +175,9 @@ public async ValueTask SetModeAsync(SseEventStreamMode mode, CancellationToken c { EventType = newItem.EventType, EventId = eventId, + ReconnectionIntervalMs = newItem.ReconnectionInterval.HasValue + ? (int)newItem.ReconnectionInterval.Value.TotalMilliseconds + : null, Data = newItem.Data, }; @@ -299,6 +303,9 @@ public DistributedCacheEventStreamReader( yield return new SseItem(storedEvent.Data, storedEvent.EventType) { EventId = storedEvent.EventId, + ReconnectionInterval = storedEvent.ReconnectionIntervalMs.HasValue + ? TimeSpan.FromMilliseconds(storedEvent.ReconnectionIntervalMs.Value) + : null, }; } } diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index eb5d5c309..534d8f079 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -130,6 +130,95 @@ public async Task WriteEventAsync_PreservesEventTypeProperty_InReturnedItem() Assert.Equal("custom-event-type", result.EventType); } + [Fact] + public async Task WriteEventAsync_PreservesReconnectionIntervalProperty_InStoredEvent() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + var expectedInterval = TimeSpan.FromSeconds(5); + var item = new SseItem(null) { ReconnectionInterval = expectedInterval }; + + // Act + var result = await writer.WriteEventAsync(item, CancellationToken); + + // Assert - ReconnectionInterval should be preserved in returned item + Assert.Equal(expectedInterval, result.ReconnectionInterval); + + // Get a reader and verify ReconnectionInterval is preserved after round-trip + var reader = await store.GetStreamReaderAsync(result.EventId!, CancellationToken); + Assert.NotNull(reader); + + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Reader should not return the event we just wrote (it starts after lastEventId) + Assert.Empty(events); + + // Write another event and verify it can be read with correct ReconnectionInterval + var secondItem = new SseItem(null) { ReconnectionInterval = TimeSpan.FromSeconds(10) }; + _ = await writer.WriteEventAsync(secondItem, CancellationToken); + + // Re-fetch reader using the first event ID to get the second event + reader = await store.GetStreamReaderAsync(result.EventId!, CancellationToken); + Assert.NotNull(reader); + + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + Assert.Single(events); + Assert.Equal(TimeSpan.FromSeconds(10), events[0].ReconnectionInterval); + } + + [Fact] + public async Task WriteEventAsync_HandlesNullReconnectionInterval_InStoredEvent() + { + // Arrange + var cache = CreateMemoryCache(); + var store = new DistributedCacheEventStreamStore(cache); + var writer = await store.CreateStreamAsync(new SseEventStreamOptions + { + SessionId = "session-1", + StreamId = "stream-1", + Mode = SseEventStreamMode.Polling + }, CancellationToken); + + // Write an event WITH a reconnection interval first + var firstItem = new SseItem(null) { ReconnectionInterval = TimeSpan.FromSeconds(5) }; + var firstResult = await writer.WriteEventAsync(firstItem, CancellationToken); + + // Write an event WITHOUT a reconnection interval + var secondItem = new SseItem(null); + var secondResult = await writer.WriteEventAsync(secondItem, CancellationToken); + Assert.Null(secondResult.ReconnectionInterval); + + // Get a reader starting after the first event + var reader = await store.GetStreamReaderAsync(firstResult.EventId!, CancellationToken); + Assert.NotNull(reader); + + var events = new List>(); + await foreach (var evt in reader.ReadEventsAsync(CancellationToken)) + { + events.Add(evt); + } + + // Should get the second event with null ReconnectionInterval + Assert.Single(events); + Assert.Null(events[0].ReconnectionInterval); + } + [Fact] public async Task WriteEventAsync_HandlesNullData_AssignsEventIdAndStoresEvent() { From 66083eabe568dc0d6930962d8581ee170eb8ec40 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 15 Jan 2026 15:22:37 -0800 Subject: [PATCH 07/18] Use span-based APIs for event ID parsing --- .../DistributedCacheEventIdFormatter.cs | 59 +++++++++++++++++++ .../DistributedCacheEventStreamStoreTests.cs | 19 ++++++ 2 files changed, 78 insertions(+) diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs index e54ac0b9d..5fa8525d3 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventIdFormatter.cs @@ -4,6 +4,12 @@ // This is a shared source file included in both ModelContextProtocol.Core and the test project. // Do not reference symbols internal to the core project, as they won't be available in tests. +#if NET +using System.Buffers; +using System.Buffers.Text; +using System.Diagnostics.CodeAnalysis; + +#endif using System.Text; namespace ModelContextProtocol.Server; @@ -39,6 +45,34 @@ public static bool TryParse(string eventId, out string sessionId, out string str streamId = string.Empty; sequence = 0; +#if NET + ReadOnlySpan eventIdSpan = eventId.AsSpan(); + Span partRanges = stackalloc Range[4]; + int rangeCount = eventIdSpan.Split(partRanges, Separator); + if (rangeCount != 3) + { + return false; + } + + try + { + ReadOnlySpan sessionBase64 = eventIdSpan[partRanges[0]]; + ReadOnlySpan streamBase64 = eventIdSpan[partRanges[1]]; + ReadOnlySpan sequenceSpan = eventIdSpan[partRanges[2]]; + + if (!TryDecodeBase64ToString(sessionBase64, out sessionId!) || + !TryDecodeBase64ToString(streamBase64, out streamId!)) + { + return false; + } + + return long.TryParse(sequenceSpan, out sequence); + } + catch + { + return false; + } +#else var parts = eventId.Split(Separator); if (parts.Length != 3) { @@ -55,5 +89,30 @@ public static bool TryParse(string eventId, out string sessionId, out string str { return false; } +#endif + } + +#if NET + private static bool TryDecodeBase64ToString(ReadOnlySpan base64Chars, [NotNullWhen(true)] out string? result) + { + // Use a single buffer: base64 chars are ASCII (1:1 with UTF8 bytes), + // and decoded data is always smaller than encoded, so we can decode in-place. + int bufferLength = base64Chars.Length; + Span buffer = bufferLength <= 256 + ? stackalloc byte[bufferLength] + : new byte[bufferLength]; + + Encoding.UTF8.GetBytes(base64Chars, buffer); + + OperationStatus status = Base64.DecodeFromUtf8InPlace(buffer, out int bytesWritten); + if (status != OperationStatus.Done) + { + result = null; + return false; + } + + result = Encoding.UTF8.GetString(buffer[..bytesWritten]); + return true; } +#endif } diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index 534d8f079..baad69b29 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -1523,6 +1523,25 @@ public void EventIdFormatter_TryParse_RoundTripsSuccessfully() Assert.Equal(originalSequence, sequence); } + [Fact] + public void EventIdFormatter_TryParse_HandlesEmptySessionAndStreamIds() + { + // Arrange + var originalSessionId = ""; + var originalStreamId = ""; + var originalSequence = 42L; + + // Act + var eventId = DistributedCacheEventIdFormatter.Format(originalSessionId, originalStreamId, originalSequence); + var parsed = DistributedCacheEventIdFormatter.TryParse(eventId, out var sessionId, out var streamId, out var sequence); + + // Assert + Assert.True(parsed); + Assert.Equal(originalSessionId, sessionId); + Assert.Equal(originalStreamId, streamId); + Assert.Equal(originalSequence, sequence); + } + [Fact] public void EventIdFormatter_TryParse_HandlesSpecialCharactersInSessionId() { From 17867ee7031489b7dab0c2f6c627fa2172186065 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 15 Jan 2026 15:29:53 -0800 Subject: [PATCH 08/18] Add shorter timeout on `Client_CanResumeUnsolicitedMessageStream_AfterDisconnection` --- .../ResumabilityIntegrationTests.cs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs index e3425b253..83566c956 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -279,6 +279,7 @@ public async Task Client_CanResumePostResponseStream_AfterDisconnection() [Fact] public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() { + var timeout = TimeSpan.FromSeconds(10); using var faultingStreamHandler = new FaultingStreamHandler() { InnerHandler = SocketsHttpHandler, @@ -304,12 +305,12 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() await using var client = await ConnectClientAsync(); // Get the server instance - var server = await serverTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + var server = await serverTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Set up notification tracking with unique messages - var clientReceivedInitialNotificationTcs = new TaskCompletionSource(); - var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(); - var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(); + var clientReceivedInitialNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); const string CustomNotificationMethod = "test/custom_notification"; const string InitialMessage = "Initial notification"; @@ -347,7 +348,7 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken); // Wait for client to receive the first notification - await clientReceivedInitialNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedInitialNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Fault the unsolicited message stream (GET SSE) var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken); @@ -359,13 +360,13 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() reconnectAttempt.Continue(); // Wait for client to receive the notification via replay - await clientReceivedReplayedNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedReplayedNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Send a final notification while the client has reconnected - this should be handled by the transport await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReconnectMessage }, cancellationToken: TestContext.Current.CancellationToken); // Wait for the client to receive the final notification - await clientReceivedReconnectNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedReconnectNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Assert each notification was received exactly once Assert.Equal(1, initialNotificationReceivedCount); From f59c0e04e644377ffbdfcd9f28ddde63ddd79c60 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 15 Jan 2026 16:11:34 -0800 Subject: [PATCH 09/18] Use longer timeout in test --- .../Server/DistributedCacheEventStreamStoreTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index baad69b29..6900917b9 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -956,7 +956,7 @@ public async Task ReadEventsAsync_InStreamingMode_CompletesWhenStreamIsDisposed( // Assert - The read should complete gracefully within timeout using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); - timeoutCts.CancelAfter(TimeSpan.FromSeconds(2)); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(10)); await readTask.WaitAsync(timeoutCts.Token); } From 1fae1f6d0849ccc2f5b6b165599a8abfa67078ae Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Thu, 15 Jan 2026 16:29:32 -0800 Subject: [PATCH 10/18] Use versioned cache keys --- .../Server/DistributedCacheEventStreamStore.cs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs index d9c96fda5..2b74f6989 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -87,9 +87,18 @@ public ValueTask CreateStreamAsync(SseEventStreamOptions /// /// Provides methods for generating cache keys. /// + /// + /// Cache keys are versioned to allow format changes without conflicts with existing entries. + /// When the cache format changes, increment to invalidate old entries. + /// internal static class CacheKeys { - private const string Prefix = "mcp:sse:"; + /// + /// The current cache key version. Increment this when changing the cache format + /// to ensure old entries are ignored. + /// + private const string Version = "v1"; + private const string Prefix = $"mcp:sse:{Version}:"; public static string StreamMetadata(string sessionId, string streamId) => $"{Prefix}meta:{sessionId}:{streamId}"; From e2d3ba01acd54d7565af6bd27d846b4f4cc2bbad Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Fri, 16 Jan 2026 10:53:44 -0800 Subject: [PATCH 11/18] Fix flaky test --- .../ResumabilityIntegrationTests.cs | 3 +++ .../Utils/FaultingStreamHandler.cs | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs index 83566c956..528439f47 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -344,6 +344,9 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() return default; }); + // Wait for the client's GET SSE stream to be established before sending notifications + await faultingStreamHandler.WaitForStreamAsync(TestContext.Current.CancellationToken); + // Send a custom notification to the client on the unsolicited message stream await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs index cace4d8be..0beb498f2 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs @@ -11,6 +11,10 @@ internal sealed class FaultingStreamHandler : DelegatingHandler { private FaultingStream? _lastStream; private TaskCompletionSource? _reconnectTcs; + private TaskCompletionSource _streamAvailableTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public Task WaitForStreamAsync(CancellationToken cancellationToken = default) + => _streamAvailableTcs.Task.WaitAsync(cancellationToken); public async Task TriggerFaultAsync(CancellationToken cancellationToken) { @@ -24,6 +28,7 @@ public async Task TriggerFaultAsync(CancellationToken cancella throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection."); } + _streamAvailableTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); _reconnectTcs = new(); await _lastStream.TriggerFaultAsync(cancellationToken); @@ -63,6 +68,8 @@ protected override async Task SendAsync( } response.Content = newContent; + + _streamAvailableTcs.TrySetResult(); } return response; From 114feac66df97b212269d0d885f09f1170ea0abc Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Fri, 16 Jan 2026 11:40:43 -0800 Subject: [PATCH 12/18] Amend test fix --- .../ResumabilityIntegrationTests.cs | 4 +-- .../Utils/FaultingStreamHandler.cs | 30 +++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs index 528439f47..a75906463 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -344,8 +344,8 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() return default; }); - // Wait for the client's GET SSE stream to be established before sending notifications - await faultingStreamHandler.WaitForStreamAsync(TestContext.Current.CancellationToken); + // Wait for the client's unsolicited message stream to be established before sending notifications + await faultingStreamHandler.WaitForUnsolicitedMessageStreamAsync(TestContext.Current.CancellationToken); // Send a custom notification to the client on the unsolicited message stream await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs index 0beb498f2..dc157735f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs @@ -11,10 +11,12 @@ internal sealed class FaultingStreamHandler : DelegatingHandler { private FaultingStream? _lastStream; private TaskCompletionSource? _reconnectTcs; - private TaskCompletionSource _streamAvailableTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + private TaskCompletionSource _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); - public Task WaitForStreamAsync(CancellationToken cancellationToken = default) - => _streamAvailableTcs.Task.WaitAsync(cancellationToken); + public Task WaitForUnsolicitedMessageStreamAsync(CancellationToken cancellationToken = default) + => _unsolicitedMessageStreamReadyTcs.Task.WaitAsync(cancellationToken); + + internal void SignalUnsolicitedMessageStreamReady() => _unsolicitedMessageStreamReadyTcs.TrySetResult(); public async Task TriggerFaultAsync(CancellationToken cancellationToken) { @@ -28,7 +30,9 @@ public async Task TriggerFaultAsync(CancellationToken cancella throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection."); } - _streamAvailableTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + // Reset the TCS so we can wait for the reconnected unsolicited message stream + _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + _reconnectTcs = new(); await _lastStream.TriggerFaultAsync(cancellationToken); @@ -51,6 +55,7 @@ protected override async Task SendAsync( _reconnectTcs = null; } + var isGetRequest = request.Method == HttpMethod.Get; var response = await base.SendAsync(request, cancellationToken); // Only wrap SSE streams (text/event-stream) @@ -69,7 +74,12 @@ protected override async Task SendAsync( response.Content = newContent; - _streamAvailableTcs.TrySetResult(); + // For GET requests (unsolicited message stream), set up the stream to signal + // when first data is read. This ensures the server's transport handler is ready. + if (isGetRequest) + { + _lastStream.SetReadyCallback(SignalUnsolicitedMessageStreamReady); + } } return response; @@ -96,10 +106,14 @@ private sealed class FaultingStream(Stream innerStream) : Stream { private readonly CancellationTokenSource _cts = new(); private TaskCompletionSource? _faultTcs; + private Action? _readyCallback; + private bool _readySignaled; private bool _disposed; public bool IsDisposed => _disposed; + public void SetReadyCallback(Action callback) => _readyCallback = callback; + public async Task TriggerFaultAsync(CancellationToken cancellationToken) { if (_faultTcs is not null) @@ -138,6 +152,12 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation _cts.Token.ThrowIfCancellationRequested(); + if (bytesRead > 0 && !_readySignaled) + { + _readySignaled = true; + _readyCallback?.Invoke(); + } + return bytesRead; } catch (OperationCanceledException) when (_cts.IsCancellationRequested) From 76c064a74d01ebb56fbf994484ffaad5e0864636 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Fri, 16 Jan 2026 15:10:13 -0800 Subject: [PATCH 13/18] Update log message --- .../Server/DistributedCacheEventStreamStore.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs index 2b74f6989..f2a595b2d 100644 --- a/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs +++ b/src/ModelContextProtocol.Core/Server/DistributedCacheEventStreamStore.cs @@ -351,7 +351,7 @@ public DistributedCacheEventStreamReader( } } - [LoggerMessage(Level = LogLevel.Debug, Message = "Starting to read events for session '{SessionId}', stream '{StreamId}' from sequence {StartSequence} to {LastSequence}.")] + [LoggerMessage(Level = LogLevel.Debug, Message = "Starting to read events for session '{SessionId}', stream '{StreamId}' starting at sequence {StartSequence}. Last available sequence: {LastSequence}.")] private partial void LogReadingEventsStarted(string sessionId, string streamId, long startSequence, long lastSequence); [LoggerMessage(Level = LogLevel.Trace, Message = "Event read from session '{SessionId}', stream '{StreamId}' with ID '{EventId}' (sequence {Sequence}).")] From 0a0bc540f9b916867dc563be83f94f24760f83e7 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Fri, 16 Jan 2026 15:24:21 -0800 Subject: [PATCH 14/18] Lengthen unusually short test timeouts --- .../Server/DistributedCacheEventStreamStoreTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index 6900917b9..4d0048d97 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -1047,7 +1047,7 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling() // Start reading in streaming mode (will wait for new events) using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); - cts.CancelAfter(TimeSpan.FromSeconds(3)); + cts.CancelAfter(TimeSpan.FromSeconds(10)); var events = new List>(); var readCompleted = false; @@ -1065,7 +1065,7 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling() // Assert - Read should complete within timeout after switching to polling mode using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); - timeoutCts.CancelAfter(TimeSpan.FromSeconds(1)); + timeoutCts.CancelAfter(TimeSpan.FromSeconds(10)); await readTask.WaitAsync(timeoutCts.Token); Assert.True(readCompleted); Assert.Empty(events); // No new events were written after the one we used to create the reader From 3238d6550d039e7d9feebc480ac7bc44ce5e09f0 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Fri, 16 Jan 2026 15:25:34 -0800 Subject: [PATCH 15/18] Remove redundant CTS --- .../Server/DistributedCacheEventStreamStoreTests.cs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index 4d0048d97..34188a694 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -1064,9 +1064,7 @@ public async Task ReadEventsAsync_RespectsModeSwitchFromStreamingToPolling() await writer.SetModeAsync(SseEventStreamMode.Polling, CancellationToken); // Assert - Read should complete within timeout after switching to polling mode - using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); - timeoutCts.CancelAfter(TimeSpan.FromSeconds(10)); - await readTask.WaitAsync(timeoutCts.Token); + await readTask.WaitAsync(cts.Token); Assert.True(readCompleted); Assert.Empty(events); // No new events were written after the one we used to create the reader } From d3f70cf2e168f4519613bcf3ddcab9311784af12 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 20 Jan 2026 09:28:26 -0800 Subject: [PATCH 16/18] Delay flushing the unsolicited message stream --- src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs | 5 ----- .../Server/StreamableHttpServerTransport.cs | 5 +++++ 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 6e11b9b86..c50e51388 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -154,11 +154,6 @@ await WriteJsonRpcErrorAsync(context, { await using var _ = await session.AcquireReferenceAsync(cancellationToken); InitializeSseResponse(context); - - // We should flush headers to indicate a 200 success quickly, because the initialization response - // will be sent in response to a different POST request. It might be a while before we send a message - // over this response body. - await context.Response.Body.FlushAsync(cancellationToken); await session.Transport.HandleGetRequestAsync(context.Response.Body, cancellationToken); } catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 6307726d0..adc65c496 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -125,6 +125,11 @@ public async Task HandleGetRequestAsync(Stream sseResponseStream, CancellationTo var primingItem = await _storeSseWriter.WriteEventAsync(SseItem.Prime(), cancellationToken).ConfigureAwait(false); await _httpSseWriter.WriteAsync(primingItem, cancellationToken).ConfigureAwait(false); } + + // We should flush to indicate a 200 success quickly, because the initialization response + // will be sent in response to a different POST request. It might be a while before we send a message + // over this response body. + await sseResponseStream.FlushAsync(cancellationToken).ConfigureAwait(false); } // Wait for the response to be written before returning from the handler. From 34033552124832c8b9ead3f5259689497afb2dbc Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 20 Jan 2026 09:32:32 -0800 Subject: [PATCH 17/18] Lengthen timeouts --- .../Server/DistributedCacheEventStreamStoreTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs index 34188a694..68cc04431 100644 --- a/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/DistributedCacheEventStreamStoreTests.cs @@ -830,7 +830,7 @@ public async Task ReadEventsAsync_InStreamingMode_WaitsForNewEvents() // Act - Start reading and then write a new event using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); - cts.CancelAfter(TimeSpan.FromSeconds(2)); + cts.CancelAfter(TimeSpan.FromSeconds(10)); var events = new List>(); var readTask = Task.Run(async () => { @@ -887,7 +887,7 @@ public async Task ReadEventsAsync_InStreamingMode_YieldsNewlyWrittenEvents() // Act - Write multiple events while reader is active using var cts = CancellationTokenSource.CreateLinkedTokenSource(CancellationToken); - cts.CancelAfter(TimeSpan.FromSeconds(3)); + cts.CancelAfter(TimeSpan.FromSeconds(10)); var events = new List>(); var readTask = Task.Run(async () => { From 5ea239d6bedf6a927f403ffbd8fb8d4d76c68236 Mon Sep 17 00:00:00 2001 From: Mackinnon Buck Date: Tue, 20 Jan 2026 14:31:54 -0800 Subject: [PATCH 18/18] Allow any OCE in test --- .../ResumabilityIntegrationTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs index a75906463..f6c8f3a92 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -535,7 +535,7 @@ public async Task PostResponse_EndsAndSseEventStreamWriterIsDisposed_WhenWriteEv timeoutCts.CancelAfter(TimeSpan.FromSeconds(10)); // The call task should throw an OCE due to cancellation - await Assert.ThrowsAsync(() => callTask).WaitAsync(timeoutCts.Token); + await Assert.ThrowsAnyAsync(() => callTask).WaitAsync(timeoutCts.Token); // Wait for the writer to be disposed await blockingStore.DisposedTask.WaitAsync(timeoutCts.Token);