Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ repos:
- { id: uv-lock, name: "uv-lock-rag", args: [--project, packages/nvidia_nat_rag], files: "packages/nvidia_nat_rag/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-ragaai", args: [--project, packages/nvidia_nat_ragaai], files: "packages/nvidia_nat_ragaai/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-redis", args: [--project, packages/nvidia_nat_redis], files: "packages/nvidia_nat_redis/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-runtime", args: [--project, packages/nvidia_nat_runtime], files: "packages/nvidia_nat_runtime/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-s3", args: [--project, packages/nvidia_nat_s3], files: "packages/nvidia_nat_s3/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-semantic_kernel", args: [--project, packages/nvidia_nat_semantic_kernel], files: "packages/nvidia_nat_semantic_kernel/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-strands", args: [--project, packages/nvidia_nat_strands], files: "packages/nvidia_nat_strands/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-test", args: [--project, packages/nvidia_nat_test], files: "packages/nvidia_nat_test/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-vanna", args: [--project, packages/nvidia_nat_vanna], files: "packages/nvidia_nat_vanna/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-weave", args: [--project, packages/nvidia_nat_weave], files: "packages/nvidia_nat_weave/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-workspace", args: [--project, packages/nvidia_nat_workspace], files: "packages/nvidia_nat_workspace/pyproject.toml" }
- { id: uv-lock, name: "uv-lock-zep_cloud", args: [--project, packages/nvidia_nat_zep_cloud], files: "packages/nvidia_nat_zep_cloud/pyproject.toml" }
- repo: local
hooks:
Expand Down
34 changes: 34 additions & 0 deletions packages/nvidia_nat_core/src/nat/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,19 @@
from nat.data_models.object_store import ObjectStoreBaseConfig
from nat.data_models.retriever import RetrieverBaseConfig
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
from nat.data_models.workspace import WorkspaceBaseConfig
from nat.experimental.decorators.experimental_warning_decorator import experimental
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
from nat.finetuning.interfaces.finetuning_runner import Trainer
from nat.finetuning.interfaces.trainer_adapter import TrainerAdapter
from nat.finetuning.interfaces.trajectory_builder import TrajectoryBuilder
from nat.guardrails.workspace import WorkspaceGuardrail
from nat.memory.interfaces import MemoryEditor
from nat.middleware.middleware import Middleware
from nat.object_store.interfaces import ObjectStore
from nat.retriever.interface import Retriever
from nat.workspace.types import WorkspaceManagerBase

if typing.TYPE_CHECKING:
from nat.builder.sync_builder import SyncBuilder
Expand Down Expand Up @@ -236,6 +239,37 @@ def get_workflow_config(self) -> FunctionBaseConfig:
"""
pass

@abstractmethod
def get_workspace_config(self) -> WorkspaceBaseConfig | None:
"""Get the workspace configuration.

Returns:
The workspace configuration if configured, otherwise None
"""
pass

@abstractmethod
async def get_workspace_manager(self) -> WorkspaceManagerBase | None:
"""Get the workspace manager.

Returns:
The workspace manager if configured, otherwise None
"""
pass

@abstractmethod
async def get_workspace_guardrails(self) -> list[WorkspaceGuardrail]:
"""Get instantiated workspace guardrail instances from the workspace config.

Implementations should resolve ``WorkspaceBaseConfig.workspace_guardrails``
refs to live :class:`~nat.guardrails.workspace.WorkspaceGuardrail` instances.

