diff --git a/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityOptions.cs b/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityOptions.cs
index 3b7fc99..95224de 100644
--- a/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityOptions.cs
+++ b/src/DurableTask.SqlServer.AzureFunctions/SqlDurabilityOptions.cs
@@ -77,6 +77,14 @@ internal SqlOrchestrationServiceSettings GetOrchestrationServiceSettings(
                 settings.MaxActiveOrchestrations = extensionOptions.MaxConcurrentOrchestratorFunctions.Value;
             }
 
+            settings.ExtendedSessionsEnabled = extensionOptions.ExtendedSessionsEnabled;
+            // A non-positive value leaves the idle timeout already present on the settings object untouched.
+            if (extensionOptions.ExtendedSessionIdleTimeoutInSeconds > 0)
+            {
+                settings.ExtendedSessionIdleTimeout =
+                    TimeSpan.FromSeconds(extensionOptions.ExtendedSessionIdleTimeoutInSeconds);
+            }
+
             return settings;
         }
     }
diff --git a/src/DurableTask.SqlServer/Scripts/drop-schema.sql b/src/DurableTask.SqlServer/Scripts/drop-schema.sql
index f50f629..39d33bb 100644
--- a/src/DurableTask.SqlServer/Scripts/drop-schema.sql
+++ b/src/DurableTask.SqlServer/Scripts/drop-schema.sql
@@ -31,6 +31,8 @@ DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._LockNextTask
 DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._QueryManyOrchestrations
 DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._RenewOrchestrationLocks
 DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._RenewTaskLocks
+DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._FetchOrchestrationMessages
+DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._ReleaseOrchestrationLock
 DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._UpdateVersion
 DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._RewindInstance
 DROP PROCEDURE IF EXISTS __SchemaNamePlaceholder__._RewindInstanceRecursive
diff --git a/src/DurableTask.SqlServer/Scripts/logic.sql b/src/DurableTask.SqlServer/Scripts/logic.sql
index 6fbc422..ec10c71 100644
--- a/src/DurableTask.SqlServer/Scripts/logic.sql
+++ b/src/DurableTask.SqlServer/Scripts/logic.sql
@@ -758,7 +758,10 @@ CREATE OR ALTER PROCEDURE __SchemaNamePlaceholder__._CheckpointOrchestration
     @DeletedEvents MessageIDs READONLY,
     @NewHistoryEvents HistoryEvents READONLY,
     @NewOrchestrationEvents OrchestrationEvents READONLY,
-    @NewTaskEvents TaskEvents READONLY
+    @NewTaskEvents TaskEvents READONLY,
+    @KeepLocked bit = 0,
+    @LockedBy varchar(100) = NULL,
+    @NewLockExpiration datetime2 = NULL
 AS
 BEGIN
     BEGIN TRANSACTION
@@ -869,17 +872,26 @@ BEGIN
         [RuntimeStatus] = @RuntimeStatus,
         [LastUpdatedTime] = SYSUTCDATETIME(),
         [CompletedTime] = (CASE WHEN @IsCompleted = 1 THEN SYSUTCDATETIME() ELSE NULL END),
-        [LockExpiration] = NULL, -- release the lock
+        -- Release the lock unless the caller asked to keep it for an extended session and the instance is not in a terminal state.
+        [LockExpiration] = (CASE WHEN @KeepLocked = 1 AND @IsCompleted = 0 THEN @NewLockExpiration ELSE NULL END),
+        [LockedBy] = (CASE WHEN @KeepLocked = 1 AND @IsCompleted = 0 THEN @LockedBy ELSE NULL END),
         [CustomStatusPayloadID] = @CustomStatusPayloadID,
         [InputPayloadID] = @InputPayloadID,
         [OutputPayloadID] = @OutputPayloadID
     FROM Instances
-    WHERE [TaskHub] = @TaskHub and [InstanceID] = @InstanceID
+    WHERE
+        [TaskHub] = @TaskHub
+        AND [InstanceID] = @InstanceID
+        -- Do not overwrite a row that was taken over by a different worker after our lock expired.
+        AND (@KeepLocked = 0 OR [LockedBy] = @LockedBy)
 
     IF @@ROWCOUNT = 0
     BEGIN
         ROLLBACK TRANSACTION;
+        -- With @KeepLocked = 1 the UPDATE above also filters on [LockedBy], so report a lost lock (50003).
+        IF @KeepLocked = 1
+            THROW 50003, 'Lock lost.', 1;
         THROW 50000, 'The instance does not exist.', 1;
     END
     -- External event messages can create new instances
     -- NOTE: There is a chance this could result in deadlocks if two
@@ -1337,6 +1349,99 @@ END
 GO
 
 
