diff --git a/src/adcp/__main__.py b/src/adcp/__main__.py index e45e7d6e2..09aa87948 100644 --- a/src/adcp/__main__.py +++ b/src/adcp/__main__.py @@ -397,6 +397,41 @@ async def _dispatch_tool(client: ADCPClient, tool_name: str, payload: dict[str, ) +_HEADER_FORBIDDEN = ("\r", "\n", "\x00") + + +def parse_header_args(header_args: list[str] | None) -> dict[str, str] | None: + """Parse repeated --header KEY=VALUE arguments into a dict.""" + if not header_args: + return None + result: dict[str, str] = {} + for raw in header_args: + key, sep, value = raw.partition("=") + if not sep: + print( + f"Error: --header value must be KEY=VALUE, got: {raw!r}", + file=sys.stderr, + ) + sys.exit(1) + if not key: + print(f"Error: --header key cannot be empty in: {raw!r}", file=sys.stderr) + sys.exit(1) + if any(c in key for c in _HEADER_FORBIDDEN): + print( + f"Error: --header key contains illegal characters (CRLF/null): {key!r}", + file=sys.stderr, + ) + sys.exit(1) + if any(c in value for c in _HEADER_FORBIDDEN): + print( + f"Error: --header value for {key!r} contains illegal characters (CRLF/null)", + file=sys.stderr, + ) + sys.exit(1) + result[key] = value + return result or None + + def load_payload(payload_arg: str | None) -> dict[str, Any]: """Load payload from argument (JSON, @file, or stdin).""" if not payload_arg: @@ -424,7 +459,12 @@ def load_payload(payload_arg: str | None) -> dict[str, Any]: sys.exit(1) -def handle_save_auth(alias: str, url: str | None, protocol: str | None) -> None: +def handle_save_auth( + alias: str, + url: str | None, + protocol: str | None, + extra_headers: dict[str, str] | None = None, +) -> None: """Handle --save-auth command.""" if not url: # Interactive mode @@ -438,7 +478,7 @@ def handle_save_auth(alias: str, url: str | None, protocol: str | None) -> None: auth_token = input("Auth token (optional): ").strip() or None - save_agent(alias, url, protocol, auth_token) + save_agent(alias, url, protocol, auth_token, extra_headers or None) print(f"✓ Saved agent '{alias}'") @@ -457,6 +497,9 @@ def handle_list_agents() -> None: print(f" URL: {config.get('agent_uri')}") print(f" Protocol: {config.get('protocol', 'mcp').upper()}") print(f" Auth: {auth}") + extra_h = config.get("extra_headers") + if extra_h: + print(f" Extra headers: {', '.join(extra_h.keys())}") def handle_remove_agent(alias: str) -> None: @@ -520,6 +563,14 @@ def main() -> None: # Execution options parser.add_argument("--protocol", choices=["mcp", "a2a"], help="Force protocol type") parser.add_argument("--auth", help="Authentication token") + parser.add_argument( + "--header", + "-H", + metavar="KEY=VALUE", + action="append", + dest="headers", + help="Extra request header (repeatable, e.g. --header x-adcp-tenant=acme)", + ) parser.add_argument("--json", action="store_true", help="Output as JSON") parser.add_argument("--debug", action="store_true", help="Enable debug mode") parser.add_argument("--help", "-h", action="store_true", help="Show help") @@ -572,7 +623,8 @@ def main() -> None: if args.save_auth: url = args.agent if args.agent else None protocol = args.tool if args.tool else None - handle_save_auth(args.save_auth, url, protocol) + parsed_extra_headers = parse_header_args(args.headers) + handle_save_auth(args.save_auth, url, protocol, parsed_extra_headers) sys.exit(0) if args.list_agents: @@ -609,6 +661,12 @@ def main() -> None: if args.debug: agent_config["debug"] = True + extra_headers = parse_header_args(args.headers) + if extra_headers: + # Merge CLI headers on top of any stored per-agent extra_headers + stored: dict[str, str] = agent_config.get("extra_headers") or {} + agent_config["extra_headers"] = {**stored, **extra_headers} + # Load payload payload = load_payload(args.payload) diff --git a/src/adcp/config.py b/src/adcp/config.py index 95d09127d..595fb19d9 100644 --- a/src/adcp/config.py +++ b/src/adcp/config.py @@ -38,7 +38,11 @@ def save_config(config: dict[str, Any]) -> None: def save_agent( - alias: str, url: str, protocol: str | None = None, auth_token: str | None = None + alias: str, + url: str, + protocol: str | None = None, + auth_token: str | None = None, + extra_headers: dict[str, str] | None = None, ) -> None: """Save agent configuration.""" config = load_config() @@ -54,6 +58,9 @@ def save_agent( if auth_token: config["agents"][alias]["auth_token"] = auth_token + if extra_headers: + config["agents"][alias]["extra_headers"] = extra_headers + save_config(config) diff --git a/src/adcp/protocols/a2a.py b/src/adcp/protocols/a2a.py index 1d9a2a996..bd9615d55 100644 --- a/src/adcp/protocols/a2a.py +++ b/src/adcp/protocols/a2a.py @@ -227,12 +227,24 @@ async def _get_httpx_client(self) -> httpx.AsyncClient: keepalive_expiry=30.0, ) - headers = {} + headers: dict[str, str] = {} + if self.agent_config.extra_headers: + headers.update(self.agent_config.extra_headers) if self.agent_config.auth_token: + # auth always wins on conflict if self.agent_config.auth_type == "bearer": - headers["Authorization"] = f"Bearer {self.agent_config.auth_token}" + auth_header_name = "Authorization" + auth_value = f"Bearer {self.agent_config.auth_token}" else: - headers[self.agent_config.auth_header] = self.agent_config.auth_token + auth_header_name = self.agent_config.auth_header + auth_value = self.agent_config.auth_token + if auth_header_name in headers: + logger.warning( + "extra_headers contains %r which conflicts with auth header; " + "auth_token wins", + auth_header_name, + ) + headers[auth_header_name] = auth_value # When ADCPClient installed a signing_request_hook, register it as # an httpx request event hook so RFC 9421 signature headers are diff --git a/src/adcp/protocols/mcp.py b/src/adcp/protocols/mcp.py index a88167de5..54c4902d4 100644 --- a/src/adcp/protocols/mcp.py +++ b/src/adcp/protocols/mcp.py @@ -340,15 +340,22 @@ async def _get_session(self) -> ClientSession: self._exit_stack = AsyncExitStack() # Create SSE client with authentication header - headers = {} + headers: dict[str, str] = {} + if self.agent_config.extra_headers: + headers.update(self.agent_config.extra_headers) if self.agent_config.auth_token: - # Support custom auth headers and types - if self.agent_config.auth_type == "bearer": - headers[self.agent_config.auth_header] = ( - f"Bearer {self.agent_config.auth_token}" + # Support custom auth headers and types; auth always wins on conflict + auth_header_name = self.agent_config.auth_header + if auth_header_name in headers: + logger.warning( + "extra_headers contains %r which conflicts with auth_header; " + "auth_token wins", + auth_header_name, ) + if self.agent_config.auth_type == "bearer": + headers[auth_header_name] = f"Bearer {self.agent_config.auth_token}" else: - headers[self.agent_config.auth_header] = self.agent_config.auth_token + headers[auth_header_name] = self.agent_config.auth_token # Try the user's exact URL first urls_to_try = [self.agent_config.agent_uri] diff --git a/src/adcp/types/core.py b/src/adcp/types/core.py index 18a5f7ce9..371408694 100644 --- a/src/adcp/types/core.py +++ b/src/adcp/types/core.py @@ -30,6 +30,7 @@ class AgentConfig(BaseModel): "streamable_http" # "streamable_http" (default, modern) or "sse" (legacy fallback) ) debug: bool = False # Enable debug mode to capture request/response details + extra_headers: dict[str, str] | None = None # Extra request headers (e.g. x-adcp-tenant) @field_validator("agent_uri") @classmethod @@ -86,6 +87,19 @@ def validate_auth_type(cls, v: str) -> str: ) return v + @field_validator("extra_headers") + @classmethod + def validate_extra_headers(cls, v: dict[str, str] | None) -> dict[str, str] | None: + """Reject CRLF/null sequences in header names and values (header injection guard).""" + if v is None: + return v + for key, value in v.items(): + if "\r" in key or "\n" in key or "\x00" in key: + raise ValueError(f"header name contains CRLF or null byte: {key!r}") + if "\r" in value or "\n" in value or "\x00" in value: + raise ValueError(f"header value for {key!r} contains CRLF or null byte") + return v + class TaskStatus(str, Enum): """Task execution status.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index 56f4fdb49..ba615602a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,8 +13,8 @@ import pytest -from adcp.__main__ import load_payload, resolve_agent_config -from adcp.config import save_agent +from adcp.__main__ import load_payload, parse_header_args, resolve_agent_config +from adcp.config import get_agent, save_agent class TestCLIBasics: @@ -465,3 +465,100 @@ def test_check_deprecated_fields_handles_list(self, capsys): _check_deprecated_fields(formats) captured = capsys.readouterr() assert "deprecated" not in captured.err.lower() + + +class TestParseHeaderArgs: + """Tests for --header KEY=VALUE parsing.""" + + def test_single_header(self): + result = parse_header_args(["x-adcp-tenant=acme"]) + assert result == {"x-adcp-tenant": "acme"} + + def test_repeated_headers(self): + result = parse_header_args(["x-adcp-tenant=acme", "x-trace-id=abc"]) + assert result == {"x-adcp-tenant": "acme", "x-trace-id": "abc"} + + def test_value_contains_equals(self): + """Value containing = must not be split at the second =.""" + result = parse_header_args(["Authorization=Basic dXNlcjpwYXNzd29yZA=="]) + assert result == {"Authorization": "Basic dXNlcjpwYXNzd29yZA=="} + + def test_none_returns_none(self): + assert parse_header_args(None) is None + + def test_empty_list_returns_none(self): + assert parse_header_args([]) is None + + def test_missing_equals_exits(self): + with pytest.raises(SystemExit): + parse_header_args(["x-adcp-tenant"]) + + def test_empty_key_exits(self): + with pytest.raises(SystemExit): + parse_header_args(["=value"]) + + def test_crlf_in_key_exits(self): + with pytest.raises(SystemExit): + parse_header_args(["x-bad\r\nkey=value"]) + + def test_null_in_value_exits(self): + with pytest.raises(SystemExit): + parse_header_args(["x-key=val\x00ue"]) + + +class TestExtraHeadersSaveLoad: + """Tests for extra_headers persistence in config.json.""" + + def test_save_agent_persists_extra_headers(self, tmp_path, monkeypatch): + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps({"agents": {}})) + + import adcp.config + + monkeypatch.setattr(adcp.config, "CONFIG_FILE", config_file) + + save_agent( + "local", + "http://localhost:8000/mcp", + "mcp", + "tok", + {"x-adcp-tenant": "acme"}, + ) + + raw = json.loads(config_file.read_text()) + assert raw["agents"]["local"]["extra_headers"] == {"x-adcp-tenant": "acme"} + + def test_get_agent_returns_extra_headers(self, tmp_path, monkeypatch): + config_data = { + "agents": { + "local": { + "agent_uri": "http://localhost:8000/mcp", + "protocol": "mcp", + "extra_headers": {"x-adcp-tenant": "acme"}, + } + } + } + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps(config_data)) + + import adcp.config + + monkeypatch.setattr(adcp.config, "CONFIG_FILE", config_file) + + agent = get_agent("local") + assert agent is not None + assert agent["extra_headers"] == {"x-adcp-tenant": "acme"} + + def test_save_agent_without_extra_headers(self, tmp_path, monkeypatch): + """Agents saved without extra_headers should not have the key in config.""" + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps({"agents": {}})) + + import adcp.config + + monkeypatch.setattr(adcp.config, "CONFIG_FILE", config_file) + + save_agent("bare", "http://localhost:8000/mcp", "mcp") + + raw = json.loads(config_file.read_text()) + assert "extra_headers" not in raw["agents"]["bare"] diff --git a/tests/test_client.py b/tests/test_client.py index f1a283d5e..6c5a007e8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -20,6 +20,75 @@ def test_agent_config_creation(): assert config.protocol == Protocol.A2A +def test_agent_config_extra_headers(): + """AgentConfig accepts extra_headers and stores as dict.""" + config = AgentConfig( + id="test", + agent_uri="https://agent.example.com", + protocol=Protocol.MCP, + extra_headers={"x-adcp-tenant": "acme", "x-trace-id": "abc123"}, + ) + assert config.extra_headers == {"x-adcp-tenant": "acme", "x-trace-id": "abc123"} + assert config.extra_headers is not None + + +def test_agent_config_extra_headers_default_none(): + """extra_headers defaults to None.""" + config = AgentConfig( + id="test", + agent_uri="https://agent.example.com", + protocol=Protocol.MCP, + ) + assert config.extra_headers is None + + +def test_agent_config_extra_headers_roundtrip(): + """extra_headers survives model_dump / model_validate round-trip as plain dict.""" + config = AgentConfig( + id="rt", + agent_uri="https://agent.example.com", + protocol=Protocol.MCP, + extra_headers={"x-adcp-tenant": "t1"}, + ) + dumped = config.model_dump() + restored = AgentConfig.model_validate(dumped) + assert restored.extra_headers == {"x-adcp-tenant": "t1"} + assert isinstance(restored.extra_headers, dict) + + +def test_agent_config_extra_headers_crlf_key_rejected(): + """extra_headers validator rejects CRLF in header names.""" + with pytest.raises(Exception, match="CRLF or null"): + AgentConfig( + id="test", + agent_uri="https://agent.example.com", + protocol=Protocol.MCP, + extra_headers={"x-bad\r\nkey": "value"}, + ) + + +def test_agent_config_extra_headers_crlf_value_rejected(): + """extra_headers validator rejects CRLF in header values.""" + with pytest.raises(Exception, match="CRLF or null"): + AgentConfig( + id="test", + agent_uri="https://agent.example.com", + protocol=Protocol.MCP, + extra_headers={"x-ok-key": "bad\nvalue"}, + ) + + +def test_agent_config_extra_headers_null_byte_rejected(): + """extra_headers validator rejects null bytes (matching webhooks.py _HEADER_FORBIDDEN_CHARS).""" + with pytest.raises(Exception, match="CRLF or null"): + AgentConfig( + id="test", + agent_uri="https://agent.example.com", + protocol=Protocol.MCP, + extra_headers={"x-ok-key": "val\x00ue"}, + ) + + def test_client_creation(): """Test creating ADCP client.""" config = AgentConfig(