diff --git a/Forge.TreeWalker.UnitTests/Forge.TreeWalker.UnitTests.csproj b/Forge.TreeWalker.UnitTests/Forge.TreeWalker.UnitTests.csproj
index 1b6e0e3..73010ad 100644
--- a/Forge.TreeWalker.UnitTests/Forge.TreeWalker.UnitTests.csproj
+++ b/Forge.TreeWalker.UnitTests/Forge.TreeWalker.UnitTests.csproj
@@ -41,6 +41,8 @@
+
+
@@ -54,6 +56,7 @@
+
@@ -64,6 +67,7 @@
+
diff --git a/Forge.TreeWalker.UnitTests/test/ActionsAndCallbacks/DependencyInjectionTestAction.cs b/Forge.TreeWalker.UnitTests/test/ActionsAndCallbacks/DependencyInjectionTestAction.cs
new file mode 100644
index 0000000..7ed6d91
--- /dev/null
+++ b/Forge.TreeWalker.UnitTests/test/ActionsAndCallbacks/DependencyInjectionTestAction.cs
@@ -0,0 +1,197 @@
+//-----------------------------------------------------------------------
+//
+// Copyright (c) Microsoft Corporation. All rights reserved.
+//
+//
+// ForgeActions used to test dependency injection capabilities.
+//
+//-----------------------------------------------------------------------
+
+namespace Microsoft.Forge.TreeWalker.UnitTests
+{
+ using System;
+ using System.Threading.Tasks;
+
+ using Microsoft.Forge.Attributes;
+ using Microsoft.Forge.TreeWalker;
+
+ ///
+ /// A simple service interface used to verify constructor injection in ForgeActions.
+ ///
+ public interface IDiTestService
+ {
+ string GetValue();
+ }
+
+ ///
+ /// Default implementation of IDiTestService.
+ ///
+ public class DiTestService : IDiTestService
+ {
+ private readonly string value;
+
+ public DiTestService(string value)
+ {
+ this.value = value;
+ }
+
+ public string GetValue()
+ {
+ return this.value;
+ }
+ }
+
+ ///
+ /// A second service interface used to verify multi-dependency constructor injection.
+ ///
+ public interface IDiTestCounter
+ {
+ int Increment();
+ int GetCount();
+ }
+
+ ///
+ /// Default implementation of IDiTestCounter.
+ ///
+ public class DiTestCounter : IDiTestCounter
+ {
+ private int count;
+
+ public int Increment()
+ {
+ return ++this.count;
+ }
+
+ public int GetCount()
+ {
+ return this.count;
+ }
+ }
+
+ ///
+ /// A ForgeAction that requires a single IDiTestService dependency via constructor injection.
+ /// Used to verify that ServiceProviderActionFactory resolves actions with injected services.
+ ///
+ [ForgeAction]
+ public class SingleDependencyAction : BaseAction
+ {
+ private readonly IDiTestService testService;
+
+ public SingleDependencyAction(IDiTestService testService)
+ {
+ this.testService = testService ?? throw new ArgumentNullException(nameof(testService));
+ }
+
+ public override Task RunAction(ActionContext actionContext)
+ {
+ return Task.FromResult(new ActionResponse
+ {
+ Status = "Success",
+ Output = this.testService.GetValue()
+ });
+ }
+ }
+
+ ///
+ /// A ForgeAction that requires multiple dependencies via constructor injection.
+ /// Used to verify that ServiceProviderActionFactory resolves actions with multiple injected services.
+ ///
+ [ForgeAction]
+ public class MultipleDependencyAction : BaseAction
+ {
+ private readonly IDiTestService testService;
+ private readonly IDiTestCounter testCounter;
+
+ public MultipleDependencyAction(IDiTestService testService, IDiTestCounter testCounter)
+ {
+ this.testService = testService ?? throw new ArgumentNullException(nameof(testService));
+ this.testCounter = testCounter ?? throw new ArgumentNullException(nameof(testCounter));
+ }
+
+ public override Task RunAction(ActionContext actionContext)
+ {
+ int count = this.testCounter.Increment();
+
+ return Task.FromResult(new ActionResponse
+ {
+ Status = this.testService.GetValue(),
+ StatusCode = count,
+ Output = string.Format("{0}_{1}", this.testService.GetValue(), count)
+ });
+ }
+ }
+
+ ///
+ /// A ForgeAction with a typed input and constructor-injected dependency.
+ /// Used to verify that DI works alongside Forge's ActionInput deserialization.
+ ///
+ [ForgeAction(InputType: typeof(DiActionWithInputTypeInput))]
+ public class DiActionWithInputType : BaseAction
+ {
+ private readonly IDiTestService testService;
+
+ public DiActionWithInputType(IDiTestService testService)
+ {
+ this.testService = testService ?? throw new ArgumentNullException(nameof(testService));
+ }
+
+ public override Task RunAction(ActionContext actionContext)
+ {
+ var input = (DiActionWithInputTypeInput)actionContext.ActionInput;
+
+ return Task.FromResult(new ActionResponse
+ {
+ Status = "Success",
+ Output = string.Format("{0}_{1}", this.testService.GetValue(), input.MessageProperty)
+ });
+ }
+ }
+
+ public class DiActionWithInputTypeInput
+ {
+ public string MessageProperty { get; set; }
+ }
+
+ ///
+ /// A custom IForgeActionFactory implementation for testing.
+ /// Tracks how many times CreateAction was called.
+ ///
+ public class TestCustomActionFactory : IForgeActionFactory
+ {
+ private readonly IDiTestService testService;
+ public int CreateActionCallCount { get; private set; }
+
+ public TestCustomActionFactory(IDiTestService testService)
+ {
+ this.testService = testService;
+ }
+
+ public BaseAction CreateAction(Type actionType, TreeWalkerParameters parameters)
+ {
+ this.CreateActionCallCount++;
+
+ if (actionType == typeof(SubroutineAction))
+ {
+ return new SubroutineAction(parameters);
+ }
+
+ if (actionType == typeof(SingleDependencyAction))
+ {
+ return new SingleDependencyAction(this.testService);
+ }
+
+ if (actionType == typeof(MultipleDependencyAction))
+ {
+ return new MultipleDependencyAction(this.testService, new DiTestCounter());
+ }
+
+ if (actionType == typeof(DiActionWithInputType))
+ {
+ return new DiActionWithInputType(this.testService);
+ }
+
+ // Fall back to parameterless constructor for other test actions.
+ return (BaseAction)Activator.CreateInstance(actionType);
+ }
+ }
+}
diff --git a/Forge.TreeWalker.UnitTests/test/DependencyInjectionTests.cs b/Forge.TreeWalker.UnitTests/test/DependencyInjectionTests.cs
new file mode 100644
index 0000000..eee4f99
--- /dev/null
+++ b/Forge.TreeWalker.UnitTests/test/DependencyInjectionTests.cs
@@ -0,0 +1,528 @@
+//-----------------------------------------------------------------------
+//
+// Copyright (c) Microsoft Corporation. All rights reserved.
+//
+//
+// Tests the dependency injection capabilities of Forge's action factory system.
+//
+//-----------------------------------------------------------------------
+
+namespace Microsoft.Forge.TreeWalker.UnitTests
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using Microsoft.Extensions.DependencyInjection;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ using Microsoft.Forge.DataContracts;
+ using Microsoft.Forge.TreeWalker;
+ using Microsoft.Forge.TreeWalker.ForgeExceptions;
+ using Newtonsoft.Json;
+
+ [TestClass]
+ public class DependencyInjectionTests
+ {
+ #region Test Schemas
+
+ private const string SingleDependencySchema = @"
+ {
+ ""Tree"": {
+ ""Root"": {
+ ""Type"": ""Action"",
+ ""Actions"": {
+ ""Root_SingleDependencyAction"": {
+ ""Action"": ""SingleDependencyAction""
+ }
+ }
+ }
+ }
+ }";
+
+ private const string MultipleDependencySchema = @"
+ {
+ ""Tree"": {
+ ""Root"": {
+ ""Type"": ""Action"",
+ ""Actions"": {
+ ""Root_MultipleDependencyAction"": {
+ ""Action"": ""MultipleDependencyAction""
+ }
+ }
+ }
+ }
+ }";
+
+ private const string ParallelDependencyActionsSchema = @"
+ {
+ ""Tree"": {
+ ""Root"": {
+ ""Type"": ""Action"",
+ ""Actions"": {
+ ""Root_MultipleDependencyAction_1"": {
+ ""Action"": ""MultipleDependencyAction""
+ },
+ ""Root_MultipleDependencyAction_2"": {
+ ""Action"": ""MultipleDependencyAction""
+ }
+ }
+ }
+ }
+ }";
+
+ private const string DiActionWithInputTypeSchema = @"
+ {
+ ""Tree"": {
+ ""Root"": {
+ ""Type"": ""Action"",
+ ""Actions"": {
+ ""Root_DiActionWithInputType"": {
+ ""Action"": ""DiActionWithInputType"",
+ ""Input"": {
+ ""MessageProperty"": ""Hello""
+ }
+ }
+ }
+ }
+ }
+ }";
+
+ private const string DiActionWithChildSelectorSchema = @"
+ {
+ ""Tree"": {
+ ""Root"": {
+ ""Type"": ""Action"",
+ ""Actions"": {
+ ""Root_SingleDependencyAction"": {
+ ""Action"": ""SingleDependencyAction""
+ }
+ },
+ ""ChildSelector"": [
+ {
+ ""Label"": ""Success"",
+ ""ShouldSelect"": ""C#|Session.GetLastActionResponse().Status == \""Success\"""",
+ ""Child"": ""SuccessLeaf""
+ }
+ ]
+ },
+ ""SuccessLeaf"": {
+ ""Type"": ""Leaf""
+ }
+ }
+ }";
+
+ private const string MultiNodeDiSchema = @"
+ {
+ ""Tree"": {
+ ""Root"": {
+ ""Type"": ""Action"",
+ ""Actions"": {
+ ""Root_SingleDependencyAction"": {
+ ""Action"": ""SingleDependencyAction""
+ }
+ },
+ ""ChildSelector"": [
+ {
+ ""Label"": ""Success"",
+ ""ShouldSelect"": ""C#|Session.GetLastActionResponse().Status == \""Success\"""",
+ ""Child"": ""SecondNode""
+ }
+ ]
+ },
+ ""SecondNode"": {
+ ""Type"": ""Action"",
+ ""Actions"": {
+ ""SecondNode_MultipleDependencyAction"": {
+ ""Action"": ""MultipleDependencyAction""
+ }
+ }
+ }
+ }
+ }";
+
+ #endregion Test Schemas
+
+ #region Helper Methods
+
+ private static ServiceProvider BuildServiceProvider(string serviceValue = "Injected")
+ {
+ var services = new ServiceCollection();
+ services.AddSingleton(new DiTestService(serviceValue));
+ services.AddSingleton();
+ return services.BuildServiceProvider();
+ }
+
+ private static TreeWalkerSession CreateSession(
+ string jsonSchema,
+ IForgeActionFactory actionFactory = null)
+ {
+ Guid sessionId = Guid.NewGuid();
+ var forgeState = new ForgeDictionary(new Dictionary(), sessionId, sessionId);
+ var callbacks = new TreeWalkerCallbacksV2();
+ var token = new CancellationTokenSource().Token;
+ ForgeTree forgeTree = JsonConvert.DeserializeObject(jsonSchema);
+
+ var parameters = new TreeWalkerParameters(
+ sessionId,
+ forgeTree,
+ forgeState,
+ callbacks,
+ token)
+ {
+ UserContext = new ForgeUserContext(),
+ ForgeActionsAssembly = typeof(SingleDependencyAction).Assembly,
+ ActionFactory = actionFactory
+ };
+
+ return new TreeWalkerSession(parameters);
+ }
+
+ #endregion Helper Methods
+
+ #region DefaultForgeActionFactory Tests
+
+ [TestMethod]
+ public async Task TestDefaultFactory_IsUsedWhenNoServiceProviderOrFactorySet()
+ {
+ // Test - When ActionFactory is not set, DefaultForgeActionFactory is used.
+ // Actions requiring DI dependencies will fail because Activator.CreateInstance cannot resolve them.
+ var session = CreateSession(SingleDependencySchema);
+
+ await Assert.ThrowsExceptionAsync(async () =>
+ {
+ await session.WalkTree("Root");
+ }, "Expected WalkTree to fail because SingleDependencyAction requires constructor injection which DefaultForgeActionFactory cannot provide.");
+
+ Assert.AreEqual("TimeoutOnAction", session.Status,
+ "Expected TimeoutOnAction because DefaultForgeActionFactory uses Activator.CreateInstance which cannot resolve IDiTestService.");
+ }
+
+ [TestMethod]
+ public void TestDefaultFactory_CreateAction_ParameterlessConstructor()
+ {
+ // Test - DefaultForgeActionFactory can create actions with parameterless constructors.
+ var factory = new DefaultForgeActionFactory();
+ var action = factory.CreateAction(typeof(TardigradeAction), null);
+
+ Assert.IsNotNull(action, "Expected DefaultForgeActionFactory to create an action with a parameterless constructor.");
+ Assert.IsInstanceOfType(action, typeof(TardigradeAction));
+ }
+
+ [TestMethod]
+ public void TestDefaultFactory_CreateAction_SubroutineAction()
+ {
+ // Test - DefaultForgeActionFactory can create SubroutineAction with TreeWalkerParameters.
+ Guid sessionId = Guid.NewGuid();
+ var forgeState = new ForgeDictionary(new Dictionary(), sessionId, sessionId);
+ var callbacks = new TreeWalkerCallbacksV2();
+ ForgeTree forgeTree = JsonConvert.DeserializeObject(SingleDependencySchema);
+ var parameters = new TreeWalkerParameters(sessionId, forgeTree, forgeState, callbacks, CancellationToken.None);
+
+ var factory = new DefaultForgeActionFactory();
+ var action = factory.CreateAction(typeof(SubroutineAction), parameters);
+
+ Assert.IsNotNull(action, "Expected DefaultForgeActionFactory to create SubroutineAction.");
+ Assert.IsInstanceOfType(action, typeof(SubroutineAction));
+ }
+
+ #endregion DefaultForgeActionFactory Tests
+
+ #region ServiceProviderActionFactory Tests
+
+ [TestMethod]
+ public void TestServiceProviderFactory_Constructor_NullServiceProvider_ThrowsArgumentNullException()
+ {
+ // Test - ServiceProviderActionFactory constructor throws on null serviceProvider.
+ Assert.ThrowsException(() =>
+ {
+ new ServiceProviderActionFactory(null);
+ }, "Expected ArgumentNullException when creating ServiceProviderActionFactory with null serviceProvider.");
+ }
+
+ [TestMethod]
+ public async Task TestServiceProviderFactory_SingleDependency_WalkTree_Success()
+ {
+ // Test - WalkTree with ServiceProviderActionFactory, action resolved via ActivatorUtilities with single dependency.
+ using (var sp = BuildServiceProvider("InjectedValue"))
+ {
+ var session = CreateSession(SingleDependencySchema, actionFactory: new ServiceProviderActionFactory(sp));
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected WalkTree to complete successfully with ServiceProviderActionFactory resolving SingleDependencyAction.");
+
+ ActionResponse response = await session.GetLastActionResponseAsync();
+ Assert.AreEqual("Success", response.Status);
+ Assert.AreEqual("InjectedValue", response.Output,
+ "Expected the injected IDiTestService value to be returned as Output.");
+ }
+ }
+
+ [TestMethod]
+ public async Task TestServiceProviderFactory_MultipleDependencies_WalkTree_Success()
+ {
+ // Test - WalkTree with ServiceProviderActionFactory, action resolved via ActivatorUtilities with multiple dependencies.
+ using (var sp = BuildServiceProvider("MultiDep"))
+ {
+ var session = CreateSession(MultipleDependencySchema, actionFactory: new ServiceProviderActionFactory(sp));
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected WalkTree to complete successfully with ServiceProviderActionFactory resolving MultipleDependencyAction.");
+
+ ActionResponse response = await session.GetLastActionResponseAsync();
+ Assert.AreEqual("MultiDep", response.Status,
+ "Expected the injected IDiTestService value to be returned as Status.");
+ Assert.AreEqual(1, response.StatusCode,
+ "Expected the IDiTestCounter to have been incremented to 1.");
+ Assert.AreEqual("MultiDep_1", response.Output,
+ "Expected the combined output from both injected services.");
+ }
+ }
+
+ [TestMethod]
+ public async Task TestServiceProviderFactory_ParallelActions_WalkTree_Success()
+ {
+ // Test - WalkTree with two parallel actions on the same node, both resolved via ServiceProviderActionFactory.
+ using (var sp = BuildServiceProvider("Parallel"))
+ {
+ var session = CreateSession(ParallelDependencyActionsSchema, actionFactory: new ServiceProviderActionFactory(sp));
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected WalkTree to complete successfully with parallel DI-resolved actions.");
+
+ ActionResponse response1 = await session.GetOutputAsync("Root_MultipleDependencyAction_1");
+ ActionResponse response2 = await session.GetOutputAsync("Root_MultipleDependencyAction_2");
+
+ Assert.IsNotNull(response1, "Expected ActionResponse for first parallel action.");
+ Assert.IsNotNull(response2, "Expected ActionResponse for second parallel action.");
+ Assert.AreEqual("Parallel", response1.Status);
+ Assert.AreEqual("Parallel", response2.Status);
+ }
+ }
+
+ [TestMethod]
+ public async Task TestServiceProviderFactory_ActionWithInputType_WalkTree_Success()
+ {
+ // Test - WalkTree with a DI-resolved action that also uses a typed ActionInput.
+ using (var sp = BuildServiceProvider("WithInput"))
+ {
+ var session = CreateSession(DiActionWithInputTypeSchema, actionFactory: new ServiceProviderActionFactory(sp));
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected WalkTree to complete successfully with DI action using typed ActionInput.");
+
+ ActionResponse response = await session.GetLastActionResponseAsync();
+ Assert.AreEqual("Success", response.Status);
+ Assert.AreEqual("WithInput_Hello", response.Output,
+ "Expected output to combine injected service value with deserialized ActionInput.");
+ }
+ }
+
+ [TestMethod]
+ public async Task TestServiceProviderFactory_MultiNodeWalk_WalkTree_Success()
+ {
+ // Test - WalkTree across multiple nodes, each with DI-resolved actions.
+ using (var sp = BuildServiceProvider("MultiNode"))
+ {
+ var session = CreateSession(MultiNodeDiSchema, actionFactory: new ServiceProviderActionFactory(sp));
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected WalkTree to complete successfully across multiple nodes with DI-resolved actions.");
+
+ ActionResponse rootResponse = await session.GetOutputAsync("Root_SingleDependencyAction");
+ Assert.IsNotNull(rootResponse, "Expected ActionResponse from Root node.");
+ Assert.AreEqual("MultiNode", rootResponse.Output);
+
+ ActionResponse secondResponse = await session.GetOutputAsync("SecondNode_MultipleDependencyAction");
+ Assert.IsNotNull(secondResponse, "Expected ActionResponse from SecondNode.");
+ Assert.AreEqual("MultiNode", secondResponse.Status);
+ }
+ }
+
+ [TestMethod]
+ public async Task TestServiceProviderFactory_ChildSelectorEvaluatesAfterDiAction_WalkTree_Success()
+ {
+ // Test - WalkTree where a DI-resolved action's response drives ChildSelector evaluation.
+ using (var sp = BuildServiceProvider("Selector"))
+ {
+ var session = CreateSession(DiActionWithChildSelectorSchema, actionFactory: new ServiceProviderActionFactory(sp));
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected WalkTree to complete and select the SuccessLeaf child based on DI action response.");
+
+ string currentNode = await session.GetCurrentTreeNode();
+ Assert.AreEqual("SuccessLeaf", currentNode,
+ "Expected tree walker to visit SuccessLeaf after DI action returned Status=Success.");
+ }
+ }
+
+ [TestMethod]
+ public void TestServiceProviderFactory_CreateAction_ResolvesViaActivatorUtilities()
+ {
+ // Test - ServiceProviderActionFactory always uses ActivatorUtilities.CreateInstance to resolve actions,
+ // injecting registered dependencies via constructor injection.
+ var services = new ServiceCollection();
+ services.AddSingleton(new DiTestService("Resolved"));
+ services.AddSingleton();
+
+ using (var sp = services.BuildServiceProvider())
+ {
+ var factory = new ServiceProviderActionFactory(sp);
+ var action = factory.CreateAction(typeof(MultipleDependencyAction), null);
+
+ Assert.IsNotNull(action, "Expected factory to resolve the action via ActivatorUtilities.");
+ Assert.IsInstanceOfType(action, typeof(MultipleDependencyAction));
+ }
+ }
+
+ [TestMethod]
+ public void TestServiceProviderFactory_CreateAction_SubroutineAction()
+ {
+ // Test - ServiceProviderActionFactory creates SubroutineAction with TreeWalkerParameters.
+ Guid sessionId = Guid.NewGuid();
+ var forgeState = new ForgeDictionary(new Dictionary(), sessionId, sessionId);
+ var callbacks = new TreeWalkerCallbacksV2();
+ ForgeTree forgeTree = JsonConvert.DeserializeObject(SingleDependencySchema);
+ var parameters = new TreeWalkerParameters(sessionId, forgeTree, forgeState, callbacks, CancellationToken.None);
+
+ using (var sp = BuildServiceProvider())
+ {
+ var factory = new ServiceProviderActionFactory(sp);
+ var action = factory.CreateAction(typeof(SubroutineAction), parameters);
+
+ Assert.IsNotNull(action, "Expected factory to create SubroutineAction via ActivatorUtilities.");
+ Assert.IsInstanceOfType(action, typeof(SubroutineAction));
+ }
+ }
+
+ #endregion ServiceProviderActionFactory Tests
+
+ #region Custom ActionFactory Tests
+
+ [TestMethod]
+ public async Task TestCustomFactory_WalkTree_Success()
+ {
+ // Test - WalkTree with a custom IForgeActionFactory that manually wires dependencies.
+ var testService = new DiTestService("Custom");
+ var customFactory = new TestCustomActionFactory(testService);
+
+ var session = CreateSession(SingleDependencySchema, actionFactory: customFactory);
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected WalkTree to complete successfully with custom ActionFactory.");
+
+ ActionResponse response = await session.GetLastActionResponseAsync();
+ Assert.AreEqual("Success", response.Status);
+ Assert.AreEqual("Custom", response.Output,
+ "Expected the custom factory to provide the injected IDiTestService value.");
+ }
+
+ [TestMethod]
+ public async Task TestCustomFactory_CreateActionCalled_WalkTree_Success()
+ {
+ // Test - Verify that the custom factory's CreateAction was actually invoked during WalkTree.
+ var testService = new DiTestService("Tracked");
+ var customFactory = new TestCustomActionFactory(testService);
+
+ var session = CreateSession(SingleDependencySchema, actionFactory: customFactory);
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status);
+ Assert.IsTrue(customFactory.CreateActionCallCount > 0,
+ "Expected custom ActionFactory.CreateAction to have been called at least once during WalkTree.");
+ }
+
+ [TestMethod]
+ public async Task TestCustomFactory_OverridesDefaultFactory()
+ {
+ // Test - When ActionFactory is set, it is used instead of the DefaultForgeActionFactory.
+ var testService = new DiTestService("FactoryWins");
+ var customFactory = new TestCustomActionFactory(testService);
+
+ var session = CreateSession(SingleDependencySchema, actionFactory: customFactory);
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status);
+
+ ActionResponse response = await session.GetLastActionResponseAsync();
+ Assert.AreEqual("FactoryWins", response.Output,
+ "Expected custom ActionFactory to be used instead of DefaultForgeActionFactory.");
+ Assert.IsTrue(customFactory.CreateActionCallCount > 0,
+ "Expected custom ActionFactory.CreateAction to have been called.");
+ }
+
+ [TestMethod]
+ public async Task TestCustomFactory_MultipleDependencies_WalkTree_Success()
+ {
+ // Test - Custom factory resolves an action with multiple dependencies.
+ var testService = new DiTestService("CustomMulti");
+ var customFactory = new TestCustomActionFactory(testService);
+
+ var session = CreateSession(MultipleDependencySchema, actionFactory: customFactory);
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected WalkTree to complete successfully with custom factory resolving MultipleDependencyAction.");
+
+ ActionResponse response = await session.GetLastActionResponseAsync();
+ Assert.AreEqual("CustomMulti", response.Status);
+ Assert.AreEqual(1, response.StatusCode);
+ }
+
+ #endregion Custom ActionFactory Tests
+
+ #region Factory Fallback Tests
+
+ [TestMethod]
+ public async Task TestFactoryFallback_NoFactorySet_UsesDefaultFactory()
+ {
+ // Test - When no ActionFactory is set, the DefaultForgeActionFactory is used.
+ // TardigradeAction has a parameterless constructor so it should work.
+ string tardigradeOnlySchema = @"
+ {
+ ""Tree"": {
+ ""Root"": {
+ ""Type"": ""Action"",
+ ""Actions"": {
+ ""Root_TardigradeAction"": {
+ ""Action"": ""TardigradeAction""
+ }
+ }
+ }
+ }
+ }";
+
+ var session = CreateSession(tardigradeOnlySchema);
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected DefaultForgeActionFactory to resolve TardigradeAction via Activator.CreateInstance.");
+ }
+
+ [TestMethod]
+ public async Task TestFactoryFallback_ServiceProviderActionFactory_WalkTree_Success()
+ {
+ // Test - When ActionFactory is explicitly set to a ServiceProviderActionFactory, it resolves actions.
+ using (var sp = BuildServiceProvider("Explicit"))
+ {
+ var session = CreateSession(SingleDependencySchema, actionFactory: new ServiceProviderActionFactory(sp));
+ string status = await session.WalkTree("Root");
+
+ Assert.AreEqual("RanToCompletion", status,
+ "Expected ServiceProviderActionFactory set via ActionFactory to resolve actions.");
+
+ ActionResponse response = await session.GetLastActionResponseAsync();
+ Assert.AreEqual("Explicit", response.Output);
+ }
+ }
+
+ #endregion Factory Fallback Tests
+ }
+}
diff --git a/Forge.TreeWalker/Forge.TreeWalker.csproj b/Forge.TreeWalker/Forge.TreeWalker.csproj
index 22cbab8..e26e512 100644
--- a/Forge.TreeWalker/Forge.TreeWalker.csproj
+++ b/Forge.TreeWalker/Forge.TreeWalker.csproj
@@ -21,6 +21,7 @@
all
+
diff --git a/Forge.TreeWalker/src/DefaultForgeActionFactory.cs b/Forge.TreeWalker/src/DefaultForgeActionFactory.cs
new file mode 100644
index 0000000..9947482
--- /dev/null
+++ b/Forge.TreeWalker/src/DefaultForgeActionFactory.cs
@@ -0,0 +1,35 @@
+//-----------------------------------------------------------------------
+//
+// Copyright (c) Microsoft Corporation. All rights reserved.
+//
+//
+// The DefaultForgeActionFactory class.
+//
+//-----------------------------------------------------------------------
+
+namespace Microsoft.Forge.TreeWalker
+{
+ using System;
+
+ ///
+ /// The default IForgeActionFactory implementation that uses Activator.CreateInstance to instantiate ForgeAction classes.
+ /// This preserves the original behavior of Forge when no custom factory is provided.
+ ///
+ public class DefaultForgeActionFactory : IForgeActionFactory
+ {
+ ///
+ /// Creates an instance of the specified ForgeAction type using Activator.CreateInstance.
+ ///
+ /// The Type of the ForgeAction class to instantiate.
+ /// The TreeWalkerParameters for the current session.
+ /// An instance of the specified action type.
+ public BaseAction CreateAction(Type actionType, TreeWalkerParameters parameters)
+ {
+ if (actionType == typeof(SubroutineAction))
+ {
+ return (SubroutineAction)Activator.CreateInstance(actionType, parameters);
+ }
+ return (BaseAction)Activator.CreateInstance(actionType);
+ }
+ }
+}
diff --git a/Forge.TreeWalker/src/IForgeActionFactory.cs b/Forge.TreeWalker/src/IForgeActionFactory.cs
new file mode 100644
index 0000000..d394bfc
--- /dev/null
+++ b/Forge.TreeWalker/src/IForgeActionFactory.cs
@@ -0,0 +1,28 @@
+//-----------------------------------------------------------------------
+//
+// Copyright (c) Microsoft Corporation. All rights reserved.
+//
+//
+// The IForgeActionFactory interface.
+//
+//-----------------------------------------------------------------------
+
+namespace Microsoft.Forge.TreeWalker
+{
+ using System;
+
+ ///
+ /// The IForgeActionFactory interface defines a factory for creating ForgeAction instances.
+ /// Implement this interface to integrate a dependency injection container of your choice.
+ ///
+ public interface IForgeActionFactory
+ {
+ ///
+ /// Creates an instance of the specified ForgeAction type.
+ ///
+ /// The Type of the ForgeAction class to instantiate. This type derives from .
+ /// The TreeWalkerParameters for the current session. Provided for native actions that require it (e.g. SubroutineAction).
+ /// An instance of the specified action type.
+ BaseAction CreateAction(Type actionType, TreeWalkerParameters parameters);
+ }
+}
diff --git a/Forge.TreeWalker/src/ServiceProviderActionFactory.cs b/Forge.TreeWalker/src/ServiceProviderActionFactory.cs
new file mode 100644
index 0000000..2aa8bee
--- /dev/null
+++ b/Forge.TreeWalker/src/ServiceProviderActionFactory.cs
@@ -0,0 +1,46 @@
+//
+// Copyright (c) Microsoft Corporation. All rights reserved.
+//
+//
+// The ServiceProviderActionFactory class.
+//
+//-----------------------------------------------------------------------
+
+namespace Microsoft.Forge.TreeWalker
+{
+ using System;
+ using Microsoft.Extensions.DependencyInjection;
+
+ ///
+ /// An implementation that resolves ForgeAction instances
+ /// from an backed by Microsoft.Extensions.DependencyInjection.
+ ///
+ public class ServiceProviderActionFactory : IForgeActionFactory
+ {
+ private readonly IServiceProvider serviceProvider;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The service provider used to resolve ForgeAction instances.
+ public ServiceProviderActionFactory(IServiceProvider serviceProvider)
+ {
+ this.serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
+ }
+
+ ///
+ /// Creates an instance of the specified ForgeAction type by resolving it from the service provider.
+ ///
+ /// The Type of the ForgeAction class to instantiate.
+ /// The TreeWalkerParameters for the current session.
+ /// An instance of the specified action type.
+ public BaseAction CreateAction(Type actionType, TreeWalkerParameters parameters)
+ {
+ if (actionType == typeof(SubroutineAction))
+ {
+ return (SubroutineAction)ActivatorUtilities.CreateInstance(this.serviceProvider, actionType, parameters);
+ }
+ return (BaseAction)ActivatorUtilities.CreateInstance(this.serviceProvider, actionType);
+ }
+ }
+}
diff --git a/Forge.TreeWalker/src/TreeWalkerParameters.cs b/Forge.TreeWalker/src/TreeWalkerParameters.cs
index 2092946..c16d48b 100644
--- a/Forge.TreeWalker/src/TreeWalkerParameters.cs
+++ b/Forge.TreeWalker/src/TreeWalkerParameters.cs
@@ -137,6 +137,13 @@ public class TreeWalkerParameters
///
public bool RetryCurrentTreeNodeActions { get; set; }
+ ///
+ /// The factory used to create ForgeAction instances.
+ /// Implement to integrate a dependency injection container of your choice.
+ /// When it is null, Forge will use the which creates instances via Activator.CreateInstance.
+ ///
+ public IForgeActionFactory ActionFactory { get; set; }
+
#endregion
#region Constructor with ITreeWalkerCallbacks, [DEPRECATED]
diff --git a/Forge.TreeWalker/src/TreeWalkerSession.cs b/Forge.TreeWalker/src/TreeWalkerSession.cs
index 38efbf3..da4977e 100644
--- a/Forge.TreeWalker/src/TreeWalkerSession.cs
+++ b/Forge.TreeWalker/src/TreeWalkerSession.cs
@@ -9,6 +9,11 @@
namespace Microsoft.Forge.TreeWalker
{
+ using Microsoft.Forge.Attributes;
+ using Microsoft.Forge.DataContracts;
+ using Microsoft.Forge.TreeWalker.ForgeExceptions;
+ using Newtonsoft.Json;
+ using Newtonsoft.Json.Linq;
using System;
using System.Collections;
using System.Collections.Generic;
@@ -17,11 +22,6 @@ namespace Microsoft.Forge.TreeWalker
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
- using Microsoft.Forge.Attributes;
- using Microsoft.Forge.DataContracts;
- using Microsoft.Forge.TreeWalker.ForgeExceptions;
- using Newtonsoft.Json;
- using Newtonsoft.Json.Linq;
///
/// The TreeWalkerSession tries to walk the given tree schema to completion.
@@ -162,6 +162,7 @@ public TreeWalkerSession(TreeWalkerParameters parameters)
// Initialize properties from optional TreeWalkerParameters properties.
GetActionsMapFromAssembly(parameters.ForgeActionsAssembly, out this.actionsMap);
this.Parameters.ExternalExecutors = parameters.ExternalExecutors ?? new Dictionary>>();
+ this.Parameters.ActionFactory = parameters.ActionFactory ?? new DefaultForgeActionFactory();
// TODO: Consider using a factory pattern to construct asynchronously.
this.Parameters.TreeInput = this.GetOrCommitTreeInput(parameters.TreeInput).GetAwaiter().GetResult();
@@ -910,17 +911,8 @@ await this.EvaluateDynamicProperty(treeAction.Properties, null).ConfigureAwait(f
this.Parameters.RootSessionId
);
- // Instantiate the BaseAction-derived ActionType class and invoke the RunAction method on it.
- object actionObject;
- if (actionDefinition.ActionType == typeof(SubroutineAction))
- {
- // Special initializer is used for the native SubroutineAction.
- actionObject = Activator.CreateInstance(actionDefinition.ActionType, this.Parameters);
- }
- else
- {
- actionObject = Activator.CreateInstance(actionDefinition.ActionType);
- }
+ //Create the action object using the ActionFactory and kick off the action task.
+ BaseAction actionObject = this.Parameters.ActionFactory.CreateAction(actionDefinition.ActionType, this.Parameters);
MethodInfo method = typeof(BaseAction).GetMethod("RunAction");
Task runActionTask = (Task) method.Invoke(actionObject, new object[] { actionContext });