diff --git a/databricks-mcp-server/databricks_mcp_server/agentguard/__init__.py b/databricks-mcp-server/databricks_mcp_server/agentguard/__init__.py new file mode 100644 index 00000000..febe38a9 --- /dev/null +++ b/databricks-mcp-server/databricks_mcp_server/agentguard/__init__.py @@ -0,0 +1 @@ +"""AgentGuard hooks for the MCP server (middleware and tools).""" diff --git a/databricks-mcp-server/databricks_mcp_server/agentguard/commands.py b/databricks-mcp-server/databricks_mcp_server/agentguard/commands.py new file mode 100644 index 00000000..62c7c595 --- /dev/null +++ b/databricks-mcp-server/databricks_mcp_server/agentguard/commands.py @@ -0,0 +1,137 @@ +"""MCP tools for starting, stopping, and inspecting AgentGuard sessions.""" + +from __future__ import annotations + +import json +from typing import Optional + +from databricks_tools_core.agentguard.context import get_active_session +from databricks_tools_core.agentguard.models import AgentGuardMode +from databricks_tools_core.agentguard.session import ( + get_session_status, + start_session, + stop_session, +) + +from ..server import mcp + + +@mcp.tool +def agentguard_start( + mode: str = "monitor_only", + description: str = "", + scope_template: Optional[str] = None, + scope_variables: Optional[dict] = None, +) -> str: + """Start a session: mode monitor_only|enforce, optional scope_template/variables.""" + try: + guard_mode = AgentGuardMode(mode) + except ValueError: + return f"Invalid mode '{mode}'. Use 'monitor_only' or 'enforce'." 
+ + try: + session = start_session( + mode=guard_mode, + description=description, + scope_template=scope_template, + scope_variables=scope_variables, + ) + except ValueError as e: + return str(e) + + mode_label = "monitor-only" if guard_mode == AgentGuardMode.MONITOR_ONLY else "enforce" + lines = [ + f"AgentGuard session started ({mode_label}).", + f"Task ID: {session.task_id}", + "All actions will be recorded.", + ] + if guard_mode == AgentGuardMode.MONITOR_ONLY: + lines.append("Nothing will be blocked. Use mode='enforce' to enable enforcement.") + else: + lines.append("Policy and scope violations will be BLOCKED.") + if scope_template: + lines.append(f"Scope template: {scope_template}") + if scope_variables: + lines.append(f"Scope variables: {scope_variables}") + lines.append("Risk scoring is active. High-risk actions will require approval.") + + return "\n".join(lines) + + +@mcp.tool +def agentguard_stop() -> str: + """Stop the session, flush audit trail if possible, return summary text.""" + session = stop_session() + if session is None: + return "No active AgentGuard session to stop." + + lines = [ + "AgentGuard session stopped.", + session.summary(), + ] + + ledger = session._ledger_result + if ledger: + status = ledger.get("status", "unknown") + if status == "success": + dest = ledger.get("destination", "?") + lines.append(f"Audit trail: {ledger.get('rows', 0)} actions written to {dest}") + lines.append(f"Query: SELECT * FROM {dest} WHERE task_id = '{session.task_id}'") + elif status == "pending": + lines.append(f"Audit trail: saved locally at {ledger.get('destination', '?')}") + lines.append(f"Note: {ledger.get('note', 'Delta write unavailable. 
Data saved locally.')}") + elif status == "skipped": + lines.append("Audit trail: no actions to write.") + else: + lines.append(f"Audit trail: {status} — {ledger.get('error', ledger.get('note', 'unknown'))}") + else: + lines.append(f"Audit trail: agentguard.core.action_log WHERE task_id = '{session.task_id}'") + + return "\n".join(lines) + + +@mcp.tool +def agentguard_status() -> str: + """Return session status text, or a hint to call agentguard_start.""" + status = get_session_status() + if status is None: + return "No active AgentGuard session. Run agentguard_start to begin." + return status + + +@mcp.tool +def agentguard_history(limit: int = 50) -> str: + """Return recent actions as JSON (limit defaults to 50, tail of the list).""" + session = get_active_session() + if session is None: + return "No active AgentGuard session. Run agentguard_start to begin." + + if not session.actions: + return "No actions recorded yet in this session." + + actions_to_show = session.actions[-limit:] if limit > 0 else session.actions + history = [] + for action in actions_to_show: + entry = { + "seq": action.action_sequence, + "tool": action.tool_name, + "operation": action.operation, + "category": action.action_category.value, + "decision": action.final_decision, + "risk_score": action.checkpoint_result.risk_score if action.checkpoint_result else 0, + "policy": action.checkpoint_result.policy_result if action.checkpoint_result else "n/a", + "overhead_ms": round(action.overhead_ms, 1) if action.overhead_ms else 0, + "success": action.execution_success, + "timestamp": action.received_at.isoformat(), + } + if action.sql_statement: + stmt = action.sql_statement + entry["sql"] = stmt[:120] + "..." 
if len(stmt) > 120 else stmt + if action.checkpoint_result and action.checkpoint_result.block_reason: + entry["block_reason"] = action.checkpoint_result.block_reason + history.append(entry) + + total = len(session.actions) + shown = len(history) + suffix = f"\n\n(Showing {shown} of {total} actions)" if shown < total else "" + return json.dumps(history, indent=2) + suffix diff --git a/databricks-mcp-server/databricks_mcp_server/agentguard/middleware.py b/databricks-mcp-server/databricks_mcp_server/agentguard/middleware.py new file mode 100644 index 00000000..d5988b5a --- /dev/null +++ b/databricks-mcp-server/databricks_mcp_server/agentguard/middleware.py @@ -0,0 +1,403 @@ +"""MCP middleware: runs AgentGuard checkpoints on each tool call when a session is active.""" + +import asyncio +import json +import logging +import time +from typing import Any + +from databricks_tools_core.agentguard.context import get_active_session +from databricks_tools_core.agentguard.models import ( + Action, + AgentGuardMode, + CheckpointDecision, + CheckpointResult, +) +from databricks_tools_core.agentguard.policy import PolicyEngine +from databricks_tools_core.agentguard.risk import compute_risk_score +from databricks_tools_core.agentguard.scope import check_scope, check_scope_limits +from databricks_tools_core.agentguard.timing import Timer +from fastmcp.server.context import ( + AcceptedElicitation, + CancelledElicitation, + DeclinedElicitation, +) +from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext +from fastmcp.tools.tool import ToolResult +from mcp.types import CallToolRequestParams, TextContent + +_APPROVAL_TIMEOUT_SECONDS = 300 + +logger = logging.getLogger(__name__) + +_AGENTGUARD_TOOLS = frozenset( + { + "agentguard_start", + "agentguard_stop", + "agentguard_status", + "agentguard_history", + } +) + + +class AgentGuardMiddleware(Middleware): + """Wraps tool calls with policy/scope/risk checks and optional human approval.""" + + def __init__(self) -> 
# File: databricks-mcp-server/databricks_mcp_server/agentguard/middleware.py
"""MCP middleware: runs AgentGuard checkpoints on each tool call when a session is active."""

import asyncio
import json
import logging
import time
from typing import Any

from databricks_tools_core.agentguard.context import get_active_session
from databricks_tools_core.agentguard.models import (
    Action,
    AgentGuardMode,
    CheckpointDecision,
    CheckpointResult,
)
from databricks_tools_core.agentguard.policy import PolicyEngine
from databricks_tools_core.agentguard.risk import compute_risk_score
from databricks_tools_core.agentguard.scope import check_scope, check_scope_limits
from databricks_tools_core.agentguard.timing import Timer
from fastmcp.server.context import (
    AcceptedElicitation,
    CancelledElicitation,
    DeclinedElicitation,
)
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from fastmcp.tools.tool import ToolResult
from mcp.types import CallToolRequestParams, TextContent

# How long to wait for a human yes/no before failing closed (blocking the action).
_APPROVAL_TIMEOUT_SECONDS = 300

logger = logging.getLogger(__name__)

# AgentGuard's own control tools are never checkpointed — otherwise a BLOCK
# decision could make it impossible to stop or inspect the session.
_AGENTGUARD_TOOLS = frozenset(
    {
        "agentguard_start",
        "agentguard_stop",
        "agentguard_status",
        "agentguard_history",
    }
)


class AgentGuardMiddleware(Middleware):
    """Wraps tool calls with policy/scope/risk checks and optional human approval."""

    def __init__(self) -> None:
        self.policy_engine = PolicyEngine()

    async def on_call_tool(
        self,
        context: MiddlewareContext[CallToolRequestParams],
        call_next: CallNext[CallToolRequestParams, ToolResult],
    ) -> ToolResult:
        """Checkpoint one tool call; every exit path records/logs/persists the action."""
        session = get_active_session()

        # No active session: AgentGuard is dormant, pass straight through.
        if session is None:
            return await call_next(context)

        tool_name = context.message.name
        tool_params = context.message.arguments or {}

        # Never checkpoint AgentGuard's own control tools (see _AGENTGUARD_TOOLS).
        if tool_name in _AGENTGUARD_TOOLS:
            return await call_next(context)

        timer = Timer()

        action = timer.measure(
            "build_action",
            Action.from_tool_call,
            tool_name=tool_name,
            tool_params=tool_params,
            task_id=session.task_id,
            agent_id=session.agent_id,
            sequence=session.next_sequence(),
        )

        # CP-1..CP-3: policy, scope/limits, risk score.
        checkpoint_result = timer.measure(
            "checkpoint_pipeline",
            self._run_checkpoints,
            action,
            session,
        )
        action.checkpoint_result = checkpoint_result
        checkpoint_result.timings = timer.records

        # CP-4: in enforce mode, FLAG/HOLD_FOR_APPROVAL escalates to a human.
        if (
            checkpoint_result.decision
            in (
                CheckpointDecision.FLAG,
                CheckpointDecision.HOLD_FOR_APPROVAL,
            )
            and session.mode == AgentGuardMode.ENFORCE
        ):
            checkpoint_result.original_decision = checkpoint_result.decision
            checkpoint_result.approval_requested = True

            approved = await self._request_human_approval(context, action, checkpoint_result)
            if not approved:
                # Fail closed: no approval (decline/cancel/timeout) means BLOCK.
                checkpoint_result.decision = CheckpointDecision.BLOCK
                checkpoint_result.blocking_checkpoint = "CP-4: Human Approval"
                action.final_decision = "rejected_by_user"
                action.overhead_ms = timer.total_ms
                session.record_action(action)
                self._log_action(action, checkpoint_result)
                self._persist_action(session, action)
                return self._blocked_result(tool_name, checkpoint_result)

            checkpoint_result.approval_outcome = "approved"
            action.final_decision = "approved_by_user"
            logger.info(
                "[AgentGuard] #%d %s — approved by user",
                action.action_sequence,
                tool_name,
            )

        # Hard BLOCK is only enforced in enforce mode (monitor mode records WOULD_BLOCK).
        should_block = checkpoint_result.decision == CheckpointDecision.BLOCK and session.mode == AgentGuardMode.ENFORCE

        if should_block:
            # final_decision may already be "rejected_by_user"; don't overwrite it.
            action.final_decision = action.final_decision or "blocked"
            action.overhead_ms = timer.total_ms
            session.record_action(action)
            self._log_action(action, checkpoint_result)
            self._persist_action(session, action)
            return self._blocked_result(tool_name, checkpoint_result)

        # Execute the underlying tool; timing/outcome are captured even on failure.
        exec_start = time.perf_counter()
        execution_failed = False
        try:
            tool_result = await call_next(context)
            action.execution_success = True
        except Exception as e:
            execution_failed = True
            action.execution_success = False
            action.execution_error = str(e)
            raise
        finally:
            action.execution_duration_ms = (time.perf_counter() - exec_start) * 1000
            action.overhead_ms = timer.total_ms

        # Resolve the audit-trail verdict for the executed (or failed) call.
        if execution_failed:
            action.final_decision = action.final_decision or "failed"
        elif not action.final_decision:
            if checkpoint_result.decision in (
                CheckpointDecision.WOULD_BLOCK,
                CheckpointDecision.BLOCK,
                CheckpointDecision.HOLD_FOR_APPROVAL,
            ):
                action.final_decision = "would_block"
            elif checkpoint_result.decision == CheckpointDecision.FLAG:
                action.final_decision = "flagged"
            else:
                action.final_decision = "executed"

        session.record_action(action)
        self._log_action(action, checkpoint_result)
        self._persist_action(session, action)

        return tool_result

    async def _request_human_approval(
        self,
        context: MiddlewareContext[CallToolRequestParams],
        action: Action,
        result: CheckpointResult,
    ) -> bool:
        """Elicit yes/no approval; False on timeout, missing context, or decline."""
        ctx = context.fastmcp_context
        if ctx is None:
            # No FastMCP context (e.g. out-of-band invocation): fail closed.
            logger.warning(
                "[AgentGuard] CP-4 BLOCKED (no context): %s %s — %s",
                action.tool_name,
                action.operation,
                result.block_reason,
            )
            result.approval_outcome = "unavailable"
            result.approval_note = "No FastMCP context available"
            return False

        message = self._format_approval_message(action, result)

        try:
            elicit_result = await asyncio.wait_for(
                ctx.elicit(message=message, response_type=bool),
                timeout=_APPROVAL_TIMEOUT_SECONDS,
            )
        except asyncio.TimeoutError:
            logger.warning(
                "[AgentGuard] CP-4 BLOCKED (timeout after %ds): %s %s — %s",
                _APPROVAL_TIMEOUT_SECONDS,
                action.tool_name,
                action.operation,
                result.block_reason,
            )
            result.approval_outcome = "timeout"
            result.approval_note = f"No response within {_APPROVAL_TIMEOUT_SECONDS}s. Action blocked for safety."
            return False
        except Exception as exc:
            # Most likely the client does not implement MCP elicitation at all.
            logger.warning(
                "[AgentGuard] CP-4 BLOCKED (elicitation unavailable): %s %s — %s "
                "(client may not support elicitation: %s). "
                "Action blocked for safety.",
                action.tool_name,
                action.operation,
                result.block_reason,
                exc,
            )
            result.approval_outcome = "unavailable"
            result.approval_note = (
                f"Client does not support interactive approval dialogs ({exc}). Action blocked for safety."
            )
            return False

        if isinstance(elicit_result, AcceptedElicitation):
            result.approval_outcome = "approved"
            return True

        if isinstance(elicit_result, DeclinedElicitation):
            result.approval_outcome = "declined"
            result.approval_note = "User declined the action"
        elif isinstance(elicit_result, CancelledElicitation):
            result.approval_outcome = "cancelled"
            result.approval_note = "User cancelled the approval dialog"

        return False

    @staticmethod
    def _format_approval_message(action: Action, result: CheckpointResult) -> str:
        """Build the human-readable approval prompt shown in the elicitation dialog."""
        severity = (
            "HIGH RISK — Approval Required"
            if result.decision == CheckpointDecision.HOLD_FOR_APPROVAL
            else "Review Required"
        )
        parts = [
            f"AgentGuard — {severity}",
            "",
            f"  Tool: {action.tool_name}",
            f"  Operation: {action.operation}",
        ]
        if action.target_resource_id:
            parts.append(f"  Target: {action.target_resource_id}")
        if action.sql_statement:
            stmt = action.sql_statement
            display = (stmt[:100] + "...") if len(stmt) > 100 else stmt
            parts.append(f"  SQL: {display}")
        if result.risk_score > 0:
            parts.append(f"  Risk: {result.risk_score:.0f}/100")
        parts += [
            "",
            f"  Reason: {result.block_reason or 'Flagged by policy'}",
            "",
            "Check the box (Space), then select Accept or Decline (Enter).",
        ]
        return "\n".join(parts)

    def _run_checkpoints(
        self,
        action: Action,
        session,
    ) -> CheckpointResult:
        """Policy, optional scope/limits, then risk; later stages cannot undo a block."""
        result = CheckpointResult()
        scope_violated = False

        # CP-1: policy rules.
        policy_decision, rule_hit = self.policy_engine.check(action, session.mode)
        result.policy_result = policy_decision.value
        result.policy_rule_hit = rule_hit

        if policy_decision in (
            CheckpointDecision.BLOCK,
            CheckpointDecision.WOULD_BLOCK,
        ):
            result.decision = policy_decision
            result.block_reason = rule_hit
            result.blocking_checkpoint = "CP-1: Policy"
        elif policy_decision == CheckpointDecision.FLAG:
            result.decision = CheckpointDecision.FLAG
            result.block_reason = rule_hit
            result.blocking_checkpoint = "CP-1: Policy -> CP-4: Approval"

        # CP-2: scope and limits (only when the session declares a scope).
        if session.scope is not None:
            scope_decision, scope_violation = check_scope(
                action,
                session.scope,
                session.mode,
            )
            result.scope_result = scope_decision.value
            result.scope_violation = scope_violation

            if scope_decision in (
                CheckpointDecision.BLOCK,
                CheckpointDecision.WOULD_BLOCK,
            ):
                scope_violated = True
                # Only escalate; never downgrade an earlier BLOCK/WOULD_BLOCK.
                if result.decision in (
                    CheckpointDecision.AUTO_APPROVE,
                    CheckpointDecision.FLAG,
                ):
                    result.decision = scope_decision
                    result.block_reason = scope_violation
                    result.blocking_checkpoint = "CP-2: Scope"

            limit_decision, limit_violation = check_scope_limits(
                action,
                session.scope,
                session.action_count,
                session.write_count,
                session.mode,
            )
            if limit_decision in (
                CheckpointDecision.BLOCK,
                CheckpointDecision.WOULD_BLOCK,
            ):
                if result.decision in (
                    CheckpointDecision.AUTO_APPROVE,
                    CheckpointDecision.FLAG,
                ):
                    result.decision = limit_decision
                    result.block_reason = limit_violation
                    result.blocking_checkpoint = "CP-2: Scope (limit)"

        # CP-3: risk score (scope violations feed into the score).
        risk = compute_risk_score(action, scope_violated=scope_violated)
        result.risk_score = risk.score
        result.risk_breakdown = risk.breakdown

        if risk.decision == CheckpointDecision.BLOCK:
            if result.decision in (
                CheckpointDecision.AUTO_APPROVE,
                CheckpointDecision.FLAG,
                CheckpointDecision.HOLD_FOR_APPROVAL,
            ):
                result.decision = CheckpointDecision.BLOCK
                result.block_reason = result.block_reason or f"Risk score {risk.score:.0f} exceeds block threshold"
                result.blocking_checkpoint = "CP-3: Risk Score"
        elif risk.decision == CheckpointDecision.HOLD_FOR_APPROVAL:
            if result.decision in (
                CheckpointDecision.AUTO_APPROVE,
                CheckpointDecision.FLAG,
            ):
                result.decision = CheckpointDecision.HOLD_FOR_APPROVAL
                result.block_reason = result.block_reason or f"Risk score {risk.score:.0f} requires human approval"
                result.blocking_checkpoint = "CP-3: Risk Score -> CP-4: Approval"
        elif risk.decision == CheckpointDecision.FLAG:
            if result.decision == CheckpointDecision.AUTO_APPROVE:
                result.decision = CheckpointDecision.FLAG
                result.block_reason = result.block_reason or f"Risk score {risk.score:.0f} exceeds flag threshold"
                result.blocking_checkpoint = "CP-3: Risk Score -> CP-4: Approval"

        # In monitor mode, never hard-BLOCK — downgrade to WOULD_BLOCK so
        # session counters and summaries are consistent with policy/scope behavior.
        if session.mode == AgentGuardMode.MONITOR_ONLY and result.decision == CheckpointDecision.BLOCK:
            result.decision = CheckpointDecision.WOULD_BLOCK

        return result

    @staticmethod
    def _blocked_result(tool_name: str, result: CheckpointResult) -> ToolResult:
        """JSON ToolResult returned to the agent in place of the blocked call."""
        return ToolResult(
            content=[
                TextContent(
                    type="text",
                    text=json.dumps(
                        {
                            "agentguard_blocked": True,
                            "tool": tool_name,
                            "reason": result.block_reason,
                            "checkpoint": result.blocking_checkpoint,
                            "risk_score": result.risk_score,
                            "suggestion": (
                                "This action was blocked by AgentGuard. "
                                "If you believe this is correct, ask the user "
                                "to approve it or contact an admin to adjust the policy."
                            ),
                        }
                    ),
                )
            ]
        )

    @staticmethod
    def _log_action(action: Action, result: CheckpointResult) -> None:
        """One-line structured log entry per checkpointed action."""
        overhead = f"{action.overhead_ms:.0f}ms" if action.overhead_ms else "?"
        risk = f"risk={result.risk_score:.0f}"
        decision = result.decision.value

        logger.info(
            "[AgentGuard] #%d %s %s | %s | %s | policy=%s | overhead=%s",
            action.action_sequence,
            action.tool_name,
            action.operation,
            decision,
            risk,
            result.policy_result,
            overhead,
        )

    @staticmethod
    def _persist_action(session: Any, action: Action) -> None:
        """Best-effort JSONL append; audit failures must never break the agent."""
        try:
            from databricks_tools_core.agentguard.ledger import append_action

            append_action(session.task_id, action, session)
        except Exception as e:
            logger.debug("Audit persist failed for %s: %s", action.action_id, e)


# --- server.py (same patch) registers this middleware and the tools: ---
#   from .agentguard.middleware import AgentGuardMiddleware
#   mcp.add_middleware(AgentGuardMiddleware())
#   from .agentguard import commands as _agentguard_commands  # noqa: F401, E402
b/databricks-tools-core/databricks_tools_core/agentguard/__init__.py @@ -0,0 +1,17 @@ +""" +AgentGuard + +Runtime firewall for AI agents on Databricks: intercept, verify, score, and record actions. +""" + +from databricks_tools_core.agentguard.models import ( + ActionCategory, + AgentGuardMode, + CheckpointDecision, +) + +__all__ = [ + "AgentGuardMode", + "ActionCategory", + "CheckpointDecision", +] diff --git a/databricks-tools-core/databricks_tools_core/agentguard/context.py b/databricks-tools-core/databricks_tools_core/agentguard/context.py new file mode 100644 index 00000000..7dd6a9c9 --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/agentguard/context.py @@ -0,0 +1,38 @@ +""" +AgentGuard context + +Module-level active session so it survives separate MCP tool calls. ContextVar would not, +because each message is a new async context (same idea as auth’s _active_profile / _active_host). +""" + +from __future__ import annotations + +import threading +from typing import Optional + +from databricks_tools_core.agentguard.models import AgentGuardSession + +_lock = threading.Lock() +_active_session: Optional[AgentGuardSession] = None + + +def get_active_session() -> Optional[AgentGuardSession]: + with _lock: + return _active_session + + +def set_active_session(session: AgentGuardSession) -> None: + global _active_session + with _lock: + _active_session = session + + +def clear_active_session() -> None: + global _active_session + with _lock: + _active_session = None + + +def has_active_session() -> bool: + with _lock: + return _active_session is not None diff --git a/databricks-tools-core/databricks_tools_core/agentguard/ledger.py b/databricks-tools-core/databricks_tools_core/agentguard/ledger.py new file mode 100644 index 00000000..fc0855ce --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/agentguard/ledger.py @@ -0,0 +1,462 @@ +""" +Audit ledger writer for AgentGuard. + +Two-tier persistence strategy: +1. 
HOT PATH (per-action): Append each action to a local JSONL file as it happens. + Cost: ~0.1ms per action. Survives crashes, forgotten stops, and closed connections. +2. COLD PATH (at session stop): Batch-flush the JSONL to a Delta table on Databricks. + If Delta is unavailable, the JSONL file is kept for later recovery. + +On next session start, any orphaned JSONL files (from crashed/abandoned sessions) +are detected and reported so they can be flushed to Delta. + +This guarantees zero audit data loss regardless of how the session ends. +""" + +from __future__ import annotations + +import json +import logging +import os +import threading +from pathlib import Path +from typing import Any, Optional + +logger = logging.getLogger(__name__) + +# Default catalog/schema for the audit ledger. +# Default: main.agentguard — "main" catalog exists on all workspaces. +# Override via env var if needed (e.g., AGENTGUARD_CATALOG=my_catalog). +_DEFAULT_CATALOG = os.environ.get("AGENTGUARD_CATALOG", "main") +_DEFAULT_SCHEMA = os.environ.get("AGENTGUARD_SCHEMA", "agentguard") +_ACTION_LOG_TABLE = "action_log" + +# Local storage directory for JSONL session files +_SESSIONS_DIR = Path.home() / ".agentguard" / "sessions" + +# Cache: once we confirm the Delta table exists, skip CREATE IF NOT EXISTS. +_table_exists_cache: set[str] = set() +_cache_lock = threading.Lock() + + +# --------------------------------------------------------------------------- +# Hot path: per-action local JSONL append +# --------------------------------------------------------------------------- + + +def init_session_file(task_id: str) -> Path: + """Create the JSONL file for a new session. 
Called at session start.""" + _SESSIONS_DIR.mkdir(parents=True, exist_ok=True) + file_path = _SESSIONS_DIR / f"{task_id}.jsonl" + # Write a header line with session metadata + file_path.touch(exist_ok=True) + return file_path + + +def append_action(task_id: str, action: Any, session: Any) -> None: + """Append a single action to the session's JSONL file. + + Called from the middleware after every tool call. Must be fast (<1ms). + Failures are logged but never block the agent. + """ + file_path = _SESSIONS_DIR / f"{task_id}.jsonl" + try: + row = _action_to_row(action, session) + line = json.dumps(row, default=str) + with open(file_path, "a") as f: + f.write(line + "\n") + except Exception as e: + # Never fail the agent because of audit logging + logger.warning("Audit append failed for action %s: %s", action.action_id, e) + + +def _action_to_row(action: Any, session: Any) -> dict[str, Any]: + """Convert a single action to a flat row dict.""" + cp = action.checkpoint_result + return { + "action_id": action.action_id, + "task_id": action.task_id, + "agent_id": action.agent_id, + "action_sequence": action.action_sequence, + "tool_name": action.tool_name, + "action_category": action.action_category.value, + "operation": action.operation, + "target_resource_type": action.target_resource_type, + "target_resource_id": action.target_resource_id, + "target_environment": action.target_environment, + "sql_statement": action.sql_statement, + "sql_parsed_type": action.sql_parsed_type, + "tables_read": json.dumps(action.tables_read), + "tables_written": json.dumps(action.tables_written), + "rows_affected": action.rows_affected, + "cp1_policy_result": cp.policy_result if cp else None, + "cp1_policy_rule_hit": cp.policy_rule_hit if cp else None, + "cp2_scope_result": cp.scope_result if cp else None, + "cp2_scope_violation": cp.scope_violation if cp else None, + "cp3_risk_score": cp.risk_score if cp else 0.0, + "cp3_risk_breakdown": json.dumps(cp.risk_breakdown) if cp else None, + 
"cp3_risk_decision": cp.decision.value if cp else None, + "cp4_approval_requested": cp.approval_requested if cp else False, + "cp4_approval_outcome": cp.approval_outcome if cp else None, + "cp4_approval_note": cp.approval_note if cp else None, + "final_decision": action.final_decision, + "execution_success": action.execution_success, + "execution_error": action.execution_error, + "execution_duration_ms": action.execution_duration_ms, + "overhead_ms": action.overhead_ms, + "received_at": action.received_at.isoformat() if action.received_at else None, + "executed_at": action.executed_at.isoformat() if action.executed_at else None, + "completed_at": action.completed_at.isoformat() if action.completed_at else None, + "session_mode": session.mode.value, + "session_description": session.description, + "scope_template": session.scope_template, + # Project identity — matches the tagging in identity.py so audit + # records can be correlated with other AI Dev Kit resources + "project_name": getattr(session, "project_name", "unknown"), + } + + +# --------------------------------------------------------------------------- +# Cold path: flush JSONL to Delta at session stop +# --------------------------------------------------------------------------- + + +def flush_session_to_delta( + session: Any, + catalog: str = _DEFAULT_CATALOG, + schema: str = _DEFAULT_SCHEMA, + warehouse_id: Optional[str] = None, +) -> dict[str, Any]: + """Flush a session's JSONL file to the Delta audit ledger. + + Reads all rows from the local JSONL, batch-inserts into Delta, + and deletes the JSONL on success. If Delta write fails, the JSONL + is kept for later recovery. + + Args: + session: A completed AgentGuardSession. + catalog: Unity Catalog catalog name. + schema: Schema within the catalog. + warehouse_id: SQL warehouse ID. If None, auto-selects. + + Returns: + Dict with write status, row count, and destination. 
+ """ + jsonl_path = _SESSIONS_DIR / f"{session.task_id}.jsonl" + + # Read rows from JSONL + rows = _read_jsonl(jsonl_path) + if not rows: + _cleanup_jsonl(jsonl_path) + return {"status": "skipped", "reason": "no actions to flush", "rows": 0} + + try: + result = _write_to_delta(rows, catalog, schema, warehouse_id) + # Delta write succeeded — delete the local JSONL + _cleanup_jsonl(jsonl_path) + logger.info( + "Audit ledger: %d actions written to %s for task %s", + len(rows), + result["destination"], + session.task_id, + ) + return result + except Exception as e: + logger.warning( + "Audit ledger: Delta write failed for task %s (%s). JSONL preserved at %s for recovery.", + session.task_id, + e, + jsonl_path, + ) + return { + "status": "pending", + "destination": str(jsonl_path), + "rows": len(rows), + "method": "local_jsonl", + "note": ( + f"Delta write failed. {len(rows)} actions preserved locally at {jsonl_path}. " + f"They will be flushed on the next successful session stop, " + f"or can be loaded manually." + ), + } + + +def _read_jsonl(path: Path) -> list[dict[str, Any]]: + """Read all rows from a JSONL file.""" + if not path.exists(): + return [] + rows = [] + with open(path) as f: + for line in f: + line = line.strip() + if line: + try: + rows.append(json.loads(line)) + except json.JSONDecodeError: + logger.warning("Skipping malformed JSONL line in %s", path) + return rows + + +def _cleanup_jsonl(path: Path) -> None: + """Delete a JSONL file after successful Delta flush.""" + try: + if path.exists(): + path.unlink() + except OSError as e: + logger.warning("Could not delete JSONL %s: %s", path, e) + + +# --------------------------------------------------------------------------- +# Orphan recovery: detect abandoned sessions from previous runs +# --------------------------------------------------------------------------- + + +def find_orphaned_sessions() -> list[dict[str, Any]]: + """Find JSONL files from sessions that were never flushed to Delta. 
+ + These are sessions where the user closed Claude Code without running + /agentguard stop, or where the Delta write failed. + + Returns: + List of dicts with task_id, file path, and row count. + """ + if not _SESSIONS_DIR.exists(): + return [] + + orphans = [] + for jsonl_file in _SESSIONS_DIR.glob("*.jsonl"): + # Skip empty files (created by init_session_file but never written to) + if jsonl_file.stat().st_size == 0: + _cleanup_jsonl(jsonl_file) + continue + rows = _read_jsonl(jsonl_file) + if rows: + task_id = jsonl_file.stem + orphans.append( + { + "task_id": task_id, + "file": str(jsonl_file), + "actions": len(rows), + "oldest_action": rows[0].get("received_at", "unknown"), + } + ) + return orphans + + +def flush_orphaned_sessions( + catalog: str = _DEFAULT_CATALOG, + schema: str = _DEFAULT_SCHEMA, + warehouse_id: Optional[str] = None, +) -> list[dict[str, Any]]: + """Flush all orphaned JSONL files to Delta. + + Called on session start to recover data from abandoned sessions. + + Returns: + List of results, one per orphaned session. 
+ """ + orphans = find_orphaned_sessions() + if not orphans: + return [] + + results = [] + for orphan in orphans: + jsonl_path = Path(orphan["file"]) + rows = _read_jsonl(jsonl_path) + if not rows: + _cleanup_jsonl(jsonl_path) + continue + + try: + result = _write_to_delta(rows, catalog, schema, warehouse_id) + _cleanup_jsonl(jsonl_path) + results.append( + { + "task_id": orphan["task_id"], + "status": "recovered", + "rows": len(rows), + "destination": result["destination"], + } + ) + logger.info( + "Recovered orphaned session %s: %d actions → Delta", + orphan["task_id"], + len(rows), + ) + except Exception as e: + results.append( + { + "task_id": orphan["task_id"], + "status": "failed", + "rows": len(rows), + "error": str(e), + } + ) + logger.warning( + "Failed to recover orphaned session %s: %s", + orphan["task_id"], + e, + ) + + return results + + +# --------------------------------------------------------------------------- +# Delta write internals +# --------------------------------------------------------------------------- + + +_TABLE_COLUMNS = """( + action_id STRING, task_id STRING, agent_id STRING, + action_sequence INT, tool_name STRING, action_category STRING, + operation STRING, target_resource_type STRING, + target_resource_id STRING, target_environment STRING, + sql_statement STRING, sql_parsed_type STRING, + tables_read STRING, tables_written STRING, rows_affected BIGINT, + cp1_policy_result STRING, cp1_policy_rule_hit STRING, + cp2_scope_result STRING, cp2_scope_violation STRING, + cp3_risk_score DOUBLE, cp3_risk_breakdown STRING, + cp3_risk_decision STRING, cp4_approval_requested BOOLEAN, + cp4_approval_outcome STRING, cp4_approval_note STRING, + final_decision STRING, execution_success BOOLEAN, + execution_error STRING, execution_duration_ms DOUBLE, + overhead_ms DOUBLE, received_at STRING, executed_at STRING, + completed_at STRING, session_mode STRING, + session_description STRING, scope_template STRING, + project_name STRING +)""" + + +def 
_write_to_delta(
    rows: list[dict[str, Any]],
    catalog: str,
    schema: str,
    warehouse_id: Optional[str],
) -> dict[str, Any]:
    """Write rows to the Delta audit ledger via SQL INSERT VALUES.

    Simple, direct approach — one SQL call. No Volume, no COPY INTO,
    no extra permissions beyond table write access.

    If the INSERT fails (e.g., table was deleted externally), clears the
    cache and retries with table recreation.

    Args:
        rows: Flat dicts keyed by the columns in ``_COLUMN_ORDER``.
        catalog: Unity Catalog catalog name (trusted config, not user input).
        schema: Schema name within the catalog.
        warehouse_id: Optional SQL warehouse to run the statements on.

    Returns:
        A result dict: ``{"status", "destination", "rows", "method"}``.

    Raises:
        Propagates whatever ``execute_sql`` raises if the retry INSERT
        also fails.
    """
    from databricks_tools_core.sql.sql import execute_sql

    full_table = f"{catalog}.{schema}.{_ACTION_LOG_TABLE}"

    # Ensure table exists (cached after first call)
    with _cache_lock:
        needs_setup = full_table not in _table_exists_cache

    if needs_setup:
        _ensure_table_exists(catalog, schema, warehouse_id)
        with _cache_lock:
            _table_exists_cache.add(full_table)

    # Build batch INSERT. All rows go into a single statement.
    # NOTE(review): for very large sessions this can produce a huge SQL
    # string — presumably sessions are small enough; confirm statement
    # size limits on the target warehouse.
    value_rows = []
    for row in rows:
        values = []
        for col in _COLUMN_ORDER:
            val = row.get(col)
            if val is None:
                values.append("NULL")
            elif isinstance(val, bool):
                # bool must be checked before int/float (bool is an int subclass).
                values.append("TRUE" if val else "FALSE")
            elif isinstance(val, (int, float)):
                values.append(str(val))
            else:
                # Escape backslashes first, then single quotes, so the
                # literal survives SQL string quoting.
                escaped = str(val).replace("\\", "\\\\").replace("'", "''")
                values.append(f"'{escaped}'")
        value_rows.append(f"({', '.join(values)})")

    columns = ", ".join(_COLUMN_ORDER)
    values_sql = ",\n".join(value_rows)
    insert_sql = f"INSERT INTO {full_table} ({columns}) VALUES\n{values_sql}"

    try:
        execute_sql(insert_sql, warehouse_id=warehouse_id)
    except Exception:
        # Table may have been deleted externally. Clear cache, recreate, retry.
        # NOTE(review): this catches *any* failure (network, permissions, …)
        # and retries once unconditionally — the second failure propagates.
        logger.info("INSERT failed — recreating table and retrying")
        with _cache_lock:
            _table_exists_cache.discard(full_table)
        _ensure_table_exists(catalog, schema, warehouse_id)
        with _cache_lock:
            _table_exists_cache.add(full_table)
        execute_sql(insert_sql, warehouse_id=warehouse_id)

    return {
        "status": "success",
        "destination": full_table,
        "rows": len(rows),
        "method": "delta",
    }


def _ensure_table_exists(
    catalog: str,
    schema: str,
    warehouse_id: Optional[str],
) -> None:
    """Create schema and table if they don't exist.

    Runs once per process lifetime (cached). Skips catalog creation —
    most shared workspaces don't allow CREATE CATALOG. The catalog
    must exist already (or be set via AGENTGUARD_CATALOG env var to
    an existing catalog the user has access to).
    """
    from databricks_tools_core.sql.sql import execute_sql

    # Schema + table only. No CREATE CATALOG (requires admin).
    # Failures are logged and swallowed: the subsequent INSERT surfaces
    # any real problem, so setup errors are non-fatal here.
    for stmt in [
        f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}",
        f"CREATE TABLE IF NOT EXISTS {catalog}.{schema}.{_ACTION_LOG_TABLE} {_TABLE_COLUMNS} USING DELTA TBLPROPERTIES ('delta.appendOnly' = 'true')",
    ]:
        try:
            execute_sql(stmt, warehouse_id=warehouse_id)
        except Exception as e:
            logger.warning("Setup statement failed (non-fatal): %s — %s", stmt[:60], e)


# Column order must match the INSERT and CREATE TABLE
_COLUMN_ORDER = [
    "action_id",
    "task_id",
    "agent_id",
    "action_sequence",
    "tool_name",
    "action_category",
    "operation",
    "target_resource_type",
    "target_resource_id",
    "target_environment",
    "sql_statement",
    "sql_parsed_type",
    "tables_read",
    "tables_written",
    "rows_affected",
    "cp1_policy_result",
    "cp1_policy_rule_hit",
    "cp2_scope_result",
    "cp2_scope_violation",
    "cp3_risk_score",
    "cp3_risk_breakdown",
    "cp3_risk_decision",
    "cp4_approval_requested",
    "cp4_approval_outcome",
    "cp4_approval_note",
    "final_decision",
    "execution_success",
    "execution_error",
    "execution_duration_ms",
    "overhead_ms",
    "received_at",
    "executed_at",
    "completed_at",
    "session_mode",
    "session_description",
    "scope_template",
    "project_name",
]
diff --git a/databricks-tools-core/databricks_tools_core/agentguard/models.py b/databricks-tools-core/databricks_tools_core/agentguard/models.py
new file mode 100644
index 00000000..726129e3
--- /dev/null
+++ b/databricks-tools-core/databricks_tools_core/agentguard/models.py
@@ -0,0 +1,580 @@
"""
AgentGuard models

Pydantic models for checkpoint state, actions, and sessions.
"""

from __future__ import annotations

import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Any, ClassVar, Optional

from pydantic import BaseModel, Field, PrivateAttr


class AgentGuardMode(str, Enum):
    """AgentGuard session mode."""

    # Record everything, block nothing.
    MONITOR_ONLY = "monitor_only"
    # Record everything and block policy/scope violations.
    ENFORCE = "enforce"


class ActionCategory(str, Enum):
    """Broad classification of an intercepted action."""

    READ = "READ"
    WRITE = "WRITE"
    DDL = "DDL"
    DCL = "DCL"
    ADMIN = "ADMIN"
    EXTERNAL = "EXTERNAL"
    UNKNOWN = "UNKNOWN"


class CheckpointDecision(str, Enum):
    """Outcome of the checkpoint pipeline for one action."""

    AUTO_APPROVE = "auto_approve"
    FLAG = "flag"
    HOLD_FOR_APPROVAL = "hold_for_approval"
    BLOCK = "block"
    # Monitor-only counterpart of BLOCK: recorded, never enforced.
    WOULD_BLOCK = "would_block"


class SessionStatus(str, Enum):
    """Lifecycle state of an AgentGuard session."""

    ACTIVE = "active"
    COMPLETED = "completed"
    FAILED = "failed"
    ROLLED_BACK = "rolled_back"


class TimingRecord(BaseModel):
    """One named timing sample (checkpoint phase)."""

    name: str
    duration_ms: float


class CheckpointResult(BaseModel):
    """Checkpoint pipeline outcome for one action."""

    decision: CheckpointDecision = CheckpointDecision.AUTO_APPROVE
    # Decision before any monitor-only downgrade, if different.
    original_decision: Optional[CheckpointDecision] = None
    block_reason: Optional[str] = None
    blocking_checkpoint: Optional[str] = None
    risk_score: float = 0.0
    risk_breakdown: dict[str, float] = Field(default_factory=dict)

    # CP-1 (policy) raw results, stored as plain strings for the ledger.
    policy_result: str = CheckpointDecision.AUTO_APPROVE.value
    policy_rule_hit: Optional[str] = None
    scope_result: str = 
CheckpointDecision.AUTO_APPROVE.value + scope_violation: Optional[str] = None + + approval_requested: bool = False + approval_outcome: Optional[str] = None + approval_note: Optional[str] = None + + timings: list[TimingRecord] = Field(default_factory=list) + + @property + def blocked(self) -> bool: + """Hard block only; WOULD_BLOCK is monitor-only.""" + return self.decision == CheckpointDecision.BLOCK + + @property + def total_overhead_ms(self) -> float: + return sum(t.duration_ms for t in self.timings) + + +def _generate_action_id() -> str: + return f"act_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" + + +_SQL_CATEGORY_MAP: dict[str, tuple[ActionCategory, str]] = { + "SELECT": (ActionCategory.READ, "SELECT"), + "DESCRIBE": (ActionCategory.READ, "DESCRIBE"), + "SHOW": (ActionCategory.READ, "SHOW"), + "EXPLAIN": (ActionCategory.READ, "EXPLAIN"), + "INSERT": (ActionCategory.WRITE, "INSERT"), + "UPDATE": (ActionCategory.WRITE, "UPDATE"), + "MERGE": (ActionCategory.WRITE, "MERGE"), + "DELETE": (ActionCategory.WRITE, "DELETE"), + "TRUNCATE": (ActionCategory.WRITE, "TRUNCATE"), + "COPY": (ActionCategory.WRITE, "COPY_INTO"), + "CREATE": (ActionCategory.DDL, "CREATE"), + "ALTER": (ActionCategory.DDL, "ALTER"), + "DROP": (ActionCategory.DDL, "DROP"), + "GRANT": (ActionCategory.DCL, "GRANT"), + "REVOKE": (ActionCategory.DCL, "REVOKE"), +} + + +def _strip_sql_noise(sql: str) -> str: + """Normalize SQL for keyword classification. + + Strips leading comments (-- and /* */), CTEs (WITH ... AS), and + whitespace so the first token is the actual operation keyword. 
+ + Examples: + "-- fix\\nDROP TABLE t" → "DROP TABLE T" + "/* cleanup */ DELETE FROM t" → "DELETE FROM T" + "WITH cte AS (SELECT 1) SELECT" → "SELECT" + """ + import re + + text = sql.strip().upper() + # Strip line comments + text = re.sub(r"--[^\n]*", " ", text) + # Strip block comments + text = re.sub(r"/\*.*?\*/", " ", text, flags=re.DOTALL) + text = text.strip() + # Strip CTE prefix: WITH AS (...) + # Find the last top-level closing paren of the CTE, then take what follows + if text.startswith("WITH "): + # Simple heuristic: find the keyword after the last balanced ")" + depth = 0 + for i, ch in enumerate(text): + if ch == "(": + depth += 1 + elif ch == ")": + depth -= 1 + if depth == 0: + remainder = text[i + 1 :].strip() + if remainder: + text = remainder + break + return text + + +class Action(BaseModel): + """One intercepted agent tool call.""" + + action_id: str = Field(default_factory=_generate_action_id) + task_id: str = "" + agent_id: str = "" + + tool_name: str + tool_params: dict[str, Any] = Field(default_factory=dict) + action_sequence: int = 0 + + action_category: ActionCategory = ActionCategory.UNKNOWN + operation: str = "" + target_resource_type: str = "" + target_resource_id: str = "" + target_environment: str = "" + + sql_statement: Optional[str] = None + sql_parsed_type: Optional[str] = None + tables_read: list[str] = Field(default_factory=list) + tables_written: list[str] = Field(default_factory=list) + rows_affected: Optional[int] = None + + api_endpoint: Optional[str] = None + api_method: Optional[str] = None + + checkpoint_result: Optional[CheckpointResult] = None + + final_decision: str = "" + execution_success: Optional[bool] = None + execution_error: Optional[str] = None + execution_duration_ms: Optional[float] = None + + received_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + executed_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + overhead_ms: Optional[float] = None + + 
    @classmethod
    def from_tool_call(
        cls,
        tool_name: str,
        tool_params: dict[str, Any],
        task_id: str,
        agent_id: str,
        sequence: int,
    ) -> Action:
        """Build from an MCP tool call and enrich classification."""
        action = cls(
            tool_name=tool_name,
            tool_params=tool_params,
            task_id=task_id,
            agent_id=agent_id,
            action_sequence=sequence,
        )
        action._enrich_from_tool()
        return action

    # Tools whose behavior depends on an "action"/"operation" param;
    # value is the resource type they operate on.
    _ACTION_BASED_TOOLS: ClassVar[dict[str, str]] = {
        "manage_jobs": "job",
        "manage_job_runs": "job_run",
        "manage_uc_objects": "unity_catalog",
        "manage_uc_grants": "uc_grant",
        "manage_uc_storage": "uc_storage",
        "manage_uc_connections": "uc_connection",
        "manage_uc_tags": "uc_tag",
        "manage_uc_security_policies": "uc_security",
        "manage_uc_monitors": "uc_monitor",
        "manage_uc_sharing": "uc_sharing",
        "manage_metric_views": "metric_view",
        "manage_ka": "knowledge_assistant",
        "manage_mas": "supervisor_agent",
        "manage_vs_data": "vector_search",
        "manage_workspace": "workspace",
        "create_or_update_app": "app",
        "create_or_update_pipeline": "pipeline",
        "create_or_update_dashboard": "dashboard",
        "create_or_update_genie": "genie_space",
        "create_or_update_vs_endpoint": "vs_endpoint",
        "create_or_update_vs_index": "vs_index",
        "create_or_update_lakebase_database": "lakebase_database",
        "create_or_update_lakebase_branch": "lakebase_branch",
        "create_or_update_lakebase_sync": "lakebase_sync",
    }

    # Dedicated delete tools: tool name → (resource type, param holding the id).
    # "_dynamic" means the resource type comes from tool_params["type"].
    _DELETE_TOOLS: ClassVar[dict[str, tuple[str, str]]] = {
        "delete_app": ("app", "name"),
        "delete_dashboard": ("dashboard", "dashboard_id"),
        "delete_pipeline": ("pipeline", "pipeline_id"),
        "delete_genie": ("genie_space", "space_id"),
        "delete_vs_endpoint": ("vs_endpoint", "name"),
        "delete_vs_index": ("vs_index", "index_name"),
        "delete_lakebase_database": ("lakebase_database", "name"),
        "delete_lakebase_branch": ("lakebase_branch", "name"),
        "delete_lakebase_sync": ("lakebase_sync", "table_name"),
        "delete_volume_file": ("file", "volume_path"),
        "delete_volume_directory": ("file", "volume_path"),
        "delete_tracked_resource": ("_dynamic", "resource_id"),
    }

    # Tools with no side effects — always categorized READ.
    _READ_ONLY_TOOLS: ClassVar[frozenset[str]] = frozenset(
        {
            "get_table_details",
            "get_volume_folder_details",
            "list_warehouses",
            "get_best_warehouse",
            "list_clusters",
            "get_best_cluster",
            "get_cluster_status",
            "get_pipeline",
            "get_update",
            "get_pipeline_events",
            "find_pipeline_by_name",
            "get_app",
            "get_dashboard",
            "get_genie",
            "ask_genie",
            "get_vs_endpoint",
            "get_vs_index",
            "query_vs_index",
            "get_serving_endpoint_status",
            "list_serving_endpoints",
            "get_lakebase_database",
            "list_volume_files",
            "get_volume_file_info",
            "download_from_volume",
            "list_tracked_resources",
            "get_current_user",
        }
    )

    # Single-purpose mutating tools: tool name → (resource type, operation label).
    _STANDALONE_WRITE_TOOLS: ClassVar[dict[str, tuple[str, str]]] = {
        "create_pipeline": ("pipeline", "CREATE"),
        "update_pipeline": ("pipeline", "UPDATE"),
        "start_update": ("pipeline", "START"),
        "stop_pipeline": ("pipeline", "STOP"),
        "start_cluster": ("cluster", "START_CLUSTER"),
        "upload_file": ("file", "UPLOAD"),
        "upload_folder": ("file", "UPLOAD"),
        "upload_to_volume": ("file", "UPLOAD"),
        "create_volume_directory": ("file", "CREATE_DIRECTORY"),
        "generate_and_upload_pdf": ("file", "GENERATE_PDF"),
        "generate_and_upload_pdfs": ("file", "GENERATE_PDF"),
        "publish_dashboard": ("dashboard", "PUBLISH"),
        "generate_lakebase_credential": ("lakebase_credential", "GENERATE_CREDENTIAL"),
        "migrate_genie": ("genie_space", "MIGRATE"),
        "query_serving_endpoint": ("serving_endpoint", "QUERY"),
    }

    # Arbitrary code execution — always treated as ADMIN.
    _CODE_EXECUTION_TOOLS: ClassVar[frozenset[str]] = frozenset(
        {
            "execute_databricks_command",
            "run_python_file_on_databricks",
        }
    )

    # action-param values considered destructive (→ ADMIN / DELETE_*).
    _DESTRUCTIVE_ACTIONS: ClassVar[frozenset[str]] = frozenset(
        {
            "delete",
            "drop",
            "remove",
            "destroy",
            "terminate",
            "purge",
        }
    )

    # action-param values considered mutating but non-destructive (→ WRITE).
    _WRITE_ACTIONS: ClassVar[frozenset[str]] = frozenset(
        {
            "create",
            "update",
            "run_now",
            "start",
            "restart",
            "reset",
            "run",
            "trigger",
            "deploy",
            "grant",
            "revoke",
        }
    )

    def _enrich_from_tool(self) -> None:
        """Set category, operation, and resource fields from tool_name and tool_params.

        Dispatch order matters: SQL tools first, then read-only, delete,
        action-based, standalone-write, code-execution; anything left is UNKNOWN.
        """

        if self.tool_name == "execute_sql":
            self.target_resource_type = "sql"
            stmt = self.tool_params.get("sql_query") or self.tool_params.get("query", "")
            if stmt:
                self.sql_statement = stmt
                self._classify_sql(stmt)
            else:
                # No statement given — treat as a harmless no-op read.
                self.action_category = ActionCategory.READ
                self.operation = "SQL_EMPTY"
            return

        if self.tool_name == "execute_sql_multi":
            self.target_resource_type = "sql"
            sql_content = self.tool_params.get("sql_content", "")
            if sql_content and isinstance(sql_content, str):
                sqls = [s.strip() for s in sql_content.split(";") if s.strip()]
            else:
                sqls = self.tool_params.get("sqls", [])
            if sqls and isinstance(sqls, list) and len(sqls) > 0:
                # Only the first 5 statements are stored for the ledger,
                # but ALL statements feed classification below.
                self.sql_statement = "; ".join(sqls[:5])
                self._classify_sql_multi(sqls)
            else:
                self.action_category = ActionCategory.READ
                self.operation = "SQL_MULTI_EMPTY"
            return

        if self.tool_name in self._READ_ONLY_TOOLS:
            self.action_category = ActionCategory.READ
            self.operation = "READ"
            self.target_resource_id = self._extract_resource_id()
            return

        if self.tool_name in self._DELETE_TOOLS:
            resource_type, id_key = self._DELETE_TOOLS[self.tool_name]
            if resource_type == "_dynamic":
                # delete_tracked_resource: the type comes from the params.
                resource_type = self.tool_params.get("type", "unknown")
            self.target_resource_type = resource_type
            self.action_category = ActionCategory.ADMIN
            self.target_resource_id = self.tool_params.get(id_key, "")
            self.operation = f"DELETE_{resource_type.upper()}"
            return

        if self.tool_name in self._ACTION_BASED_TOOLS:
            self.target_resource_type = self._ACTION_BASED_TOOLS[self.tool_name]
            action_param = self._extract_action_param()
            self.target_resource_id = self._extract_resource_id()

            action_lower = action_param.lower()
            if action_lower in self._DESTRUCTIVE_ACTIONS:
                self.action_category = ActionCategory.ADMIN
                self.operation = f"DELETE_{self.target_resource_type.upper()}"
            elif action_lower in ("grant", "revoke"):
                self.action_category = ActionCategory.DCL
                self.operation = action_param.upper()
            elif action_lower in self._WRITE_ACTIONS:
                self.action_category = ActionCategory.WRITE
                self.operation = action_param.upper()
            elif action_lower in ("get", "list", "find_by_name", "describe"):
                self.action_category = ActionCategory.READ
                self.operation = action_param.upper()
            else:
                # Unrecognized action on a managed resource: be conservative.
                self.action_category = ActionCategory.ADMIN
                self.operation = action_param.upper()
            return

        if self.tool_name in self._STANDALONE_WRITE_TOOLS:
            resource_type, operation = self._STANDALONE_WRITE_TOOLS[self.tool_name]
            self.target_resource_type = resource_type
            self.action_category = ActionCategory.WRITE
            self.operation = operation
            self.target_resource_id = self._extract_resource_id()
            return

        if self.tool_name in self._CODE_EXECUTION_TOOLS:
            self.target_resource_type = "code_execution"
            self.action_category = ActionCategory.ADMIN
            self.operation = "EXECUTE_CODE"
            self.target_resource_id = self.tool_params.get("cluster_id", "")
            return

        # Fallback for tools this map doesn't know about.
        self.action_category = ActionCategory.UNKNOWN
        self.operation = self.tool_name.upper()

    def _extract_action_param(self) -> str:
        """Best-effort extraction of the action verb from the params or tool name."""
        for key in ("action", "operation", "command"):
            val = self.tool_params.get(key, "")
            if val:
                return str(val)
        # No explicit param: infer from the tool name itself.
        if "delete" in self.tool_name.lower():
            return "delete"
        if "create" in self.tool_name.lower():
            return "create"
        return "unknown"

    def _extract_resource_id(self) -> str:
        """First matching param wins; order avoids generic keys shadowing specific IDs.

        Note: "project_name" is intentionally excluded — it conflicts with the
        session-level project_name field from identity.py.
+ """ + for key in ( + "app_name", + "job_id", + "pipeline_id", + "endpoint_name", + "dashboard_id", + "resource_id", + "space_id", + "tile_id", + "index_name", + "cluster_id", + "catalog_name", + "schema_name", + "table_name", + "full_name", + "instance_name", + "volume_path", + "workspace_path", + "display_name", + "name", + ): + val = self.tool_params.get(key, "") + if val: + return str(val) + return "" + + def _classify_sql(self, stmt: str) -> None: + normalized = _strip_sql_noise(stmt) + tokens = normalized.split() + first_keyword = tokens[0] if tokens else "" + + category, operation = _SQL_CATEGORY_MAP.get(first_keyword, (ActionCategory.UNKNOWN, first_keyword)) + self.action_category = category + self.sql_parsed_type = operation + self.operation = operation + + def _classify_sql_multi(self, sqls: list[str]) -> None: + """Use highest-severity statement (DCL > DDL > WRITE > READ > UNKNOWN).""" + priority = { + ActionCategory.DCL: 5, + ActionCategory.DDL: 4, + ActionCategory.WRITE: 3, + ActionCategory.READ: 2, + ActionCategory.UNKNOWN: 1, + } + worst_category = ActionCategory.UNKNOWN + worst_operation = "MULTI" + + for stmt in sqls: + normalized = stmt.strip().upper() + tokens = normalized.split() + first_keyword = tokens[0] if tokens else "" + cat, op = _SQL_CATEGORY_MAP.get(first_keyword, (ActionCategory.UNKNOWN, first_keyword)) + if priority.get(cat, 0) > priority.get(worst_category, 0): + worst_category = cat + worst_operation = op + + self.action_category = worst_category + self.sql_parsed_type = worst_operation + self.operation = f"MULTI({worst_operation})" + + +def _generate_task_id() -> str: + return f"task_{datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" + + +class AgentGuardSession(BaseModel): + """One agent task. 
Not thread-safe; OK while MCP tools run single-threaded async."""

    task_id: str = Field(default_factory=_generate_task_id)
    agent_id: str = "unknown"
    user_id: str = "unknown"
    project_name: str = "unknown"
    mode: AgentGuardMode = AgentGuardMode.MONITOR_ONLY
    status: SessionStatus = SessionStatus.ACTIVE

    description: str = ""
    scope_template: Optional[str] = None
    scope_variables: Optional[dict[str, str]] = None
    scope: Optional[Any] = None  # Runtime type: Optional[ScopeManifest] (avoids circular import with scope.py)

    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None

    # Private (non-serialized) runtime state.
    _sequence_counter: int = PrivateAttr(default=0)
    _ledger_result: Optional[dict] = PrivateAttr(default=None)
    # Running counters maintained by record_action().
    blocked_count: int = 0
    would_block_count: int = 0
    write_count: int = 0
    total_risk_score: float = 0.0
    max_risk_score: float = 0.0

    actions: list[Action] = Field(default_factory=list)

    def next_sequence(self) -> int:
        """Next sequence number (does not append an action)."""
        self._sequence_counter += 1
        return self._sequence_counter

    @property
    def action_count(self) -> int:
        return len(self.actions)

    def record_action(self, action: Action) -> None:
        """Append an action and update the session's aggregate counters."""
        self.actions.append(action)
        # WRITE/DDL/ADMIN all count as "writes" for summary purposes.
        if action.action_category in (ActionCategory.WRITE, ActionCategory.DDL, ActionCategory.ADMIN):
            self.write_count += 1
        if action.checkpoint_result:
            score = action.checkpoint_result.risk_score
            self.total_risk_score += score
            if score > self.max_risk_score:
                self.max_risk_score = score
            if action.checkpoint_result.decision == CheckpointDecision.BLOCK:
                self.blocked_count += 1
            elif action.checkpoint_result.decision == CheckpointDecision.WOULD_BLOCK:
                self.would_block_count += 1

    @property
    def avg_risk_score(self) -> float:
        if self.action_count == 0:
            return 0.0
        return self.total_risk_score / self.action_count

    def complete(self) -> None:
        """Mark the session finished and stamp the completion time."""
        self.status = SessionStatus.COMPLETED
        self.completed_at = datetime.now(timezone.utc)

    def summary(self) -> str:
        """One-line human-readable summary for agentguard_stop output."""
        mode_label = "monitor-only" if self.mode == AgentGuardMode.MONITOR_ONLY else "enforce"
        duration = ""
        if self.created_at:
            elapsed = datetime.now(timezone.utc) - self.created_at
            minutes = int(elapsed.total_seconds() // 60)
            seconds = int(elapsed.total_seconds() % 60)
            duration = f" | Duration: {minutes}m {seconds}s"

        # In monitor-only mode report would-block counts instead of blocks.
        block_label = "Would-block" if self.mode == AgentGuardMode.MONITOR_ONLY else "Blocked"
        block_count = self.would_block_count if self.mode == AgentGuardMode.MONITOR_ONLY else self.blocked_count

        return (
            f"Session: {self.status.value} ({mode_label}) | "
            f"Task: {self.task_id} | "
            f"Actions: {self.action_count} | "
            f"{block_label}: {block_count} | "
            f"Avg Risk: {self.avg_risk_score:.0f} | "
            f"Max Risk: {self.max_risk_score:.0f}"
            f"{duration}"
        )
diff --git a/databricks-tools-core/databricks_tools_core/agentguard/policy.py b/databricks-tools-core/databricks_tools_core/agentguard/policy.py
new file mode 100644
index 00000000..a3226a21
--- /dev/null
+++ b/databricks-tools-core/databricks_tools_core/agentguard/policy.py
@@ -0,0 +1,151 @@
"""
AgentGuard policy

Rules match normalized action text (not raw SQL) so optional keywords and whitespace
do not bypass patterns. Example: `DROP TABLE IF EXISTS main.foo;` → `DROP TABLE main.foo`.
"""

from __future__ import annotations

import re
from typing import Optional

from databricks_tools_core.agentguard.models import (
    Action,
    AgentGuardMode,
    CheckpointDecision,
)


class PolicyRule:
    """Glob pattern against normalized text; `*` → any substring. 
Case-insensitive."""

    def __init__(self, pattern: str, decision: CheckpointDecision, description: str = ""):
        self.pattern = pattern
        self.decision = decision
        self.description = description
        # Escape the glob, then turn the escaped "*" back into ".*";
        # anchored at the start so patterns match the leading keyword.
        escaped = re.escape(pattern)
        regex_str = "^" + escaped.replace(r"\*", ".*")
        self._regex = re.compile(regex_str, re.IGNORECASE)

    def matches(self, text: str) -> bool:
        return self._regex.search(text) is not None


# Patterns blocked outright in enforce mode (WOULD_BLOCK in monitor mode).
_DEFAULT_ALWAYS_BLOCK: list[tuple[str, str]] = [
    ("DROP DATABASE *", "Dropping databases is globally blocked"),
    ("DROP SCHEMA *", "Dropping schemas is globally blocked"),
    ("TRUNCATE TABLE *", "Truncating tables is globally blocked"),
    ("GRANT ALL PRIVILEGES *", "Granting all privileges is globally blocked"),
    ("REVOKE * FROM *", "Revoking privileges is globally blocked"),
]

# Patterns that flag the action for approval (checked after always-block).
_DEFAULT_REQUIRE_APPROVAL: list[tuple[str, str]] = [
    ("DROP TABLE *", "Dropping tables requires approval"),
    ("ALTER TABLE * DROP COLUMN *", "Dropping columns requires approval"),
    ("DELETE FROM *", "Deleting data requires approval"),
    ("CREATE OR REPLACE TABLE *", "Overwriting tables requires approval"),
    ("CREATE OR REPLACE VIEW *", "Overwriting views requires approval"),
    ("DELETE_APP *", "Deleting Databricks apps requires approval"),
    ("DELETE_JOB *", "Deleting jobs requires approval"),
    ("DELETE_JOB_RUN *", "Cancelling/deleting job runs requires approval"),
    ("DELETE_PIPELINE *", "Deleting pipelines requires approval"),
    ("DELETE_DASHBOARD *", "Deleting dashboards requires approval"),
    ("DELETE_UNITY_CATALOG *", "Deleting Unity Catalog objects requires approval"),
    ("DELETE_UC_GRANT *", "Revoking UC grants requires approval"),
    ("DELETE_UC_STORAGE *", "Deleting UC storage credentials requires approval"),
    ("DELETE_UC_CONNECTION *", "Deleting UC connections requires approval"),
    ("DELETE_UC_SECURITY *", "Deleting UC security policies requires approval"),
    ("DELETE_UC_MONITOR *", "Deleting UC monitors requires approval"),
    ("DELETE_UC_SHARING *", "Deleting UC sharing requires approval"),
    ("DELETE_METRIC_VIEW *", "Deleting metric views requires approval"),
    ("DELETE_KNOWLEDGE_ASSISTANT *", "Deleting knowledge assistants requires approval"),
    ("DELETE_SUPERVISOR_AGENT *", "Deleting supervisor agents requires approval"),
    ("DELETE_GENIE_SPACE *", "Deleting Genie spaces requires approval"),
    ("DELETE_VS_ENDPOINT *", "Deleting vector search endpoints requires approval"),
    ("DELETE_VS_INDEX *", "Deleting vector search indexes requires approval"),
    ("DELETE_VECTOR_SEARCH *", "Deleting vector search resources requires approval"),
    ("DELETE_LAKEBASE_DATABASE *", "Deleting Lakebase databases requires approval"),
    ("DELETE_LAKEBASE_BRANCH *", "Deleting Lakebase branches requires approval"),
    ("DELETE_LAKEBASE_SYNC *", "Deleting Lakebase sync configs requires approval"),
    ("DELETE_FILE *", "Deleting files requires approval"),
    ("DELETE_WORKSPACE *", "Deleting workspace resources requires approval"),
    ("GRANT *", "Granting permissions via UC tool requires approval"),
    ("REVOKE *", "Revoking permissions via UC tool requires approval"),
    ("EXECUTE_CODE *", "Executing arbitrary code on a cluster requires approval"),
]


def _normalize_sql(sql: str) -> str:
    """Strip comments, IF EXISTS noise, collapse whitespace."""
    text = sql.strip().rstrip(";").strip()
    text = re.sub(r"/\*.*?\*/", " ", text, flags=re.DOTALL)
    text = re.sub(r"--[^\n]*", " ", text)
    # Remove optional IF [NOT] EXISTS so "DROP TABLE IF EXISTS t"
    # still matches the "DROP TABLE *" patterns.
    text = re.sub(r"\bIF\s+NOT\s+EXISTS\b", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\bIF\s+EXISTS\b", "", text, flags=re.IGNORECASE)
    text = re.sub(r"\s+", " ", text).strip()
    return text


class PolicyEngine:
    """Match actions against always-block and require-approval rules."""

    def __init__(self, use_defaults: bool = True) -> None:
        self.always_block: list[PolicyRule] = []
        self.require_approval: list[PolicyRule] = []

        if use_defaults:
            for pattern, desc in _DEFAULT_ALWAYS_BLOCK:
                self.always_block.append(PolicyRule(pattern, 
CheckpointDecision.BLOCK, desc))
            for pattern, desc in _DEFAULT_REQUIRE_APPROVAL:
                self.require_approval.append(PolicyRule(pattern, CheckpointDecision.HOLD_FOR_APPROVAL, desc))

    def check(self, action: Action, mode: AgentGuardMode) -> tuple[CheckpointDecision, Optional[str]]:
        """Returns (decision, matching rule description or None).

        Always-block rules win immediately (downgraded to WOULD_BLOCK in
        monitor mode). NOTE(review): require-approval matches are reported
        as FLAG here even though the rules are constructed with
        HOLD_FOR_APPROVAL — presumably the hold is decided later by the
        risk/approval checkpoints; confirm this is intentional.
        """
        texts_to_check = self._action_to_policy_texts(action)
        if not texts_to_check:
            return CheckpointDecision.AUTO_APPROVE, None

        worst_decision = CheckpointDecision.AUTO_APPROVE
        worst_rule: Optional[str] = None

        for text in texts_to_check:
            for rule in self.always_block:
                if rule.matches(text):
                    if mode == AgentGuardMode.MONITOR_ONLY:
                        return (
                            CheckpointDecision.WOULD_BLOCK,
                            f"[monitor] {rule.description} (pattern: {rule.pattern})",
                        )
                    return (
                        CheckpointDecision.BLOCK,
                        f"{rule.description} (pattern: {rule.pattern})",
                    )

            for rule in self.require_approval:
                if rule.matches(text):
                    # First match wins; later matches don't overwrite it.
                    if worst_decision != CheckpointDecision.FLAG:
                        worst_decision = CheckpointDecision.FLAG
                        worst_rule = f"{rule.description} (pattern: {rule.pattern})"

        if worst_decision != CheckpointDecision.AUTO_APPROVE:
            return worst_decision, worst_rule

        return CheckpointDecision.AUTO_APPROVE, None

    def _action_to_policy_texts(self, action: Action) -> list[str]:
        """Normalized SQL lines or `OPERATION resource_id` for non-SQL."""
        if action.sql_statement:
            # Multi-statement SQL is split on ";" and each part normalized.
            raw_statements = action.sql_statement.split(";")
            texts = []
            for stmt in raw_statements:
                normalized = _normalize_sql(stmt)
                if normalized:
                    texts.append(normalized)
            return texts if texts else []

        if action.operation:
            resource = action.target_resource_id or ""
            return [f"{action.operation} {resource}"]

        return []
diff --git a/databricks-tools-core/databricks_tools_core/agentguard/risk.py b/databricks-tools-core/databricks_tools_core/agentguard/risk.py
new file mode 100644
index 00000000..c2e9272a
--- /dev/null
+++ 
b/databricks-tools-core/databricks_tools_core/agentguard/risk.py
@@ -0,0 +1,257 @@
"""
AgentGuard risk (CP-3)

Weighted 0–100 score per action; optional scope violation adds a fixed penalty.
Thresholds map to flag / hold / block decisions.
"""

from __future__ import annotations

import re
from typing import Optional

from pydantic import BaseModel, Field

from databricks_tools_core.agentguard.models import (
    Action,
    ActionCategory,
    CheckpointDecision,
)


class RiskScore(BaseModel):
    """Score, decision, and per-factor breakdown."""

    score: float = 0.0
    decision: CheckpointDecision = CheckpointDecision.AUTO_APPROVE
    breakdown: dict[str, float] = Field(default_factory=dict)


# Operation name → base risk (0–100). Exact match first; substring fallback below.
_ACTION_TYPE_SCORES = {
    "SELECT": 0,
    "DESCRIBE": 0,
    "SHOW": 0,
    "EXPLAIN": 0,
    "READ": 0,
    "LIST": 0,
    "GET": 0,
    "FIND_BY_NAME": 0,
    "INSERT": 30,
    "COPY_INTO": 35,
    "CREATE": 25,
    "UPLOAD": 25,
    "CREATE_DIRECTORY": 15,
    "GENERATE_PDF": 20,
    "UPDATE": 50,
    "MERGE": 60,
    "START": 35,
    "RESTART": 40,
    "STOP": 30,
    "MIGRATE": 40,
    "DEPLOY": 50,
    "DELETE": 80,
    "TRUNCATE": 90,
    "DROP": 95,
    "GRANT": 85,
    "REVOKE": 85,
    "EXECUTE_CODE": 75,
    "GENERATE_CREDENTIAL": 70,
    "START_CLUSTER": 45,
}

# Longest keyword first so e.g. "START_CLUSTER" wins over "START".
_ACTION_TYPE_SCORES_SORTED = sorted(
    _ACTION_TYPE_SCORES.items(),
    key=lambda x: len(x[0]),
    reverse=True,
)


def _score_action_type(action: Action) -> float:
    """Base score for the operation; falls back to substring, prefix, then category."""
    op = action.operation.upper()
    if op in _ACTION_TYPE_SCORES:
        return _ACTION_TYPE_SCORES[op]

    # Substring fallback for composite operation names.
    for keyword, score in _ACTION_TYPE_SCORES_SORTED:
        if keyword in op:
            return score

    if op.startswith("DELETE_"):
        return 80

    if op.startswith("MULTI("):
        # MULTI(<op>) from execute_sql_multi: score by the inner operation.
        inner = op.replace("MULTI(", "").rstrip(")")
        return _ACTION_TYPE_SCORES.get(inner, 50)

    category_defaults = {
        ActionCategory.READ: 0,
        ActionCategory.WRITE: 40,
        ActionCategory.DDL: 50,
        ActionCategory.DCL: 80,
        ActionCategory.ADMIN: 60,
        ActionCategory.EXTERNAL: 50,
        ActionCategory.UNKNOWN: 30,
    }
    return category_defaults.get(action.action_category, 30)


def _score_environment(action: Action) -> float:
    """Guess environment risk from substrings in the target/SQL.

    NOTE(review): plain substring checks — "test" also matches "latest",
    "prod." matches "reprod." — acceptable as a heuristic, but worth
    confirming against real resource-naming conventions.
    """
    resource = (action.target_resource_id or "").lower()
    sql = (action.sql_statement or "").lower()
    combined = f"{resource} {sql}"

    if "production" in combined or "prod." in combined:
        return 90
    if "staging" in combined or "stg." in combined:
        return 40
    if "dev." in combined or "sandbox" in combined or "test" in combined:
        return 10

    return 30


def _score_blast_radius(action: Action) -> float:
    """Estimate impact size from rows_affected, SQL shape, or operation."""
    if action.rows_affected is not None:
        if action.rows_affected <= 1:
            return 5
        if action.rows_affected <= 100:
            return 15
        if action.rows_affected <= 10000:
            return 35
        if action.rows_affected <= 1000000:
            return 60
        return 85

    sql = (action.sql_statement or "").upper()
    if sql:
        # Unbounded DELETE/UPDATE (no WHERE anywhere) is the worst case.
        if ("DELETE" in sql or "UPDATE" in sql) and "WHERE" not in sql:
            return 90
        if "DROP " in sql or "TRUNCATE " in sql:
            return 100
        if action.action_category == ActionCategory.READ:
            return 10
        if "WHERE" in sql:
            return 30

    op = action.operation.upper()
    if op.startswith("DELETE_"):
        return 70
    if op in ("START_CLUSTER", "EXECUTE_CODE"):
        return 40

    return 20


def _score_time_context(action: Action) -> float:
    # Placeholder: constant until time-of-day/business-hours logic exists.
    return 25


def _score_behavioral(action: Action) -> float:
    # Placeholder: constant until session-history profiling exists.
    return 15


def _sensitive_re(keyword: str) -> re.Pattern:
    """Word-ish boundary match: keyword delimited by _, ., /, whitespace, or ends."""
    return re.compile(rf"(?:^|[_.\/\s]){keyword}(?:$|[_.\/\s])", re.IGNORECASE)


# (pattern, score) pairs; highest matching score wins in _score_data_sensitivity.
_SENSITIVE_PATTERNS = [
    (_sensitive_re("pii"), 75),
    (_sensitive_re("phi"), 90),
    (_sensitive_re("financial"), 85),
    (_sensitive_re("credit.?card"), 90),
    (_sensitive_re("ssn"), 95),
    (_sensitive_re("password"), 90),
    (_sensitive_re("secret"), 80),
    (_sensitive_re("token"), 70),
    (_sensitive_re("salary"), 75),
    (_sensitive_re("medical"), 85),
    (_sensitive_re("hipaa"), 90),
    (_sensitive_re("gdpr"), 80),
    (_sensitive_re("sensitive"), 65),
    (_sensitive_re("confidential"), 70),
    (_sensitive_re("restricted"), 65),
def _score_data_sensitivity(action: Action) -> float:
    """Score 0-100 from sensitive keywords in the action's targets and SQL.

    Scans the target resource id, the SQL statement, and the read/written
    table names against ``_SENSITIVE_PATTERNS`` and returns the highest
    matching pattern score (0.0 when nothing matches or there is no text).
    """
    # filter(None, ...) drops absent fields so a missing sql_statement
    # doesn't break the join.
    text = " ".join(
        filter(
            None,
            [
                action.target_resource_id,
                action.sql_statement,
                " ".join(action.tables_read),
                " ".join(action.tables_written),
            ],
        )
    )

    if not text:
        return 0.0

    # Take the max (not the sum) so several weak matches don't outrank
    # one strong match.
    max_score = 0.0
    for pattern, score in _SENSITIVE_PATTERNS:
        if pattern.search(text):
            max_score = max(max_score, score)

    return max_score


# Relative weight of each risk dimension; the values sum to 1.0 so the
# weighted sum stays on the same 0-100 scale as each dimension score.
_DEFAULT_WEIGHTS = {
    "action_type": 0.25,
    "environment": 0.25,
    "blast_radius": 0.20,
    "time_context": 0.10,
    "behavioral": 0.10,
    "data_sensitivity": 0.10,
}

# Decision thresholds applied to the final 0-100 score.
_DEFAULT_THRESHOLDS = {
    "flag_above": 35,
    "hold_above": 70,
    "block_above": 90,
}

# Flat penalty added when CP-2 reported a scope violation.
_SCOPE_VIOLATION_PENALTY = 25


def compute_risk_score(
    action: Action,
    scope_violated: bool = False,
    weights: Optional[dict[str, float]] = None,
    thresholds: Optional[dict[str, float]] = None,
) -> RiskScore:
    """Weighted score; `scope_violated` adds a fixed penalty.

    Args:
        action: The action to score.
        scope_violated: Whether CP-2 flagged the action as out of scope.
        weights: Optional per-dimension weight overrides. May be partial;
            missing keys fall back to ``_DEFAULT_WEIGHTS``.
        thresholds: Optional overrides for ``flag_above`` / ``hold_above`` /
            ``block_above``. May be partial; missing keys fall back to
            ``_DEFAULT_THRESHOLDS``.

    Returns:
        A RiskScore with the final 0-100 score, the resulting checkpoint
        decision, and a per-dimension breakdown rounded to 2 decimals.
    """
    # Merge overrides over the defaults instead of replacing them wholesale:
    # previously a caller overriding a single weight got a KeyError below
    # because `weights or _DEFAULT_WEIGHTS` discarded the missing keys.
    w = {**_DEFAULT_WEIGHTS, **(weights or {})}
    t = {**_DEFAULT_THRESHOLDS, **(thresholds or {})}

    breakdown = {
        "action_type": _score_action_type(action),
        "environment": _score_environment(action),
        "blast_radius": _score_blast_radius(action),
        "time_context": _score_time_context(action),
        "behavioral": _score_behavioral(action),
        "data_sensitivity": _score_data_sensitivity(action),
    }

    base_score = sum(breakdown[k] * w[k] for k in breakdown)

    penalty = _SCOPE_VIOLATION_PENALTY if scope_violated else 0
    breakdown["scope_violation_penalty"] = penalty

    # Clamp so the penalty can never push the score past the 0-100 scale.
    final_score = min(100, base_score + penalty)

    if final_score >= t["block_above"]:
        decision = CheckpointDecision.BLOCK
    elif final_score >= t["hold_above"]:
        decision = CheckpointDecision.HOLD_FOR_APPROVAL
    elif final_score >= t["flag_above"]:
        decision = CheckpointDecision.FLAG
    else:
        decision = CheckpointDecision.AUTO_APPROVE

    return RiskScore(
        score=round(final_score, 2),
        decision=decision,
        breakdown={k: round(v, 2) for k, v in breakdown.items()},
    )
def _resource_scope_has_patterns(rs: ResourceScope) -> bool:
    """True when any of the four operation allow-lists has at least one pattern."""
    return bool(rs.read or rs.write or rs.ddl or rs.delete)


def _matches_any(value: str, patterns: list[str]) -> bool:
    """Case-insensitive glob match of `value` against any of `patterns`."""
    value_lower = value.lower()
    for pattern in patterns:
        if fnmatch.fnmatch(value_lower, pattern.lower()):
            return True
    return False


# Maps Action.target_resource_type to the ScopeManifest field governing it.
# Unknown resource types are not policed by CP-2 (check_scope defers).
_RESOURCE_TYPE_TO_SCOPE_FIELD = {
    "sql": "tables",
    "job": "jobs",
    "job_run": "jobs",
    "pipeline": "pipelines",
    "app": "apps",
    "dashboard": "dashboards",
    "cluster": "clusters",
    "serving_endpoint": "serving_endpoints",
    "genie_space": "genie_spaces",
    "knowledge_assistant": "knowledge_assistants",
    "supervisor_agent": "supervisor_agents",
    "unity_catalog": "unity_catalog",
    "uc_grant": "unity_catalog",
    "uc_storage": "unity_catalog",
    "uc_connection": "unity_catalog",
    "uc_tag": "unity_catalog",
    "uc_security": "unity_catalog",
    "uc_monitor": "unity_catalog",
    "uc_sharing": "unity_catalog",
    "metric_view": "unity_catalog",
    "vs_endpoint": "vector_search",
    "vs_index": "vector_search",
    "vector_search": "vector_search",
    "lakebase_database": "lakebase",
    "lakebase_branch": "lakebase",
    "lakebase_sync": "lakebase",
    "lakebase_credential": "lakebase",
    "file": "files",
    "code_execution": "code_execution",
    "workspace": "unity_catalog",
}

# Maps an action's category to the ResourceScope bucket to check.
# DCL (grants) checks the "write" bucket; ADMIN checks "delete".
_CATEGORY_TO_SCOPE_OP = {
    ActionCategory.READ: "read",
    ActionCategory.WRITE: "write",
    ActionCategory.DDL: "ddl",
    ActionCategory.DCL: "write",
    ActionCategory.ADMIN: "delete",
}


def check_scope(
    action: Action,
    scope: ScopeManifest,
    mode: AgentGuardMode,
) -> tuple[CheckpointDecision, Optional[str]]:
    """Returns (decision, violation message or None).

    Defers (AUTO_APPROVE) when the action has no resource id or an unmapped
    resource type. In MONITOR_ONLY mode, violations yield WOULD_BLOCK with a
    "[monitor]"-prefixed message; in ENFORCE mode they yield BLOCK.
    """
    resource_id = action.target_resource_id
    if not resource_id:
        return CheckpointDecision.AUTO_APPROVE, None

    scope_field_name = _RESOURCE_TYPE_TO_SCOPE_FIELD.get(action.target_resource_type)
    if scope_field_name is None:
        return CheckpointDecision.AUTO_APPROVE, None

    resource_scope: ResourceScope = getattr(scope, scope_field_name)

    # Unknown categories fall back to the "write" bucket (safer default).
    scope_op = _CATEGORY_TO_SCOPE_OP.get(action.action_category, "write")

    allowed_patterns: list[str] = getattr(resource_scope, scope_op, [])

    if not allowed_patterns:
        # Empty allow-list: deny only when default_deny is on AND the
        # resource type was deliberately scoped (has patterns for some op);
        # a completely unscoped resource type is left alone.
        if scope.default_deny and _resource_scope_has_patterns(resource_scope):
            violation = (
                f"Operation '{scope_op}' on '{scope_field_name}' is not explicitly allowed (default_deny is enabled)"
            )
            if mode == AgentGuardMode.MONITOR_ONLY:
                return CheckpointDecision.WOULD_BLOCK, f"[monitor] {violation}"
            return CheckpointDecision.BLOCK, violation
        return CheckpointDecision.AUTO_APPROVE, None

    if _matches_any(resource_id, allowed_patterns):
        return CheckpointDecision.AUTO_APPROVE, None

    violation = (
        f"Resource '{resource_id}' ({scope_op}) is not in scope for "
        f"{scope_field_name}. Allowed patterns: {allowed_patterns}"
    )

    if mode == AgentGuardMode.MONITOR_ONLY:
        return CheckpointDecision.WOULD_BLOCK, f"[monitor] {violation}"

    return CheckpointDecision.BLOCK, violation


def check_scope_limits(
    action: Action,
    scope: ScopeManifest,
    session_action_count: int,
    session_write_count: int,
    mode: AgentGuardMode,
) -> tuple[CheckpointDecision, Optional[str]]:
    """Enforce max_actions / max_write_actions.

    A limit of None means unlimited. Note: an explicit limit of 0 now means
    "nothing allowed" — previously the truthiness check made 0 behave as
    unlimited, silently disabling the limit.
    """
    if scope.max_actions is not None and session_action_count >= scope.max_actions:
        violation = f"Session exceeded max_actions limit ({scope.max_actions})"
        if mode == AgentGuardMode.MONITOR_ONLY:
            return CheckpointDecision.WOULD_BLOCK, f"[monitor] {violation}"
        return CheckpointDecision.BLOCK, violation

    if (
        scope.max_write_actions is not None
        and action.action_category in (ActionCategory.WRITE, ActionCategory.DDL, ActionCategory.ADMIN)
        and session_write_count >= scope.max_write_actions
    ):
        violation = f"Session exceeded max_write_actions limit ({scope.max_write_actions})"
        if mode == AgentGuardMode.MONITOR_ONLY:
            return CheckpointDecision.WOULD_BLOCK, f"[monitor] {violation}"
        return CheckpointDecision.BLOCK, violation

    return CheckpointDecision.AUTO_APPROVE, None


def load_template(template_name: str, variables: Optional[dict[str, str]] = None) -> ScopeManifest:
    """Load `templates/{name}.json`; substitute `${key}` from variables.

    Raises:
        FileNotFoundError: when the template file does not exist (message
            lists the available template names).
        ValueError: when `${...}` placeholders remain after substitution.
    """
    template_path = _TEMPLATES_DIR / f"{template_name}.json"
    if not template_path.exists():
        available = [f.stem for f in _TEMPLATES_DIR.glob("*.json")]
        raise FileNotFoundError(f"Scope template '{template_name}' not found. Available templates: {available}")

    raw = template_path.read_text()

    # Plain textual substitution of ${key} before JSON parsing.
    if variables:
        for key, value in variables.items():
            raw = raw.replace(f"${{{key}}}", value)

    # Fail loudly on any placeholder the caller forgot to supply.
    unresolved = re.findall(r"\$\{(\w+)\}", raw)
    if unresolved:
        raise ValueError(
            f"Scope template '{template_name}' has unresolved variables: {unresolved}. "
            f"Provide them via variables parameter."
        )

    data = json.loads(raw)
    return ScopeManifest.model_validate(data)


def list_templates() -> list[str]:
    """Names (stems) of all bundled scope templates, sorted."""
    if not _TEMPLATES_DIR.exists():
        return []
    return sorted(f.stem for f in _TEMPLATES_DIR.glob("*.json"))
" + f"Provide them via variables parameter." + ) + + data = json.loads(raw) + return ScopeManifest.model_validate(data) + + +def list_templates() -> list[str]: + if not _TEMPLATES_DIR.exists(): + return [] + return sorted(f.stem for f in _TEMPLATES_DIR.glob("*.json")) diff --git a/databricks-tools-core/databricks_tools_core/agentguard/session.py b/databricks-tools-core/databricks_tools_core/agentguard/session.py new file mode 100644 index 00000000..b8dd8093 --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/agentguard/session.py @@ -0,0 +1,128 @@ +""" +AgentGuard session lifecycle + +Start: optional scope load, JSONL init, orphan report. Stop: flush ledger to Delta, clear context. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +from databricks_tools_core.agentguard.context import ( + clear_active_session, + get_active_session, + set_active_session, +) +from databricks_tools_core.agentguard.models import ( + AgentGuardMode, + AgentGuardSession, + SessionStatus, +) + +logger = logging.getLogger(__name__) + + +def start_session( + mode: AgentGuardMode = AgentGuardMode.MONITOR_ONLY, + description: str = "", + scope_template: Optional[str] = None, + scope_variables: Optional[dict[str, str]] = None, + agent_id: str = "unknown", + user_id: str = "unknown", +) -> AgentGuardSession: + """Raises ValueError if an active session already exists.""" + existing = get_active_session() + if existing and existing.status == SessionStatus.ACTIVE: + raise ValueError(f"An AgentGuard session is already active: {existing.task_id}. 
Run /agentguard stop first.") + + project_name = "unknown" + try: + from databricks_tools_core.identity import detect_project_name + + project_name = detect_project_name() + except Exception: + pass + + session = AgentGuardSession( + mode=mode, + description=description, + scope_template=scope_template, + scope_variables=scope_variables, + agent_id=agent_id, + user_id=user_id, + project_name=project_name, + ) + + if scope_template: + try: + from databricks_tools_core.agentguard.scope import load_template + + session.scope = load_template(scope_template, scope_variables) + logger.info(f"Scope template '{scope_template}' loaded for task {session.task_id}") + except (FileNotFoundError, ValueError) as e: + logger.warning(f"Failed to load scope template '{scope_template}': {e}") + + try: + from databricks_tools_core.agentguard.ledger import init_session_file + + init_session_file(session.task_id) + except Exception as e: + logger.warning(f"Could not initialize session JSONL: {e}") + + _report_orphans() + + set_active_session(session) + + mode_label = "monitor-only" if mode == AgentGuardMode.MONITOR_ONLY else "enforce" + scope_label = f" | Scope: {scope_template}" if scope_template else " | No scope" + logger.info(f"AgentGuard session started ({mode_label}{scope_label}). 
Task ID: {session.task_id}") + return session + + +def stop_session() -> Optional[AgentGuardSession]: + """Complete session, flush JSONL to Delta, return session (or None if none active).""" + session = get_active_session() + if session is None: + return None + + session.complete() + clear_active_session() + + try: + from databricks_tools_core.agentguard.ledger import flush_session_to_delta + + ledger_result = flush_session_to_delta(session) + session._ledger_result = ledger_result + logger.info( + f"Audit ledger flush: {ledger_result.get('status')} " + f"({ledger_result.get('rows', 0)} rows -> {ledger_result.get('destination', 'unknown')})" + ) + except Exception as e: + logger.warning(f"Audit ledger flush failed: {e}") + session._ledger_result = {"status": "error", "error": str(e)} + + logger.info(f"AgentGuard session stopped. {session.summary()}") + return session + + +def get_session_status() -> Optional[str]: + session = get_active_session() + if session is None: + return None + return session.summary() + + +def _report_orphans() -> None: + try: + from databricks_tools_core.agentguard.ledger import find_orphaned_sessions + + orphans = find_orphaned_sessions() + if orphans: + ids = ", ".join(o["task_id"] for o in orphans) + logger.warning( + f"Found {len(orphans)} orphaned AgentGuard session(s) from previous runs. " + f"Data is preserved locally at ~/.agentguard/sessions/. Sessions: {ids}" + ) + except Exception: + pass diff --git a/databricks-tools-core/databricks_tools_core/agentguard/templates/data_quality_check.json b/databricks-tools-core/databricks_tools_core/agentguard/templates/data_quality_check.json new file mode 100644 index 00000000..3a9c98fa --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/agentguard/templates/data_quality_check.json @@ -0,0 +1,36 @@ +{ + "_description": "Scope for running data quality analysis. Agent can read target tables and write results to a dedicated dq_results schema. 
No deletes, no resource modifications.", + "_variables": ["catalog", "schema", "target_table"], + "_example_usage": "agentguard_start(mode='enforce', scope_template='data_quality_check', scope_variables={'catalog': 'main', 'schema': 'production', 'target_table': 'customer_orders'})", + + "tables": { + "read": [ + "${catalog}.${schema}.${target_table}", + "${catalog}.${schema}.${target_table}_*", + "${catalog}.information_schema.*" + ], + "write": [ + "${catalog}.dq_results.*" + ], + "ddl": [ + "${catalog}.dq_results.*" + ], + "delete": [] + }, + "clusters": { + "read": ["*"], + "write": [], + "ddl": [], + "delete": [] + }, + "files": { + "read": ["*/dq_configs/*", "*/quality_rules/*"], + "write": ["*/dq_results/*"], + "ddl": [], + "delete": [] + }, + + "default_deny": true, + "max_actions": 200, + "max_write_actions": 30 +} diff --git a/databricks-tools-core/databricks_tools_core/agentguard/templates/etl_pipeline_fix.json b/databricks-tools-core/databricks_tools_core/agentguard/templates/etl_pipeline_fix.json new file mode 100644 index 00000000..a4852179 --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/agentguard/templates/etl_pipeline_fix.json @@ -0,0 +1,49 @@ +{ + "_description": "Scope for fixing an ETL pipeline. Agent can read production source tables, write to staging temp tables, and modify the target pipeline/job. 
All other resources are denied.", + "_variables": ["catalog", "target_table"], + "_example_usage": "agentguard_start(mode='enforce', scope_template='etl_pipeline_fix', scope_variables={'catalog': 'main', 'target_table': 'customer_orders'})", + + "tables": { + "read": [ + "${catalog}.production.${target_table}", + "${catalog}.production.${target_table}_*", + "${catalog}.staging.*" + ], + "write": [ + "${catalog}.staging.${target_table}_*", + "${catalog}.staging.temp_*" + ], + "ddl": [ + "${catalog}.staging.temp_*" + ], + "delete": [] + }, + "jobs": { + "read": ["*"], + "write": ["*${target_table}*", "*etl*"], + "ddl": [], + "delete": [] + }, + "pipelines": { + "read": ["*"], + "write": ["*${target_table}*", "*etl*"], + "ddl": [], + "delete": [] + }, + "clusters": { + "read": ["*"], + "write": [], + "ddl": [], + "delete": [] + }, + "files": { + "read": ["*/staging/*", "*/etl_configs/*"], + "write": ["*/staging/*"], + "ddl": [], + "delete": [] + }, + + "default_deny": true, + "max_actions": 100, + "max_write_actions": 20 +} diff --git a/databricks-tools-core/databricks_tools_core/agentguard/templates/model_deployment.json b/databricks-tools-core/databricks_tools_core/agentguard/templates/model_deployment.json new file mode 100644 index 00000000..ae8fa1b5 --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/agentguard/templates/model_deployment.json @@ -0,0 +1,37 @@ +{ + "_description": "Scope for deploying an ML model to a serving endpoint. Agent can read model registry and feature store, update the target serving endpoint. 
No table writes, no deletes.", + "_variables": ["catalog", "endpoint_name"], + "_example_usage": "agentguard_start(mode='enforce', scope_template='model_deployment', scope_variables={'catalog': 'main', 'endpoint_name': 'fraud_detection_v2'})", + + "tables": { + "read": [ + "${catalog}.ml_models.*", + "${catalog}.feature_store.*" + ], + "write": [], + "ddl": [], + "delete": [] + }, + "serving_endpoints": { + "read": ["*"], + "write": ["${endpoint_name}"], + "ddl": [], + "delete": [] + }, + "unity_catalog": { + "read": ["*"], + "write": [], + "ddl": [], + "delete": [] + }, + "clusters": { + "read": ["*"], + "write": [], + "ddl": [], + "delete": [] + }, + + "default_deny": true, + "max_actions": 75, + "max_write_actions": 10 +} diff --git a/databricks-tools-core/databricks_tools_core/agentguard/timing.py b/databricks-tools-core/databricks_tools_core/agentguard/timing.py new file mode 100644 index 00000000..d8f925d8 --- /dev/null +++ b/databricks-tools-core/databricks_tools_core/agentguard/timing.py @@ -0,0 +1,45 @@ +""" +AgentGuard timing + +Per-phase timing for the checkpoint pipeline. 
+""" + +from __future__ import annotations + +import time +from typing import Any, Callable, TypeVar + +from databricks_tools_core.agentguard.models import TimingRecord + +T = TypeVar("T") + + +class Timer: + """Timing measurements for one action’s checkpoint run.""" + + def __init__(self) -> None: + self.records: list[TimingRecord] = [] + self._start: float = time.perf_counter() + + def measure(self, name: str, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: + start = time.perf_counter() + result = func(*args, **kwargs) + elapsed_ms = (time.perf_counter() - start) * 1000 + self.records.append(TimingRecord(name=name, duration_ms=round(elapsed_ms, 3))) + return result + + async def measure_async(self, name: str, coro: Any) -> Any: + start = time.perf_counter() + result = await coro + elapsed_ms = (time.perf_counter() - start) * 1000 + self.records.append(TimingRecord(name=name, duration_ms=round(elapsed_ms, 3))) + return result + + @property + def total_ms(self) -> float: + return sum(r.duration_ms for r in self.records) + + def summary(self) -> str: + parts = [f"{r.name}: {r.duration_ms:.1f}ms" for r in self.records] + parts.append(f"TOTAL: {self.total_ms:.1f}ms") + return " | ".join(parts) diff --git a/databricks-tools-core/tests/integration/agentguard/__init__.py b/databricks-tools-core/tests/integration/agentguard/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/databricks-tools-core/tests/integration/agentguard/conftest.py b/databricks-tools-core/tests/integration/agentguard/conftest.py new file mode 100644 index 00000000..0820a6ee --- /dev/null +++ b/databricks-tools-core/tests/integration/agentguard/conftest.py @@ -0,0 +1,123 @@ +""" +Fixtures for AgentGuard integration tests. + +Extends the root conftest.py fixtures (workspace_client, test_catalog, +test_schema, warehouse_id, test_tables) with AgentGuard-specific fixtures +for session lifecycle, middleware, and audit ledger testing. 
"""
Fixtures for AgentGuard integration tests.

Extends the root conftest.py fixtures (workspace_client, test_catalog,
test_schema, warehouse_id, test_tables) with AgentGuard-specific fixtures
for session lifecycle, middleware, and audit ledger testing.

Requires a live Databricks workspace with a running SQL warehouse.
"""

import logging
import os
from typing import Generator

import pytest

from databricks_tools_core.agentguard.context import clear_active_session
from databricks_tools_core.agentguard.models import AgentGuardMode, AgentGuardSession
from databricks_tools_core.agentguard.policy import PolicyEngine
from databricks_tools_core.agentguard.session import start_session, stop_session

logger = logging.getLogger(__name__)

# Audit ledger test catalog/schema — isolated from the main "agentguard" catalog
LEDGER_TEST_CATALOG = os.environ.get("TEST_CATALOG", "ai_dev_kit_test")
LEDGER_TEST_SCHEMA = "agentguard_test"


@pytest.fixture(autouse=True)
def _clean_session():
    """Ensure no stale session leaks between tests."""
    clear_active_session()
    yield
    clear_active_session()


@pytest.fixture
def policy_engine() -> PolicyEngine:
    """Fresh policy engine per test."""
    return PolicyEngine()


def _guarded_session(mode: AgentGuardMode, description: str) -> Generator[AgentGuardSession, None, None]:
    """Shared body for the monitor/enforce session fixtures: start, yield, stop."""
    session = start_session(
        mode=mode,
        description=description,
        agent_id="test-agent",
        user_id="test-user",
    )
    yield session
    try:
        stop_session()
    except Exception:
        # If stop fails, at least clear the context so later tests start clean.
        clear_active_session()


@pytest.fixture
def monitor_session() -> Generator[AgentGuardSession, None, None]:
    """Start a monitor-only AgentGuard session, stop it after the test."""
    yield from _guarded_session(AgentGuardMode.MONITOR_ONLY, "integration-test-monitor")


@pytest.fixture
def enforce_session() -> Generator[AgentGuardSession, None, None]:
    """Start an enforce-mode AgentGuard session, stop it after the test."""
    yield from _guarded_session(AgentGuardMode.ENFORCE, "integration-test-enforce")


@pytest.fixture(scope="module")
def ledger_schema(workspace_client, warehouse_id):
    """Create and clean up the audit ledger test schema.

    Uses the test catalog from the root conftest and creates a dedicated
    schema for ledger tests so we don't pollute production data.
    """
    from databricks_tools_core.sql.sql import execute_sql

    full_schema = f"{LEDGER_TEST_CATALOG}.{LEDGER_TEST_SCHEMA}"

    # Ensure catalog exists
    execute_sql(
        f"CREATE CATALOG IF NOT EXISTS {LEDGER_TEST_CATALOG}",
        warehouse_id=warehouse_id,
    )

    # Clean slate: drop any leftover schema from a previous run.
    try:
        execute_sql(
            f"DROP SCHEMA IF EXISTS {full_schema} CASCADE",
            warehouse_id=warehouse_id,
        )
    except Exception as e:
        logger.debug("Schema cleanup on setup failed (may not exist): %s", e)

    execute_sql(
        f"CREATE SCHEMA IF NOT EXISTS {full_schema}",
        warehouse_id=warehouse_id,
    )

    logger.info("Created ledger test schema: %s", full_schema)

    yield {
        "catalog": LEDGER_TEST_CATALOG,
        "schema": LEDGER_TEST_SCHEMA,
        "full_schema": full_schema,
    }

    # Cleanup after all tests in this module
    # NOTE: Commented out so you can inspect the table after test runs.
    # Re-enable when done investigating.
    # try:
    #     execute_sql(
    #         f"DROP SCHEMA IF EXISTS {full_schema} CASCADE",
    #         warehouse_id=warehouse_id,
    #     )
    #     logger.info("Cleaned up ledger test schema: %s", full_schema)
    # except Exception as e:
    #     logger.warning("Failed to clean up ledger schema: %s", e)
+""" + +import json + +import pytest + +from databricks_tools_core.agentguard.ledger import ( + append_action, + flush_session_to_delta, + init_session_file, +) +from databricks_tools_core.agentguard.models import ( + Action, + AgentGuardMode, + AgentGuardSession, + CheckpointDecision, + CheckpointResult, +) + + +def _build_test_session(num_actions: int = 3) -> AgentGuardSession: + """Builds a completed session with actions mirrored to JSONL.""" + session = AgentGuardSession( + mode=AgentGuardMode.MONITOR_ONLY, + description="ledger-integration-test", + agent_id="test-agent", + user_id="test-user", + ) + + init_session_file(session.task_id) + + for i in range(num_actions): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"sql_query": f"SELECT {i} AS val"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + action.checkpoint_result = CheckpointResult( + decision=CheckpointDecision.AUTO_APPROVE, + risk_score=float(i * 10), + risk_breakdown={"action_type": float(i * 5), "environment": 10.0}, + ) + action.final_decision = "executed" + action.execution_success = True + session.record_action(action) + append_action(session.task_id, action, session) + + session.complete() + return session + + +@pytest.fixture(autouse=True) +def _use_tmp_sessions_dir(tmp_path, monkeypatch): + """Points JSONL output at tmp_path for isolation.""" + monkeypatch.setattr( + "databricks_tools_core.agentguard.ledger._SESSIONS_DIR", + tmp_path, + ) + + +@pytest.mark.integration +class TestAuditLedgerDeltaWrite: + """Tests for writing audit rows to Delta.""" + + def test_flush_creates_table_and_writes(self, warehouse_id, ledger_schema): + """Should create or reuse the table and insert rows.""" + session = _build_test_session(num_actions=3) + + result = flush_session_to_delta( + session, + catalog=ledger_schema["catalog"], + schema=ledger_schema["schema"], + warehouse_id=warehouse_id, + ) + + assert result["status"] == "success" + 
assert result["rows"] == 3 + assert result["method"] == "delta" + assert ledger_schema["catalog"] in result["destination"] + + def test_written_data_is_queryable(self, warehouse_id, ledger_schema): + """Should read back rows with expected columns.""" + from databricks_tools_core.sql.sql import execute_sql + + session = _build_test_session(num_actions=2) + task_id = session.task_id + + flush_session_to_delta( + session, + catalog=ledger_schema["catalog"], + schema=ledger_schema["schema"], + warehouse_id=warehouse_id, + ) + + rows = execute_sql( + f"SELECT * FROM {ledger_schema['full_schema']}.action_log " + f"WHERE task_id = '{task_id}' ORDER BY action_sequence", + warehouse_id=warehouse_id, + ) + + assert len(rows) == 2 + assert rows[0]["tool_name"] == "execute_sql" + assert rows[0]["task_id"] == task_id + assert rows[0]["agent_id"] == "test-agent" + assert rows[0]["final_decision"] == "executed" + + def test_risk_scores_persisted(self, warehouse_id, ledger_schema): + """Should persist cp3_risk_score and breakdown.""" + from databricks_tools_core.sql.sql import execute_sql + + session = _build_test_session(num_actions=1) + task_id = session.task_id + + flush_session_to_delta( + session, + catalog=ledger_schema["catalog"], + schema=ledger_schema["schema"], + warehouse_id=warehouse_id, + ) + + rows = execute_sql( + f"SELECT cp3_risk_score, cp3_risk_breakdown FROM " + f"{ledger_schema['full_schema']}.action_log " + f"WHERE task_id = '{task_id}'", + warehouse_id=warehouse_id, + ) + + assert len(rows) == 1 + risk_score = float(rows[0]["cp3_risk_score"]) + assert risk_score == 0.0 + + def test_empty_session_skipped(self, warehouse_id, ledger_schema): + """Should skip flush when there are no actions.""" + session = AgentGuardSession( + mode=AgentGuardMode.MONITOR_ONLY, + description="empty-session", + ) + init_session_file(session.task_id) + session.complete() + + result = flush_session_to_delta( + session, + catalog=ledger_schema["catalog"], + 
schema=ledger_schema["schema"], + warehouse_id=warehouse_id, + ) + + assert result["status"] == "skipped" + assert result["rows"] == 0 + + def test_sql_injection_safe(self, warehouse_id, ledger_schema): + """Should store malicious-looking text without breaking SQL.""" + from databricks_tools_core.sql.sql import execute_sql + + session = AgentGuardSession( + mode=AgentGuardMode.MONITOR_ONLY, + description="injection'; DROP TABLE evil; --", + agent_id="test-agent", + ) + init_session_file(session.task_id) + + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"sql_query": "SELECT 'O''Brien' AS name"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + action.checkpoint_result = CheckpointResult() + action.final_decision = "executed" + action.execution_success = True + session.record_action(action) + append_action(session.task_id, action, session) + session.complete() + + result = flush_session_to_delta( + session, + catalog=ledger_schema["catalog"], + schema=ledger_schema["schema"], + warehouse_id=warehouse_id, + ) + + assert result["status"] == "success" + + rows = execute_sql( + f"SELECT session_description FROM {ledger_schema['full_schema']}.action_log " + f"WHERE task_id = '{session.task_id}'", + warehouse_id=warehouse_id, + ) + assert len(rows) == 1 + assert "DROP TABLE evil" in rows[0]["session_description"] + + +@pytest.mark.integration +class TestAuditLedgerJSONLPersistence: + """Tests for JSONL hot path.""" + + def test_append_action_writes_jsonl(self): + """Should append one JSON line per action.""" + session = AgentGuardSession( + mode=AgentGuardMode.MONITOR_ONLY, + description="jsonl-write-test", + agent_id="test-agent", + ) + jsonl_path = init_session_file(session.task_id) + + for i in range(2): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"sql_query": f"SELECT {i}"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) 
+ action.checkpoint_result = CheckpointResult() + action.final_decision = "executed" + session.record_action(action) + append_action(session.task_id, action, session) + + assert jsonl_path.exists() + lines = [line for line in jsonl_path.read_text().splitlines() if line.strip()] + assert len(lines) == 2 + + row = json.loads(lines[0]) + assert row["tool_name"] == "execute_sql" + assert row["task_id"] == session.task_id + + def test_delta_failure_preserves_jsonl(self, monkeypatch): + """Should keep JSONL when Delta write fails.""" + session = AgentGuardSession( + mode=AgentGuardMode.MONITOR_ONLY, + description="delta-failure-test", + agent_id="test-agent", + ) + jsonl_path = init_session_file(session.task_id) + + for i in range(2): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"sql_query": f"SELECT {i}"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + action.checkpoint_result = CheckpointResult() + action.final_decision = "executed" + session.record_action(action) + append_action(session.task_id, action, session) + + session.complete() + + def _failing_write(*args, **kwargs): + raise ConnectionError("Simulated Delta failure") + + monkeypatch.setattr( + "databricks_tools_core.agentguard.ledger._write_to_delta", + _failing_write, + ) + + result = flush_session_to_delta(session) + + assert result["status"] == "pending" + assert result["rows"] == 2 + assert result["method"] == "local_jsonl" + assert jsonl_path.exists(), "JSONL should be preserved when Delta fails" diff --git a/databricks-tools-core/tests/integration/agentguard/test_full_pipeline.py b/databricks-tools-core/tests/integration/agentguard/test_full_pipeline.py new file mode 100644 index 00000000..892a971a --- /dev/null +++ b/databricks-tools-core/tests/integration/agentguard/test_full_pipeline.py @@ -0,0 +1,461 @@ +""" +Integration tests for the full AgentGuard checkpoint pipeline. 
+ +Tests: +- monitor vs enforce sessions, scope, limits, risk escalation, summary +""" + +import pytest + +from databricks_tools_core.agentguard.models import ( + Action, + AgentGuardMode, + CheckpointDecision, + CheckpointResult, + SessionStatus, +) +from databricks_tools_core.agentguard.policy import PolicyEngine +from databricks_tools_core.agentguard.risk import compute_risk_score +from databricks_tools_core.agentguard.scope import ( + ScopeManifest, + ResourceScope, + check_scope, + check_scope_limits, +) +from databricks_tools_core.agentguard.session import start_session, stop_session + + +def _run_checkpoint_pipeline( + action: Action, + session, + policy_engine: PolicyEngine, +) -> CheckpointResult: + """Runs policy, scope/limits, and risk like AgentGuardMiddleware.""" + result = CheckpointResult() + scope_violated = False + + policy_decision, rule_hit = policy_engine.check(action, session.mode) + result.policy_result = policy_decision.value + result.policy_rule_hit = rule_hit + + if policy_decision in (CheckpointDecision.BLOCK, CheckpointDecision.WOULD_BLOCK): + result.decision = policy_decision + result.block_reason = rule_hit + result.blocking_checkpoint = "CP-1: Policy" + elif policy_decision == CheckpointDecision.FLAG: + result.decision = CheckpointDecision.FLAG + result.block_reason = rule_hit + result.blocking_checkpoint = "CP-1: Policy -> CP-4: Approval" + + if session.scope is not None: + scope_decision, scope_violation = check_scope(action, session.scope, session.mode) + result.scope_result = scope_decision.value + result.scope_violation = scope_violation + + if scope_decision in (CheckpointDecision.BLOCK, CheckpointDecision.WOULD_BLOCK): + scope_violated = True + if result.decision in (CheckpointDecision.AUTO_APPROVE, CheckpointDecision.FLAG): + result.decision = scope_decision + result.block_reason = scope_violation + result.blocking_checkpoint = "CP-2: Scope" + + limit_decision, limit_violation = check_scope_limits( + action, + session.scope, + 
session.action_count, + session.write_count, + session.mode, + ) + if limit_decision in (CheckpointDecision.BLOCK, CheckpointDecision.WOULD_BLOCK): + if result.decision in (CheckpointDecision.AUTO_APPROVE, CheckpointDecision.FLAG): + result.decision = limit_decision + result.block_reason = limit_violation + result.blocking_checkpoint = "CP-2: Scope (limit)" + + risk = compute_risk_score(action, scope_violated=scope_violated) + result.risk_score = risk.score + result.risk_breakdown = risk.breakdown + + if risk.decision == CheckpointDecision.BLOCK: + if result.decision in ( + CheckpointDecision.AUTO_APPROVE, + CheckpointDecision.FLAG, + CheckpointDecision.HOLD_FOR_APPROVAL, + ): + result.decision = CheckpointDecision.BLOCK + result.block_reason = result.block_reason or f"Risk score {risk.score:.0f} exceeds block threshold" + result.blocking_checkpoint = "CP-3: Risk Score" + elif risk.decision == CheckpointDecision.HOLD_FOR_APPROVAL: + if result.decision in (CheckpointDecision.AUTO_APPROVE, CheckpointDecision.FLAG): + result.decision = CheckpointDecision.HOLD_FOR_APPROVAL + result.block_reason = result.block_reason or f"Risk score {risk.score:.0f} requires human approval" + result.blocking_checkpoint = "CP-3: Risk Score -> CP-4: Approval" + elif risk.decision == CheckpointDecision.FLAG: + if result.decision == CheckpointDecision.AUTO_APPROVE: + result.decision = CheckpointDecision.FLAG + result.block_reason = result.block_reason or f"Risk score {risk.score:.0f} exceeds flag threshold" + result.blocking_checkpoint = "CP-3: Risk Score -> CP-4: Approval" + + return result + + +@pytest.mark.integration +class TestFullPipelineMonitorMode: + """Tests for pipeline behavior in MONITOR_ONLY.""" + + def test_mixed_operations_all_recorded(self): + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.MONITOR_ONLY, + description="full-pipeline-monitor", + ) + + tool_calls = [ + ("execute_sql", {"query": "SELECT * FROM main.analytics.orders"}), + 
("execute_sql", {"query": "INSERT INTO main.staging.temp VALUES (1)"}), + ("execute_sql", {"query": "DROP TABLE main.staging.old_temp"}), + ("list_clusters", {}), + ("delete_app", {"name": "test-app"}), + ] + + for tool_name, tool_params in tool_calls: + action = Action.from_tool_call( + tool_name=tool_name, + tool_params=tool_params, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + result = _run_checkpoint_pipeline(action, session, policy_engine) + action.checkpoint_result = result + + if result.decision in (CheckpointDecision.BLOCK, CheckpointDecision.WOULD_BLOCK): + action.final_decision = "would_block" + elif result.decision == CheckpointDecision.FLAG: + action.final_decision = "flagged" + else: + action.final_decision = "executed" + action.execution_success = True + session.record_action(action) + + completed = stop_session() + + assert completed is not None + assert completed.status == SessionStatus.COMPLETED + assert completed.action_count == 5 + assert completed.would_block_count >= 0 + + def test_dangerous_sql_recorded_as_would_block(self): + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.MONITOR_ONLY, + description="monitor-dangerous-sql", + ) + + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "TRUNCATE TABLE main.production.billing"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + result = _run_checkpoint_pipeline(action, session, policy_engine) + action.checkpoint_result = result + action.final_decision = "would_block" + action.execution_success = True + session.record_action(action) + + assert result.decision == CheckpointDecision.WOULD_BLOCK + assert result.blocking_checkpoint == "CP-1: Policy" + assert session.would_block_count == 1 + + stop_session() + + +@pytest.mark.integration +class TestFullPipelineEnforceMode: + """Tests for pipeline behavior in ENFORCE.""" + + def 
test_safe_read_passes_all_checkpoints(self): + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.ENFORCE, + description="enforce-safe-read", + ) + + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT count(*) FROM main.staging.temp"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + result = _run_checkpoint_pipeline(action, session, policy_engine) + action.checkpoint_result = result + action.final_decision = "executed" + action.execution_success = True + session.record_action(action) + + assert result.decision == CheckpointDecision.AUTO_APPROVE + assert result.risk_score < 35 + + stop_session() + + def test_drop_database_hard_blocked(self): + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.ENFORCE, + description="enforce-drop-db", + ) + + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP DATABASE production"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + result = _run_checkpoint_pipeline(action, session, policy_engine) + action.checkpoint_result = result + action.final_decision = "blocked" + session.record_action(action) + + assert result.decision == CheckpointDecision.BLOCK + assert result.blocking_checkpoint == "CP-1: Policy" + assert session.blocked_count == 1 + + stop_session() + + def test_delete_app_flagged_for_approval(self): + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.ENFORCE, + description="enforce-delete-app", + ) + + action = Action.from_tool_call( + tool_name="delete_app", + tool_params={"name": "critical-production-app"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + result = _run_checkpoint_pipeline(action, session, policy_engine) + + assert result.decision == CheckpointDecision.FLAG + assert result.blocking_checkpoint is not None 
+ + stop_session() + + +@pytest.mark.integration +class TestFullPipelineWithScope: + """Tests for pipeline with scope attached.""" + + def test_in_scope_read_passes(self): + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.ENFORCE, + description="scope-in", + ) + session.scope = ScopeManifest( + tables=ResourceScope( + read=["main.analytics.*"], + write=["main.staging.temp_*"], + ), + default_deny=True, + max_actions=100, + ) + + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT * FROM main.analytics.orders"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + action.target_resource_id = "main.analytics.orders" + + result = _run_checkpoint_pipeline(action, session, policy_engine) + + assert result.decision == CheckpointDecision.AUTO_APPROVE + + stop_session() + + def test_out_of_scope_read_blocked(self): + """Should block reads outside allowed table patterns.""" + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.ENFORCE, + description="scope-out", + ) + session.scope = ScopeManifest( + tables=ResourceScope( + read=["main.analytics.orders"], + ), + default_deny=True, + max_actions=100, + ) + + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT * FROM main.analytics.customer_pii"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + action.target_resource_id = "main.analytics.customer_pii" + + result = _run_checkpoint_pipeline(action, session, policy_engine) + + assert result.decision == CheckpointDecision.BLOCK + assert result.blocking_checkpoint == "CP-2: Scope" + assert "not in scope" in result.scope_violation + + stop_session() + + def test_scope_limit_blocks_after_max_actions(self): + """Should block once session action count reaches max_actions.""" + policy_engine = PolicyEngine() + + session = start_session( + 
mode=AgentGuardMode.ENFORCE, + description="scope-limit", + ) + session.scope = ScopeManifest(max_actions=3) + + for i in range(3): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": f"SELECT {i}"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + action.checkpoint_result = CheckpointResult() + action.final_decision = "executed" + session.record_action(action) + + action_4 = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT 3"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + + result = _run_checkpoint_pipeline(action_4, session, policy_engine) + + assert result.decision == CheckpointDecision.BLOCK + assert "max_actions" in result.block_reason + + stop_session() + + +@pytest.mark.integration +class TestRiskEscalation: + """Tests for risk overriding policy auto-approve.""" + + def test_policy_approve_risk_flag_escalates(self): + """Should escalate to FLAG or HOLD when risk is high despite policy approve.""" + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.ENFORCE, + description="risk-escalation", + ) + + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "INSERT INTO main.production.billing VALUES (1, 100)"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + + result = _run_checkpoint_pipeline(action, session, policy_engine) + + assert result.policy_result == CheckpointDecision.AUTO_APPROVE.value + assert result.risk_score >= 35 + assert result.decision in (CheckpointDecision.FLAG, CheckpointDecision.HOLD_FOR_APPROVAL) + + stop_session() + + def test_policy_block_not_downgraded_by_low_risk(self): + """Should keep policy BLOCK even when environment risk is lower.""" + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.ENFORCE, + description="no-downgrade", + ) + + action = 
Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "TRUNCATE TABLE main.staging.small_table"}, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + + result = _run_checkpoint_pipeline(action, session, policy_engine) + + assert result.decision == CheckpointDecision.BLOCK + assert result.blocking_checkpoint == "CP-1: Policy" + + stop_session() + + +@pytest.mark.integration +class TestSessionSummaryAccuracy: + """Tests for session.summary() after pipeline runs.""" + + def test_summary_counts_match(self): + policy_engine = PolicyEngine() + + session = start_session( + mode=AgentGuardMode.ENFORCE, + description="summary-test", + ) + + actions_spec = [ + ("execute_sql", {"query": "SELECT 1"}), + ("execute_sql", {"query": "SELECT 2"}), + ("execute_sql", {"query": "DROP DATABASE test_db"}), + ("delete_app", {"name": "test-app"}), + ] + + for tool_name, tool_params in actions_spec: + action = Action.from_tool_call( + tool_name=tool_name, + tool_params=tool_params, + task_id=session.task_id, + agent_id="test-agent", + sequence=session.next_sequence(), + ) + result = _run_checkpoint_pipeline(action, session, policy_engine) + action.checkpoint_result = result + + if result.decision == CheckpointDecision.BLOCK: + action.final_decision = "blocked" + elif result.decision == CheckpointDecision.FLAG: + action.final_decision = "flagged" + else: + action.final_decision = "executed" + session.record_action(action) + + summary = session.summary() + + assert "Actions: 4" in summary + assert "Blocked: 1" in summary + + completed = stop_session() + assert completed.action_count == 4 + assert completed.blocked_count == 1 diff --git a/databricks-tools-core/tests/integration/agentguard/test_policy_enforcement.py b/databricks-tools-core/tests/integration/agentguard/test_policy_enforcement.py new file mode 100644 index 00000000..7f238161 --- /dev/null +++ 
b/databricks-tools-core/tests/integration/agentguard/test_policy_enforcement.py @@ -0,0 +1,561 @@ +""" +Integration tests for AgentGuard policy enforcement. + +Tests: +- destructive SQL, grants/revokes, deletions, monitor vs enforce, multi-SQL +- Action.from_tool_call and real MCP parameter names +""" + +import pytest + +from databricks_tools_core.agentguard.models import ( + Action, + ActionCategory, + AgentGuardMode, + CheckpointDecision, +) +from databricks_tools_core.agentguard.policy import PolicyEngine + + +@pytest.mark.integration +class TestDestructiveSQLBlocked: + """Tests for blocking or flagging dangerous SQL.""" + + def test_drop_database_blocked(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP DATABASE analytics_prod"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, rule = policy_engine.check(action, AgentGuardMode.ENFORCE) + + assert decision == CheckpointDecision.BLOCK + assert rule is not None + assert "blocked" in rule.lower() + + def test_drop_schema_blocked(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP SCHEMA main.customer_data"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.BLOCK + + def test_truncate_table_blocked(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "TRUNCATE TABLE main.billing.records"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.BLOCK + + def test_drop_table_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP TABLE main.analytics.orders"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, rule = 
policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + assert "approval" in rule.lower() + + def test_delete_from_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DELETE FROM main.warehouse.users WHERE id = 5"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + def test_create_or_replace_table_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "CREATE OR REPLACE TABLE main.staging.temp AS SELECT 1"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + def test_drop_table_if_exists_still_caught(self, policy_engine): + """Should flag DROP TABLE after IF EXISTS is stripped.""" + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP TABLE IF EXISTS main.analytics.temp"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + +@pytest.mark.integration +class TestPrivilegeEscalationCaught: + """Tests for grant/revoke policy behavior.""" + + def test_grant_all_privileges_blocked(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "GRANT ALL PRIVILEGES ON CATALOG main TO `agent-sp`"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.BLOCK + + def test_revoke_blocked(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "REVOKE ALL ON main.prod.customers FROM analyst"}, + task_id="t1", + agent_id="a1", + 
sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.BLOCK + + def test_narrow_grant_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "GRANT SELECT ON main.analytics.table TO analyst"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + def test_uc_grants_tool_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="manage_uc_grants", + tool_params={"action": "grant", "privilege": "SELECT", "principal": "user@test.com"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + +@pytest.mark.integration +class TestResourceDeletionCaught: + """Tests for non-SQL destructive tools.""" + + def test_delete_app_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="delete_app", + tool_params={"name": "my-production-app"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + assert action.operation == "DELETE_APP" + + def test_delete_pipeline_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="delete_pipeline", + tool_params={"pipeline_id": "abc-123"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + def test_delete_volume_file_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="delete_volume_file", + tool_params={"volume_path": "/Volumes/main/data/configs/rules.json"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, 
AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + assert action.target_resource_id == "/Volumes/main/data/configs/rules.json" + + def test_delete_tracked_resource_job_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="delete_tracked_resource", + tool_params={"type": "job", "resource_id": "job-42"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + assert action.operation == "DELETE_JOB" + + def test_manage_jobs_delete_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="manage_jobs", + tool_params={"action": "delete", "job_id": "12345"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + assert action.operation == "DELETE_JOB" + + def test_code_execution_flagged(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_databricks_command", + tool_params={"cluster_id": "0123-abc", "code": "print('hello')"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + assert action.operation == "EXECUTE_CODE" + + +@pytest.mark.integration +class TestMonitorModeNeverBlocks: + """Tests for MONITOR_ONLY policy decisions.""" + + def test_drop_database_would_block(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP DATABASE production_db"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.MONITOR_ONLY) + assert decision == CheckpointDecision.WOULD_BLOCK + + def test_truncate_would_block(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "TRUNCATE TABLE important_data"}, + 
task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.MONITOR_ONLY) + assert decision == CheckpointDecision.WOULD_BLOCK + + def test_delete_app_would_block(self, policy_engine): + """Should record FLAG or WOULD_BLOCK for delete_app in monitor mode.""" + action = Action.from_tool_call( + tool_name="delete_app", + tool_params={"name": "my-app"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.MONITOR_ONLY) + assert decision in (CheckpointDecision.FLAG, CheckpointDecision.WOULD_BLOCK) + + +@pytest.mark.integration +class TestSafeOperationsAllowed: + """Tests for operations that policy auto-approves.""" + + def test_select_auto_approved(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT * FROM main.analytics.orders"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.AUTO_APPROVE + + def test_insert_auto_approved(self, policy_engine): + """Should auto-approve INSERT statements.""" + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "INSERT INTO staging.temp VALUES (1, 'test')"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.AUTO_APPROVE + + def test_read_only_tools_auto_approved(self, policy_engine): + for tool in ("list_clusters", "get_app", "list_warehouses", "get_current_user"): + action = Action.from_tool_call( + tool_name=tool, + tool_params={}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.AUTO_APPROVE, f"{tool} should be auto-approved" + + +@pytest.mark.integration +class TestMultiSQLBypassPrevention: + """Tests for 
execute_sql_multi policy coverage.""" + + def test_hidden_drop_in_multi_sql(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql_multi", + tool_params={ + "sqls": [ + "SELECT * FROM orders", + "DROP TABLE main.analytics.orders", + ] + }, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + def test_hidden_truncate_in_multi_sql(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql_multi", + tool_params={ + "sqls": [ + "SELECT count(*) FROM billing", + "TRUNCATE TABLE main.billing.records", + "SELECT 1", + ] + }, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.BLOCK + + +@pytest.mark.integration +class TestActionClassificationAccuracy: + """Tests for Action.from_tool_call classifications.""" + + def test_sql_select_classified_as_read(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT * FROM test_table"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.READ + assert action.operation == "SELECT" + + def test_sql_drop_classified_as_ddl(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP TABLE main.staging.temp"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.DDL + assert action.operation == "DROP" + + def test_delete_app_classified_as_admin(self): + action = Action.from_tool_call( + tool_name="delete_app", + tool_params={"name": "my-app"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.ADMIN + assert action.operation == "DELETE_APP" + assert action.target_resource_type == "app" + assert action.target_resource_id == "my-app" + + def 
test_start_cluster_classified_as_write(self): + action = Action.from_tool_call( + tool_name="start_cluster", + tool_params={"cluster_id": "0123-abc"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.WRITE + assert action.operation == "START_CLUSTER" + + def test_grant_sql_classified_as_dcl(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "GRANT SELECT ON TABLE main.data.t TO user@test.com"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.DCL + assert action.operation == "GRANT" + + def test_execute_code_classified_as_admin(self): + action = Action.from_tool_call( + tool_name="execute_databricks_command", + tool_params={"cluster_id": "abc", "code": "print(1)"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.ADMIN + assert action.operation == "EXECUTE_CODE" + assert action.target_resource_type == "code_execution" + + def test_publish_dashboard_classified_as_write(self): + action = Action.from_tool_call( + tool_name="publish_dashboard", + tool_params={"dashboard_id": "dash-42", "publish": True}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.WRITE + assert action.operation == "PUBLISH" + assert action.target_resource_id == "dash-42" + + def test_delete_lakebase_sync_has_resource_id(self): + action = Action.from_tool_call( + tool_name="delete_lakebase_sync", + tool_params={"table_name": "my_sync_table"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.operation == "DELETE_LAKEBASE_SYNC" + assert action.target_resource_id == "my_sync_table" + + +@pytest.mark.integration +class TestRealMCPParamNames: + """Tests for MCP-style parameter names on from_tool_call.""" + + def test_sql_query_param_classified_correctly(self, policy_engine): + """Should parse sql_query like execute_sql from MCP.""" + action 
= Action.from_tool_call( + tool_name="execute_sql", + tool_params={"sql_query": "DROP TABLE main.analytics.orders"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.DDL + assert action.operation == "DROP" + assert action.sql_statement == "DROP TABLE main.analytics.orders" + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + def test_sql_query_select_classified_as_read(self, policy_engine): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"sql_query": "SELECT * FROM main.db.table LIMIT 10"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.READ + assert action.operation == "SELECT" + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.AUTO_APPROVE + + def test_sql_content_multi_classified_correctly(self, policy_engine): + """Should parse sql_content for execute_sql_multi.""" + action = Action.from_tool_call( + tool_name="execute_sql_multi", + tool_params={"sql_content": "SELECT 1; DROP TABLE main.analytics.orders"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.action_category == ActionCategory.DDL + assert "DROP" in action.sql_statement + + decision, _ = policy_engine.check(action, AgentGuardMode.ENFORCE) + assert decision == CheckpointDecision.FLAG + + def test_genie_space_id_extracted(self): + """Should read space_id for delete_genie.""" + action = Action.from_tool_call( + tool_name="delete_genie", + tool_params={"space_id": "genie-abc-123"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.target_resource_id == "genie-abc-123" + + def test_display_name_extracted(self): + """Should use display_name for dashboard create tools.""" + action = Action.from_tool_call( + tool_name="create_or_update_dashboard", + tool_params={"display_name": "Revenue Dashboard", "parent_path": "/"}, 
+ task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.target_resource_id == "Revenue Dashboard" + + def test_full_name_extracted(self): + """Should use full_name for manage_uc_grants.""" + action = Action.from_tool_call( + tool_name="manage_uc_grants", + tool_params={"action": "grant", "full_name": "main.prod.orders"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.target_resource_id == "main.prod.orders" + + def test_cluster_id_extracted(self): + """Should use cluster_id for start_cluster.""" + action = Action.from_tool_call( + tool_name="start_cluster", + tool_params={"cluster_id": "0123-abc-xyz"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + assert action.target_resource_id == "0123-abc-xyz" diff --git a/databricks-tools-core/tests/integration/agentguard/test_risk_scoring.py b/databricks-tools-core/tests/integration/agentguard/test_risk_scoring.py new file mode 100644 index 00000000..c049f795 --- /dev/null +++ b/databricks-tools-core/tests/integration/agentguard/test_risk_scoring.py @@ -0,0 +1,227 @@ +""" +Integration tests for AgentGuard risk scoring. 
+ +Tests: +- compute_risk_score on realistic actions, breakdown, thresholds, scope penalty +""" + +import pytest + +from databricks_tools_core.agentguard.models import ( + Action, + ActionCategory, + CheckpointDecision, +) +from databricks_tools_core.agentguard.risk import compute_risk_score + + +@pytest.mark.integration +class TestRiskScoreComputation: + """Tests for end-to-end risk scores.""" + + def test_select_low_risk(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT * FROM main.analytics.orders"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + + assert risk.score < 35, f"SELECT should be low risk, got {risk.score}" + assert risk.decision == CheckpointDecision.AUTO_APPROVE + + def test_drop_table_high_risk(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP TABLE main.production.billing"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + + assert risk.score >= 35, f"DROP TABLE on production should be high risk, got {risk.score}" + assert risk.decision in ( + CheckpointDecision.FLAG, + CheckpointDecision.HOLD_FOR_APPROVAL, + CheckpointDecision.BLOCK, + ) + + def test_insert_into_production_flagged(self): + """Should score INSERT into production above the low-risk band.""" + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "INSERT INTO main.production.billing VALUES (1, 100)"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + + assert risk.score >= 35, f"INSERT into production should be flaggable, got {risk.score}" + + def test_update_without_where_high_risk(self): + """Should treat unbounded UPDATE on production as high risk.""" + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "UPDATE main.production.users SET status = 'inactive'"}, + task_id="t1", + agent_id="a1", + 
sequence=1, + ) + + risk = compute_risk_score(action) + + assert risk.score >= 35, f"Unbounded UPDATE on production should flag, got {risk.score}" + + def test_read_only_tool_low_risk(self): + action = Action.from_tool_call( + tool_name="list_clusters", + tool_params={}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + assert risk.score < 35 + assert risk.decision == CheckpointDecision.AUTO_APPROVE + + def test_delete_app_moderate_risk(self): + action = Action.from_tool_call( + tool_name="delete_app", + tool_params={"name": "my-staging-app"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + + assert risk.score > 20, f"DELETE_APP should have meaningful risk, got {risk.score}" + + def test_execute_code_high_risk(self): + action = Action.from_tool_call( + tool_name="execute_databricks_command", + tool_params={"cluster_id": "prod-cluster", "code": "df.write.mode('overwrite').saveAsTable('t')"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + assert risk.score >= 35, f"Code execution should be flaggable, got {risk.score}" + + +@pytest.mark.integration +class TestRiskBreakdown: + """Tests for risk breakdown keys and environment signal.""" + + def test_breakdown_has_all_factors(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "DROP TABLE main.production.billing"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + + expected_factors = { + "action_type", + "environment", + "blast_radius", + "time_context", + "behavioral", + "data_sensitivity", + } + assert expected_factors.issubset(set(risk.breakdown.keys())), ( + f"Missing factors: {expected_factors - set(risk.breakdown.keys())}" + ) + + def test_production_environment_scores_high(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT * FROM main.production.orders"}, + 
task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + + assert risk.breakdown.get("environment", 0) >= 60 + + +@pytest.mark.integration +class TestRiskThresholds: + """Tests for threshold-driven decisions.""" + + def test_auto_approve_below_35(self): + """Should auto-approve trivial reads.""" + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT 1"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + assert risk.score < 35 + assert risk.decision == CheckpointDecision.AUTO_APPROVE + + def test_truncate_production_very_high_risk(self): + """Should score TRUNCATE on production very high.""" + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "TRUNCATE TABLE main.production.billing_records"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk = compute_risk_score(action) + assert risk.score >= 50, f"TRUNCATE production should be very high risk, got {risk.score}" + + +@pytest.mark.integration +class TestScopeViolationPenalty: + """Tests for scope_violated scoring.""" + + def test_scope_violation_increases_score(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT * FROM main.staging.temp"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk_normal = compute_risk_score(action, scope_violated=False) + risk_violated = compute_risk_score(action, scope_violated=True) + + assert risk_violated.score > risk_normal.score, ( + f"Scope violation should increase risk: normal={risk_normal.score}, violated={risk_violated.score}" + ) + + def test_scope_violation_penalty_is_25_points(self): + action = Action.from_tool_call( + tool_name="execute_sql", + tool_params={"query": "SELECT 1"}, + task_id="t1", + agent_id="a1", + sequence=1, + ) + + risk_normal = compute_risk_score(action, scope_violated=False) + risk_violated = compute_risk_score(action, scope_violated=True) + + 
# ==== new file: databricks-tools-core/tests/integration/agentguard/test_scope_enforcement.py ====
"""
Integration tests for AgentGuard scope enforcement.

Tests:
- templates, in/out of scope, default_deny, limits, monitor mode
"""

import pytest

from databricks_tools_core.agentguard.models import (
    Action,
    ActionCategory,
    AgentGuardMode,
    CheckpointDecision,
)
from databricks_tools_core.agentguard.scope import (
    ScopeManifest,
    ResourceScope,
    check_scope,
    check_scope_limits,
    list_templates,
    load_template,
)


@pytest.mark.integration
class TestScopeTemplateLoading:
    """Tests for loading scope templates."""

    def test_list_templates_returns_known_templates(self):
        # The three built-in templates must always be discoverable.
        templates = list_templates()
        assert "etl_pipeline_fix" in templates
        assert "data_quality_check" in templates
        assert "model_deployment" in templates

    def test_load_etl_template_with_variables(self):
        # Variables are substituted into the template's resource patterns.
        scope = load_template(
            "etl_pipeline_fix",
            {
                "catalog": "main",
                "target_table": "customer_orders",
            },
        )

        assert isinstance(scope, ScopeManifest)
        assert scope.default_deny is True
        assert scope.max_actions is not None
        assert scope.max_write_actions is not None

        # The substituted table name must appear in at least one read pattern.
        assert any("customer_orders" in p for p in scope.tables.read)

    def test_load_data_quality_template(self):
        scope = load_template(
            "data_quality_check",
            {
                "catalog": "main",
                "schema": "production",
                "target_table": "customer_orders",
            },
        )

        assert isinstance(scope, ScopeManifest)
        assert scope.default_deny is True

    def test_load_model_deployment_template(self):
        scope = load_template(
            "model_deployment",
            {
                "catalog": "ml_catalog",
                "endpoint_name": "fraud_detection_v2",
            },
        )

        assert isinstance(scope, ScopeManifest)

    def test_missing_variables_raises(self):
        # Leaving template variables unresolved is a hard error, not a warning.
        with pytest.raises(ValueError, match="unresolved variables"):
            load_template("etl_pipeline_fix", {})

    def test_unknown_template_raises(self):
        with pytest.raises(FileNotFoundError, match="not found"):
            load_template("nonexistent_template_xyz", {})


@pytest.mark.integration
class TestScopeMatchingInScope:
    """Tests for actions inside the declared scope."""

    @pytest.fixture
    def etl_scope(self):
        # A realistic manifest built from the ETL template.
        return load_template(
            "etl_pipeline_fix",
            {
                "catalog": "main",
                "target_table": "customer_orders",
            },
        )

    def test_read_allowed_table_in_scope(self, etl_scope):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT * FROM main.production.customer_orders"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        # Resource id / category are set explicitly so the test does not depend
        # on SQL-parsing heuristics inside from_tool_call.
        action.target_resource_id = "main.production.customer_orders"
        action.action_category = ActionCategory.READ

        decision, violation = check_scope(action, etl_scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE
        assert violation is None


@pytest.mark.integration
class TestScopeMatchingOutOfScope:
    """Tests for actions outside the declared scope."""

    @pytest.fixture
    def strict_scope(self):
        # Hand-built manifest: one readable table, wildcard-writable staging
        # temps, default-deny everything else.
        return ScopeManifest(
            tables=ResourceScope(
                read=["main.analytics.orders"],
                write=["main.staging.temp_*"],
            ),
            default_deny=True,
            max_actions=50,
            max_write_actions=10,
        )

    def test_read_pii_table_blocked(self, strict_scope):
        """Should block reads outside allowed table patterns."""
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT * FROM main.analytics.customer_pii_data"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        action.target_resource_id = "main.analytics.customer_pii_data"
        action.action_category = ActionCategory.READ

        decision, violation = check_scope(action, strict_scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK
        assert "not in scope" in violation

    def test_write_to_non_matching_table_blocked(self, strict_scope):
        """Should block writes that do not match write patterns."""
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "INSERT INTO main.production.billing VALUES (1)"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        action.target_resource_id = "main.production.billing"
        action.action_category = ActionCategory.WRITE

        decision, violation = check_scope(action, strict_scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK

    def test_write_to_matching_table_allowed(self, strict_scope):
        """Should allow writes matching temp_* under staging."""
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "INSERT INTO main.staging.temp_results VALUES (1)"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        action.target_resource_id = "main.staging.temp_results"
        action.action_category = ActionCategory.WRITE

        decision, _ = check_scope(action, strict_scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_create_unrelated_job_blocked(self, strict_scope):
        """Should pass or block job create depending on whether jobs are in scope."""
        # NOTE(review): deliberately accepts either outcome — the manifest
        # declares no job scope, so behavior depends on how check_scope treats
        # resource types absent from the manifest.
        action = Action.from_tool_call(
            tool_name="manage_jobs",
            tool_params={"action": "create", "name": "unrelated-etl-job"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )

        decision, _ = check_scope(action, strict_scope, AgentGuardMode.ENFORCE)
        assert decision in (CheckpointDecision.AUTO_APPROVE, CheckpointDecision.BLOCK)


@pytest.mark.integration
class TestScopeMonitorMode:
    """Tests for scope decisions in MONITOR_ONLY."""

    def test_out_of_scope_would_block_in_monitor(self):
        # Monitor mode downgrades BLOCK to WOULD_BLOCK and tags the violation.
        scope = ScopeManifest(
            tables=ResourceScope(read=["main.allowed.*"]),
            default_deny=True,
        )

        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT * FROM main.forbidden.secrets"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        action.target_resource_id = "main.forbidden.secrets"
        action.action_category = ActionCategory.READ

        decision, violation = check_scope(action, scope, AgentGuardMode.MONITOR_ONLY)
        assert decision == CheckpointDecision.WOULD_BLOCK
        assert "[monitor]" in violation


@pytest.mark.integration
class TestDefaultDenyBehavior:
    """Tests for default_deny on table scope."""

    def test_default_deny_false_allows_unspecified(self):
        # With default_deny off, a write not covered by any write pattern is
        # still permitted.
        scope = ScopeManifest(
            tables=ResourceScope(read=["main.analytics.*"]),
            default_deny=False,
        )

        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "INSERT INTO main.analytics.temp VALUES (1)"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        action.target_resource_id = "main.analytics.temp"
        action.action_category = ActionCategory.WRITE

        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_default_deny_true_blocks_unspecified(self):
        # Same action, same scope patterns — only default_deny flipped.
        scope = ScopeManifest(
            tables=ResourceScope(read=["main.analytics.*"]),
            default_deny=True,
        )

        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "INSERT INTO main.analytics.temp VALUES (1)"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        action.target_resource_id = "main.analytics.temp"
        action.action_category = ActionCategory.WRITE

        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK


@pytest.mark.integration
class TestScopeLimits:
    """Tests for max_actions and max_write_actions."""

    def test_max_actions_exceeded(self):
        scope = ScopeManifest(max_actions=5, max_write_actions=3)

        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT 1"},
            task_id="t1",
            agent_id="a1",
            sequence=6,
        )

        # 5 actions already recorded + this one exceeds max_actions=5.
        decision, violation = check_scope_limits(
            action,
            scope,
            session_action_count=5,
            session_write_count=0,
            mode=AgentGuardMode.ENFORCE,
        )
        assert decision == CheckpointDecision.BLOCK
        assert "max_actions" in violation

    def test_max_write_actions_exceeded(self):
        scope = ScopeManifest(max_actions=100, max_write_actions=3)

        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "INSERT INTO test VALUES (1)"},
            task_id="t1",
            agent_id="a1",
            sequence=4,
        )

        decision, violation = check_scope_limits(
            action,
            scope,
            session_action_count=10,
            session_write_count=3,
            mode=AgentGuardMode.ENFORCE,
        )
        assert decision == CheckpointDecision.BLOCK
        assert "max_write_actions" in violation

    def test_within_limits_passes(self):
        scope = ScopeManifest(max_actions=100, max_write_actions=20)

        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT 1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )

        decision, _ = check_scope_limits(
            action,
            scope,
            session_action_count=5,
            session_write_count=2,
            mode=AgentGuardMode.ENFORCE,
        )
        assert decision == CheckpointDecision.AUTO_APPROVE
# ==== new file: databricks-tools-core/tests/integration/agentguard/test_session_lifecycle.py ====
"""
Integration tests for AgentGuard session lifecycle.

Tests:
- start_session, stop_session, get_session_status
- actions and scope templates
"""

import pytest

from databricks_tools_core.agentguard.context import (
    clear_active_session,
    get_active_session,
    has_active_session,
)
from databricks_tools_core.agentguard.models import (
    Action,
    AgentGuardMode,
    SessionStatus,
)
from databricks_tools_core.agentguard.session import (
    get_session_status,
    start_session,
    stop_session,
)

# NOTE(review): these tests assume the active session is reset between tests
# (e.g. an autouse fixture in conftest.py calling clear_active_session) —
# not visible here; the `monitor_session` fixture below is presumably also
# defined in conftest.py. Verify before relying on test ordering.


@pytest.mark.integration
class TestSessionStartStop:
    """Tests for starting and stopping sessions."""

    def test_start_creates_active_session(self):
        """Should create an active monitor session with metadata."""
        session = start_session(
            mode=AgentGuardMode.MONITOR_ONLY,
            description="lifecycle-test",
            agent_id="test-agent",
            user_id="test-user",
        )

        assert session.status == SessionStatus.ACTIVE
        assert session.mode == AgentGuardMode.MONITOR_ONLY
        assert session.description == "lifecycle-test"
        assert session.task_id.startswith("task_")
        assert has_active_session()

    def test_stop_completes_session(self):
        """Should complete the session and clear the active handle."""
        start_session(mode=AgentGuardMode.MONITOR_ONLY, description="stop-test")

        session = stop_session()

        assert session is not None
        assert session.status == SessionStatus.COMPLETED
        assert session.completed_at is not None
        assert not has_active_session()

    def test_stop_returns_none_when_no_session(self):
        """Should return None when no session is active."""
        result = stop_session()
        assert result is None

    def test_double_start_raises(self):
        """Should reject starting a second session while one is active."""
        start_session(mode=AgentGuardMode.MONITOR_ONLY, description="first")

        with pytest.raises(ValueError, match="already active"):
            start_session(mode=AgentGuardMode.MONITOR_ONLY, description="second")

    def test_enforce_mode_session(self):
        """Should start in enforce mode."""
        session = start_session(
            mode=AgentGuardMode.ENFORCE,
            description="enforce-test",
        )

        assert session.mode == AgentGuardMode.ENFORCE

    def test_session_status_string(self):
        """Should expose a human-readable status string."""
        session = start_session(
            mode=AgentGuardMode.MONITOR_ONLY,
            description="status-test",
        )

        status = get_session_status()

        assert status is not None
        assert "monitor-only" in status
        assert session.task_id in status


@pytest.mark.integration
class TestSessionWithActions:
    """Tests for recording actions on a session."""

    def test_actions_recorded_in_session(self, monitor_session):
        """Should append actions to the active session."""
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT 1"},
            task_id=monitor_session.task_id,
            agent_id="test-agent",
            sequence=monitor_session.next_sequence(),
        )
        action.final_decision = "executed"
        monitor_session.record_action(action)

        assert monitor_session.action_count == 1
        assert monitor_session.actions[0].tool_name == "execute_sql"

    def test_session_summary_after_actions(self, monitor_session):
        """Should reflect action count in summary()."""
        for i in range(3):
            action = Action.from_tool_call(
                tool_name="execute_sql",
                tool_params={"query": f"SELECT {i}"},
                task_id=monitor_session.task_id,
                agent_id="test-agent",
                sequence=monitor_session.next_sequence(),
            )
            action.final_decision = "executed"
            monitor_session.record_action(action)

        summary = monitor_session.summary()
        assert "Actions: 3" in summary

    def test_write_count_tracks_mutations(self, monitor_session):
        """Should count writes separately from reads."""
        read_action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT * FROM test"},
            task_id=monitor_session.task_id,
            agent_id="test-agent",
            sequence=monitor_session.next_sequence(),
        )
        monitor_session.record_action(read_action)

        write_action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "INSERT INTO test VALUES (1)"},
            task_id=monitor_session.task_id,
            agent_id="test-agent",
            sequence=monitor_session.next_sequence(),
        )
        monitor_session.record_action(write_action)

        assert monitor_session.write_count == 1
        assert monitor_session.action_count == 2

    def test_session_persists_across_calls(self):
        """Should reuse the same session object from context helpers."""
        start_session(
            mode=AgentGuardMode.MONITOR_ONLY,
            description="persist-test",
        )

        session = get_active_session()
        assert session is not None
        assert session.description == "persist-test"

        # Identity (not equality) — same object handed back each time.
        session2 = get_active_session()
        assert session2 is session


@pytest.mark.integration
class TestSessionWithScopeTemplate:
    """Tests for scope templates on session start."""

    def test_start_with_scope_template(self):
        """Should attach scope when a template loads successfully."""
        session = start_session(
            mode=AgentGuardMode.ENFORCE,
            description="scope-template-test",
            scope_template="etl_pipeline_fix",
            scope_variables={
                "catalog": "main",
                "target_table": "customer_orders",
            },
        )

        assert session.scope is not None
        assert session.scope_template == "etl_pipeline_fix"

    def test_invalid_scope_template_warns_but_continues(self):
        """Should start with scope=None when the template name is unknown."""
        session = start_session(
            mode=AgentGuardMode.ENFORCE,
            description="bad-template-test",
            scope_template="nonexistent_template_xyz",
        )

        assert session.scope is None
        assert session.status == SessionStatus.ACTIVE
# ==== new file: databricks-tools-core/tests/unit/agentguard/test_context.py ====
"""
Unit tests for AgentGuard session context helpers.

Tests:
- set_active_session, get_active_session, has_active_session, clear_active_session
"""

from databricks_tools_core.agentguard.context import (
    clear_active_session,
    get_active_session,
    has_active_session,
    set_active_session,
)
from databricks_tools_core.agentguard.models import AgentGuardSession


class TestSessionContext:
    """Tests for thread-local active session helpers."""

    def setup_method(self):
        # Start every test from a clean slate.
        clear_active_session()

    def teardown_method(self):
        # Never leak an active session into the next test.
        clear_active_session()

    def test_no_session_initially(self):
        """A fresh context holds no session."""
        assert get_active_session() is None
        assert has_active_session() is False

    def test_set_and_get(self):
        """The exact object that was set is returned."""
        sess = AgentGuardSession()
        set_active_session(sess)
        assert get_active_session() is sess
        assert has_active_session() is True

    def test_clear(self):
        """Clearing removes a previously set session."""
        sess = AgentGuardSession()
        set_active_session(sess)
        clear_active_session()
        assert get_active_session() is None
        assert has_active_session() is False

    def test_replace_session(self):
        """Setting a second session overwrites the first."""
        first = AgentGuardSession()
        second = AgentGuardSession()
        set_active_session(first)
        set_active_session(second)
        assert get_active_session() is second

    def test_clear_when_empty_is_safe(self):
        """Clearing twice in a row is a no-op, not an error."""
        clear_active_session()
        clear_active_session()
        assert get_active_session() is None
# ==== new file: databricks-tools-core/tests/unit/agentguard/test_models.py ====
"""
Unit tests for AgentGuard core models.

Tests:
- Action.from_tool_call (SQL, read-only, delete, action-based, writes, fallback)
- CheckpointResult, AgentGuardSession
"""

import pytest
from databricks_tools_core.agentguard.models import (
    Action,
    ActionCategory,
    AgentGuardMode,
    AgentGuardSession,
    CheckpointDecision,
    CheckpointResult,
    SessionStatus,
)


class TestActionSQL:
    """from_tool_call classification of raw SQL tools."""

    def test_execute_sql_select(self):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT * FROM main.prod.orders"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.READ
        assert action.operation == "SELECT"
        assert action.target_resource_type == "sql"
        assert action.sql_statement == "SELECT * FROM main.prod.orders"

    def test_execute_sql_insert(self):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "INSERT INTO staging.t VALUES (1)"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.WRITE
        assert action.operation == "INSERT"

    def test_execute_sql_drop(self):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "DROP TABLE main.prod.orders"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.DDL
        assert action.operation == "DROP"

    def test_execute_sql_grant(self):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "GRANT SELECT ON TABLE main.t TO user1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.DCL
        assert action.operation == "GRANT"

    def test_execute_sql_empty_query(self):
        # Empty SQL degrades to a harmless READ with a sentinel operation.
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": ""},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.READ
        assert action.operation == "SQL_EMPTY"

    def test_execute_sql_multi_batch(self):
        # A batch takes the category of its most dangerous statement and a
        # MULTI(...) operation label; statements are joined with "; ".
        action = Action.from_tool_call(
            tool_name="execute_sql_multi",
            tool_params={"sqls": ["SELECT 1", "DROP TABLE t"]},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.DDL
        assert action.operation == "MULTI(DROP)"
        assert "SELECT 1; DROP TABLE t" in action.sql_statement

    def test_execute_sql_multi_empty(self):
        action = Action.from_tool_call(
            tool_name="execute_sql_multi",
            tool_params={"sqls": []},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.READ
        assert action.operation == "SQL_MULTI_EMPTY"


class TestActionReadOnly:
    """Known read-only tools map to READ / "READ"."""

    @pytest.mark.parametrize(
        "tool_name",
        [
            "get_table_details",
            "list_warehouses",
            "get_best_warehouse",
            "list_clusters",
            "get_cluster_status",
            "get_app",
            "get_dashboard",
            "get_current_user",
            "list_volume_files",
        ],
    )
    def test_read_only_tools(self, tool_name):
        action = Action.from_tool_call(
            tool_name=tool_name,
            tool_params={},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.READ
        assert action.operation == "READ"


class TestActionDelete:
    """Delete tools map to ADMIN with a DELETE_* operation and resource id."""

    def test_delete_app(self):
        action = Action.from_tool_call(
            tool_name="delete_app",
            tool_params={"name": "my-app"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.ADMIN
        assert action.operation == "DELETE_APP"
        assert action.target_resource_type == "app"
        assert action.target_resource_id == "my-app"

    def test_delete_pipeline(self):
        action = Action.from_tool_call(
            tool_name="delete_pipeline",
            tool_params={"pipeline_id": "p-123"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.ADMIN
        assert action.operation == "DELETE_PIPELINE"
        assert action.target_resource_id == "p-123"

    def test_delete_volume_file(self):
        action = Action.from_tool_call(
            tool_name="delete_volume_file",
            tool_params={"volume_path": "/Volumes/cat/sch/vol/file.csv"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.ADMIN
        assert action.operation == "DELETE_FILE"
        assert action.target_resource_id == "/Volumes/cat/sch/vol/file.csv"

    def test_delete_tracked_resource_dynamic_type(self):
        # The operation name is derived from the "type" parameter at runtime.
        action = Action.from_tool_call(
            tool_name="delete_tracked_resource",
            tool_params={"type": "job", "resource_id": "j-456"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.ADMIN
        assert action.operation == "DELETE_JOB"
        assert action.target_resource_type == "job"
        assert action.target_resource_id == "j-456"


class TestActionBased:
    """Tools dispatched by an "action" parameter (manage_jobs, manage_uc_grants)."""

    def test_manage_jobs_create(self):
        action = Action.from_tool_call(
            tool_name="manage_jobs",
            tool_params={"action": "create", "name": "etl-daily"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.WRITE
        assert action.operation == "CREATE"
        assert action.target_resource_type == "job"

    def test_manage_jobs_delete(self):
        action = Action.from_tool_call(
            tool_name="manage_jobs",
            tool_params={"action": "delete", "job_id": "j-1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.ADMIN
        assert action.operation == "DELETE_JOB"

    def test_manage_jobs_get(self):
        action = Action.from_tool_call(
            tool_name="manage_jobs",
            tool_params={"action": "get", "job_id": "j-1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.READ
        assert action.operation == "GET"

    def test_manage_uc_grants_grant(self):
        action = Action.from_tool_call(
            tool_name="manage_uc_grants",
            tool_params={"action": "grant", "name": "main.prod"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.DCL
        assert action.operation == "GRANT"


class TestActionStandaloneWrite:
    """Standalone mutating tools without an "action" parameter."""

    def test_start_cluster(self):
        action = Action.from_tool_call(
            tool_name="start_cluster",
            tool_params={"name": "my-cluster"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.WRITE
        assert action.operation == "START_CLUSTER"

    def test_upload_file(self):
        action = Action.from_tool_call(
            tool_name="upload_file",
            tool_params={"name": "data.csv"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.WRITE
        assert action.operation == "UPLOAD"

    def test_execute_databricks_command(self):
        # Arbitrary code execution is treated as ADMIN, targeting the cluster.
        action = Action.from_tool_call(
            tool_name="execute_databricks_command",
            tool_params={"cluster_id": "c-123"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.ADMIN
        assert action.operation == "EXECUTE_CODE"
        assert action.target_resource_type == "code_execution"
        assert action.target_resource_id == "c-123"


class TestActionFallback:
    """Unmapped tools fall back to UNKNOWN with an upper-cased operation."""

    def test_unknown_tool(self):
        action = Action.from_tool_call(
            tool_name="some_new_tool_not_mapped",
            tool_params={},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        assert action.action_category == ActionCategory.UNKNOWN
        assert action.operation == "SOME_NEW_TOOL_NOT_MAPPED"


class TestCheckpointResult:
    """blocked flag and overhead accounting on CheckpointResult."""

    def test_blocked_only_on_block(self):
        assert CheckpointResult(decision=CheckpointDecision.BLOCK).blocked is True

    def test_not_blocked_on_would_block(self):
        assert CheckpointResult(decision=CheckpointDecision.WOULD_BLOCK).blocked is False

    def test_not_blocked_on_flag(self):
        assert CheckpointResult(decision=CheckpointDecision.FLAG).blocked is False

    def test_not_blocked_on_auto_approve(self):
        assert CheckpointResult(decision=CheckpointDecision.AUTO_APPROVE).blocked is False

    def test_total_overhead(self):
        # Local import keeps TimingRecord out of the module-wide import list.
        from databricks_tools_core.agentguard.models import TimingRecord

        result = CheckpointResult(
            timings=[
                TimingRecord(name="cp1", duration_ms=1.5),
                TimingRecord(name="cp2", duration_ms=2.3),
            ]
        )
        assert abs(result.total_overhead_ms - 3.8) < 0.01


class TestSession:
    """Counters and lifecycle bookkeeping on AgentGuardSession."""

    def test_sequence_counter(self):
        session = AgentGuardSession()
        assert session.next_sequence() == 1
        assert session.next_sequence() == 2
        assert session.next_sequence() == 3

    def test_action_count(self):
        session = AgentGuardSession()
        assert session.action_count == 0
        action = Action(tool_name="test")
        session.record_action(action)
        assert session.action_count == 1

    def test_write_count_incremented_for_writes(self):
        # WRITE, DDL and ADMIN all count as mutations; READ does not.
        session = AgentGuardSession()
        read_action = Action(tool_name="t", action_category=ActionCategory.READ)
        write_action = Action(tool_name="t", action_category=ActionCategory.WRITE)
        ddl_action = Action(tool_name="t", action_category=ActionCategory.DDL)
        admin_action = Action(tool_name="t", action_category=ActionCategory.ADMIN)

        session.record_action(read_action)
        assert session.write_count == 0

        session.record_action(write_action)
        assert session.write_count == 1

        session.record_action(ddl_action)
        assert session.write_count == 2

        session.record_action(admin_action)
        assert session.write_count == 3

    def test_risk_score_tracking(self):
        session = AgentGuardSession()
        a1 = Action(tool_name="t", checkpoint_result=CheckpointResult(risk_score=30.0))
        a2 = Action(tool_name="t", checkpoint_result=CheckpointResult(risk_score=70.0))

        session.record_action(a1)
        session.record_action(a2)

        assert session.total_risk_score == 100.0
        assert session.max_risk_score == 70.0
        assert session.avg_risk_score == 50.0

    def test_blocked_count(self):
        session = AgentGuardSession()
        blocked = Action(
            tool_name="t",
            checkpoint_result=CheckpointResult(decision=CheckpointDecision.BLOCK),
        )
        session.record_action(blocked)
        assert session.blocked_count == 1
        assert session.would_block_count == 0

    def test_would_block_count(self):
        session = AgentGuardSession()
        wb = Action(
            tool_name="t",
            checkpoint_result=CheckpointResult(decision=CheckpointDecision.WOULD_BLOCK),
        )
        session.record_action(wb)
        assert session.would_block_count == 1
        assert session.blocked_count == 0

    def test_complete(self):
        session = AgentGuardSession()
        assert session.status == SessionStatus.ACTIVE
        assert session.completed_at is None

        session.complete()
        assert session.status == SessionStatus.COMPLETED
        assert session.completed_at is not None

    def test_summary_contains_key_info(self):
        session = AgentGuardSession(mode=AgentGuardMode.MONITOR_ONLY)
        summary = session.summary()
        assert "monitor-only" in summary
        assert session.task_id in summary
        assert "Actions: 0" in summary

    def test_avg_risk_score_zero_division(self):
        # avg over zero actions must be 0.0, not a ZeroDivisionError.
        session = AgentGuardSession()
        assert session.avg_risk_score == 0.0
# ==== new file: databricks-tools-core/tests/unit/agentguard/test_policy.py ====
"""
Unit tests for AgentGuard policy engine.

Tests:
- always-block, require-approval, auto-approve rules
- execute_sql_multi splitting
"""

import pytest
from databricks_tools_core.agentguard.models import (
    Action,
    AgentGuardMode,
    CheckpointDecision,
)
from databricks_tools_core.agentguard.policy import PolicyEngine


@pytest.fixture
def engine():
    """Returns a fresh PolicyEngine instance."""
    return PolicyEngine()


class TestAlwaysBlock:
    """Rules that block unconditionally in ENFORCE mode."""

    def test_drop_database_blocked_in_enforce(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "DROP DATABASE my_catalog"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, rule = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK
        # The matched rule is surfaced alongside the decision.
        assert rule is not None

    def test_drop_database_would_block_in_monitor(self, engine):
        # Same rule, monitor mode: downgraded to WOULD_BLOCK, rule still reported.
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "DROP DATABASE my_catalog"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, rule = engine.check(action, AgentGuardMode.MONITOR_ONLY)
        assert decision == CheckpointDecision.WOULD_BLOCK
        assert rule is not None

    def test_truncate_any_table_blocked(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "TRUNCATE TABLE main.analytics.orders"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK

    def test_grant_all_privileges_blocked(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "GRANT ALL PRIVILEGES ON TABLE t TO user1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK

    def test_revoke_from_blocked(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "REVOKE SELECT ON TABLE t FROM user1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK

    def test_drop_any_schema_blocked(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "DROP SCHEMA main.analytics"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK

    def test_if_exists_noise_stripped(self, engine):
        """Should block DROP SCHEMA after IF EXISTS is stripped from the statement."""
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "DROP SCHEMA IF EXISTS main.temp_schema"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK


class TestRequireApproval:
    """Rules that FLAG an action for human approval rather than block it."""

    def test_drop_any_table_flagged(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "DROP TABLE main.analytics.orders"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, rule = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG
        assert rule is not None

    def test_delete_from_any_table_flagged(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "DELETE FROM main.staging.temp WHERE id = 1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG

    def test_create_or_replace_table_flagged(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "CREATE OR REPLACE TABLE main.staging.temp AS SELECT 1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG

    def test_delete_app_flagged(self, engine):
        action = Action.from_tool_call(
            tool_name="delete_app",
            tool_params={"name": "my-app"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, rule = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG
        assert rule is not None

    def test_delete_pipeline_flagged(self, engine):
        action = Action.from_tool_call(
            tool_name="delete_pipeline",
            tool_params={"pipeline_id": "p-1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG

    def test_execute_code_flagged(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_databricks_command",
            tool_params={"cluster_id": "c-1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG

    def test_grant_via_manage_uc_grants_flagged(self, engine):
        action = Action.from_tool_call(
            tool_name="manage_uc_grants",
            tool_params={"action": "grant", "name": "main.prod"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG

    def test_delete_job_via_manage_jobs_flagged(self, engine):
        action = Action.from_tool_call(
            tool_name="manage_jobs",
            tool_params={"action": "delete", "job_id": "j-1"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG


class TestAutoApprove:
    """Safe actions fall through with no matching rule."""

    def test_select_approved(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "SELECT * FROM main.prod.orders"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, rule = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE
        # No rule fires for a plain read.
        assert rule is None

    def test_read_only_tool_approved(self, engine):
        action = Action.from_tool_call(
            tool_name="get_table_details",
            tool_params={},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_insert_not_blocked(self, engine):
        # Plain INSERT is a write but not policy-restricted by itself.
        action = Action.from_tool_call(
            tool_name="execute_sql",
            tool_params={"query": "INSERT INTO staging.t VALUES (1)"},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE


class TestMultiSQL:
    """Batched SQL is checked per-statement; worst decision wins."""

    def test_multi_sql_with_block_statement(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql_multi",
            tool_params={"sqls": ["SELECT 1", "DROP DATABASE my_catalog"]},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK

    def test_multi_sql_with_flag_statement(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql_multi",
            tool_params={"sqls": ["SELECT 1", "DROP TABLE main.analytics.orders"]},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.FLAG

    def test_multi_sql_safe_statements_approved(self, engine):
        action = Action.from_tool_call(
            tool_name="execute_sql_multi",
            tool_params={"sqls": ["SELECT 1", "INSERT INTO main.staging.t VALUES (1)"]},
            task_id="t1",
            agent_id="a1",
            sequence=1,
        )
        decision, _ = engine.check(action, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE
"""
Unit tests for AgentGuard risk scoring.

Tests:
- per-factor scores and compute_risk_score
"""

from databricks_tools_core.agentguard.models import (
    Action,
    ActionCategory,
    CheckpointDecision,
)
from databricks_tools_core.agentguard.risk import (
    _score_action_type,
    _score_blast_radius,
    _score_data_sensitivity,
    _score_environment,
    compute_risk_score,
)


def _typed(operation, category):
    """An Action carrying just an operation and category, for factor tests."""
    return Action(tool_name="t", operation=operation, action_category=category)


class TestActionTypeScoring:
    def test_select_is_zero(self):
        assert _score_action_type(_typed("SELECT", ActionCategory.READ)) == 0

    def test_delete_is_high(self):
        assert _score_action_type(_typed("DELETE", ActionCategory.WRITE)) == 80

    def test_drop_is_very_high(self):
        assert _score_action_type(_typed("DROP", ActionCategory.DDL)) == 95

    def test_delete_app_uses_prefix_rule(self):
        assert _score_action_type(_typed("DELETE_APP", ActionCategory.ADMIN)) == 80

    def test_restart_not_confused_with_start(self):
        """Should score RESTART_PIPELINE using the RESTART keyword, not START."""
        score = _score_action_type(_typed("RESTART_PIPELINE", ActionCategory.ADMIN))
        assert score == 40, f"Expected RESTART score (40), got {score}"

    def test_start_cluster_exact_match(self):
        assert _score_action_type(_typed("START_CLUSTER", ActionCategory.WRITE)) == 45

    def test_execute_code(self):
        assert _score_action_type(_typed("EXECUTE_CODE", ActionCategory.ADMIN)) == 75

    def test_unknown_falls_to_category_default(self):
        assert _score_action_type(_typed("FOOBAR", ActionCategory.WRITE)) == 40

    def test_multi_operation(self):
        assert _score_action_type(_typed("MULTI(DROP)", ActionCategory.DDL)) == 95


class TestEnvironmentScoring:
    def test_production_high_risk(self):
        assert _score_environment(Action(tool_name="t", target_resource_id="main.production.orders")) == 90

    def test_prod_dot_high_risk(self):
        assert _score_environment(Action(tool_name="t", sql_statement="SELECT * FROM prod.billing")) == 90

    def test_staging_moderate(self):
        assert _score_environment(Action(tool_name="t", target_resource_id="main.staging.temp")) == 40

    def test_dev_low(self):
        assert _score_environment(Action(tool_name="t", target_resource_id="main.dev.scratch")) == 10

    def test_no_signal_default(self):
        assert _score_environment(Action(tool_name="t", target_resource_id="my-job-123")) == 30


class TestBlastRadiusScoring:
    def test_delete_without_where_high(self):
        act = Action(tool_name="t", sql_statement="DELETE FROM orders", action_category=ActionCategory.WRITE)
        assert _score_blast_radius(act) == 90

    def test_delete_with_where_moderate(self):
        act = Action(
            tool_name="t",
            sql_statement="DELETE FROM orders WHERE id = 1",
            action_category=ActionCategory.WRITE,
        )
        assert _score_blast_radius(act) == 30

    def test_drop_table_max_blast(self):
        act = Action(tool_name="t", sql_statement="DROP TABLE orders", action_category=ActionCategory.DDL)
        assert _score_blast_radius(act) == 100

    def test_select_low_blast(self):
        act = Action(tool_name="t", sql_statement="SELECT * FROM orders", action_category=ActionCategory.READ)
        assert _score_blast_radius(act) == 10

    def test_rows_affected_scales(self):
        assert _score_blast_radius(Action(tool_name="t", rows_affected=5)) == 15
        assert _score_blast_radius(Action(tool_name="t", rows_affected=5_000_000)) == 85


class TestDataSensitivityScoring:
    def test_pii_detected(self):
        act = Action(tool_name="t", target_resource_id="main.prod.customer_pii_data")
        assert _score_data_sensitivity(act) >= 70

    def test_ssn_highest(self):
        act = Action(tool_name="t", sql_statement="SELECT ssn FROM users")
        assert _score_data_sensitivity(act) == 95

    def test_no_sensitive_data(self):
        act = Action(tool_name="t", target_resource_id="main.staging.temp_metrics")
        assert _score_data_sensitivity(act) == 0

    def test_hipaa_in_table_name(self):
        act = Action(tool_name="t", target_resource_id="main.prod.hipaa_records")
        assert _score_data_sensitivity(act) == 90


class TestComputeRiskScore:
    def test_read_operation_low_score(self):
        act = Action(
            tool_name="execute_sql",
            operation="SELECT",
            action_category=ActionCategory.READ,
            target_resource_id="main.dev.scratch",
            sql_statement="SELECT * FROM main.dev.scratch",
        )
        result = compute_risk_score(act)
        assert result.score < 35
        assert result.decision == CheckpointDecision.AUTO_APPROVE

    def test_drop_prod_table_high_score(self):
        act = Action(
            tool_name="execute_sql",
            operation="DROP",
            action_category=ActionCategory.DDL,
            target_resource_id="main.production.orders",
            sql_statement="DROP TABLE main.production.orders",
        )
        result = compute_risk_score(act)
        assert result.score >= 70
        assert result.decision in (
            CheckpointDecision.HOLD_FOR_APPROVAL,
            CheckpointDecision.BLOCK,
        )

    def test_scope_violation_adds_penalty(self):
        act = Action(
            tool_name="execute_sql",
            operation="INSERT",
            action_category=ActionCategory.WRITE,
            sql_statement="INSERT INTO staging.t VALUES (1)",
        )
        baseline = compute_risk_score(act, scope_violated=False)
        penalized = compute_risk_score(act, scope_violated=True)
        assert penalized.score == baseline.score + 25

    def test_breakdown_contains_all_factors(self):
        result = compute_risk_score(_typed("SELECT", ActionCategory.READ))
        assert set(result.breakdown.keys()) == {
            "action_type",
            "environment",
            "blast_radius",
            "time_context",
            "behavioral",
            "data_sensitivity",
            "scope_violation_penalty",
        }

    def test_threshold_auto_approve(self):
        """Should auto-approve when the combined score stays below the flag threshold."""
        act = Action(
            tool_name="t",
            operation="SELECT",
            action_category=ActionCategory.READ,
            target_resource_id="main.dev.x",
            sql_statement="SELECT 1",
        )
        assert compute_risk_score(act).decision == CheckpointDecision.AUTO_APPROVE

    def test_threshold_hold_for_approval_used(self):
        """Should yield FLAG or HOLD_FOR_APPROVAL for a high-risk production DELETE."""
        act = Action(
            tool_name="t",
            operation="DELETE",
            action_category=ActionCategory.WRITE,
            target_resource_id="main.production.orders",
            sql_statement="DELETE FROM main.production.orders",
        )
        assert compute_risk_score(act).decision in (
            CheckpointDecision.FLAG,
            CheckpointDecision.HOLD_FOR_APPROVAL,
        )

    def test_custom_thresholds(self):
        result = compute_risk_score(
            _typed("SELECT", ActionCategory.READ),
            thresholds={"flag_above": 1, "hold_above": 2, "block_above": 3},
        )
        assert result.decision == CheckpointDecision.BLOCK
"""
Unit tests for AgentGuard scope enforcement and templates.

Tests:
- check_scope, check_scope_limits, load_template, list_templates
"""

import pytest
from databricks_tools_core.agentguard.models import (
    Action,
    ActionCategory,
    AgentGuardMode,
    CheckpointDecision,
)
from databricks_tools_core.agentguard.scope import (
    ResourceScope,
    ScopeManifest,
    check_scope,
    check_scope_limits,
    list_templates,
    load_template,
)


def _sql_action(operation, category, resource):
    """An execute_sql Action of *operation*/*category* targeting *resource*."""
    return Action(
        tool_name="execute_sql",
        operation=operation,
        action_category=category,
        target_resource_type="sql",
        target_resource_id=resource,
    )


def _read_scope(*patterns, **kwargs):
    """A ScopeManifest whose tables are readable per *patterns*."""
    return ScopeManifest(tables=ResourceScope(read=list(patterns)), **kwargs)


class TestCheckScopeInScope:
    def test_exact_match(self):
        scope = _read_scope("main.production.orders")
        action = _sql_action("SELECT", ActionCategory.READ, "main.production.orders")
        decision, violation = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE
        assert violation is None

    def test_glob_pattern_match(self):
        scope = _read_scope("main.staging.*")
        action = _sql_action("SELECT", ActionCategory.READ, "main.staging.temp_orders")
        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_case_insensitive_matching(self):
        scope = _read_scope("Main.Production.Orders")
        action = _sql_action("SELECT", ActionCategory.READ, "main.production.orders")
        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE


class TestCheckScopeOutOfScope:
    def test_no_matching_pattern(self):
        scope = _read_scope("main.staging.*")
        action = _sql_action("SELECT", ActionCategory.READ, "main.production.billing")
        decision, violation = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK
        assert "billing" in violation
        assert "not in scope" in violation

    def test_out_of_scope_monitor_mode(self):
        scope = _read_scope("main.staging.*")
        action = _sql_action("SELECT", ActionCategory.READ, "main.production.billing")
        decision, violation = check_scope(action, scope, AgentGuardMode.MONITOR_ONLY)
        assert decision == CheckpointDecision.WOULD_BLOCK
        assert "[monitor]" in violation

    def test_write_to_read_only_scope(self):
        scope = ScopeManifest(
            tables=ResourceScope(
                read=["main.production.*"],
                write=["main.staging.temp_*"],
            ),
        )
        action = _sql_action("INSERT", ActionCategory.WRITE, "main.production.orders")
        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK


class TestDefaultDeny:
    def test_empty_patterns_allowed_when_default_deny_false(self):
        scope = _read_scope("main.*", default_deny=False)
        action = _sql_action("INSERT", ActionCategory.WRITE, "main.staging.t")
        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_empty_patterns_blocked_when_default_deny_true(self):
        scope = _read_scope("main.*", default_deny=True)
        action = _sql_action("INSERT", ActionCategory.WRITE, "main.staging.t")
        decision, violation = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.BLOCK
        assert "default_deny" in violation

    def test_default_deny_passes_for_completely_unscoped_resource(self):
        """Should not apply default_deny when the resource type has no scope patterns."""
        scope = _read_scope("main.*", default_deny=True)
        action = Action(
            tool_name="delete_app",
            operation="DELETE_APP",
            action_category=ActionCategory.ADMIN,
            target_resource_type="app",
            target_resource_id="my-app",
        )
        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_default_deny_monitor_mode(self):
        scope = _read_scope("main.*", default_deny=True)
        action = _sql_action("INSERT", ActionCategory.WRITE, "main.staging.t")
        decision, violation = check_scope(action, scope, AgentGuardMode.MONITOR_ONLY)
        assert decision == CheckpointDecision.WOULD_BLOCK
        assert "[monitor]" in violation


class TestScopePassthrough:
    def test_no_resource_id_passes(self):
        scope = _read_scope("main.*")
        action = _sql_action("SELECT", ActionCategory.READ, "")
        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_unknown_resource_type_passes(self):
        scope = _read_scope("main.*")
        action = Action(
            tool_name="t",
            operation="SOMETHING",
            action_category=ActionCategory.UNKNOWN,
            target_resource_type="alien_resource",
            target_resource_id="x",
        )
        decision, _ = check_scope(action, scope, AgentGuardMode.ENFORCE)
        assert decision == CheckpointDecision.AUTO_APPROVE


class TestScopeLimits:
    def _limits(self, scope, action, actions, writes, mode):
        # Non-test helper: delegates to check_scope_limits with named counters.
        return check_scope_limits(
            action,
            scope,
            session_action_count=actions,
            session_write_count=writes,
            mode=mode,
        )

    def test_max_actions_exceeded(self):
        decision, violation = self._limits(
            ScopeManifest(max_actions=10),
            Action(tool_name="t", action_category=ActionCategory.READ),
            10,
            0,
            AgentGuardMode.ENFORCE,
        )
        assert decision == CheckpointDecision.BLOCK
        assert "max_actions" in violation

    def test_max_actions_not_exceeded(self):
        decision, _ = self._limits(
            ScopeManifest(max_actions=10),
            Action(tool_name="t", action_category=ActionCategory.READ),
            5,
            0,
            AgentGuardMode.ENFORCE,
        )
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_max_write_actions_exceeded(self):
        decision, violation = self._limits(
            ScopeManifest(max_write_actions=5),
            Action(tool_name="t", action_category=ActionCategory.WRITE),
            20,
            5,
            AgentGuardMode.ENFORCE,
        )
        assert decision == CheckpointDecision.BLOCK
        assert "max_write_actions" in violation

    def test_max_write_actions_only_checks_writes(self):
        decision, _ = self._limits(
            ScopeManifest(max_write_actions=5),
            Action(tool_name="t", action_category=ActionCategory.READ),
            100,
            10,
            AgentGuardMode.ENFORCE,
        )
        assert decision == CheckpointDecision.AUTO_APPROVE

    def test_limits_in_monitor_mode(self):
        decision, violation = self._limits(
            ScopeManifest(max_actions=5),
            Action(tool_name="t", action_category=ActionCategory.READ),
            5,
            0,
            AgentGuardMode.MONITOR_ONLY,
        )
        assert decision == CheckpointDecision.WOULD_BLOCK
        assert "[monitor]" in violation

    def test_no_limits_passes(self):
        decision, _ = self._limits(
            ScopeManifest(),
            Action(tool_name="t", action_category=ActionCategory.WRITE),
            9999,
            9999,
            AgentGuardMode.ENFORCE,
        )
        assert decision == CheckpointDecision.AUTO_APPROVE


class TestTemplates:
    def test_list_templates_returns_available(self):
        available = list_templates()
        for name in ("etl_pipeline_fix", "data_quality_check", "model_deployment"):
            assert name in available

    def test_load_etl_template(self):
        scope = load_template(
            "etl_pipeline_fix",
            {"catalog": "main", "target_table": "customer_orders"},
        )
        assert isinstance(scope, ScopeManifest)
        assert scope.default_deny is True
        assert "main.production.customer_orders" in scope.tables.read
        assert "main.staging.temp_*" in scope.tables.write
        assert scope.max_actions == 100

    def test_load_data_quality_template(self):
        scope = load_template(
            "data_quality_check",
            {"catalog": "main", "schema": "production", "target_table": "orders"},
        )
        assert scope.default_deny is True
        assert "main.production.orders" in scope.tables.read
        assert "main.dq_results.*" in scope.tables.write

    def test_load_model_deployment_template(self):
        scope = load_template(
            "model_deployment",
            {"catalog": "main", "endpoint_name": "fraud_v2"},
        )
        assert scope.default_deny is True
        assert "fraud_v2" in scope.serving_endpoints.write
        assert scope.tables.write == []

    def test_load_nonexistent_template_raises(self):
        with pytest.raises(FileNotFoundError, match="not found"):
            load_template("does_not_exist", {})

    def test_load_template_missing_variable_raises(self):
        with pytest.raises(ValueError, match="unresolved variables"):
            load_template("etl_pipeline_fix", {"catalog": "main"})
"""
Unit tests for AgentGuard timing helpers.

Tests:
- Timer.measure and recorded durations
"""

import time

import pytest
from databricks_tools_core.agentguard.timing import Timer


class TestTimer:
    """Tests for Timer."""

    def test_measure_sync(self):
        timer = Timer()
        doubled = timer.measure("test_step", lambda x: x * 2, 5)
        assert doubled == 10
        assert len(timer.records) == 1
        record = timer.records[0]
        assert record.name == "test_step"
        assert record.duration_ms >= 0

    def test_measure_captures_kwargs(self):
        timer = Timer()

        def add(a, b):
            return a + b

        assert timer.measure("add", add, a=3, b=7) == 10

    def test_multiple_measurements(self):
        timer = Timer()
        for step_name in ("step1", "step2", "step3"):
            timer.measure(step_name, lambda: 1)
        assert len(timer.records) == 3
        assert [rec.name for rec in timer.records] == ["step1", "step2", "step3"]

    def test_total_ms(self):
        timer = Timer()

        def slow():
            time.sleep(0.01)
            return True

        timer.measure("slow_step", slow)
        # At least the sleep duration, but well under a generous ceiling.
        assert 10 <= timer.total_ms < 500

    def test_total_ms_empty(self):
        assert Timer().total_ms == 0.0

    def test_measure_propagates_exception(self):
        timer = Timer()

        def boom():
            raise ValueError("test error")

        with pytest.raises(ValueError, match="test error"):
            timer.measure("boom", boom)

        # measure() does not record timing when the wrapped callable raises
        assert len(timer.records) == 0