+-- Used by extended sessions to fetch any new events for an already-locked
+-- instance without going through the normal lock-acquisition flow.
+-- The lock is renewed to @LockExpiration by the same UPDATE that checks ownership;
+-- error 50003 ('Lock lost.') is thrown when @LockedBy does not hold an unexpired lock.
+-- No result set is returned when the instance is Completed, Failed or Terminated.
+CREATE OR ALTER PROCEDURE __SchemaNamePlaceholder__._FetchOrchestrationMessages
+    @InstanceID varchar(100),
+    @LockedBy varchar(100),
+    @LockExpiration datetime2,
+    @BatchSize int
+AS
+BEGIN
+    DECLARE @now datetime2 = SYSUTCDATETIME()
+    DECLARE @TaskHub varchar(50) = __SchemaNamePlaceholder__.CurrentTaskHub()
+    DECLARE @parentInstanceID varchar(100)
+    DECLARE @version varchar(100)
+    DECLARE @runtimeStatus varchar(30)
+    DECLARE @tags varchar(8000)
+
+    UPDATE Instances
+    SET
+        [LockExpiration] = @LockExpiration,
+        @parentInstanceID = [ParentInstanceID],
+        @version = [Version],
+        @runtimeStatus = [RuntimeStatus],
+        @tags = [Tags]
+    WHERE
+        [TaskHub] = @TaskHub
+        AND [InstanceID] = @InstanceID
+        AND [LockedBy] = @LockedBy
+        AND [LockExpiration] IS NOT NULL
+        AND [LockExpiration] > @now
+
+    IF @@ROWCOUNT = 0
+        THROW 50003, 'Lock lost.', 1;
+
+    IF @runtimeStatus IN ('Completed', 'Failed', 'Terminated')
+        RETURN
+
+    -- Same column shape as the first result-set of _LockNextOrchestration
+    SELECT TOP (@BatchSize)
+        N.[SequenceNumber],
+        N.[Timestamp],
+        N.[VisibleTime],
+        N.[DequeueCount],
+        N.[InstanceID],
+        N.[ExecutionID],
+        N.[EventType],
+        N.[Name],
+        N.[RuntimeStatus],
+        N.[TaskID],
+        P.[Reason],
+        P.[Text] AS [PayloadText],
+        P.[PayloadID],
+        DATEDIFF(SECOND, N.[Timestamp], @now) AS [WaitTime],
+        @parentInstanceID as [ParentInstanceID],
+        @version as [Version],
+        N.[TraceContext],
+        @tags as [Tags]
+    FROM NewEvents N
+        LEFT OUTER JOIN __SchemaNamePlaceholder__.[Payloads] P ON
+            P.[TaskHub] = @TaskHub AND
+            P.[InstanceID] = N.[InstanceID] AND
+            P.[PayloadID] = N.[PayloadID]
+    WHERE
+        N.[TaskHub] = @TaskHub AND
+        N.[InstanceID] = @InstanceID AND
+        (N.[VisibleTime] IS NULL OR N.[VisibleTime] < @now)
+    -- TOP without ORDER BY may return an arbitrary subset; hand out the oldest events first.
+    ORDER BY N.[SequenceNumber]
+END
+GO
+
+
+-- Releases an instance lock that was kept across an extended session.
+-- Updates nothing when @LockedBy is not the current owner of the row.
+CREATE OR ALTER PROCEDURE __SchemaNamePlaceholder__._ReleaseOrchestrationLock
+    @InstanceID varchar(100),
+    @LockedBy varchar(100)
+AS
+BEGIN
+    DECLARE @TaskHub varchar(50) = __SchemaNamePlaceholder__.CurrentTaskHub()
+
+    UPDATE Instances
+    SET [LockExpiration] = NULL, [LockedBy] = NULL
+    WHERE
+        [TaskHub] = @TaskHub
+        AND [InstanceID] = @InstanceID
+        AND [LockedBy] = @LockedBy
+END
+GO
+
+
 CREATE OR ALTER PROCEDURE __SchemaNamePlaceholder__._RenewTaskLocks
     @RenewingTasks MessageIDs READONLY,
     @LockExpiration datetime2
diff --git a/src/DurableTask.SqlServer/Scripts/permissions.sql b/src/DurableTask.SqlServer/Scripts/permissions.sql
index 1566cd9..fb8a590 100644
--- a/src/DurableTask.SqlServer/Scripts/permissions.sql
+++ b/src/DurableTask.SqlServer/Scripts/permissions.sql
@@ -38,6 +38,8 @@ GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._LockNextTask TO __SchemaName
 GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._QueryManyOrchestrations TO __SchemaNamePlaceholder___runtime
 GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._RenewOrchestrationLocks TO __SchemaNamePlaceholder___runtime
 GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._RenewTaskLocks TO __SchemaNamePlaceholder___runtime
+GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._FetchOrchestrationMessages TO __SchemaNamePlaceholder___runtime
+GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._ReleaseOrchestrationLock TO __SchemaNamePlaceholder___runtime
 GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._UpdateVersion TO __SchemaNamePlaceholder___runtime
 GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._RewindInstance TO __SchemaNamePlaceholder___runtime
 GRANT EXECUTE ON OBJECT::__SchemaNamePlaceholder__._RewindInstanceRecursive TO __SchemaNamePlaceholder___runtime
diff --git a/src/DurableTask.SqlServer/SqlOrchestrationService.cs b/src/DurableTask.SqlServer/SqlOrchestrationService.cs
index 76774d0..4fdcec8 100644
--- a/src/DurableTask.SqlServer/SqlOrchestrationService.cs
+++
b/src/DurableTask.SqlServer/SqlOrchestrationService.cs
@@ -300,12 +300,24 @@ await SqlUtils.ExecuteNonQueryAsync(
                         instance = new OrchestrationInstance();
                     }
 
+                    string orchestrationInstanceId = messages[0].OrchestrationInstance.InstanceId;
+
                     return new ExtendedOrchestrationWorkItem(orchestrationName, instance, eventPayloadMappings)
                     {
-                        InstanceId = messages[0].OrchestrationInstance.InstanceId,
+                        InstanceId = orchestrationInstanceId,
                         LockedUntilUtc = lockExpiration,
                         NewMessages = messages,
                         OrchestrationRuntimeState = runtimeState,
+                        Session = this.settings.ExtendedSessionsEnabled
+                            ? new SqlOrchestrationSession(
+                                this.settings,
+                                this.orchestrationBackoffHelper,
+                                this.traceHelper,
+                                eventPayloadMappings,
+                                orchestrationInstanceId,
+                                this.lockedByValue,
+                                this.ShutdownToken)
+                            : null,
                     };
                 }
             } while (stopwatch.Elapsed < receiveTimeout);
