Add input guard middleware for pre-execution safety classification #1619
ctmackay wants to merge 1 commit into NVIDIA:develop from
Conversation
- Add InputGuardMiddleware that classifies user prompts as safe/unsafe before the agent processes them, using a classification system prompt with general-purpose LLMs - Register the middleware in the retail agent example and wire it into config-with-defenses.yml Co-authored-by: Cursor <cursoragent@cursor.com>
Walkthrough

The changes introduce an LLM-based input guard defense component that classifies user prompts as Safe or Unsafe before execution. A new InputGuardMiddleware is implemented, registered, and integrated into the retail agent's middleware chain via configuration.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant InputGuard as InputGuardMiddleware
    participant LLM as LLM Classifier
    participant Handler as Original Handler
    User->>InputGuard: function_middleware_invoke(args)
    InputGuard->>InputGuard: _check_input(first_arg)
    InputGuard->>LLM: classify prompt (system + user message)
    LLM->>LLM: generate Safe/Unsafe response
    LLM-->>InputGuard: classification result + categories
    InputGuard->>InputGuard: _parse_guard_response()
    alt is_safe == true
        InputGuard->>Handler: forward to next middleware
        Handler-->>User: result
    else is_safe == false
        InputGuard-->>User: return redirected value or block
    end
```
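The flow in the diagram can be sketched as a minimal async guard. This is an illustrative stand-in only: the names `GuardResult`, `classify_with_llm`, `input_guard_invoke`, and `REDIRECT_MESSAGE` are assumptions for the sketch, not the toolkit's actual API.

```python
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable

@dataclass
class GuardResult:
    is_safe: bool
    categories: list[str] = field(default_factory=list)

async def classify_with_llm(prompt: str) -> GuardResult:
    # Placeholder for the LLM call: the real middleware wraps `prompt` in a
    # classification system prompt and parses the Safe/Unsafe verdict from
    # the model's reply. Here a trivial substring check stands in for it.
    return GuardResult(is_safe="ignore previous instructions" not in prompt.lower())

REDIRECT_MESSAGE = "I can't help with that request."

async def input_guard_invoke(args: tuple[Any, ...],
                             call_next: Callable[..., Awaitable[Any]]) -> Any:
    """Classify the first argument before forwarding to the wrapped handler."""
    result = await classify_with_llm(str(args[0]))
    if result.is_safe:
        return await call_next(*args)  # forward to next middleware / handler
    return REDIRECT_MESSAGE            # block or redirect unsafe input
```

The key property the diagram shows is that classification happens before `call_next`, so unsafe prompts never reach the original handler.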
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
test out the input middleware like so:
Actionable comments posted: 3
🧹 Nitpick comments (1)
examples/safety_and_security/retail_agent/src/nat_retail_agent/input_guard_middleware.py (1)
86-89: Suppress ruff ARG002 for unused-but-required interface parameters.
`original_input` and `context` are part of the parent's abstract method signature but are unused in this override. Prefix them with `_` to satisfy ruff (or suppress with `# noqa: ARG002`).

♻️ Proposed fix

```diff
     async def _analyze_content(self, content: Any,
-                               original_input: Any = None,
-                               context: FunctionMiddlewareContext | None = None) -> ContentAnalysisResult:
+                               _original_input: Any = None,
+                               _context: FunctionMiddlewareContext | None = None) -> ContentAnalysisResult:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/safety_and_security/retail_agent/src/nat_retail_agent/input_guard_middleware.py` around lines 86 - 89, The override of _analyze_content currently has unused interface parameters original_input and context which trigger ruff ARG002; modify the parameter names in the _analyze_content signature to _original_input and _context (or prefix them with an underscore) so they are recognized as intentionally unused (alternatively add a trailing comment `# noqa: ARG002`), leaving the function body unchanged; this change should be made on the _analyze_content method to align with the parent abstract signature while satisfying ruff.
```yaml
# Input Guard: Classify user prompts as safe/unsafe before the agent processes them
input_guard:
  _type: input_guard
  llm_name: nim_llm
  action: redirection
```
🧩 Analysis chain

🏁 Scripts executed (Repository: NVIDIA/NeMo-Agent-Toolkit):

```shell
# First, let's examine the config file at the specified lines
cd examples/safety_and_security/retail_agent && find . -name "config-with-defenses.yml" -type f

# Look at the structure of the entire codebase to find the config file
find . -path "*/nat_retail_agent/configs/config-with-defenses.yml" -type f

# Check for config file more broadly
find . -name "config-with-defenses.yml" -type f 2>/dev/null | head -5

cat -n examples/safety_and_security/retail_agent/src/nat_retail_agent/configs/config-with-defenses.yml | sed -n '80,110p'

# Look at the entire config file to understand all middleware blocks
wc -l examples/safety_and_security/retail_agent/src/nat_retail_agent/configs/config-with-defenses.yml

# Find the input_guard implementation and _should_apply_defense method
rg -n "_should_apply_defense" --type py -B 2 -A 10

# Search for input_guard class definition
rg -n "class.*InputGuard|class.*input_guard" --type py -A 5

# Find where target_function_or_group is defined
rg -n "target_function_or_group" --type py -B 2 -A 2

# Get the complete _should_apply_defense method
cat -n packages/nvidia_nat_core/src/nat/middleware/defense/defense_middleware.py | sed -n '144,180p'

# Get the rest of the _should_apply_defense method
cat -n packages/nvidia_nat_core/src/nat/middleware/defense/defense_middleware.py | sed -n '180,185p'

# Verify the default value in the config class
cat -n packages/nvidia_nat_core/src/nat/middleware/defense/defense_middleware.py | sed -n '74,80p'
```
Add target_function_or_group: <workflow> to the input_guard block to restrict it to the workflow entry point.
The input_guard middleware (lines 90-94) omits target_function_or_group, which defaults to None. According to the code logic in _should_apply_defense, when this field is None, the defense applies to all functions—including tool calls and nested functions, not just the workflow entry. This contradicts the design pattern established by the other middleware blocks in this config (pii_defense_workflow and workflow_output_verifier), which explicitly set target_function_or_group: <workflow> to restrict application to workflow-level only.
Add the missing field to align with the intended behavior:
Suggested fix

```diff
 input_guard:
   _type: input_guard
   llm_name: nim_llm
   action: redirection
+  target_function_or_group: <workflow>
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/safety_and_security/retail_agent/src/nat_retail_agent/configs/config-with-defenses.yml`
around lines 90 - 95, The input_guard middleware block is missing
target_function_or_group, causing it to apply to all functions; update the
input_guard config (the input_guard block that sets _type: input_guard,
llm_name: nim_llm, action: redirection) to include target_function_or_group:
<workflow> so it is restricted to the workflow entry point in line with
pii_defense_workflow and workflow_output_verifier and how _should_apply_defense
expects a non-null target to limit application.
```python
    def __init__(self, config: InputGuardMiddlewareConfig, builder):
        from nat.middleware.defense.defense_middleware import DefenseMiddleware
        DefenseMiddleware.__init__(self, config, builder)
        self.config: InputGuardMiddlewareConfig = config  # type: ignore[assignment]
        self._llm = None
```
🧩 Analysis chain

🏁 Scripts executed (Repository: NVIDIA/NeMo-Agent-Toolkit):

```shell
#!/bin/bash
# Find the ContentSafetyGuardMiddleware class definition
find . -type f -name "*.py" | xargs grep -l "class ContentSafetyGuardMiddleware" 2>/dev/null

# Search for ContentSafetyGuardMiddleware definition and its __init__
rg "class ContentSafetyGuardMiddleware" -A 30 --type py

# Find the file we're reviewing to understand its class hierarchy
find . -type f -path "*/input_guard_middleware.py" | head -5

# Get the full class definition including parent classes
rg "class InputGuardMiddlewareClass" -B 2 -A 50 --type py | head -100

# Get the full InputGuardMiddlewareClass definition with inheritance
cat -n examples/safety_and_security/retail_agent/src/nat_retail_agent/input_guard_middleware.py | head -100

# Check the complete class definition of InputGuardMiddlewareClass
rg "class InputGuardMiddlewareClass" -B 5 -A 50 examples/safety_and_security/retail_agent/src/nat_retail_agent/input_guard_middleware.py

# Get the full ContentSafetyGuardMiddleware.__init__ to understand what's being bypassed
rg "class ContentSafetyGuardMiddleware" -A 60 packages/nvidia_nat_core/src/nat/middleware/defense/defense_middleware_content_guard.py | head -80

# Find the Builder class or type to determine correct type annotation for builder parameter
rg "class.*Builder" packages/nvidia_nat_core/src/nat/ -A 2 | head -30

# Check how builder is typed in DefenseMiddleware.__init__
rg "def __init__.*builder" packages/nvidia_nat_core/src/nat/middleware/defense/defense_middleware.py -B 2 -A 5

# Find the Builder class definition in the nat package
rg "class Builder" packages/nvidia_nat_core/src/nat/ --type py | grep -v "PredictionTrieBuilder\|TrajectoryBuilder"

# Check imports in the workflow template to see if Builder is imported
rg "from.*Builder\|import.*Builder" packages/nvidia_nat_core/src/nat/ --type py | head -20

# Look at how other middleware handle the builder parameter
rg "def __init__.*builder" packages/nvidia_nat_core/src/nat/middleware/ -A 3 --type py
```
Add type annotation to builder parameter and document why ContentSafetyGuardMiddleware.__init__ is bypassed.
The builder parameter lacks a type annotation, violating PEP 8 guidelines. Additionally, directly calling DefenseMiddleware.__init__ skips the immediate parent's validation, which explicitly forbids target_location='input'. While this appears intentional (since InputGuardMiddleware supports input analysis), the override should be explicit and documented rather than silently bypassed.
Import Builder from nat.builder.builder and update the method signature:

```diff
-    def __init__(self, config: InputGuardMiddlewareConfig, builder):
-        from nat.middleware.defense.defense_middleware import DefenseMiddleware
-        DefenseMiddleware.__init__(self, config, builder)
+    def __init__(self, config: InputGuardMiddlewareConfig, builder: Builder):
+        from nat.middleware.defense.defense_middleware import DefenseMiddleware
+        # InputGuardMiddleware intentionally bypasses ContentSafetyGuardMiddleware.__init__
+        # to allow input analysis, which the parent explicitly forbids.
+        DefenseMiddleware.__init__(self, config, builder)
```

Add the import at the top of the file:

```python
from nat.builder.builder import Builder
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/safety_and_security/retail_agent/src/nat_retail_agent/input_guard_middleware.py`
around lines 80 - 84, The __init__ of InputGuardMiddleware is missing a type on
the builder param and it silently bypasses the immediate parent's validation
(ContentSafetyGuardMiddleware.__init__) by calling DefenseMiddleware.__init__
directly; import Builder from nat.builder.builder and change the signature of
InputGuardMiddleware.__init__ to accept builder: Builder, and add an inline
comment/docstring explaining why you intentionally call
DefenseMiddleware.__init__ (to allow target_location='input' for input analysis)
instead of ContentSafetyGuardMiddleware.__init__ so the bypass is explicit and
documented.
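The grandparent-initializer pattern this comment asks to document can be shown in a generic, self-contained form. Class names here are illustrative, not the toolkit's: the point is that the subclass deliberately skips its immediate parent's `__init__` (whose validation would reject an input-side target) and calls the base class directly.

```python
class DefenseBase:
    def __init__(self, target_location: str):
        self.target_location = target_location

class OutputOnlyGuard(DefenseBase):
    def __init__(self, target_location: str):
        # The immediate parent's validation: output-only.
        if target_location != "output":
            raise ValueError("this guard only supports target_location='output'")
        super().__init__(target_location)

class InputGuard(OutputOnlyGuard):
    def __init__(self):
        # Intentionally bypass OutputOnlyGuard.__init__: this subclass exists
        # precisely to analyze inputs, which the parent's validation forbids.
        DefenseBase.__init__(self, target_location="input")
```

Documenting the bypass with a comment like the one above keeps the deviation from cooperative `super().__init__()` visible to future readers.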
```python
        except Exception as e:
            logger.exception("InputGuardMiddleware analysis failed: %s", e)
            return ContentAnalysisResult(is_safe=True,
                                         categories=[],
                                         raw_response="",
                                         should_refuse=False,
                                         error=True,
                                         error_message=str(e))
```
Two concerns in the exception handler: redundant `e` in `logger.exception` (TRY401) and a fail-open security posture.

- TRY401: `logger.exception` already attaches the current exception; passing `e` as a format argument is redundant and produces a duplicate message.
- Fail-open: When the LLM call raises (e.g., network outage, model unavailable), the guard silently returns `is_safe=True` and allows all prompts through. This effectively disables the input guard during outages — consider at minimum logging a `WARNING` or `ERROR` to make the degraded state observable, and document the intentional fail-open policy so operators can configure alerting.

🛠️ Fix for TRY401

```diff
-        except Exception as e:
-            logger.exception("InputGuardMiddleware analysis failed: %s", e)
+        except Exception:
+            logger.exception("InputGuardMiddleware analysis failed")
```

🧰 Tools
🪛 Ruff (0.15.1)
[warning] 118-118: Redundant exception object included in logging.exception call
(TRY401)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/safety_and_security/retail_agent/src/nat_retail_agent/input_guard_middleware.py`
around lines 117 - 124, In the except block inside InputGuardMiddleware (where
logger.exception and ContentAnalysisResult are used) remove the redundant
exception argument from logger.exception (call logger.exception(...) without
passing e), and change the failure behavior from silently failing-open to
failing-closed: log an explicit warning/error that the guard is degraded (so
outages are observable) and return a ContentAnalysisResult with is_safe=False,
should_refuse=True, error=True and error_message=str(e); also add a short
comment documenting the intentional fail-closed policy so operators can
configure alerting or override if desired.
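A fail-closed variant of the handler, along the lines the prompt above suggests, might look like the following sketch. The `ContentAnalysisResult` defined here is a local stand-in using only the field names that appear in the snippet under review; the real class is defined in the toolkit.

```python
import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

@dataclass
class ContentAnalysisResult:
    # Stand-in with the fields named in the review; not the toolkit's class.
    is_safe: bool
    categories: list = field(default_factory=list)
    raw_response: str = ""
    should_refuse: bool = False
    error: bool = False
    error_message: str = ""

def guarded_analyze(analyze_fn, content):
    """Run the classifier; on failure, refuse rather than letting prompts through."""
    try:
        return analyze_fn(content)
    except Exception as e:
        # Intentional fail-closed policy: an unreachable classifier blocks input.
        # Operators can alert on this log line, or relax the policy if
        # availability matters more than strict enforcement.
        logger.exception("InputGuardMiddleware analysis failed; failing closed")
        return ContentAnalysisResult(is_safe=False,
                                     should_refuse=True,
                                     error=True,
                                     error_message=str(e))
```

Whether fail-open or fail-closed is right depends on the deployment; the important part is that the choice is explicit and observable in the logs.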
Description
This PR adds an `InputGuardMiddleware` that classifies user prompts as safe or unsafe before the agent processes them, complementing the existing output-side defenses (output verifier, content safety guard, PII defense).

The existing `ContentSafetyGuardMiddleware` only supports `target_location: output`, so input-side safety classification was not possible without a custom middleware. `InputGuardMiddleware` extends `ContentSafetyGuardMiddleware` and overrides `_analyze_content` to wrap user input in a classification system prompt, enabling general-purpose LLMs (e.g. Llama 3.3 via NIM) to reliably return Safe/Unsafe verdicts that `_parse_guard_response` can parse. It overrides `function_middleware_invoke` and `function_middleware_stream` to intercept the input before `call_next`, rather than post-processing the output.

By Submitting this PR I confirm:
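The verdict-parsing step described above can be sketched as follows. The reply format ("Safe" or "Unsafe" on the first line, an optional "Categories:" line after) is a guess for illustration; `_parse_guard_response`'s real contract is defined in the toolkit.

```python
def parse_guard_response(raw: str) -> tuple[bool, list[str]]:
    """Parse an LLM classification reply of the assumed rough form:
    'Safe' or 'Unsafe\\nCategories: violence, fraud'.
    Returns (is_safe, categories)."""
    lines = [ln.strip() for ln in raw.strip().splitlines() if ln.strip()]
    if not lines:
        return False, []  # unparseable reply -> treat as unsafe
    verdict = lines[0].lower()
    is_safe = verdict.startswith("safe")  # 'unsafe' does not match this prefix
    categories: list[str] = []
    for ln in lines[1:]:
        if ln.lower().startswith("categories:"):
            categories = [c.strip() for c in ln.split(":", 1)[1].split(",") if c.strip()]
    return is_safe, categories
```

Treating an empty or malformed reply as unsafe keeps the parser itself fail-closed, which matters when general-purpose LLMs occasionally deviate from the requested format.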
Summary by CodeRabbit
Release Notes