Returns:
List of workspace guardrail instances configured for the active workspace.
Returns an empty list if no workspace or no guardrails are configured.
"""
pass

@abstractmethod
async def get_tools(self,
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
Expand Down
12 changes: 11 additions & 1 deletion packages/nvidia_nat_core/src/nat/builder/child_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from nat.data_models.object_store import ObjectStoreBaseConfig
from nat.data_models.retriever import RetrieverBaseConfig
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
from nat.data_models.workspace import WorkspaceBaseConfig
from nat.experimental.decorators.experimental_warning_decorator import experimental
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
Expand All @@ -60,6 +61,7 @@
from nat.object_store.interfaces import ObjectStore
from nat.retriever.interface import Retriever
from nat.utils.type_utils import override
from nat.workspace.types import WorkspaceManagerBase


class ChildBuilder(Builder):
Expand All @@ -70,8 +72,8 @@ def __init__(self, workflow_builder: Builder) -> None:

self._dependencies = FunctionDependencies()

@override
@property
@override
def sync_builder(self) -> SyncBuilder:
return SyncBuilder(self)

Expand Down Expand Up @@ -125,6 +127,14 @@ def get_workflow(self) -> Function:
def get_workflow_config(self) -> FunctionBaseConfig:
return self._workflow_builder.get_workflow_config()

@override
def get_workspace_config(self) -> WorkspaceBaseConfig | None:
return self._workflow_builder.get_workspace_config()

@override
async def get_workspace_manager(self) -> WorkspaceManagerBase | None:
return await self._workflow_builder.get_workspace_manager()

@override
async def get_tools(self,
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
Expand Down
8 changes: 6 additions & 2 deletions packages/nvidia_nat_core/src/nat/builder/component_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from nat.data_models.object_store import ObjectStoreBaseConfig
from nat.data_models.retriever import RetrieverBaseConfig
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
from nat.data_models.workspace_guardrail import WorkspaceGuardrailBaseConfig
from nat.utils.type_utils import DecomposedType

WORKFLOW_COMPONENT_NAME = "<workflow>"
Expand All @@ -57,6 +58,7 @@
ComponentGroup.RETRIEVERS,
ComponentGroup.TTC_STRATEGIES,
ComponentGroup.MIDDLEWARE,
ComponentGroup.WORKSPACE_GUARDRAILS,
ComponentGroup.FUNCTION_GROUPS,
ComponentGroup.FUNCTIONS,
ComponentGroup.TRAINER_ADAPTERS,
Expand Down Expand Up @@ -130,6 +132,8 @@ def group_from_component(component: TypedBaseModel) -> ComponentGroup | None:
return ComponentGroup.FUNCTION_GROUPS
if (isinstance(component, MiddlewareBaseConfig)):
return ComponentGroup.MIDDLEWARE
if (isinstance(component, WorkspaceGuardrailBaseConfig)):
return ComponentGroup.WORKSPACE_GUARDRAILS
if (isinstance(component, LLMBaseConfig)):
return ComponentGroup.LLMS
if (isinstance(component, MemoryBaseConfig)):
Expand Down Expand Up @@ -289,8 +293,8 @@ def build_dependency_sequence(config: "Config") -> list[ComponentInstanceData]:
total_node_count = (len(config.embedders) + len(config.functions) + len(config.function_groups) + len(config.llms) +
len(config.memory) + len(config.object_stores) + len(config.retrievers) +
len(config.ttc_strategies) + len(config.authentication) + len(config.middleware) +
len(config.trainers) + len(config.trajectory_builders) + len(config.trainer_adapters) + 1
) # +1 for the workflow
len(config.workspace_guardrails) + len(config.trainers) + len(config.trajectory_builders) +
len(config.trainer_adapters) + 1) # +1 for the workflow

dependency_map: dict
dependency_graph: nx.DiGraph
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ class LLMFrameworkEnum(StrEnum):
ADK = "adk"
STRANDS = "strands"
AUTOGEN = "autogen"
RUNTIME = "runtime"
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from nat.data_models.object_store import ObjectStoreBaseConfig
from nat.data_models.retriever import RetrieverBaseConfig
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
from nat.data_models.workspace import WorkspaceBaseConfig
from nat.experimental.decorators.experimental_warning_decorator import experimental
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
Expand All @@ -78,6 +79,7 @@
from nat.object_store.interfaces import ObjectStore
from nat.retriever.interface import Retriever
from nat.utils.type_utils import override
from nat.workspace.types import WorkspaceManagerBase

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -134,8 +136,8 @@ def _get_exit_stack(self) -> AsyncExitStack:
"Exit stack not initialized. Did you forget to call `async with PerUserWorkflowBuilder() as builder`?")
return self._exit_stack

@override
@property
@override
def sync_builder(self) -> SyncBuilder:
return SyncBuilder(self)

Expand Down Expand Up @@ -354,6 +356,14 @@ def get_workflow_config(self) -> FunctionBaseConfig:
# Otherwise, delegate to shared builder
return self._shared_builder.get_workflow_config()

@override
def get_workspace_config(self) -> WorkspaceBaseConfig | None:
return self._shared_builder.get_workspace_config()

@override
async def get_workspace_manager(self) -> WorkspaceManagerBase | None:
return await self._shared_builder.get_workspace_manager()

@override
def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
if isinstance(fn_name, FunctionRef):
Expand Down
18 changes: 18 additions & 0 deletions packages/nvidia_nat_core/src/nat/builder/sync_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from nat.data_models.object_store import ObjectStoreBaseConfig
from nat.data_models.retriever import RetrieverBaseConfig
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
from nat.data_models.workspace import WorkspaceBaseConfig
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
from nat.finetuning.interfaces.finetuning_runner import Trainer
Expand All @@ -59,6 +60,7 @@
from nat.middleware.middleware import Middleware
from nat.object_store.interfaces import ObjectStore
from nat.retriever.interface import Retriever
from nat.workspace.types import WorkspaceManagerBase

if typing.TYPE_CHECKING:
from nat.experimental.test_time_compute.models.strategy_base import StrategyBase
Expand Down Expand Up @@ -180,6 +182,22 @@ def get_workflow_config(self) -> FunctionBaseConfig:
"""
return self._builder.get_workflow_config()