@@ -361,6 +373,15 @@ public override async Task CompleteTaskOrchestrationWorkItemAsync(
             command.Parameters.Add("@RuntimeStatus", SqlDbType.VarChar, size: 30).Value = orchestrationState.OrchestrationStatus.ToString();
             command.Parameters.Add("@CustomStatusPayload", SqlDbType.VarChar).Value = orchestrationState.Status ?? SqlString.Null;
 
+            bool keepLocked = workItem.Session != null && !IsTerminalStatus(orchestrationState.OrchestrationStatus);
+            DateTime newLockExpiration = DateTime.UtcNow.Add(this.settings.WorkItemLockTimeout);
+            if (keepLocked)
+            {
+                command.Parameters.Add("@KeepLocked", SqlDbType.Bit).Value = true;
+                command.Parameters.Add("@LockedBy", SqlDbType.VarChar, size: 100).Value = this.lockedByValue;
+                command.Parameters.Add("@NewLockExpiration", SqlDbType.DateTime2).Value = newLockExpiration;
+            }
+
             currentWorkItem.EventPayloadMappings.Add(outboundMessages);
             currentWorkItem.EventPayloadMappings.Add(orchestratorMessages);
 
@@ -400,6 +421,16 @@ public override async Task CompleteTaskOrchestrationWorkItemAsync(
                 this.traceHelper.DuplicateExecutionDetected(instance, orchestrationState.Name);
                 return;
             }
+            catch (SqlException e) when (keepLocked && SqlUtils.HasErrorNumber(e, SqlOrchestrationSession.LockLostErrorNumber))
+            {
+                throw new SessionAbortedException(
+                    $"Lost the lock for instance '{instance.InstanceId}' during checkpoint.", e);
+            }
+
+            if (keepLocked)
+            {
+                workItem.LockedUntilUtc = newLockExpiration;
+            }
 
             // notify pollers that new messages may be available
             if (outboundMessages.Count > 0)
@@ -415,10 +446,25 @@ public override async Task CompleteTaskOrchestrationWorkItemAsync(
             this.traceHelper.CheckpointCompleted(orchestrationState, sw);
         }
 
+        static bool IsTerminalStatus(OrchestrationStatus status) =>
+            status == OrchestrationStatus.Completed ||
+            status == OrchestrationStatus.Failed ||
+            status == OrchestrationStatus.Terminated;
+
         // We abandon work items by just letting their locks expire. The benefit of this "lazy" approach is that it
         // removes the need for a DB access and also ensures that a work-item can't spam the error logs in a tight loop.
         public override Task AbandonTaskOrchestrationWorkItemAsync(TaskOrchestrationWorkItem workItem) => Task.CompletedTask;
 
+        public override Task ReleaseTaskOrchestrationWorkItemAsync(TaskOrchestrationWorkItem workItem)
+        {
+            if (workItem.Session is SqlOrchestrationSession session)
+            {
+                return session.ReleaseLockAsync();
+            }
+
+            return Task.CompletedTask;
+        }
+
         public override async Task<TaskActivityWorkItem?> LockNextTaskActivityWorkItem(
             TimeSpan receiveTimeout,
             CancellationToken shutdownCancellationToken)
diff --git a/src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs b/src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs
index 048a723..e3ae3bf 100644
--- a/src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs
+++ b/src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs
@@ -93,6 +93,18 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
         [JsonProperty("maxActiveOrchestrations")]
         public int MaxActiveOrchestrations { get; set; } = Environment.ProcessorCount;
 
+        /// <summary>
+        /// Gets or sets a flag indicating whether to enable extended sessions.
+        /// </summary>
+        [JsonProperty("extendedSessionsEnabled")]
+        public bool ExtendedSessionsEnabled { get; set; } = false;
+
+        /// <summary>
+        /// Gets or sets how long an extended session keeps polling for new messages before it ends.
+        /// </summary>
+        [JsonProperty("extendedSessionIdleTimeout")]
+        public TimeSpan ExtendedSessionIdleTimeout { get; set; } = TimeSpan.FromSeconds(30);
+
         /// <summary>
         /// Gets or sets the minimum interval to poll for orchestrations.
         /// Polling interval increases when no orchestrations or activities are found.
diff --git a/src/DurableTask.SqlServer/SqlOrchestrationSession.cs b/src/DurableTask.SqlServer/SqlOrchestrationSession.cs
new file mode 100644
index 0000000..d6af5da
--- /dev/null
+++ b/src/DurableTask.SqlServer/SqlOrchestrationSession.cs
@@ -0,0 +1,170 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace DurableTask.SqlServer
+{
+    using System;
+    using System.Collections.Generic;
+    using System.Data;
+    using System.Data.Common;
+    using System.Threading;
+    using System.Threading.Tasks;
+    using DurableTask.Core;
+    using DurableTask.Core.Exceptions;
+    using Microsoft.Data.SqlClient;
+
+    /// <summary>
+    /// Extended-session handle for a single locked orchestration instance. It polls the
+    /// <c>_FetchOrchestrationMessages</c> stored procedure for new messages and releases the
+    /// instance lock through <c>_ReleaseOrchestrationLock</c>.
+    /// </summary>
+    class SqlOrchestrationSession : IOrchestrationSession
+    {
+        /// <summary>SQL error number that this class and the checkpoint path translate into a <see cref="SessionAbortedException"/>.</summary>
+        internal const int LockLostErrorNumber = 50003;
+
+        readonly SqlOrchestrationServiceSettings settings;
+        readonly BackoffPollingHelper orchestrationBackoffHelper;
+        readonly LogHelper traceHelper;
+        readonly EventPayloadMap eventPayloadMappings;
+        readonly string instanceId;
+        readonly string lockedByValue;
+        readonly CancellationToken shutdownToken;
+
+        public SqlOrchestrationSession(
+            SqlOrchestrationServiceSettings settings,
+            BackoffPollingHelper orchestrationBackoffHelper,
+            LogHelper traceHelper,
+            EventPayloadMap eventPayloadMappings,
+            string instanceId,
+            string lockedByValue,
+            CancellationToken shutdownToken)
+        {
+            this.settings = settings;
+            this.orchestrationBackoffHelper = orchestrationBackoffHelper;
+            this.traceHelper = traceHelper;
+            this.eventPayloadMappings = eventPayloadMappings;
+            this.instanceId = instanceId;
+            this.lockedByValue = lockedByValue;
+            this.shutdownToken = shutdownToken;
+        }
+
+        /// <summary>
+        /// Polls for new messages for the locked instance until some arrive, the idle timeout
+        /// elapses, or shutdown is requested. Each poll passes a fresh lock expiration to the database.
+        /// </summary>
+        /// <param name="workItem">The work item whose <c>LockedUntilUtc</c> is updated after each successful poll.</param>
+        /// <returns>The fetched messages, or <c>null</c> when no messages arrived before the timeout or shutdown.</returns>
+        /// <exception cref="SessionAbortedException">The fetch failed with <see cref="LockLostErrorNumber"/>.</exception>
+        public async Task<IList<TaskMessage>?> FetchNewOrchestrationMessagesAsync(TaskOrchestrationWorkItem workItem)
+        {
+            DateTime deadline = DateTime.UtcNow + this.settings.ExtendedSessionIdleTimeout;
+            while (!this.shutdownToken.IsCancellationRequested)
+            {
+                DateTime newLockExpiration = DateTime.UtcNow + this.settings.WorkItemLockTimeout;
+                IList<TaskMessage> messages;
+                try
+                {
+                    messages = await this.FetchAsync(newLockExpiration);
+                }
+                catch (SqlException e) when (SqlUtils.HasErrorNumber(e, LockLostErrorNumber))
+                {
+                    throw new SessionAbortedException(
+                        $"Lost the lock for instance '{this.instanceId}'.", e);
+                }
+                catch (OperationCanceledException) when (this.shutdownToken.IsCancellationRequested)
+                {
+                    return null;
+                }
+
+                // A successful fetch sent newLockExpiration to the database even if no messages came back.
+                workItem.LockedUntilUtc = newLockExpiration;
+
+                if (messages.Count > 0)
+                {
+                    return messages;
+                }
+
+                if (DateTime.UtcNow >= deadline)
+                {
+                    return null;
+                }
+
+                try
+                {
+                    await this.orchestrationBackoffHelper.WaitAsync(this.shutdownToken);
+                }
+                catch (OperationCanceledException)
+                {
+                    return null;
+                }
+            }
+
+            return null;
+        }
+
+        /// <summary>
+        /// Releases the instance lock held by this session. Failures are logged as warnings and not rethrown.
+        /// </summary>
+        public async Task ReleaseLockAsync()
+        {
+            try
+            {
+                using SqlConnection connection = this.settings.CreateConnection();
+                await connection.OpenAsync();
+                using SqlCommand command = connection.CreateCommand();
+                command.CommandType = CommandType.StoredProcedure;
+                command.CommandText = $"{this.settings.SchemaName}._ReleaseOrchestrationLock";
+                command.Parameters.Add("@InstanceID", SqlDbType.VarChar, 100).Value = this.instanceId;
+                command.Parameters.Add("@LockedBy", SqlDbType.VarChar, 100).Value = this.lockedByValue;
+
+                await SqlUtils.ExecuteNonQueryAsync(command, this.traceHelper, this.instanceId);
+            }
+            catch (Exception e)
+            {
+                // Best-effort release; the lock will expire naturally after WorkItemLockTimeout if this fails.
+                this.traceHelper.GenericWarning(
+                    $"Failed to release orchestration lock for instance '{this.instanceId}': {e.Message}",
+                    this.instanceId);
+            }
+        }
+
+        /// <summary>
+        /// Calls <c>_FetchOrchestrationMessages</c> with <paramref name="newLockExpiration"/> as the new lock
+        /// expiration and records the payload ID of every returned message that has one.
+        /// </summary>
+        async Task<IList<TaskMessage>> FetchAsync(DateTime newLockExpiration)
+        {
+            using SqlConnection connection = this.settings.CreateConnection();
+            await connection.OpenAsync(this.shutdownToken);
+
+            using SqlCommand command = connection.CreateCommand();
+            command.CommandType = CommandType.StoredProcedure;
+            command.CommandText = $"{this.settings.SchemaName}._FetchOrchestrationMessages";
+            command.Parameters.Add("@InstanceID", SqlDbType.VarChar, 100).Value = this.instanceId;
+            command.Parameters.Add("@LockedBy", SqlDbType.VarChar, 100).Value = this.lockedByValue;
+            command.Parameters.Add("@LockExpiration", SqlDbType.DateTime2).Value = newLockExpiration;
+            command.Parameters.Add("@BatchSize", SqlDbType.Int).Value = this.settings.WorkItemBatchSize;
+
+            using DbDataReader reader = await SqlUtils.ExecuteReaderAsync(
+                command,
+                this.traceHelper,
+                this.instanceId,
+                this.shutdownToken);
+
+            var messages = new List<TaskMessage>(capacity: this.settings.WorkItemBatchSize);
+            while (await reader.ReadAsync(this.shutdownToken))
+            {
+                TaskMessage message = reader.GetTaskMessage();
+                messages.Add(message);
+                Guid? payloadId = reader.GetPayloadId();
+                if (payloadId.HasValue)
+                {
+                    this.eventPayloadMappings.Add(message.Event, payloadId.Value);
+                }
+            }
+
+            return messages;
+        }
+    }
+}
diff --git a/src/DurableTask.SqlServer/SqlUtils.cs b/src/DurableTask.SqlServer/SqlUtils.cs
index 84983f4..c7cec2b 100644
--- a/src/DurableTask.SqlServer/SqlUtils.cs
+++ b/src/DurableTask.SqlServer/SqlUtils.cs
@@ -681,6 +681,14 @@ public static bool IsUniqueKeyViolation(SqlException exception)
             return exception.Errors.Cast<SqlError>().Any(e => e.Class == 14 && (e.Number == 2601 || e.Number == 2627));
         }
 
