Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 61 additions & 3 deletions src/adcp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,41 @@ async def _dispatch_tool(client: ADCPClient, tool_name: str, payload: dict[str,
)


_HEADER_FORBIDDEN = ("\r", "\n", "\x00")


def parse_header_args(header_args: list[str] | None) -> dict[str, str] | None:
"""Parse repeated --header KEY=VALUE arguments into a dict."""
if not header_args:
return None
result: dict[str, str] = {}
for raw in header_args:
key, sep, value = raw.partition("=")
if not sep:
print(
f"Error: --header value must be KEY=VALUE, got: {raw!r}",
file=sys.stderr,
)
sys.exit(1)
if not key:
print(f"Error: --header key cannot be empty in: {raw!r}", file=sys.stderr)
sys.exit(1)
if any(c in key for c in _HEADER_FORBIDDEN):
print(
f"Error: --header key contains illegal characters (CRLF/null): {key!r}",
file=sys.stderr,
)
sys.exit(1)
if any(c in value for c in _HEADER_FORBIDDEN):
print(
f"Error: --header value for {key!r} contains illegal characters (CRLF/null)",
file=sys.stderr,
)
sys.exit(1)
result[key] = value
return result or None


def load_payload(payload_arg: str | None) -> dict[str, Any]:
"""Load payload from argument (JSON, @file, or stdin)."""
if not payload_arg:
Expand Down Expand Up @@ -424,7 +459,12 @@ def load_payload(payload_arg: str | None) -> dict[str, Any]:
sys.exit(1)


def handle_save_auth(alias: str, url: str | None, protocol: str | None) -> None:
def handle_save_auth(
alias: str,
url: str | None,
protocol: str | None,
extra_headers: dict[str, str] | None = None,
) -> None:
"""Handle --save-auth command."""
if not url:
# Interactive mode
Expand All @@ -438,7 +478,7 @@ def handle_save_auth(alias: str, url: str | None, protocol: str | None) -> None:

auth_token = input("Auth token (optional): ").strip() or None

save_agent(alias, url, protocol, auth_token)
save_agent(alias, url, protocol, auth_token, extra_headers or None)
print(f"✓ Saved agent '{alias}'")


Expand All @@ -457,6 +497,9 @@ def handle_list_agents() -> None:
print(f" URL: {config.get('agent_uri')}")
print(f" Protocol: {config.get('protocol', 'mcp').upper()}")
print(f" Auth: {auth}")
extra_h = config.get("extra_headers")
if extra_h:
print(f" Extra headers: {', '.join(extra_h.keys())}")


def handle_remove_agent(alias: str) -> None:
Expand Down Expand Up @@ -520,6 +563,14 @@ def main() -> None:
# Execution options
parser.add_argument("--protocol", choices=["mcp", "a2a"], help="Force protocol type")
parser.add_argument("--auth", help="Authentication token")
parser.add_argument(
"--header",
"-H",
metavar="KEY=VALUE",
action="append",
dest="headers",
help="Extra request header (repeatable, e.g. --header x-adcp-tenant=acme)",
)
parser.add_argument("--json", action="store_true", help="Output as JSON")
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
parser.add_argument("--help", "-h", action="store_true", help="Show help")
Expand Down Expand Up @@ -572,7 +623,8 @@ def main() -> None:
if args.save_auth:
url = args.agent if args.agent else None
protocol = args.tool if args.tool else None
handle_save_auth(args.save_auth, url, protocol)
parsed_extra_headers = parse_header_args(args.headers)
handle_save_auth(args.save_auth, url, protocol, parsed_extra_headers)
sys.exit(0)

if args.list_agents:
Expand Down Expand Up @@ -609,6 +661,12 @@ def main() -> None:
if args.debug:
agent_config["debug"] = True

extra_headers = parse_header_args(args.headers)
if extra_headers:
# Merge CLI headers on top of any stored per-agent extra_headers
stored: dict[str, str] = agent_config.get("extra_headers") or {}
agent_config["extra_headers"] = {**stored, **extra_headers}

# Load payload
payload = load_payload(args.payload)

Expand Down
9 changes: 8 additions & 1 deletion src/adcp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def save_config(config: dict[str, Any]) -> None:


def save_agent(
alias: str, url: str, protocol: str | None = None, auth_token: str | None = None
alias: str,
url: str,
protocol: str | None = None,
auth_token: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> None:
"""Save agent configuration."""
config = load_config()
Expand All @@ -54,6 +58,9 @@ def save_agent(
if auth_token:
config["agents"][alias]["auth_token"] = auth_token

if extra_headers:
config["agents"][alias]["extra_headers"] = extra_headers

save_config(config)


Expand Down
18 changes: 15 additions & 3 deletions src/adcp/protocols/a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,24 @@ async def _get_httpx_client(self) -> httpx.AsyncClient:
keepalive_expiry=30.0,
)

headers = {}
headers: dict[str, str] = {}
if self.agent_config.extra_headers:
headers.update(self.agent_config.extra_headers)
if self.agent_config.auth_token:
# auth always wins on conflict
if self.agent_config.auth_type == "bearer":
headers["Authorization"] = f"Bearer {self.agent_config.auth_token}"
auth_header_name = "Authorization"
auth_value = f"Bearer {self.agent_config.auth_token}"
else:
headers[self.agent_config.auth_header] = self.agent_config.auth_token
auth_header_name = self.agent_config.auth_header
auth_value = self.agent_config.auth_token
if auth_header_name in headers:
logger.warning(
"extra_headers contains %r which conflicts with auth header; "
"auth_token wins",
auth_header_name,
)
headers[auth_header_name] = auth_value

# When ADCPClient installed a signing_request_hook, register it as
# an httpx request event hook so RFC 9421 signature headers are
Expand Down
19 changes: 13 additions & 6 deletions src/adcp/protocols/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,22 @@ async def _get_session(self) -> ClientSession:
self._exit_stack = AsyncExitStack()

# Create SSE client with authentication header
headers = {}
headers: dict[str, str] = {}
if self.agent_config.extra_headers:
headers.update(self.agent_config.extra_headers)
if self.agent_config.auth_token:
# Support custom auth headers and types
if self.agent_config.auth_type == "bearer":
headers[self.agent_config.auth_header] = (
f"Bearer {self.agent_config.auth_token}"
# Support custom auth headers and types; auth always wins on conflict
auth_header_name = self.agent_config.auth_header
if auth_header_name in headers:
logger.warning(
"extra_headers contains %r which conflicts with auth_header; "
"auth_token wins",
auth_header_name,
)
if self.agent_config.auth_type == "bearer":
headers[auth_header_name] = f"Bearer {self.agent_config.auth_token}"
else:
headers[self.agent_config.auth_header] = self.agent_config.auth_token
headers[auth_header_name] = self.agent_config.auth_token

# Try the user's exact URL first
urls_to_try = [self.agent_config.agent_uri]
Expand Down
14 changes: 14 additions & 0 deletions src/adcp/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class AgentConfig(BaseModel):
"streamable_http" # "streamable_http" (default, modern) or "sse" (legacy fallback)
)
debug: bool = False # Enable debug mode to capture request/response details
extra_headers: dict[str, str] | None = None # Extra request headers (e.g. x-adcp-tenant)

@field_validator("agent_uri")
@classmethod
Expand Down Expand Up @@ -86,6 +87,19 @@ def validate_auth_type(cls, v: str) -> str:
)
return v

@field_validator("extra_headers")
@classmethod
def validate_extra_headers(cls, v: dict[str, str] | None) -> dict[str, str] | None:
"""Reject CRLF/null sequences in header names and values (header injection guard)."""
if v is None:
return v
for key, value in v.items():
if "\r" in key or "\n" in key or "\x00" in key:
raise ValueError(f"header name contains CRLF or null byte: {key!r}")
if "\r" in value or "\n" in value or "\x00" in value:
raise ValueError(f"header value for {key!r} contains CRLF or null byte")
return v


class TaskStatus(str, Enum):
"""Task execution status."""
Expand Down
101 changes: 99 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

import pytest

from adcp.__main__ import load_payload, resolve_agent_config
from adcp.config import save_agent
from adcp.__main__ import load_payload, parse_header_args, resolve_agent_config
from adcp.config import get_agent, save_agent


class TestCLIBasics:
Expand Down Expand Up @@ -465,3 +465,100 @@ def test_check_deprecated_fields_handles_list(self, capsys):
_check_deprecated_fields(formats)
captured = capsys.readouterr()
assert "deprecated" not in captured.err.lower()


class TestParseHeaderArgs:
"""Tests for --header KEY=VALUE parsing."""

def test_single_header(self):
result = parse_header_args(["x-adcp-tenant=acme"])
assert result == {"x-adcp-tenant": "acme"}

def test_repeated_headers(self):
result = parse_header_args(["x-adcp-tenant=acme", "x-trace-id=abc"])
assert result == {"x-adcp-tenant": "acme", "x-trace-id": "abc"}

def test_value_contains_equals(self):
"""Value containing = must not be split at the second =."""
result = parse_header_args(["Authorization=Basic dXNlcjpwYXNzd29yZA=="])
assert result == {"Authorization": "Basic dXNlcjpwYXNzd29yZA=="}

def test_none_returns_none(self):
assert parse_header_args(None) is None

def test_empty_list_returns_none(self):
assert parse_header_args([]) is None

def test_missing_equals_exits(self):
with pytest.raises(SystemExit):
parse_header_args(["x-adcp-tenant"])

def test_empty_key_exits(self):
with pytest.raises(SystemExit):
parse_header_args(["=value"])

def test_crlf_in_key_exits(self):
with pytest.raises(SystemExit):
parse_header_args(["x-bad\r\nkey=value"])

def test_null_in_value_exits(self):
with pytest.raises(SystemExit):
parse_header_args(["x-key=val\x00ue"])


class TestExtraHeadersSaveLoad:
"""Tests for extra_headers persistence in config.json."""

def test_save_agent_persists_extra_headers(self, tmp_path, monkeypatch):
config_file = tmp_path / "config.json"
config_file.write_text(json.dumps({"agents": {}}))

import adcp.config

monkeypatch.setattr(adcp.config, "CONFIG_FILE", config_file)

save_agent(
"local",
"http://localhost:8000/mcp",
"mcp",
"tok",
{"x-adcp-tenant": "acme"},
)

raw = json.loads(config_file.read_text())
assert raw["agents"]["local"]["extra_headers"] == {"x-adcp-tenant": "acme"}

def test_get_agent_returns_extra_headers(self, tmp_path, monkeypatch):
config_data = {
"agents": {
"local": {
"agent_uri": "http://localhost:8000/mcp",
"protocol": "mcp",
"extra_headers": {"x-adcp-tenant": "acme"},
}
}
}
config_file = tmp_path / "config.json"
config_file.write_text(json.dumps(config_data))

import adcp.config

monkeypatch.setattr(adcp.config, "CONFIG_FILE", config_file)

agent = get_agent("local")
assert agent is not None
assert agent["extra_headers"] == {"x-adcp-tenant": "acme"}

def test_save_agent_without_extra_headers(self, tmp_path, monkeypatch):
"""Agents saved without extra_headers should not have the key in config."""
config_file = tmp_path / "config.json"
config_file.write_text(json.dumps({"agents": {}}))

import adcp.config

monkeypatch.setattr(adcp.config, "CONFIG_FILE", config_file)

save_agent("bare", "http://localhost:8000/mcp", "mcp")

raw = json.loads(config_file.read_text())
assert "extra_headers" not in raw["agents"]["bare"]
Loading
Loading