def get_workspace_config(self) -> WorkspaceBaseConfig | None:
"""Get the workspace configuration.

Returns:
The workspace configuration if configured, otherwise None
"""
return self._builder.get_workspace_config()

def get_workspace_manager(self) -> WorkspaceManagerBase | None:
"""Get the workspace manager.

Returns:
The workspace manager if configured, otherwise None
"""
return self._loop.run_until_complete(self._builder.get_workspace_manager())

def get_tools(self,
tool_names: Sequence[str | FunctionRef | FunctionGroupRef],
wrapper_type: LLMFrameworkEnum | str) -> list[typing.Any]:
Expand Down
25 changes: 24 additions & 1 deletion packages/nvidia_nat_core/src/nat/builder/workflow_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from nat.data_models.retriever import RetrieverBaseConfig
from nat.data_models.telemetry_exporter import TelemetryExporterBaseConfig
from nat.data_models.ttc_strategy import TTCStrategyBaseConfig
from nat.data_models.workspace import WorkspaceBaseConfig
from nat.experimental.decorators.experimental_warning_decorator import experimental
from nat.experimental.test_time_compute.models.stage_enums import PipelineTypeEnum
from nat.experimental.test_time_compute.models.stage_enums import StageTypeEnum
Expand All @@ -91,6 +92,7 @@
from nat.object_store.interfaces import ObjectStore
from nat.observability.exporter.base_exporter import BaseExporter
from nat.utils.type_utils import override
from nat.workspace.types import WorkspaceManagerBase

try:
from nat.plugins.eval.profiler.decorators.framework_wrapper import chain_wrapped_build_fn
Expand Down Expand Up @@ -351,6 +353,8 @@ def __init__(self, *, general_config: GeneralConfig | None = None, registry: Typ
self._trainers: dict[str, ConfiguredTrainer] = {}
self._trainer_adapters: dict[str, ConfiguredTrainerAdapter] = {}
self._trajectory_builders: dict[str, ConfiguredTrajectoryBuilder] = {}
self._workspace_config: WorkspaceBaseConfig | None = None
self._workspace_manager: WorkspaceManagerBase | None = None

self._context_state = ContextState.get()

Expand Down Expand Up @@ -444,8 +448,8 @@ async def __aexit__(self, *exc_details):

await self._exit_stack.__aexit__(*exc_details)

@override
@property
@override
def sync_builder(self) -> SyncBuilder:
return SyncBuilder(self)

Expand Down Expand Up @@ -782,6 +786,22 @@ def get_workflow_config(self) -> FunctionBaseConfig:

return self._workflow.config

@override
def get_workspace_config(self) -> WorkspaceBaseConfig | None:
return self._workspace_config

@override
async def get_workspace_manager(self) -> WorkspaceManagerBase | None:
if self._workspace_config is None:
return None

if self._workspace_manager is None:
workspace_info = self._registry.get_workspace(type(self._workspace_config))
self._workspace_manager = await self._get_exit_stack().enter_async_context(
workspace_info.build_fn(self._workspace_config))

return self._workspace_manager

@override
def get_function_dependencies(self, fn_name: str | FunctionRef) -> FunctionDependencies:
if isinstance(fn_name, FunctionRef):
Expand Down Expand Up @@ -1418,6 +1438,9 @@ async def populate_builder(self, config: Config, skip_workflow: bool = False):
skip_workflow (bool): If True, skips the workflow instantiation step. Defaults to False.

"""
self._workspace_config = config.workspace
self._workspace_manager = None

# Generate the build sequence
build_sequence = build_dependency_sequence(config)

Expand Down
Loading