+        /// <summary>
+        /// Returns true if any error in <paramref name="ex"/> has the given error number.
+        /// </summary>
+        public static bool HasErrorNumber(SqlException ex, int errorNumber)
+        {
+            return ex.Errors.Cast<SqlError>().Any(e => e.Number == errorNumber);
+        }
+
         public static void SetDateTime(this SqlDataRecord record, int ordinal, DateTime? dateTime)
         {
             if (dateTime.HasValue)
diff --git a/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs b/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs
index 75f95b3..4205e76 100644
--- a/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs
+++ b/test/DurableTask.SqlServer.Tests/Integration/DatabaseManagement.cs
@@ -464,10 +464,12 @@ async Task ValidateDatabaseSchemaAsync(TestDatabase database, string schemaName
                 $"{schemaName}._CheckpointOrchestration",
                 $"{schemaName}._CompleteTasks",
                 $"{schemaName}._DiscardEventsAndUnlockInstance",
+                $"{schemaName}._FetchOrchestrationMessages",
                 $"{schemaName}._GetVersions",
                 $"{schemaName}._LockNextOrchestration",
                 $"{schemaName}._LockNextTask",
                 $"{schemaName}._QueryManyOrchestrations",
+                $"{schemaName}._ReleaseOrchestrationLock",
                 $"{schemaName}._RenewOrchestrationLocks",
                 $"{schemaName}._RenewTaskLocks",
                 $"{schemaName}._UpdateVersion",
diff --git a/test/DurableTask.SqlServer.Tests/Integration/ExtendedSessionTests.cs b/test/DurableTask.SqlServer.Tests/Integration/ExtendedSessionTests.cs
new file mode 100644
index 0000000..36812b8
--- /dev/null
+++
b/test/DurableTask.SqlServer.Tests/Integration/ExtendedSessionTests.cs
@@ -0,0 +1,255 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+namespace DurableTask.SqlServer.Tests.Integration
+{
+    using System;
+    using System.Diagnostics;
+    using System.Threading.Tasks;
+    using DurableTask.Core;
+    using DurableTask.SqlServer.Tests.Utils;
+    using Newtonsoft.Json;
+    using Xunit;
+    using Xunit.Abstractions;
+
+    [Collection("Integration")]
+    public class ExtendedSessionTests : IAsyncLifetime
+    {
+        readonly TestService testService;
+        readonly ITestOutputHelper output;
+
+        public ExtendedSessionTests(ITestOutputHelper output)
+        {
+            this.testService = new TestService(output);
+            this.output = output;
+        }
+
+        Task IAsyncLifetime.InitializeAsync() => this.testService.InitializeAsync(
+            extendedSessions: true,
+            extendedSessionIdleTimeout: TimeSpan.FromSeconds(15));
+
+        Task IAsyncLifetime.DisposeAsync() => this.testService.DisposeAsync();
+
+        [Fact]
+        public async Task LockHeldWhileWaitingForExternalEvent()
+        {
+            TaskCompletionSource<string> tcs = null;
+
+            TestInstance<string> instance = await this.testService.RunOrchestration<string, string>(
+                input: null,
+                orchestrationName: nameof(LockHeldWhileWaitingForExternalEvent),
+                implementation: (ctx, _) =>
+                {
+                    tcs = new TaskCompletionSource<string>();
+                    return tcs.Task;
+                },
+                onEvent: (ctx, name, value) => tcs.TrySetResult(JsonConvert.DeserializeObject<string>(value)));
+
+            await instance.WaitForStart();
+            await this.WaitForLockToBeHeldAsync(instance.InstanceId, TimeSpan.FromSeconds(10));
+
+            await instance.RaiseEventAsync("Continue", "done");
+            await instance.WaitForCompletion(expectedOutput: "done");
+
+            string lockedBy = await this.GetLockedByAsync(instance.InstanceId);
+            Assert.Equal(string.Empty, lockedBy);
+        }
+
+        [Fact]
+        public async Task LockReleasedAfterIdleTimeout()
+        {
+            this.testService.OrchestrationServiceOptions.ExtendedSessionIdleTimeout = TimeSpan.FromSeconds(2);
+
+            TaskCompletionSource<string> tcs = null;
+
+            TestInstance<string> instance = await this.testService.RunOrchestration<string, string>(
+                input: null,
+                orchestrationName: nameof(LockReleasedAfterIdleTimeout),
+                implementation: (ctx, _) =>
+                {
+                    tcs = new TaskCompletionSource<string>();
+                    return tcs.Task;
+                },
+                onEvent: (ctx, name, value) => tcs.TrySetResult(JsonConvert.DeserializeObject<string>(value)));
+
+            await instance.WaitForStart();
+            await this.WaitForLockToBeHeldAsync(instance.InstanceId, TimeSpan.FromSeconds(10));
+            await this.WaitForLockToBeReleasedAsync(instance.InstanceId, TimeSpan.FromSeconds(15));
+
+            // The orchestration should still be Running — only the lock has expired/cleared.
+            OrchestrationState midState = await instance.GetStateAsync();
+            Assert.Equal(OrchestrationStatus.Running, midState.OrchestrationStatus);
+
+            // Re-engaging the orchestration should still work end-to-end after the session ended.
+            await instance.RaiseEventAsync("Continue", "done");
+            await instance.WaitForCompletion(expectedOutput: "done");
+        }
+
+        [Fact]
+        public async Task LockNotPoachedWhileSessionActive()
+        {
+            TaskCompletionSource<string> tcs = null;
+
+            TestInstance<string> instance = await this.testService.RunOrchestration<string, string>(
+                input: null,
+                orchestrationName: nameof(LockNotPoachedWhileSessionActive),
+                implementation: (ctx, _) =>
+                {
+                    tcs = new TaskCompletionSource<string>();
+                    return tcs.Task;
+                },
+                onEvent: (ctx, name, value) => tcs.TrySetResult(JsonConvert.DeserializeObject<string>(value)));
+
+            await instance.WaitForStart();
+            string heldBy = await this.WaitForLockToBeHeldAsync(instance.InstanceId, TimeSpan.FromSeconds(10));
+
+            // Sanity check: the session owner still holds an unexpired lock on the row. This inspects the
+            // Instances table directly; it does not exercise _LockNextOrchestration on a second worker.
+            object heldLocks = await SharedTestHelpers.ExecuteSqlAsync(
+                this.output,
+                $"SELECT COUNT(*) FROM dt.[Instances] WHERE [InstanceID] = '{instance.InstanceId}' AND [LockedBy] = '{heldBy}' AND [LockExpiration] > SYSUTCDATETIME()");
+            Assert.Equal(1, Convert.ToInt32(heldLocks));
+
+            await instance.RaiseEventAsync("Continue", "done");
+            await instance.WaitForCompletion(expectedOutput: "done");
+        }
+
+        [Fact]
+        public async Task MultipleEventsAcrossSession()
+        {
+            const int eventCount = 5;
+            TaskCompletionSource<int> tcs = null;
+
+            TestInstance<string> instance = await this.testService.RunOrchestration<int, string>(
+                input: null,
+                orchestrationName: nameof(MultipleEventsAcrossSession),
+                implementation: async (ctx, _) =>
+                {
+                    tcs = new TaskCompletionSource<int>();
+
+                    int i;
+                    for (i = 0; i < eventCount; i++)
+                    {
+                        await tcs.Task;
+                        tcs = new TaskCompletionSource<int>();
+                    }
+
+                    return i;
+                },
+                onEvent: (ctx, name, value) =>
+                {
+                    tcs.TrySetResult(int.Parse(value));
+                });
+
+            for (int i = 0; i < eventCount; i++)
+            {
+                await instance.RaiseEventAsync($"Event{i}", i);
+            }
+
+            await instance.WaitForCompletion(
+                timeout: TimeSpan.FromSeconds(15),
+                expectedOutput: eventCount);
+
+            string lockedBy = await this.GetLockedByAsync(instance.InstanceId);
+            Assert.Equal(string.Empty, lockedBy);
+        }
+
+        [Fact]
+        public async Task ContinueAsNewWithSession()
+        {
+            TestInstance<int> instance = await this.testService.RunOrchestration(
+                input: 0,
+                orchestrationName: nameof(ContinueAsNewWithSession),
+                implementation: async (ctx, input) =>
+                {
+                    if (input < 3)
+                    {
+                        await ctx.CreateTimer<object>(DateTime.MinValue, null);
+                        ctx.ContinueAsNew(input + 1);
+                    }
+
+                    return input;
+                });
+
+            await instance.WaitForCompletion(expectedOutput: 3, continuedAsNew: true);
+
+            string lockedBy = await this.GetLockedByAsync(instance.InstanceId);
+            Assert.Equal(string.Empty, lockedBy);
+        }
+
+        [Fact]
+        public async Task LockLostRecovers()
+        {
+            TaskCompletionSource<string> tcs = null;
+
+            TestInstance<string> instance = await this.testService.RunOrchestration<string, string>(
+                input: null,
+                orchestrationName: nameof(LockLostRecovers),
+                implementation: (ctx, _) =>
+                {
+                    tcs = new TaskCompletionSource<string>();
+                    return tcs.Task;
+                },
+                onEvent: (ctx, name, value) => tcs.TrySetResult(JsonConvert.DeserializeObject<string>(value)));
+
+            await instance.WaitForStart();
+            await this.WaitForLockToBeHeldAsync(instance.InstanceId, TimeSpan.FromSeconds(10));
+
+            // Forcibly clear the lock
+            await SharedTestHelpers.ExecuteSqlAsync(
+                this.output,
+                $"UPDATE dt.[Instances] SET [LockedBy] = NULL, [LockExpiration] = NULL WHERE [InstanceID] = '{instance.InstanceId}'");
+
+            await instance.RaiseEventAsync("Continue", "done");
+            await instance.WaitForCompletion(
+                timeout: TimeSpan.FromSeconds(20),
+                expectedOutput: "done");
+        }
+
+        async Task<string> GetLockedByAsync(string instanceId)
+        {
+            object result = await SharedTestHelpers.ExecuteSqlAsync(
+                this.output,
+                $"SELECT TOP 1 ISNULL([LockedBy], '') FROM dt.[Instances] WHERE [InstanceID] = '{instanceId}'");
+            return result?.ToString() ?? string.Empty;
+        }
+
+        async Task<string> WaitForLockToBeHeldAsync(string instanceId, TimeSpan timeout)
+        {
+            timeout = timeout.AdjustForDebugging();
+            Stopwatch sw = Stopwatch.StartNew();
+            while (sw.Elapsed < timeout)
+            {
+                string lockedBy = await this.GetLockedByAsync(instanceId);
+                if (!string.IsNullOrEmpty(lockedBy))
+                {
+                    return lockedBy;
+                }
+
+                await Task.Delay(TimeSpan.FromMilliseconds(200));
+            }
+
+            throw new TimeoutException(
+                $"Instance '{instanceId}' lock was not held within {timeout.TotalSeconds}s — extended session did not engage.");
+        }
+
+        async Task WaitForLockToBeReleasedAsync(string instanceId, TimeSpan timeout)
+        {
+            timeout = timeout.AdjustForDebugging();
+            Stopwatch sw = Stopwatch.StartNew();
+            while (sw.Elapsed < timeout)
+            {
+                string lockedBy = await this.GetLockedByAsync(instanceId);
+                if (string.IsNullOrEmpty(lockedBy))
+                {
+                    return;
+                }
+
+                await Task.Delay(TimeSpan.FromMilliseconds(200));
+            }
+
+            throw new TimeoutException(
+                $"Instance '{instanceId}' lock was not released within {timeout.TotalSeconds}s.");
+        }
+    }
+}
diff --git a/test/DurableTask.SqlServer.Tests/Utils/TestService.cs b/test/DurableTask.SqlServer.Tests/Utils/TestService.cs
index b0a892c..7e5251d 100644
--- a/test/DurableTask.SqlServer.Tests/Utils/TestService.cs
+++ b/test/DurableTask.SqlServer.Tests/Utils/TestService.cs
@@ -51,7 +51,11 @@ public TestService(ITestOutputHelper output)
 
         public TestLogProvider LogProvider { get; }
 
-        public async Task InitializeAsync(bool startWorker = true, bool legacyErrorPropagation = false)
+        public async Task InitializeAsync(
+            bool startWorker = true,
+            bool legacyErrorPropagation = false,
+            bool extendedSessions = false,
+            TimeSpan? extendedSessionIdleTimeout = null)
         {
             // The initialization requires administrative credentials (default)
             await new SqlOrchestrationService(this.OrchestrationServiceOptions).CreateIfNotExistsAsync();
@@ -64,8 +68,14 @@ public async Task InitializeAsync(bool startWorker = true, bool legacyErrorPropa
             this.OrchestrationServiceOptions = new SqlOrchestrationServiceSettings(this.testCredential.ConnectionString)
             {
                 LoggerFactory = this.loggerFactory,
+                ExtendedSessionsEnabled = extendedSessions,
             };
 
+            if (extendedSessionIdleTimeout.HasValue)
+            {
+                this.OrchestrationServiceOptions.ExtendedSessionIdleTimeout = extendedSessionIdleTimeout.Value;
+            }
+
             // A mock orchestration service allows us to stub out specific methods for testing.
             this.OrchestrationServiceMock = new Mock<SqlOrchestrationService>(this.OrchestrationServiceOptions) { CallBase = true };
             this.worker = new TaskHubWorker(this.OrchestrationServiceMock.Object, this.loggerFactory)