diff --git a/Forge.TreeWalker.UnitTests/Forge.TreeWalker.UnitTests.csproj b/Forge.TreeWalker.UnitTests/Forge.TreeWalker.UnitTests.csproj index 1b6e0e3..73010ad 100644 --- a/Forge.TreeWalker.UnitTests/Forge.TreeWalker.UnitTests.csproj +++ b/Forge.TreeWalker.UnitTests/Forge.TreeWalker.UnitTests.csproj @@ -41,6 +41,8 @@ + + @@ -54,6 +56,7 @@ + @@ -64,6 +67,7 @@ + diff --git a/Forge.TreeWalker.UnitTests/test/ActionsAndCallbacks/DependencyInjectionTestAction.cs b/Forge.TreeWalker.UnitTests/test/ActionsAndCallbacks/DependencyInjectionTestAction.cs new file mode 100644 index 0000000..7ed6d91 --- /dev/null +++ b/Forge.TreeWalker.UnitTests/test/ActionsAndCallbacks/DependencyInjectionTestAction.cs @@ -0,0 +1,197 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// +// ForgeActions used to test dependency injection capabilities. +// +//----------------------------------------------------------------------- + +namespace Microsoft.Forge.TreeWalker.UnitTests +{ + using System; + using System.Threading.Tasks; + + using Microsoft.Forge.Attributes; + using Microsoft.Forge.TreeWalker; + + /// + /// A simple service interface used to verify constructor injection in ForgeActions. + /// + public interface IDiTestService + { + string GetValue(); + } + + /// + /// Default implementation of IDiTestService. + /// + public class DiTestService : IDiTestService + { + private readonly string value; + + public DiTestService(string value) + { + this.value = value; + } + + public string GetValue() + { + return this.value; + } + } + + /// + /// A second service interface used to verify multi-dependency constructor injection. + /// + public interface IDiTestCounter + { + int Increment(); + int GetCount(); + } + + /// + /// Default implementation of IDiTestCounter. + /// + public class DiTestCounter : IDiTestCounter + { + private int count; + + public int Increment() + { + return ++this.count; + } + + public int GetCount() + { + return this.count; + } + } + + /// + /// A ForgeAction that requires a single IDiTestService dependency via constructor injection. + /// Used to verify that ServiceProviderActionFactory resolves actions with injected services. + /// + [ForgeAction] + public class SingleDependencyAction : BaseAction + { + private readonly IDiTestService testService; + + public SingleDependencyAction(IDiTestService testService) + { + this.testService = testService ?? throw new ArgumentNullException(nameof(testService)); + } + + public override Task RunAction(ActionContext actionContext) + { + return Task.FromResult(new ActionResponse + { + Status = "Success", + Output = this.testService.GetValue() + }); + } + } + + /// + /// A ForgeAction that requires multiple dependencies via constructor injection. + /// Used to verify that ServiceProviderActionFactory resolves actions with multiple injected services. + /// + [ForgeAction] + public class MultipleDependencyAction : BaseAction + { + private readonly IDiTestService testService; + private readonly IDiTestCounter testCounter; + + public MultipleDependencyAction(IDiTestService testService, IDiTestCounter testCounter) + { + this.testService = testService ?? throw new ArgumentNullException(nameof(testService)); + this.testCounter = testCounter ?? throw new ArgumentNullException(nameof(testCounter)); + } + + public override Task RunAction(ActionContext actionContext) + { + int count = this.testCounter.Increment(); + + return Task.FromResult(new ActionResponse + { + Status = this.testService.GetValue(), + StatusCode = count, + Output = string.Format("{0}_{1}", this.testService.GetValue(), count) + }); + } + } + + /// + /// A ForgeAction with a typed input and constructor-injected dependency. + /// Used to verify that DI works alongside Forge's ActionInput deserialization. + /// + [ForgeAction(InputType: typeof(DiActionWithInputTypeInput))] + public class DiActionWithInputType : BaseAction + { + private readonly IDiTestService testService; + + public DiActionWithInputType(IDiTestService testService) + { + this.testService = testService ?? throw new ArgumentNullException(nameof(testService)); + } + + public override Task RunAction(ActionContext actionContext) + { + var input = (DiActionWithInputTypeInput)actionContext.ActionInput; + + return Task.FromResult(new ActionResponse + { + Status = "Success", + Output = string.Format("{0}_{1}", this.testService.GetValue(), input.MessageProperty) + }); + } + } + + public class DiActionWithInputTypeInput + { + public string MessageProperty { get; set; } + } + + /// + /// A custom IForgeActionFactory implementation for testing. + /// Tracks how many times CreateAction was called. + /// + public class TestCustomActionFactory : IForgeActionFactory + { + private readonly IDiTestService testService; + public int CreateActionCallCount { get; private set; } + + public TestCustomActionFactory(IDiTestService testService) + { + this.testService = testService; + } + + public BaseAction CreateAction(Type actionType, TreeWalkerParameters parameters) + { + this.CreateActionCallCount++; + + if (actionType == typeof(SubroutineAction)) + { + return new SubroutineAction(parameters); + } + + if (actionType == typeof(SingleDependencyAction)) + { + return new SingleDependencyAction(this.testService); + } + + if (actionType == typeof(MultipleDependencyAction)) + { + return new MultipleDependencyAction(this.testService, new DiTestCounter()); + } + + if (actionType == typeof(DiActionWithInputType)) + { + return new DiActionWithInputType(this.testService); + } + + // Fall back to parameterless constructor for other test actions. + return (BaseAction)Activator.CreateInstance(actionType); + } + } +} diff --git a/Forge.TreeWalker.UnitTests/test/DependencyInjectionTests.cs b/Forge.TreeWalker.UnitTests/test/DependencyInjectionTests.cs new file mode 100644 index 0000000..eee4f99 --- /dev/null +++ b/Forge.TreeWalker.UnitTests/test/DependencyInjectionTests.cs @@ -0,0 +1,528 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// +// Tests the dependency injection capabilities of Forge's action factory system. +// +//----------------------------------------------------------------------- + +namespace Microsoft.Forge.TreeWalker.UnitTests +{ + using System; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Extensions.DependencyInjection; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + using Microsoft.Forge.DataContracts; + using Microsoft.Forge.TreeWalker; + using Microsoft.Forge.TreeWalker.ForgeExceptions; + using Newtonsoft.Json; + + [TestClass] + public class DependencyInjectionTests + { + #region Test Schemas + + private const string SingleDependencySchema = @" + { + ""Tree"": { + ""Root"": { + ""Type"": ""Action"", + ""Actions"": { + ""Root_SingleDependencyAction"": { + ""Action"": ""SingleDependencyAction"" + } + } + } + } + }"; + + private const string MultipleDependencySchema = @" + { + ""Tree"": { + ""Root"": { + ""Type"": ""Action"", + ""Actions"": { + ""Root_MultipleDependencyAction"": { + ""Action"": ""MultipleDependencyAction"" + } + } + } + } + }"; + + private const string ParallelDependencyActionsSchema = @" + { + ""Tree"": { + ""Root"": { + ""Type"": ""Action"", + ""Actions"": { + ""Root_MultipleDependencyAction_1"": { + ""Action"": ""MultipleDependencyAction"" + }, + ""Root_MultipleDependencyAction_2"": { + ""Action"": ""MultipleDependencyAction"" + } + } + } + } + }"; + + private const string DiActionWithInputTypeSchema = @" + { + ""Tree"": { + ""Root"": { + ""Type"": ""Action"", + ""Actions"": { + ""Root_DiActionWithInputType"": { + ""Action"": ""DiActionWithInputType"", + ""Input"": { + ""MessageProperty"": ""Hello"" + } + } + } + } + } + }"; + + private const string DiActionWithChildSelectorSchema = @" + { + ""Tree"": { + ""Root"": { + ""Type"": ""Action"", + ""Actions"": { + ""Root_SingleDependencyAction"": { + ""Action"": ""SingleDependencyAction"" + } + }, + ""ChildSelector"": [ + { + ""Label"": ""Success"", + ""ShouldSelect"": ""C#|Session.GetLastActionResponse().Status == \""Success\"""", + ""Child"": ""SuccessLeaf"" + } + ] + }, + ""SuccessLeaf"": { + ""Type"": ""Leaf"" + } + } + }"; + + private const string MultiNodeDiSchema = @" + { + ""Tree"": { + ""Root"": { + ""Type"": ""Action"", + ""Actions"": { + ""Root_SingleDependencyAction"": { + ""Action"": ""SingleDependencyAction"" + } + }, + ""ChildSelector"": [ + { + ""Label"": ""Success"", + ""ShouldSelect"": ""C#|Session.GetLastActionResponse().Status == \""Success\"""", + ""Child"": ""SecondNode"" + } + ] + }, + ""SecondNode"": { + ""Type"": ""Action"", + ""Actions"": { + ""SecondNode_MultipleDependencyAction"": { + ""Action"": ""MultipleDependencyAction"" + } + } + } + } + }"; + + #endregion Test Schemas + + #region Helper Methods + + private static ServiceProvider BuildServiceProvider(string serviceValue = "Injected") + { + var services = new ServiceCollection(); + services.AddSingleton(new DiTestService(serviceValue)); + services.AddSingleton(); + return services.BuildServiceProvider(); + } + + private static TreeWalkerSession CreateSession( + string jsonSchema, + IForgeActionFactory actionFactory = null) + { + Guid sessionId = Guid.NewGuid(); + var forgeState = new ForgeDictionary(new Dictionary(), sessionId, sessionId); + var callbacks = new TreeWalkerCallbacksV2(); + var token = new CancellationTokenSource().Token; + ForgeTree forgeTree = JsonConvert.DeserializeObject(jsonSchema); + + var parameters = new TreeWalkerParameters( + sessionId, + forgeTree, + forgeState, + callbacks, + token) + { + UserContext = new ForgeUserContext(), + ForgeActionsAssembly = typeof(SingleDependencyAction).Assembly, + ActionFactory = actionFactory + }; + + return new TreeWalkerSession(parameters); + } + + #endregion Helper Methods + + #region DefaultForgeActionFactory Tests + + [TestMethod] + public async Task TestDefaultFactory_IsUsedWhenNoServiceProviderOrFactorySet() + { + // Test - When ActionFactory is not set, DefaultForgeActionFactory is used. + // Actions requiring DI dependencies will fail because Activator.CreateInstance cannot resolve them. + var session = CreateSession(SingleDependencySchema); + + await Assert.ThrowsExceptionAsync(async () => + { + await session.WalkTree("Root"); + }, "Expected WalkTree to fail because SingleDependencyAction requires constructor injection which DefaultForgeActionFactory cannot provide."); + + Assert.AreEqual("TimeoutOnAction", session.Status, + "Expected TimeoutOnAction because DefaultForgeActionFactory uses Activator.CreateInstance which cannot resolve IDiTestService."); + } + + [TestMethod] + public void TestDefaultFactory_CreateAction_ParameterlessConstructor() + { + // Test - DefaultForgeActionFactory can create actions with parameterless constructors. + var factory = new DefaultForgeActionFactory(); + var action = factory.CreateAction(typeof(TardigradeAction), null); + + Assert.IsNotNull(action, "Expected DefaultForgeActionFactory to create an action with a parameterless constructor."); + Assert.IsInstanceOfType(action, typeof(TardigradeAction)); + } + + [TestMethod] + public void TestDefaultFactory_CreateAction_SubroutineAction() + { + // Test - DefaultForgeActionFactory can create SubroutineAction with TreeWalkerParameters. + Guid sessionId = Guid.NewGuid(); + var forgeState = new ForgeDictionary(new Dictionary(), sessionId, sessionId); + var callbacks = new TreeWalkerCallbacksV2(); + ForgeTree forgeTree = JsonConvert.DeserializeObject(SingleDependencySchema); + var parameters = new TreeWalkerParameters(sessionId, forgeTree, forgeState, callbacks, CancellationToken.None); + + var factory = new DefaultForgeActionFactory(); + var action = factory.CreateAction(typeof(SubroutineAction), parameters); + + Assert.IsNotNull(action, "Expected DefaultForgeActionFactory to create SubroutineAction."); + Assert.IsInstanceOfType(action, typeof(SubroutineAction)); + } + + #endregion DefaultForgeActionFactory Tests + + #region ServiceProviderActionFactory Tests + + [TestMethod] + public void TestServiceProviderFactory_Constructor_NullServiceProvider_ThrowsArgumentNullException() + { + // Test - ServiceProviderActionFactory constructor throws on null serviceProvider. + Assert.ThrowsException(() => + { + new ServiceProviderActionFactory(null); + }, "Expected ArgumentNullException when creating ServiceProviderActionFactory with null serviceProvider."); + } + + [TestMethod] + public async Task TestServiceProviderFactory_SingleDependency_WalkTree_Success() + { + // Test - WalkTree with ServiceProviderActionFactory, action resolved via ActivatorUtilities with single dependency. + using (var sp = BuildServiceProvider("InjectedValue")) + { + var session = CreateSession(SingleDependencySchema, actionFactory: new ServiceProviderActionFactory(sp)); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected WalkTree to complete successfully with ServiceProviderActionFactory resolving SingleDependencyAction."); + + ActionResponse response = await session.GetLastActionResponseAsync(); + Assert.AreEqual("Success", response.Status); + Assert.AreEqual("InjectedValue", response.Output, + "Expected the injected IDiTestService value to be returned as Output."); + } + } + + [TestMethod] + public async Task TestServiceProviderFactory_MultipleDependencies_WalkTree_Success() + { + // Test - WalkTree with ServiceProviderActionFactory, action resolved via ActivatorUtilities with multiple dependencies. + using (var sp = BuildServiceProvider("MultiDep")) + { + var session = CreateSession(MultipleDependencySchema, actionFactory: new ServiceProviderActionFactory(sp)); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected WalkTree to complete successfully with ServiceProviderActionFactory resolving MultipleDependencyAction."); + + ActionResponse response = await session.GetLastActionResponseAsync(); + Assert.AreEqual("MultiDep", response.Status, + "Expected the injected IDiTestService value to be returned as Status."); + Assert.AreEqual(1, response.StatusCode, + "Expected the IDiTestCounter to have been incremented to 1."); + Assert.AreEqual("MultiDep_1", response.Output, + "Expected the combined output from both injected services."); + } + } + + [TestMethod] + public async Task TestServiceProviderFactory_ParallelActions_WalkTree_Success() + { + // Test - WalkTree with two parallel actions on the same node, both resolved via ServiceProviderActionFactory. + using (var sp = BuildServiceProvider("Parallel")) + { + var session = CreateSession(ParallelDependencyActionsSchema, actionFactory: new ServiceProviderActionFactory(sp)); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected WalkTree to complete successfully with parallel DI-resolved actions."); + + ActionResponse response1 = await session.GetOutputAsync("Root_MultipleDependencyAction_1"); + ActionResponse response2 = await session.GetOutputAsync("Root_MultipleDependencyAction_2"); + + Assert.IsNotNull(response1, "Expected ActionResponse for first parallel action."); + Assert.IsNotNull(response2, "Expected ActionResponse for second parallel action."); + Assert.AreEqual("Parallel", response1.Status); + Assert.AreEqual("Parallel", response2.Status); + } + } + + [TestMethod] + public async Task TestServiceProviderFactory_ActionWithInputType_WalkTree_Success() + { + // Test - WalkTree with a DI-resolved action that also uses a typed ActionInput. + using (var sp = BuildServiceProvider("WithInput")) + { + var session = CreateSession(DiActionWithInputTypeSchema, actionFactory: new ServiceProviderActionFactory(sp)); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected WalkTree to complete successfully with DI action using typed ActionInput."); + + ActionResponse response = await session.GetLastActionResponseAsync(); + Assert.AreEqual("Success", response.Status); + Assert.AreEqual("WithInput_Hello", response.Output, + "Expected output to combine injected service value with deserialized ActionInput."); + } + } + + [TestMethod] + public async Task TestServiceProviderFactory_MultiNodeWalk_WalkTree_Success() + { + // Test - WalkTree across multiple nodes, each with DI-resolved actions. + using (var sp = BuildServiceProvider("MultiNode")) + { + var session = CreateSession(MultiNodeDiSchema, actionFactory: new ServiceProviderActionFactory(sp)); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected WalkTree to complete successfully across multiple nodes with DI-resolved actions."); + + ActionResponse rootResponse = await session.GetOutputAsync("Root_SingleDependencyAction"); + Assert.IsNotNull(rootResponse, "Expected ActionResponse from Root node."); + Assert.AreEqual("MultiNode", rootResponse.Output); + + ActionResponse secondResponse = await session.GetOutputAsync("SecondNode_MultipleDependencyAction"); + Assert.IsNotNull(secondResponse, "Expected ActionResponse from SecondNode."); + Assert.AreEqual("MultiNode", secondResponse.Status); + } + } + + [TestMethod] + public async Task TestServiceProviderFactory_ChildSelectorEvaluatesAfterDiAction_WalkTree_Success() + { + // Test - WalkTree where a DI-resolved action's response drives ChildSelector evaluation. + using (var sp = BuildServiceProvider("Selector")) + { + var session = CreateSession(DiActionWithChildSelectorSchema, actionFactory: new ServiceProviderActionFactory(sp)); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected WalkTree to complete and select the SuccessLeaf child based on DI action response."); + + string currentNode = await session.GetCurrentTreeNode(); + Assert.AreEqual("SuccessLeaf", currentNode, + "Expected tree walker to visit SuccessLeaf after DI action returned Status=Success."); + } + } + + [TestMethod] + public void TestServiceProviderFactory_CreateAction_ResolvesViaActivatorUtilities() + { + // Test - ServiceProviderActionFactory always uses ActivatorUtilities.CreateInstance to resolve actions, + // injecting registered dependencies via constructor injection. + var services = new ServiceCollection(); + services.AddSingleton(new DiTestService("Resolved")); + services.AddSingleton(); + + using (var sp = services.BuildServiceProvider()) + { + var factory = new ServiceProviderActionFactory(sp); + var action = factory.CreateAction(typeof(MultipleDependencyAction), null); + + Assert.IsNotNull(action, "Expected factory to resolve the action via ActivatorUtilities."); + Assert.IsInstanceOfType(action, typeof(MultipleDependencyAction)); + } + } + + [TestMethod] + public void TestServiceProviderFactory_CreateAction_SubroutineAction() + { + // Test - ServiceProviderActionFactory creates SubroutineAction with TreeWalkerParameters. + Guid sessionId = Guid.NewGuid(); + var forgeState = new ForgeDictionary(new Dictionary(), sessionId, sessionId); + var callbacks = new TreeWalkerCallbacksV2(); + ForgeTree forgeTree = JsonConvert.DeserializeObject(SingleDependencySchema); + var parameters = new TreeWalkerParameters(sessionId, forgeTree, forgeState, callbacks, CancellationToken.None); + + using (var sp = BuildServiceProvider()) + { + var factory = new ServiceProviderActionFactory(sp); + var action = factory.CreateAction(typeof(SubroutineAction), parameters); + + Assert.IsNotNull(action, "Expected factory to create SubroutineAction via ActivatorUtilities."); + Assert.IsInstanceOfType(action, typeof(SubroutineAction)); + } + } + + #endregion ServiceProviderActionFactory Tests + + #region Custom ActionFactory Tests + + [TestMethod] + public async Task TestCustomFactory_WalkTree_Success() + { + // Test - WalkTree with a custom IForgeActionFactory that manually wires dependencies. + var testService = new DiTestService("Custom"); + var customFactory = new TestCustomActionFactory(testService); + + var session = CreateSession(SingleDependencySchema, actionFactory: customFactory); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected WalkTree to complete successfully with custom ActionFactory."); + + ActionResponse response = await session.GetLastActionResponseAsync(); + Assert.AreEqual("Success", response.Status); + Assert.AreEqual("Custom", response.Output, + "Expected the custom factory to provide the injected IDiTestService value."); + } + + [TestMethod] + public async Task TestCustomFactory_CreateActionCalled_WalkTree_Success() + { + // Test - Verify that the custom factory's CreateAction was actually invoked during WalkTree. + var testService = new DiTestService("Tracked"); + var customFactory = new TestCustomActionFactory(testService); + + var session = CreateSession(SingleDependencySchema, actionFactory: customFactory); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status); + Assert.IsTrue(customFactory.CreateActionCallCount > 0, + "Expected custom ActionFactory.CreateAction to have been called at least once during WalkTree."); + } + + [TestMethod] + public async Task TestCustomFactory_OverridesDefaultFactory() + { + // Test - When ActionFactory is set, it is used instead of the DefaultForgeActionFactory. + var testService = new DiTestService("FactoryWins"); + var customFactory = new TestCustomActionFactory(testService); + + var session = CreateSession(SingleDependencySchema, actionFactory: customFactory); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status); + + ActionResponse response = await session.GetLastActionResponseAsync(); + Assert.AreEqual("FactoryWins", response.Output, + "Expected custom ActionFactory to be used instead of DefaultForgeActionFactory."); + Assert.IsTrue(customFactory.CreateActionCallCount > 0, + "Expected custom ActionFactory.CreateAction to have been called."); + } + + [TestMethod] + public async Task TestCustomFactory_MultipleDependencies_WalkTree_Success() + { + // Test - Custom factory resolves an action with multiple dependencies. + var testService = new DiTestService("CustomMulti"); + var customFactory = new TestCustomActionFactory(testService); + + var session = CreateSession(MultipleDependencySchema, actionFactory: customFactory); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected WalkTree to complete successfully with custom factory resolving MultipleDependencyAction."); + + ActionResponse response = await session.GetLastActionResponseAsync(); + Assert.AreEqual("CustomMulti", response.Status); + Assert.AreEqual(1, response.StatusCode); + } + + #endregion Custom ActionFactory Tests + + #region Factory Fallback Tests + + [TestMethod] + public async Task TestFactoryFallback_NoFactorySet_UsesDefaultFactory() + { + // Test - When no ActionFactory is set, the DefaultForgeActionFactory is used. + // TardigradeAction has a parameterless constructor so it should work. + string tardigradeOnlySchema = @" + { + ""Tree"": { + ""Root"": { + ""Type"": ""Action"", + ""Actions"": { + ""Root_TardigradeAction"": { + ""Action"": ""TardigradeAction"" + } + } + } + } + }"; + + var session = CreateSession(tardigradeOnlySchema); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected DefaultForgeActionFactory to resolve TardigradeAction via Activator.CreateInstance."); + } + + [TestMethod] + public async Task TestFactoryFallback_ServiceProviderActionFactory_WalkTree_Success() + { + // Test - When ActionFactory is explicitly set to a ServiceProviderActionFactory, it resolves actions. + using (var sp = BuildServiceProvider("Explicit")) + { + var session = CreateSession(SingleDependencySchema, actionFactory: new ServiceProviderActionFactory(sp)); + string status = await session.WalkTree("Root"); + + Assert.AreEqual("RanToCompletion", status, + "Expected ServiceProviderActionFactory set via ActionFactory to resolve actions."); + + ActionResponse response = await session.GetLastActionResponseAsync(); + Assert.AreEqual("Explicit", response.Output); + } + } + + #endregion Factory Fallback Tests + } +} diff --git a/Forge.TreeWalker/Forge.TreeWalker.csproj b/Forge.TreeWalker/Forge.TreeWalker.csproj index 22cbab8..e26e512 100644 --- a/Forge.TreeWalker/Forge.TreeWalker.csproj +++ b/Forge.TreeWalker/Forge.TreeWalker.csproj @@ -21,6 +21,7 @@ all + diff --git a/Forge.TreeWalker/src/DefaultForgeActionFactory.cs b/Forge.TreeWalker/src/DefaultForgeActionFactory.cs new file mode 100644 index 0000000..9947482 --- /dev/null +++ b/Forge.TreeWalker/src/DefaultForgeActionFactory.cs @@ -0,0 +1,35 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// +// The DefaultForgeActionFactory class. +// +//----------------------------------------------------------------------- + +namespace Microsoft.Forge.TreeWalker +{ + using System; + + /// + /// The default IForgeActionFactory implementation that uses Activator.CreateInstance to instantiate ForgeAction classes. + /// This preserves the original behavior of Forge when no custom factory is provided. + /// + public class DefaultForgeActionFactory : IForgeActionFactory + { + /// + /// Creates an instance of the specified ForgeAction type using Activator.CreateInstance. + /// + /// The Type of the ForgeAction class to instantiate. + /// The TreeWalkerParameters for the current session. + /// An instance of the specified action type. + public BaseAction CreateAction(Type actionType, TreeWalkerParameters parameters) + { + if (actionType == typeof(SubroutineAction)) + { + return (SubroutineAction)Activator.CreateInstance(actionType, parameters); + } + return (BaseAction)Activator.CreateInstance(actionType); + } + } +} diff --git a/Forge.TreeWalker/src/IForgeActionFactory.cs b/Forge.TreeWalker/src/IForgeActionFactory.cs new file mode 100644 index 0000000..d394bfc --- /dev/null +++ b/Forge.TreeWalker/src/IForgeActionFactory.cs @@ -0,0 +1,28 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// +// The IForgeActionFactory interface. +// +//----------------------------------------------------------------------- + +namespace Microsoft.Forge.TreeWalker +{ + using System; + + /// + /// The IForgeActionFactory interface defines a factory for creating ForgeAction instances. + /// Implement this interface to integrate a dependency injection container of your choice. + /// + public interface IForgeActionFactory + { + /// + /// Creates an instance of the specified ForgeAction type. + /// + /// The Type of the ForgeAction class to instantiate. This type derives from . + /// The TreeWalkerParameters for the current session. Provided for native actions that require it (e.g. SubroutineAction). + /// An instance of the specified action type. + BaseAction CreateAction(Type actionType, TreeWalkerParameters parameters); + } +} diff --git a/Forge.TreeWalker/src/ServiceProviderActionFactory.cs b/Forge.TreeWalker/src/ServiceProviderActionFactory.cs new file mode 100644 index 0000000..2aa8bee --- /dev/null +++ b/Forge.TreeWalker/src/ServiceProviderActionFactory.cs @@ -0,0 +1,46 @@ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// +// The ServiceProviderActionFactory class. +// +//----------------------------------------------------------------------- + +namespace Microsoft.Forge.TreeWalker +{ + using System; + using Microsoft.Extensions.DependencyInjection; + + /// + /// An implementation that resolves ForgeAction instances + /// from an backed by Microsoft.Extensions.DependencyInjection. + /// + public class ServiceProviderActionFactory : IForgeActionFactory + { + private readonly IServiceProvider serviceProvider; + + /// + /// Initializes a new instance of the class. + /// + /// The service provider used to resolve ForgeAction instances. + public ServiceProviderActionFactory(IServiceProvider serviceProvider) + { + this.serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); + } + + /// + /// Creates an instance of the specified ForgeAction type by resolving it from the service provider. + /// + /// The Type of the ForgeAction class to instantiate. + /// The TreeWalkerParameters for the current session. + /// An instance of the specified action type. + public BaseAction CreateAction(Type actionType, TreeWalkerParameters parameters) + { + if (actionType == typeof(SubroutineAction)) + { + return (SubroutineAction)ActivatorUtilities.CreateInstance(this.serviceProvider, actionType, parameters); + } + return (BaseAction)ActivatorUtilities.CreateInstance(this.serviceProvider, actionType); + } + } +} diff --git a/Forge.TreeWalker/src/TreeWalkerParameters.cs b/Forge.TreeWalker/src/TreeWalkerParameters.cs index 2092946..c16d48b 100644 --- a/Forge.TreeWalker/src/TreeWalkerParameters.cs +++ b/Forge.TreeWalker/src/TreeWalkerParameters.cs @@ -137,6 +137,13 @@ public class TreeWalkerParameters /// public bool RetryCurrentTreeNodeActions { get; set; } + /// + /// The factory used to create ForgeAction instances. + /// Implement to integrate a dependency injection container of your choice. + /// When it is null, Forge will use the which creates instances via Activator.CreateInstance. + /// + public IForgeActionFactory ActionFactory { get; set; } + #endregion #region Constructor with ITreeWalkerCallbacks, [DEPRECATED] diff --git a/Forge.TreeWalker/src/TreeWalkerSession.cs b/Forge.TreeWalker/src/TreeWalkerSession.cs index 38efbf3..da4977e 100644 --- a/Forge.TreeWalker/src/TreeWalkerSession.cs +++ b/Forge.TreeWalker/src/TreeWalkerSession.cs @@ -9,6 +9,11 @@ namespace Microsoft.Forge.TreeWalker { + using Microsoft.Forge.Attributes; + using Microsoft.Forge.DataContracts; + using Microsoft.Forge.TreeWalker.ForgeExceptions; + using Newtonsoft.Json; + using Newtonsoft.Json.Linq; using System; using System.Collections; using System.Collections.Generic; @@ -17,11 +22,6 @@ namespace Microsoft.Forge.TreeWalker using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; - using Microsoft.Forge.Attributes; - using Microsoft.Forge.DataContracts; - using Microsoft.Forge.TreeWalker.ForgeExceptions; - using Newtonsoft.Json; - using Newtonsoft.Json.Linq; /// /// The TreeWalkerSession tries to walk the given tree schema to completion. @@ -162,6 +162,7 @@ public TreeWalkerSession(TreeWalkerParameters parameters) // Initialize properties from optional TreeWalkerParameters properties. GetActionsMapFromAssembly(parameters.ForgeActionsAssembly, out this.actionsMap); this.Parameters.ExternalExecutors = parameters.ExternalExecutors ?? new Dictionary>>(); + this.Parameters.ActionFactory = parameters.ActionFactory ?? new DefaultForgeActionFactory(); // TODO: Consider using a factory pattern to construct asynchronously. this.Parameters.TreeInput = this.GetOrCommitTreeInput(parameters.TreeInput).GetAwaiter().GetResult(); @@ -910,17 +911,8 @@ await this.EvaluateDynamicProperty(treeAction.Properties, null).ConfigureAwait(f this.Parameters.RootSessionId ); - // Instantiate the BaseAction-derived ActionType class and invoke the RunAction method on it. - object actionObject; - if (actionDefinition.ActionType == typeof(SubroutineAction)) - { - // Special initializer is used for the native SubroutineAction. - actionObject = Activator.CreateInstance(actionDefinition.ActionType, this.Parameters); - } - else - { - actionObject = Activator.CreateInstance(actionDefinition.ActionType); - } + //Create the action object using the ActionFactory and kick off the action task. + BaseAction actionObject = this.Parameters.ActionFactory.CreateAction(actionDefinition.ActionType, this.Parameters); MethodInfo method = typeof(BaseAction).GetMethod("RunAction"); Task runActionTask = (Task) method.Invoke(actionObject, new object[] { actionContext });