From 5c040f11cb2d0e011a2bdf858a5e501f20c2f128 Mon Sep 17 00:00:00 2001 From: Damian Suess Date: Thu, 8 Jan 2026 18:13:58 -0500 Subject: [PATCH 1/2] Allow CommandState to subscribe to a collection of message types via internal EventAggregator --- .gitignore | 3 +- .../StateTests/CommandStateTests.cs | 94 +++++++++ .../StateTests/CompositeStateTest.cs | 2 - .../StateTests/CustomStateTests.cs | 2 - .../StateTests/TestBase.cs | 2 + .../TestData/Models/CustomCommands.cs | 31 +++ .../TestData/States/CommandL3States.cs | 185 ++++++++++++++++++ .../TestData/States/CommandStateBase.cs | 32 +++ .../TestData/States/CompositeL1DiStates.cs | 12 +- .../TestData/States/DiStateBase.cs | 25 ++- source/Lite.StateMachine/EventAggregator.cs | 112 ++++++++--- source/Lite.StateMachine/ICommandState.cs | 9 + source/Lite.StateMachine/IEventAggregator.cs | 20 +- source/Lite.StateMachine/StateMachine.cs | 11 +- source/Lite.StateMachine/StateRegistration.cs | 9 +- source/Sample.Basics/DiStates/DiStateBase.cs | 6 +- 16 files changed, 502 insertions(+), 53 deletions(-) create mode 100644 source/Lite.StateMachine.Tests/StateTests/CommandStateTests.cs create mode 100644 source/Lite.StateMachine.Tests/TestData/Models/CustomCommands.cs create mode 100644 source/Lite.StateMachine.Tests/TestData/States/CommandL3States.cs create mode 100644 source/Lite.StateMachine.Tests/TestData/States/CommandStateBase.cs diff --git a/.gitignore b/.gitignore index 4d4ae60..e8baef7 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ [Dd]ebug/ [Dd]ebugPublic/ [Rr]elease/ +[Tt]estResults/ x64/ x86/ build/ @@ -92,5 +93,5 @@ sdkconfig.* ## USER DEFINED /[Dd]ocs/*.csv /[Dd]ocs/backup -/[Tt]ests +/[Ss]andbox /[Tt]ools diff --git a/source/Lite.StateMachine.Tests/StateTests/CommandStateTests.cs b/source/Lite.StateMachine.Tests/StateTests/CommandStateTests.cs new file mode 100644 index 0000000..4639545 --- /dev/null +++ b/source/Lite.StateMachine.Tests/StateTests/CommandStateTests.cs @@ -0,0 +1,94 @@ +// Copyright Xeno Innovations, Inc. 2025 +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading.Tasks; +using Lite.StateMachine.Tests.TestData; +using Lite.StateMachine.Tests.TestData.Models; +using Lite.StateMachine.Tests.TestData.Services; +using Lite.StateMachine.Tests.TestData.States.CommandL3DiStates; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace Lite.StateMachine.Tests.StateTests; + +[TestClass] +public class CommandStateTests : TestBase +{ + [TestMethod] + [DataRow(false, DisplayName = "Don't skip State3")] + [DataRow(true, DisplayName = "Skip State3")] + public async Task BasicState_Override_Executes_SuccessAsync(bool skipState3) + { + // Assemble with Dependency Injection + var services = new ServiceCollection() + .AddLogging(InlineTraceLogger(LogLevel.Trace)) + .AddSingleton() + .AddSingleton() + .BuildServiceProvider(); + + var msgService = services.GetRequiredService(); + var events = services.GetRequiredService(); + Func factory = t => ActivatorUtilities.CreateInstance(services, t); + + var ctxProperties = new PropertyBag() + { + { ParameterType.Counter, 0 }, + { ParameterType.TestExitEarly, skipState3 }, + }; + + var machine = new StateMachine(factory, events) + { + // Make sure we don't get stuck. + // And send some message after leaving Command state + // to make sure we unsubscribed successfully. + DefaultStateTimeoutMs = 3000, + }; + + machine + .RegisterState(CompositeL3.State1, CompositeL3.State2) + .RegisterComposite(CompositeL3.State2, initialChildStateId: CompositeL3.State2_Sub1, onSuccess: CompositeL3.State3) + .RegisterSubState(CompositeL3.State2_Sub1, parentStateId: CompositeL3.State2, onSuccess: CompositeL3.State2_Sub2) + .RegisterSubComposite(CompositeL3.State2_Sub2, parentStateId: CompositeL3.State2, initialChildStateId: CompositeL3.State2_Sub2_Sub1, onSuccess: CompositeL3.State2_Sub3) + .RegisterSubState(CompositeL3.State2_Sub2_Sub1, parentStateId: CompositeL3.State2_Sub2, onSuccess: CompositeL3.State2_Sub2_Sub2) + .RegisterSubState(CompositeL3.State2_Sub2_Sub2, parentStateId: CompositeL3.State2_Sub2, onSuccess: CompositeL3.State2_Sub2_Sub3) + .RegisterSubState(CompositeL3.State2_Sub2_Sub3, parentStateId: CompositeL3.State2_Sub2, onSuccess: null) + .RegisterSubState(CompositeL3.State2_Sub3, parentStateId: CompositeL3.State2, onSuccess: null) + .RegisterState(CompositeL3.State3, onSuccess: null); + + events.Subscribe(msg => + { + if (msg is ICustomCommand) + { + if (msg is UnlockCommand cmd) + { + // +100 check so we don't trigger this a 2nd time. + if (cmd.Counter > 100 && cmd.Counter < 200) + return; + + // NOTE: + // First we purposely publish 'OpenCommand' to prove that our OnMessage + // filters out the bad message, followed by publishing the REAL message. + if (cmd.Counter < 200) + events.Publish(new UnlockCommand { Counter = cmd.Counter + 100 }); + + events.Publish(new OpenResponse { Counter = cmd.Counter + 100 }); + + // NOTE: This doesn't reach State2_Sub2_Sub2 because it already left (GOOD) + events.Publish(new CloseResponse { Counter = cmd.Counter + 100 }); + } + } + }); + + // Act - Start your engine! + await machine.RunAsync(CompositeL3.State1, ctxProperties, null, TestContext.CancellationToken); + + // Assert Results + Assert.IsNotNull(machine); + Assert.IsNull(machine.Context); + + Assert.AreEqual(27, msgService.Counter1); + Assert.AreEqual(14, msgService.Counter2, "State2 Context.Param Count"); + Assert.AreEqual(skipState3 ? 13 : 13, msgService.Counter3); + } +} diff --git a/source/Lite.StateMachine.Tests/StateTests/CompositeStateTest.cs b/source/Lite.StateMachine.Tests/StateTests/CompositeStateTest.cs index 1c283d6..dc6e2e4 100644 --- a/source/Lite.StateMachine.Tests/StateTests/CompositeStateTest.cs +++ b/source/Lite.StateMachine.Tests/StateTests/CompositeStateTest.cs @@ -19,8 +19,6 @@ public class CompositeStateTest : TestBase public const string ParameterSubStateEntered = "SubEntered"; public const string SUCCESS = "success"; - public TestContext TestContext { get; set; } - [TestMethod] public async Task Level1_Basic_RegisterHelpers_SuccessTestAsync() { diff --git a/source/Lite.StateMachine.Tests/StateTests/CustomStateTests.cs b/source/Lite.StateMachine.Tests/StateTests/CustomStateTests.cs index 11ad0a6..0c7b9df 100644 --- a/source/Lite.StateMachine.Tests/StateTests/CustomStateTests.cs +++ b/source/Lite.StateMachine.Tests/StateTests/CustomStateTests.cs @@ -14,8 +14,6 @@ namespace Lite.StateMachine.Tests.StateTests; [TestClass] public class CustomStateTests : TestBase { - public TestContext TestContext { get; set; } - [TestMethod] [DataRow(false, DisplayName = "Don't skip State3")] [DataRow(true, DisplayName = "Skip State3")] diff --git a/source/Lite.StateMachine.Tests/StateTests/TestBase.cs b/source/Lite.StateMachine.Tests/StateTests/TestBase.cs index b003a61..241d697 100644 --- a/source/Lite.StateMachine.Tests/StateTests/TestBase.cs +++ b/source/Lite.StateMachine.Tests/StateTests/TestBase.cs @@ -9,6 +9,8 @@ namespace Lite.StateMachine.Tests.StateTests; public class TestBase { + public TestContext TestContext { get; set; } + /// ILogger Helper for generating clean in-line logs. /// Log level (Default: Trace). /// . diff --git a/source/Lite.StateMachine.Tests/TestData/Models/CustomCommands.cs b/source/Lite.StateMachine.Tests/TestData/Models/CustomCommands.cs new file mode 100644 index 0000000..94fc803 --- /dev/null +++ b/source/Lite.StateMachine.Tests/TestData/Models/CustomCommands.cs @@ -0,0 +1,31 @@ +// Copyright Xeno Innovations, Inc. 2025 +// See the LICENSE file in the project root for more information. + +namespace Lite.StateMachine.Tests.TestData.Models; + +#pragma warning disable SA1649 // File name should match first type name +#pragma warning disable SA1402 // File may only contain a single type + +/// Signifies it's one of our event packets. +public interface ICustomCommand; + +/// Sample command sent by state machine. +public class UnlockCommand : ICustomCommand +{ + public int Counter { get; set; } = 0; +} + +/// Sample command response received by state machine. +public class OpenResponse : ICustomCommand +{ + public int Counter { get; set; } = 0; +} + +/// Sample command response received by state machine. +public class CloseResponse : ICustomCommand +{ + public int Counter { get; set; } = 0; +} + +#pragma warning restore SA1402 // File may only contain a single type +#pragma warning restore SA1649 // File name should match first type name diff --git a/source/Lite.StateMachine.Tests/TestData/States/CommandL3States.cs b/source/Lite.StateMachine.Tests/TestData/States/CommandL3States.cs new file mode 100644 index 0000000..2e6ab25 --- /dev/null +++ b/source/Lite.StateMachine.Tests/TestData/States/CommandL3States.cs @@ -0,0 +1,185 @@ +// Copyright Xeno Innovations, Inc. 2025 +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading.Tasks; +using Lite.StateMachine.Tests.TestData.Services; +using Microsoft.Extensions.Logging; + +#pragma warning disable SA1124 // Do not use regions +#pragma warning disable SA1649 // File name should match first type name +#pragma warning disable SA1402 // File may only contain a single type +#pragma warning disable IDE0130 // Namespace does not match folder structure + +/// +namespace Lite.StateMachine.Tests.TestData.States.CommandL3DiStates; + +public class CommonDiStateBase(IMessageService msg, ILogger logger) + : DiStateBase(msg, logger) + where TStateId : struct, Enum +{ + // Helper so we don't have to keep rewriting the same "override Task OnEnter(...)" + // 8 lines * 9 states.. useless + public override Task OnEnter(Context context) + { + context.Parameters.Add(context.CurrentStateId.ToString(), Guid.NewGuid()); + MessageService.AddMessage($"[Keys-{context.CurrentStateId}]: {string.Join(",", context.Parameters.Keys)}"); + return base.OnEnter(context); + } +} + +public class State1(IMessageService msg, ILogger log) + : DiStateBase(msg, log) +{ + public override Task OnEnter(Context context) + { + if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) + Assert.IsNull(context.PreviousStateId); + + context.Parameters.Add(context.CurrentStateId.ToString(), Guid.NewGuid()); + MessageService.AddMessage($"[Keys-{context.CurrentStateId}]: {string.Join(",", context.Parameters.Keys)}"); + return base.OnEnter(context); + } +} + +/// Level-1: Composite. +public class State2(IMessageService msg, ILogger log) + : CommonDiStateBase(msg, log) +{ + #region CodeMaid - DoNotReorder + + public override Task OnEntering(Context context) + { + // Demonstrate NEW parameter that will carry forward + context.Parameters.Add($"{context.CurrentStateId}!Anchor", Guid.NewGuid()); + return base.OnEntering(context); + } + + #endregion CodeMaid - DoNotReorder + + public override Task OnEnter(Context context) + { + if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) + Assert.AreEqual(CompositeL3.State1, context.PreviousStateId); + + // Demonstrate temporary parameter that will be discarded after State2's OnExit + context.Parameters.Add($"{context.CurrentStateId}!TEMP", Guid.NewGuid()); + return base.OnEnter(context); + } + + public override Task OnExit(Context context) + { + // Expected Count: 7 - State2_Sub2 is composite; therefore, discarded. + // State1,State2!Anchor,State2!TEMP,State2,State2_Sub1,State2_Sub2!Anchor,State2_Sub3 + MessageService.Counter2 = context.Parameters.Count; + return base.OnExit(context); + } +} + +/// Sublevel-2: State. +public class State2_Sub1(IMessageService msg, ILogger log) + : CommonDiStateBase(msg, log) +{ + public override Task OnEnter(Context context) + { + if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) + Assert.IsNull(context.PreviousStateId); + + return base.OnEnter(context); + } +} + +/// Sublevel-2: Composite. +public class State2_Sub2(IMessageService msg, ILogger log) + : CommonDiStateBase(msg, log) +{ + #region CodeMaid - DoNotReorder + + public override Task OnEntering(Context context) + { + // Demonstrate NEW parameter that will carry forward + context.Parameters.Add($"{context.CurrentStateId}!Anchor", Guid.NewGuid()); + return base.OnEntering(context); + } + + #endregion CodeMaid - DoNotReorder + + public override Task OnEnter(Context context) + { + if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) + Assert.AreEqual(CompositeL3.State2_Sub1, context.PreviousStateId); + + // Demonstrate temporary parameter that will be discarded after State2_Sub2's OnExit + context.Parameters.Add($"{context.CurrentStateId}!TEMP", Guid.NewGuid()); + return base.OnEnter(context); + } + + public override Task OnExit(Context context) + { + // Expected Count: 7 + MessageService.Counter3 = context.Parameters.Count; + return base.OnExit(context); + } +} + +/// Sublevel-3: State. +public class State2_Sub2_Sub1(IMessageService msg, ILogger log) + : CommonDiStateBase(msg, log) +{ + public override Task OnEnter(Context context) + { + if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) + Assert.IsNull(context.PreviousStateId); + + return base.OnEnter(context); + } +} + +/// Sublevel-3: State. +public class State2_Sub2_Sub2(IMessageService msg, ILogger log) + : CommonDiStateBase(msg, log); + +/// Sublevel-3: Last State. +public class State2_Sub2_Sub3(IMessageService msg, ILogger log) + : CommonDiStateBase(msg, log) +{ + public override Task OnEnter(Context context) + { + if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) + Assert.AreEqual(CompositeL3.State2_Sub2_Sub2, context.PreviousStateId); + + return base.OnEnter(context); + } +} + +/// Sublevel-2: Last State. +public class State2_Sub3(IMessageService msg, ILogger log) + : DiStateBase(msg, log) +{ + public override Task OnEnter(Context context) + { + if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) + Assert.AreEqual(CompositeL3.State2_Sub2, context.PreviousStateId); + + context.Parameters.Add(context.CurrentStateId.ToString(), Guid.NewGuid()); + MessageService.AddMessage($"[Keys-{context.CurrentStateId}]: {string.Join(",", context.Parameters.Keys)}"); + return base.OnEnter(context); + } +} + +/// Make sure not child-created context is there. +public class State3(IMessageService msg, ILogger log) + : DiStateBase(msg, log) +{ + public override Task OnEnter(Context context) + { + context.Parameters.Add(context.CurrentStateId.ToString(), Guid.NewGuid()); + MessageService.AddMessage($"[Keys-{context.CurrentStateId}]: {string.Join(",", context.Parameters.Keys)}"); + return base.OnEnter(context); + } +} + +#pragma warning restore IDE0130 // Namespace does not match folder structure +#pragma warning restore SA1649 // File name should match first type name +#pragma warning restore SA1402 // File may only contain a single type +#pragma warning restore SA1124 // Do not use regions diff --git a/source/Lite.StateMachine.Tests/TestData/States/CommandStateBase.cs b/source/Lite.StateMachine.Tests/TestData/States/CommandStateBase.cs new file mode 100644 index 0000000..2de8ad4 --- /dev/null +++ b/source/Lite.StateMachine.Tests/TestData/States/CommandStateBase.cs @@ -0,0 +1,32 @@ +// Copyright Xeno Innovations, Inc. 2025 +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading.Tasks; +using Lite.StateMachine.Tests.TestData.Services; +using Microsoft.Extensions.Logging; + +namespace Lite.StateMachine.Tests.TestData.States; + +public class CommandStateBase(IMessageService msg, ILogger logger) + : DiStateBase(msg, logger), ICommandState + where TStateId : struct, Enum +{ + public Task OnMessage(Context context, object message) + { + Log.LogInformation("[OnEnter]"); + + // DO NOT AUTO-SUCCESS! Placed here as a note + ////context.NextState(Result.Success); + return Task.CompletedTask; + } + + public Task OnTimeout(Context context) + { + Log.LogInformation("[OnMessage]"); + + // DO NOT AUTO-SUCCESS! Placed here as a note + ////context.NextState(Result.Success); + return Task.CompletedTask; + } +} diff --git a/source/Lite.StateMachine.Tests/TestData/States/CompositeL1DiStates.cs b/source/Lite.StateMachine.Tests/TestData/States/CompositeL1DiStates.cs index 85140e0..a886056 100644 --- a/source/Lite.StateMachine.Tests/TestData/States/CompositeL1DiStates.cs +++ b/source/Lite.StateMachine.Tests/TestData/States/CompositeL1DiStates.cs @@ -25,7 +25,7 @@ public class ParentState(IMessageService msg, ILogger log) public override Task OnExit(Context context) { MessageService.Counter1++; - MessageService.AddMessage(GetType().Name + " OnExit"); + MessageService.AddMessage(GetType().Name + " [OnExit]"); Log.LogInformation("[OnExit] => {result}", context.LastChildResult); context.NextState(context.LastChildResult switch @@ -51,7 +51,7 @@ public class ParentSub_WaitMessageState(IMessageService msg, ILogger context) { MessageService.Counter1++; - MessageService.AddMessage(GetType().Name + " OnEnter"); + MessageService.AddMessage(GetType().Name + " [OnEnter]"); Log.LogInformation("[OnEnter] (Counter2: {cnt})", MessageService.Counter2); switch (MessageService.Counter2) @@ -84,7 +84,7 @@ public override Task OnEnter(Context context) public Task OnMessage(Context context, object message) { MessageService.Counter1++; - MessageService.AddMessage(GetType().Name + " OnEnter"); + MessageService.AddMessage(GetType().Name + " [OnEnter]"); if (message is not string response) { @@ -122,7 +122,7 @@ public Task OnMessage(Context context, object message) public Task OnTimeout(Context context) { MessageService.Counter1++; - MessageService.AddMessage(GetType().Name + " OnEnter"); + MessageService.AddMessage(GetType().Name + " [OnEnter]"); context.NextState(Result.Failure); Log.LogInformation("[OnTimeout] => Failure; (Publishing: ReceivedTimeout)"); @@ -145,7 +145,7 @@ public override Task OnEnter(Context context) { MessageService.Counter1++; MessageService.Counter2++; - MessageService.AddMessage(GetType().Name + " OnEnter"); + MessageService.AddMessage(GetType().Name + " [OnEnter]"); Log.LogInformation("[{StateName}] [OnEnter] => OK; Counter2++", GetType().Name); Debug.WriteLine($"[{GetType().Name}] [OnEnter] => OK; Counter2++"); @@ -164,7 +164,7 @@ public override Task OnEnter(Context context) { MessageService.Counter1++; MessageService.Counter2++; - MessageService.AddMessage(GetType().Name + " OnEnter"); + MessageService.AddMessage(GetType().Name + " [OnEnter]"); Log.LogInformation("[{StateName}] [OnEnter] => OK; (Counter2++)", GetType().Name); Debug.WriteLine($"[{GetType().Name}] [OnEnter] => OK; (Counter2++)"); diff --git a/source/Lite.StateMachine.Tests/TestData/States/DiStateBase.cs b/source/Lite.StateMachine.Tests/TestData/States/DiStateBase.cs index 0c4ba74..0b06b65 100644 --- a/source/Lite.StateMachine.Tests/TestData/States/DiStateBase.cs +++ b/source/Lite.StateMachine.Tests/TestData/States/DiStateBase.cs @@ -9,6 +9,8 @@ namespace Lite.StateMachine.Tests.TestData.States; +#pragma warning disable SA1124 // Do not use regions + public class DiStateBase(IMessageService msg, ILogger logger) : IState where TStateId : struct, Enum { @@ -22,35 +24,36 @@ public class DiStateBase(IMessageService msg, ILogger _msgService; - public virtual Task OnEnter(Context context) + #region Suppress CodeMaid Method Sorting + + public virtual Task OnEntering(Context context) { _msgService.Counter1++; - ////_msgService.AddMessage(GetType().Name + " OnEnter"); - _logger.LogInformation("[OnEnter] => OK"); + _logger.LogInformation("[OnEntering]"); if (HasExtraLogging) - Debug.WriteLine($"[{GetType().Name}] [OnEnter] => OK"); + Debug.WriteLine($"[{GetType().Name}] [OnEntering]"); - context.NextState(Result.Success); return Task.CompletedTask; } - public virtual Task OnEntering(Context context) + #endregion Suppress CodeMaid Method Sorting + + public virtual Task OnEnter(Context context) { _msgService.Counter1++; - ////_msgService.AddMessage(GetType().Name + " OnEntering"); - _logger.LogInformation("[OnEntering]"); + _logger.LogInformation("[OnEnter] => OK"); if (HasExtraLogging) - Debug.WriteLine($"[{GetType().Name}] [OnEntering]"); + Debug.WriteLine($"[{GetType().Name}] [OnEnter] => OK"); + context.NextState(Result.Success); return Task.CompletedTask; } public virtual Task OnExit(Context context) { _msgService.Counter1++; - ////_msgService.AddMessage(GetType().Name + " OnExit"); _logger.LogInformation("[OnExit]"); if (HasExtraLogging) @@ -60,3 +63,5 @@ public virtual Task OnExit(Context context) return Task.CompletedTask; } } + +#pragma warning restore SA1124 // Do not use regions diff --git a/source/Lite.StateMachine/EventAggregator.cs b/source/Lite.StateMachine/EventAggregator.cs index 875b0aa..274944c 100644 --- a/source/Lite.StateMachine/EventAggregator.cs +++ b/source/Lite.StateMachine/EventAggregator.cs @@ -11,38 +11,109 @@ public sealed class EventAggregator : IEventAggregator //// Pre .NET 9: private readonly object _lockGate = new(); private readonly System.Threading.Lock _lockGate = new(); - ////private readonly List> _subscribers = new(); - private readonly List> _subscribers = []; + // Typed subscribers keyed by exact runtime Type + private readonly Dictionary>> _typedSubscribers = []; + + // Wildcard subscribers (receive all messages) + private readonly List> _wildcardSubscribers = []; public void Publish(object message) { - Action[] snapshot; + if (message is null) + return; + + Action[] wildcardSnapshot; + Action[] typedSnapshot; + + var msgType = message.GetType(); + lock (_lockGate) - snapshot = _subscribers.ToArray(); + { + wildcardSnapshot = _wildcardSubscribers.ToArray(); + if (_typedSubscribers.TryGetValue(msgType, out var list)) + typedSnapshot = list.ToArray(); + else + typedSnapshot = []; + } - // Fan-out; handlers decide whether to consume or ignore. - foreach (var sub in snapshot) + // Deliver to typed subscribers first (exact matches), then wildcard + // Handlers decide whether to consume or ignore; exceptions are swallowed. +#pragma warning disable SA1501 // Statement should not be on a single line + foreach (var sub in typedSnapshot) { - try - { - sub(message); - } - catch - { - // Swallow to avoid breaking publication loop. - } + try { sub(message); } + catch { /* Swallow to avoid breaking publication loop. */ } + } + + foreach (var sub in wildcardSnapshot) + { + try { sub(message); } + catch { /* Swallow to avoid breaking publication loop. */ } } +#pragma warning restore SA1501 // Statement should not be on a single line } public IDisposable Subscribe(Action handler) { - // Was: Func handler) ArgumentNullException.ThrowIfNull(handler); + lock (_lockGate) + _wildcardSubscribers.Add(handler); + + return new Subscription(() => + { + lock (_lockGate) + _wildcardSubscribers.Remove(handler); + }); + } + + public IDisposable Subscribe(Action handler, params Type[] messageTypes) + { + ArgumentNullException.ThrowIfNull(handler); + messageTypes ??= []; + if (messageTypes.Length == 0) + { + // No types specified -> treat as wildcard to preserve backward compatibility + return Subscribe(handler); + } + + // Register handler under each provided type lock (_lockGate) - _subscribers.Add(handler); + { + foreach (var t in messageTypes) + { + if (t is null) + continue; + + if (!_typedSubscribers.TryGetValue(t, out var list)) + { + list = []; + _typedSubscribers[t] = list; + } - return new Subscription(() => _subscribers.Remove(handler)); + list.Add(handler); + } + } + + // Composite unsubscribe removes handler from each type list + return new Subscription(() => + { + lock (_lockGate) + { + foreach (var t in messageTypes) + { + if (t is null) + continue; + + if (_typedSubscribers.TryGetValue(t, out var list)) + { + list.Remove(handler); + if (list.Count == 0) + _typedSubscribers.Remove(t); + } + } + } + }); } private sealed class Subscription : IDisposable @@ -50,14 +121,11 @@ private sealed class Subscription : IDisposable private readonly Action _unsubscribe; private bool _disposed; - public Subscription(Action unsubscribe) => - _unsubscribe = unsubscribe; + public Subscription(Action unsubscribe) => _unsubscribe = unsubscribe; public void Dispose() { - if (_disposed) - return; - + if (_disposed) return; _disposed = true; _unsubscribe(); } diff --git a/source/Lite.StateMachine/ICommandState.cs b/source/Lite.StateMachine/ICommandState.cs index 47d8b6e..68607b0 100644 --- a/source/Lite.StateMachine/ICommandState.cs +++ b/source/Lite.StateMachine/ICommandState.cs @@ -2,6 +2,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Threading.Tasks; namespace Lite.StateMachine; @@ -11,6 +12,14 @@ namespace Lite.StateMachine; public interface ICommandState : IState where TStateId : struct, Enum { + /// + /// Gets the declared the message types this command state wants to receive. + /// Return multiple types to subscribe to all of them. + /// Return an empty collection (default) to receive all messages (wildcard). + /// Initialize to ();]]> or NULL when not in use. + IReadOnlyCollection SubscribedMessageTypes => []; + /// Gets optional override of timeout for this state; null uses machine default. int? TimeoutMs => null; diff --git a/source/Lite.StateMachine/IEventAggregator.cs b/source/Lite.StateMachine/IEventAggregator.cs index 52909b3..e1b8ba2 100644 --- a/source/Lite.StateMachine/IEventAggregator.cs +++ b/source/Lite.StateMachine/IEventAggregator.cs @@ -8,9 +8,27 @@ namespace Lite.StateMachine; /// Simple event aggregator for delivering messages to the current command state. public interface IEventAggregator { + /// Publish a message to subscribers. + /// Message object to publish. void Publish(object message); + /// Subscribe to all messages (wildcard). + /// Subscription listener method. + /// Disposable subscription. IDisposable Subscribe(Action handler); - ////IDisposable Subscribe(Func handler); + /// + /// Subscribe to one or more specific message types. Only messages whose + /// runtime type matches one of the provided + /// will be delivered to . + /// + /// Subscription listener method. + /// Message type(s) to subscribe to. + /// Disposable subscription. + IDisposable Subscribe(Action handler, params Type[] messageTypes); + + /// Generic convenience overload. + /// Message object to publish (not null). + /// Type to publish. + void Publish(T message) => Publish((object)message!); } diff --git a/source/Lite.StateMachine/StateMachine.cs b/source/Lite.StateMachine/StateMachine.cs index 755160f..958a0d1 100644 --- a/source/Lite.StateMachine/StateMachine.cs +++ b/source/Lite.StateMachine/StateMachine.cs @@ -135,6 +135,7 @@ public StateMachine RegisterState( OnSuccess = onSuccess, OnError = onError, OnFailure = onFailure, + //// vNext: SubscribedMessages = cmdMsgs ?? [], }; _states[stateId] = reg; @@ -452,15 +453,21 @@ private StateRegistration GetRegistration(TStateId stateId) { if (_eventAggregator is not null) { + // Subscribed message types or `Array.Empty()` for none + //// vNext: IReadOnlyCollection types2 = [.. cmd.SubscribedMessageTypes ?? [], .. reg.SubscribedMessageTypes ?? []]; + var types = cmd.SubscribedMessageTypes ?? []; + subscription = _eventAggregator.Subscribe(async (msgObj) => { if (cancellationToken.IsCancellationRequested || tcs.Task.IsCompleted) return; #pragma warning disable SA1501 // Statement should not be on a single line + // Swallow to avoid breaking publication loop try { await cmd.OnMessage(ctx, msgObj).ConfigureAwait(false); } catch { } #pragma warning restore SA1501 // Statement should not be on a single line - }); + }, + [.. types]); //// [.. types] == types.ToArray() var timeoutMs = cmd.TimeoutMs ?? DefaultCommandTimeoutMs; if (timeoutMs > 0) @@ -472,9 +479,7 @@ private StateRegistration GetRegistration(TStateId stateId) { await Task.Delay(timeoutMs, timeoutCts.Token).ConfigureAwait(false); if (!tcs.Task.IsCompleted && !timeoutCts.IsCancellationRequested) - { await cmd.OnTimeout(ctx).ConfigureAwait(false); - } } catch (TaskCanceledException) { diff --git a/source/Lite.StateMachine/StateRegistration.cs b/source/Lite.StateMachine/StateRegistration.cs index 90cc6b2..a0d09e1 100644 --- a/source/Lite.StateMachine/StateRegistration.cs +++ b/source/Lite.StateMachine/StateRegistration.cs @@ -14,12 +14,12 @@ internal sealed class StateRegistration /// OLD: >? Factory = default;]]>. public Func> Factory { get; init; } = default!; - /// Gets a value indicating whether this is a composite parent state or not. - public bool IsCompositeParent { get; init; } - /// Gets the initial child (for Composite states only). public TStateId? InitialChildId { get; init; } + /// Gets a value indicating whether this is a composite parent state or not. + public bool IsCompositeParent { get; init; } + /// Gets or sets an optional auto-wire OnError StateId transition. public TStateId? OnError { get; set; } = null; @@ -37,4 +37,7 @@ internal sealed class StateRegistration /// Gets the State Id, used by ExportUml for . public TStateId StateId { get; init; } + + /////// Gets the messages for to subscribe to. + ////public System.Collections.Generic.IReadOnlyCollection? SubscribedMessageTypes { get; init; } = null; } diff --git a/source/Sample.Basics/DiStates/DiStateBase.cs b/source/Sample.Basics/DiStates/DiStateBase.cs index 3d317f3..8c541c3 100644 --- a/source/Sample.Basics/DiStates/DiStateBase.cs +++ b/source/Sample.Basics/DiStates/DiStateBase.cs @@ -27,7 +27,7 @@ public class DiStateBase(IMessageService msg, ILogger context) { _msgService.Counter1++; - ////_msgService.AddMessage(GetType().Name + " OnEnter"); + ////_msgService.AddMessage(GetType().Name + " [OnEnter]"); _logger.LogInformation("[OnEnter] => OK"); if (HasExtraLogging) @@ -40,7 +40,7 @@ public virtual Task OnEnter(Context context) public virtual Task OnEntering(Context context) { _msgService.Counter1++; - ////_msgService.AddMessage(GetType().Name + " OnEntering"); + ////_msgService.AddMessage(GetType().Name + " [OnEntering]"); _logger.LogInformation("[OnEntering]"); if (HasExtraLogging) @@ -52,7 +52,7 @@ public virtual Task OnEntering(Context context) public virtual Task OnExit(Context context) { _msgService.Counter1++; - ////_msgService.AddMessage(GetType().Name + " OnExit"); + ////_msgService.AddMessage(GetType().Name + " [OnExit]"); _logger.LogInformation("[OnExit]"); if (HasExtraLogging) From eee8ed6d9e18c9ef5fe6f414281e7db3be149944 Mon Sep 17 00:00:00 2001 From: Damian Suess Date: Fri, 9 Jan 2026 10:00:20 -0500 Subject: [PATCH 2/2] Updated tests for Command State Tests ensuring OnMessage is entered once and no race conditions --- .../StateTests/CommandStateTests.cs | 37 ++-- .../TestData/Models/CustomCommands.cs | 2 +- .../TestData/Services/MessageService.cs | 6 + .../TestData/States/CommandL3States.cs | 161 ++++++++++++------ .../TestData/States/CommandStateBase.cs | 37 +++- 5 files changed, 170 insertions(+), 73 deletions(-) diff --git a/source/Lite.StateMachine.Tests/StateTests/CommandStateTests.cs b/source/Lite.StateMachine.Tests/StateTests/CommandStateTests.cs index 4639545..7a5e325 100644 --- a/source/Lite.StateMachine.Tests/StateTests/CommandStateTests.cs +++ b/source/Lite.StateMachine.Tests/StateTests/CommandStateTests.cs @@ -16,9 +16,7 @@ namespace Lite.StateMachine.Tests.StateTests; public class CommandStateTests : TestBase { [TestMethod] - [DataRow(false, DisplayName = "Don't skip State3")] - [DataRow(true, DisplayName = "Skip State3")] - public async Task BasicState_Override_Executes_SuccessAsync(bool skipState3) + public async Task BasicState_Override_Executes_SuccessAsync() { // Assemble with Dependency Injection var services = new ServiceCollection() @@ -34,27 +32,27 @@ public async Task BasicState_Override_Executes_SuccessAsync(bool skipState3) var ctxProperties = new PropertyBag() { { ParameterType.Counter, 0 }, - { ParameterType.TestExitEarly, skipState3 }, }; - var machine = new StateMachine(factory, events) + var machine = new StateMachine(factory, events) { // Make sure we don't get stuck. // And send some message after leaving Command state // to make sure we unsubscribed successfully. DefaultStateTimeoutMs = 3000, + IsContextPersistent = true, }; machine - .RegisterState(CompositeL3.State1, CompositeL3.State2) - .RegisterComposite(CompositeL3.State2, initialChildStateId: CompositeL3.State2_Sub1, onSuccess: CompositeL3.State3) - .RegisterSubState(CompositeL3.State2_Sub1, parentStateId: CompositeL3.State2, onSuccess: CompositeL3.State2_Sub2) - .RegisterSubComposite(CompositeL3.State2_Sub2, parentStateId: CompositeL3.State2, initialChildStateId: CompositeL3.State2_Sub2_Sub1, onSuccess: CompositeL3.State2_Sub3) - .RegisterSubState(CompositeL3.State2_Sub2_Sub1, parentStateId: CompositeL3.State2_Sub2, onSuccess: CompositeL3.State2_Sub2_Sub2) - .RegisterSubState(CompositeL3.State2_Sub2_Sub2, parentStateId: CompositeL3.State2_Sub2, onSuccess: CompositeL3.State2_Sub2_Sub3) - .RegisterSubState(CompositeL3.State2_Sub2_Sub3, parentStateId: CompositeL3.State2_Sub2, onSuccess: null) - .RegisterSubState(CompositeL3.State2_Sub3, parentStateId: CompositeL3.State2, onSuccess: null) - .RegisterState(CompositeL3.State3, onSuccess: null); + .RegisterState(StateId.State1, StateId.State2) + .RegisterComposite(StateId.State2, initialChildStateId: StateId.State2_Sub1, onSuccess: StateId.State3) + .RegisterSubState(StateId.State2_Sub1, parentStateId: StateId.State2, onSuccess: StateId.State2_Sub2) + .RegisterSubComposite(StateId.State2_Sub2, parentStateId: StateId.State2, initialChildStateId: StateId.State2_Sub2_Sub1, onSuccess: StateId.State2_Sub3) + .RegisterSubState(StateId.State2_Sub2_Sub1, parentStateId: StateId.State2_Sub2, onSuccess: StateId.State2_Sub2_Sub2) + .RegisterSubState(StateId.State2_Sub2_Sub2, parentStateId: StateId.State2_Sub2, onSuccess: StateId.State2_Sub2_Sub3) + .RegisterSubState(StateId.State2_Sub2_Sub3, parentStateId: StateId.State2_Sub2, onSuccess: null) + .RegisterSubState(StateId.State2_Sub3, parentStateId: StateId.State2, onSuccess: null) + .RegisterState(StateId.State3, onSuccess: null); events.Subscribe(msg => { @@ -72,7 +70,7 @@ public async Task BasicState_Override_Executes_SuccessAsync(bool skipState3) if (cmd.Counter < 200) events.Publish(new UnlockCommand { Counter = cmd.Counter + 100 }); - events.Publish(new OpenResponse { Counter = cmd.Counter + 100 }); + events.Publish(new UnlockResponse { Counter = cmd.Counter + 100 }); // NOTE: This doesn't reach State2_Sub2_Sub2 because it already left (GOOD) events.Publish(new CloseResponse { Counter = cmd.Counter + 100 }); @@ -81,14 +79,15 @@ public async Task BasicState_Override_Executes_SuccessAsync(bool skipState3) }); // Act - Start your engine! - await machine.RunAsync(CompositeL3.State1, ctxProperties, null, TestContext.CancellationToken); + await machine.RunAsync(StateId.State1, ctxProperties, null, TestContext.CancellationToken); // Assert Results Assert.IsNotNull(machine); Assert.IsNull(machine.Context); - Assert.AreEqual(27, msgService.Counter1); - Assert.AreEqual(14, msgService.Counter2, "State2 Context.Param Count"); - Assert.AreEqual(skipState3 ? 13 : 13, msgService.Counter3); + Assert.AreEqual(29, msgService.Counter1); + Assert.AreEqual(13, msgService.Counter2, "State2 Context.Param Count"); + Assert.AreEqual(12, msgService.Counter3); + Assert.AreEqual(2, msgService.Counter4); } } diff --git a/source/Lite.StateMachine.Tests/TestData/Models/CustomCommands.cs b/source/Lite.StateMachine.Tests/TestData/Models/CustomCommands.cs index 94fc803..b60b422 100644 --- a/source/Lite.StateMachine.Tests/TestData/Models/CustomCommands.cs +++ b/source/Lite.StateMachine.Tests/TestData/Models/CustomCommands.cs @@ -16,7 +16,7 @@ public class UnlockCommand : ICustomCommand } /// Sample command response received by state machine. -public class OpenResponse : ICustomCommand +public class UnlockResponse : ICustomCommand { public int Counter { get; set; } = 0; } diff --git a/source/Lite.StateMachine.Tests/TestData/Services/MessageService.cs b/source/Lite.StateMachine.Tests/TestData/Services/MessageService.cs index a1a93b0..d233e14 100644 --- a/source/Lite.StateMachine.Tests/TestData/Services/MessageService.cs +++ b/source/Lite.StateMachine.Tests/TestData/Services/MessageService.cs @@ -21,6 +21,9 @@ public interface IMessageService /// Gets or sets the user's custom counter. int Counter3 { get; set; } + /// Gets or sets the user's custom counter. + int Counter4 { get; set; } + /// Gets a list of user's custom messages. List Messages { get; } @@ -40,6 +43,9 @@ public class MessageService : IMessageService /// public int Counter3 { get; set; } + /// + public int Counter4 { get; set; } + /// public List Messages { get; } = []; diff --git a/source/Lite.StateMachine.Tests/TestData/States/CommandL3States.cs b/source/Lite.StateMachine.Tests/TestData/States/CommandL3States.cs index 2e6ab25..2e9ef33 100644 --- a/source/Lite.StateMachine.Tests/TestData/States/CommandL3States.cs +++ b/source/Lite.StateMachine.Tests/TestData/States/CommandL3States.cs @@ -2,7 +2,9 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Threading.Tasks; +using Lite.StateMachine.Tests.TestData.Models; using Lite.StateMachine.Tests.TestData.Services; using Microsoft.Extensions.Logging; @@ -14,6 +16,21 @@ /// namespace Lite.StateMachine.Tests.TestData.States.CommandL3DiStates; +public enum StateId +{ + State1, + State2, + State2_Sub1, + State2_Sub2, + State2_Sub2_Sub1, + State2_Sub2_Sub2, + State2_Sub2_Sub3, + State2_Sub3, + State3, + Done, + Error, +} + public class CommonDiStateBase(IMessageService msg, ILogger logger) : DiStateBase(msg, logger) where TStateId : struct, Enum @@ -29,45 +46,77 @@ public override Task OnEnter(Context context) } public class State1(IMessageService msg, ILogger log) - : DiStateBase(msg, log) + : CommandStateBase(msg, log) { - public override Task OnEnter(Context context) + /// Gets message types for command state to subscribe to. + public override IReadOnlyCollection SubscribedMessageTypes => new[] { - if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) - Assert.IsNull(context.PreviousStateId); + //// typeof(OpenCommand), // <---- NOTE: Not needed + typeof(UnlockResponse), + }; + public override Task OnEnter(Context context) + { context.Parameters.Add(context.CurrentStateId.ToString(), Guid.NewGuid()); MessageService.AddMessage($"[Keys-{context.CurrentStateId}]: {string.Join(",", context.Parameters.Keys)}"); + + context.EventAggregator?.Publish(new UnlockCommand { Counter = 1 }); + return base.OnEnter(context); } + + public override Task OnMessage(Context context, object message) + { + // NOTE: Cannot supply our own types yet. + ////public override Task OnMessage(Context context, OpenResponse message) + + if (message is not UnlockResponse) + { + // SHOUD NEVER BE HERE! As only 'OpenResponse' is in the filter list + context.NextState(Result.Error); + return Task.CompletedTask; + } + + MessageService.Counter4++; + + context.NextState(Result.Success); + return base.OnMessage(context, message); + } + + public override Task OnTimeout(Context context) + { + context.NextState(Result.Error); + + // Never gets thrown + ////throw new TimeoutException(); + + return base.OnTimeout(context); + } } /// Level-1: Composite. public class State2(IMessageService msg, ILogger log) - : CommonDiStateBase(msg, log) + : CommonDiStateBase(msg, log) { - #region CodeMaid - DoNotReorder + #region CodeMaid - Suppress method sorting - public override Task OnEntering(Context context) + public override Task OnEntering(Context context) { // Demonstrate NEW parameter that will carry forward context.Parameters.Add($"{context.CurrentStateId}!Anchor", Guid.NewGuid()); return base.OnEntering(context); } - #endregion CodeMaid - DoNotReorder + #endregion CodeMaid - Suppress method sorting - public override Task OnEnter(Context context) + public override Task OnEnter(Context context) { - if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) - Assert.AreEqual(CompositeL3.State1, context.PreviousStateId); - // Demonstrate temporary parameter that will be discarded after State2's OnExit context.Parameters.Add($"{context.CurrentStateId}!TEMP", Guid.NewGuid()); return base.OnEnter(context); } - public override Task OnExit(Context context) + public override Task OnExit(Context context) { // Expected Count: 7 - State2_Sub2 is composite; therefore, discarded. // State1,State2!Anchor,State2!TEMP,State2,State2_Sub1,State2_Sub2!Anchor,State2_Sub3 @@ -78,24 +127,18 @@ public override Task OnExit(Context context) /// Sublevel-2: State. public class State2_Sub1(IMessageService msg, ILogger log) - : CommonDiStateBase(msg, log) + : CommonDiStateBase(msg, log) { - public override Task OnEnter(Context context) - { - if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) - Assert.IsNull(context.PreviousStateId); - - return base.OnEnter(context); - } + public override Task OnEnter(Context context) => base.OnEnter(context); } /// Sublevel-2: Composite. public class State2_Sub2(IMessageService msg, ILogger log) - : CommonDiStateBase(msg, log) + : CommonDiStateBase(msg, log) { #region CodeMaid - DoNotReorder - public override Task OnEntering(Context context) + public override Task OnEntering(Context context) { // Demonstrate NEW parameter that will carry forward context.Parameters.Add($"{context.CurrentStateId}!Anchor", Guid.NewGuid()); @@ -104,17 +147,14 @@ public override Task OnEntering(Context context) #endregion CodeMaid - DoNotReorder - public override Task OnEnter(Context context) + public override Task OnEnter(Context context) { - if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) - Assert.AreEqual(CompositeL3.State2_Sub1, context.PreviousStateId); - // Demonstrate temporary parameter that will be discarded after State2_Sub2's OnExit context.Parameters.Add($"{context.CurrentStateId}!TEMP", Guid.NewGuid()); return base.OnEnter(context); } - public override Task OnExit(Context context) + public override Task OnExit(Context context) { // Expected Count: 7 MessageService.Counter3 = context.Parameters.Count; @@ -124,43 +164,62 @@ public override Task OnExit(Context context) /// Sublevel-3: State. public class State2_Sub2_Sub1(IMessageService msg, ILogger log) - : CommonDiStateBase(msg, log) + : CommonDiStateBase(msg, log) { - public override Task OnEnter(Context context) - { - if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) - Assert.IsNull(context.PreviousStateId); - - return base.OnEnter(context); - } + public override Task OnEnter(Context context) => base.OnEnter(context); } /// Sublevel-3: State. public class State2_Sub2_Sub2(IMessageService msg, ILogger log) - : CommonDiStateBase(msg, log); - -/// Sublevel-3: Last State. -public class State2_Sub2_Sub3(IMessageService msg, ILogger log) - : CommonDiStateBase(msg, log) + : CommandStateBase(msg, log) { - public override Task OnEnter(Context context) + /// Gets message types for command state to subscribe to. + public override IReadOnlyCollection SubscribedMessageTypes => + [ + typeof(UnlockResponse), + typeof(CloseResponse), + ]; + + public override Task OnEnter(Context context) { - if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) - Assert.AreEqual(CompositeL3.State2_Sub2_Sub2, context.PreviousStateId); + context.Parameters.Add(context.CurrentStateId.ToString(), Guid.NewGuid()); + MessageService.AddMessage($"[Keys-{context.CurrentStateId}]: {string.Join(",", context.Parameters.Keys)}"); + // NOTE: + // 1) We're sending the same OpenCommand to prove that State1's OnMessage isn't called a 2nd time. + // 2) CloseResponse doesn't reached our OnMessage because we left already. + context.EventAggregator?.Publish(new UnlockCommand { Counter = 200 }); return base.OnEnter(context); } + + public override Task OnMessage(Context context, object message) + { + MessageService.Counter4++; + + context.NextState(Result.Success); + return base.OnMessage(context, message); + } + + public override Task OnTimeout(Context context) + { + context.NextState(Result.Error); + return base.OnTimeout(context); + } +} + +/// Sublevel-3: Last State. +public class State2_Sub2_Sub3(IMessageService msg, ILogger log) +: CommonDiStateBase(msg, log) +{ + public override Task OnEnter(Context context) => base.OnEnter(context); } /// Sublevel-2: Last State. public class State2_Sub3(IMessageService msg, ILogger log) - : DiStateBase(msg, log) + : DiStateBase(msg, log) { - public override Task OnEnter(Context context) + public override Task OnEnter(Context context) { - if (context.ParameterAsBool(ParameterType.TestExecutionOrder)) - Assert.AreEqual(CompositeL3.State2_Sub2, context.PreviousStateId); - context.Parameters.Add(context.CurrentStateId.ToString(), Guid.NewGuid()); MessageService.AddMessage($"[Keys-{context.CurrentStateId}]: {string.Join(",", context.Parameters.Keys)}"); return base.OnEnter(context); @@ -169,9 +228,9 @@ public override Task OnEnter(Context context) /// Make sure not child-created context is there. public class State3(IMessageService msg, ILogger log) - : DiStateBase(msg, log) + : DiStateBase(msg, log) { - public override Task OnEnter(Context context) + public override Task OnEnter(Context context) { context.Parameters.Add(context.CurrentStateId.ToString(), Guid.NewGuid()); MessageService.AddMessage($"[Keys-{context.CurrentStateId}]: {string.Join(",", context.Parameters.Keys)}"); diff --git a/source/Lite.StateMachine.Tests/TestData/States/CommandStateBase.cs b/source/Lite.StateMachine.Tests/TestData/States/CommandStateBase.cs index 2de8ad4..d0256f0 100644 --- a/source/Lite.StateMachine.Tests/TestData/States/CommandStateBase.cs +++ b/source/Lite.StateMachine.Tests/TestData/States/CommandStateBase.cs @@ -2,29 +2,62 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; +using System.Diagnostics; using System.Threading.Tasks; using Lite.StateMachine.Tests.TestData.Services; using Microsoft.Extensions.Logging; namespace Lite.StateMachine.Tests.TestData.States; +/// Command Class (doesn't auto-success, you must implement `context.NextState(Result.xxx);`. +/// State class object. +/// State Id. public class CommandStateBase(IMessageService msg, ILogger logger) : DiStateBase(msg, logger), ICommandState where TStateId : struct, Enum { - public Task OnMessage(Context context, object message) + //// NEEDS TESTED: public virtual IReadOnlyCollection SubscribedMessageTypes => []; + ////public virtual IReadOnlyCollection SubscribedMessageTypes => Array.Empty(); + public virtual IReadOnlyCollection SubscribedMessageTypes => []; + + public override Task OnEnter(Context context) { + MessageService.Counter1++; Log.LogInformation("[OnEnter]"); + if (HasExtraLogging) + Debug.WriteLine($"[{GetType().Name}] [{context.CurrentStateId}] [OnEnter]"); + // DO NOT AUTO-SUCCESS! Placed here as a note ////context.NextState(Result.Success); return Task.CompletedTask; } - public Task OnTimeout(Context context) + public virtual Task OnMessage(Context context, object message) { + // Note: Cannot supply our own object type + //// public virtual Task OnMessage(Context context, OpenResponse message) + + MessageService.Counter1++; Log.LogInformation("[OnMessage]"); + if (HasExtraLogging) + Debug.WriteLine($"[{GetType().Name}] [{context.CurrentStateId}] [OnMessage]"); + + // DO NOT AUTO-SUCCESS! Placed here as a note + ////context.NextState(Result.Success); + return Task.CompletedTask; + } + + public virtual Task OnTimeout(Context context) + { + MessageService.Counter1++; + Log.LogInformation("[OnTimeout] => "); + + if (HasExtraLogging) + Debug.WriteLine($"[{GetType().Name}] [{context.CurrentStateId}] [OnTimeout]"); + // DO NOT AUTO-SUCCESS! Placed here as a note ////context.NextState(Result.Success); return Task.CompletedTask;