Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions examples/ai_modinput_app/bin/agentic_weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from _collections_abc import dict_items
from typing import final, override

from splunklib.ai.messages import AIMessage, ContentBlock, TextBlock

# ! NOTE: This insert is only needed for splunk-sdk-python CI/CD to work.
# ! Remove this if you're modifying this example locally.
sys.path.insert(0, "/splunklib-deps")
Expand Down Expand Up @@ -95,9 +97,9 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None:
weather_events += list(reader)

for weather_event in weather_events:
weather_event["human_readable"] = asyncio.run(
self.invoke_agent(weather_event)
)
result = asyncio.run(self.invoke_agent(weather_event))
weather_event["human_readable"] = self.parse_content(result)

logger.debug(f"{weather_event=}")

event = Event(
Expand All @@ -112,7 +114,7 @@ def stream_events(self, inputs: InputDefinition, ew: EventWriter) -> None:

logger.debug(f"Finishing enrichment for {input_name} at {csv_file_path}")

async def invoke_agent(self, weather_event: dict[str, str | int]) -> str:
async def invoke_agent(self, weather_event: dict[str, str | int]) -> AIMessage:
if not self.service:
raise AssertionError("No Splunk connection available")

Expand All @@ -127,7 +129,27 @@ async def invoke_agent(self, weather_event: dict[str, str | int]) -> str:
data=weather_event,
)
logger.debug(f"{response=}")
return response.final_message.content
return response.final_message

def _parse_content_block(self, block: str | ContentBlock) -> str | None:
    """Return the plain text carried by a single content block.

    A ``TextBlock`` yields its ``text`` attribute and a bare string is
    returned unchanged. Any other kind of block yields ``None``.
    """
    if isinstance(block, TextBlock):
        return block.text
    if isinstance(block, str):
        return block
    # Blocks without a textual form are reported as "no text".
    return None

def parse_content(self, message: AIMessage) -> str:
    """Flatten the content of an ``AIMessage`` into a single string.

    String content is returned unchanged. Otherwise each block is passed
    through ``_parse_content_block`` and the non-empty results are joined
    with a single space.
    """
    if isinstance(message.content, str):
        return message.content

    pieces = []
    for block in message.content:
        text = self._parse_content_block(block)
        # Skip blocks that produced no text (``None`` or an empty string).
        if text:
            pieces.append(text)
    return " ".join(pieces)


if __name__ == "__main__":
Expand Down
144 changes: 126 additions & 18 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@
AgentResponse,
AIMessage,
BaseMessage,
ContentBlock,
HumanMessage,
OpaqueBlock,
OutputT,
StructuredOutputCall,
StructuredOutputMessage,
Expand All @@ -87,6 +89,7 @@
SubagentStructuredResult,
SubagentTextResult,
SystemMessage,
TextBlock,
ToolCall,
ToolFailureResult,
ToolMessage,
Expand Down Expand Up @@ -955,7 +958,7 @@ async def awrap_tool_call(
return LC_ToolMessage(
name=_normalize_agent_name(call.name),
tool_call_id=call.id,
content=content,
content=_map_content_to_langchain(content),
status=status,
artifact=sdk_result,
)
Expand Down Expand Up @@ -1089,7 +1092,10 @@ def _convert_model_response_to_model_result(
# This invariant is asserted via ModelResponse.__post_init__
assert len(resp.message.structured_output_calls) <= 1

lc_message = LC_AIMessage(content=resp.message.content)
lc_message = LC_AIMessage(
content=_map_content_to_langchain(resp.message.content),
additional_kwargs=resp.message.extras or {},
)
# This field can't be set via __init__()
lc_message.tool_calls = [_map_tool_call_to_langchain(c) for c in resp.message.calls]

Expand Down Expand Up @@ -1164,7 +1170,7 @@ def _convert_tool_message_to_lc(
name=name,
tool_call_id=message.call_id,
status=status,
content=content,
content=_map_content_to_langchain(content),
artifact=artifact,
)

Expand Down Expand Up @@ -1247,9 +1253,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
ai_message = model_response
structured_response = None

additional_kwargs = cast(dict[str, Any], ai_message.additional_kwargs)
return ModelResponse(
message=AIMessage(
content=ai_message.content.__str__(),
content=_map_content_from_langchain(ai_message.content), # pyright: ignore[reportUnknownArgumentType]
calls=[
_map_tool_call_from_langchain(tc)
for tc in ai_message.tool_calls
Expand All @@ -1264,6 +1271,7 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe
for tc in ai_message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
],
extras=additional_kwargs,
),
structured_output=structured_response,
)
Expand Down Expand Up @@ -1426,6 +1434,28 @@ def _is_agent_name_valid(name: str) -> bool:
return set(name).issubset(AGENT_NAME_ALLOWED_CHARS)


def _parse_content_block(block: str | ContentBlock) -> str | None:
    """Return the plain text carried by a single content block.

    A ``TextBlock`` yields its ``text`` attribute and a bare string is
    returned unchanged. Any other kind of block yields ``None``.
    """
    if isinstance(block, TextBlock):
        return block.text
    if isinstance(block, str):
        return block
    # Blocks without a textual form are reported as "no text".
    return None


def _parse_content(content: str | list[str | ContentBlock]) -> str:
    """Flatten message content into a single string.

    String content is returned unchanged. Otherwise each block is passed
    through ``_parse_content_block`` and the non-empty results are joined
    with a single space.
    """
    if isinstance(content, str):
        return content

    pieces = []
    for block in content:
        text = _parse_content_block(block)
        # Skip blocks that produced no text (``None`` or an empty string).
        if text:
            pieces.append(text)
    return " ".join(pieces)


def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
if not agent.name:
raise AssertionError("Agent must have a name to be used by other Agents")
Expand All @@ -1437,7 +1467,10 @@ def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:

async def invoke_agent(
message: HumanMessage, thread_id: str | None
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str,
SubagentStructuredResult | SubagentTextResult,
]:
result = await agent.invoke([message], thread_id=thread_id)

if agent.output_schema:
Expand All @@ -1446,23 +1479,28 @@ async def invoke_agent(
structured_output=result.structured_output.model_dump(),
)

return result.final_message.content, SubagentTextResult(
content=result.final_message.content
)
text_content = _parse_content(result.final_message.content)
return text_content, SubagentTextResult(content=text_content)

InputSchema = agent.input_schema
if InputSchema is None:
if agent.conversation_store:

async def _run( # pyright: ignore[reportRedeclaration]
content: str, thread_id: str
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str,
SubagentStructuredResult | SubagentTextResult,
]:
return await invoke_agent(HumanMessage(content=content), thread_id)
else:

async def _run( # pyright: ignore[reportRedeclaration]
content: str,
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str,
SubagentStructuredResult | SubagentTextResult,
]:
return await invoke_agent(HumanMessage(content=content), None)

return StructuredTool.from_function(
Expand All @@ -1475,7 +1513,10 @@ async def _run( # pyright: ignore[reportRedeclaration]

async def invoke_agent_structured(
content: BaseModel, thread_id: str | None
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str,
SubagentStructuredResult | SubagentTextResult,
]:
result = await agent.invoke_with_data(
instructions="Follow the system prompt.",
data=content.model_dump(),
Expand All @@ -1488,15 +1529,17 @@ async def invoke_agent_structured(
structured_output=result.structured_output.model_dump(),
)

return result.final_message.content, SubagentTextResult(
content=result.final_message.content
)
text_content = _parse_content(result.final_message.content)
return text_content, SubagentTextResult(content=text_content)

if agent.conversation_store:

async def _run(
**kwargs: Any, # noqa: ANN401
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str,
SubagentStructuredResult | SubagentTextResult,
]:
content: BaseModel = kwargs["content"]
thread_id: str = kwargs["thread_id"]
return await invoke_agent_structured(content, thread_id)
Expand All @@ -1516,7 +1559,10 @@ async def _run(

async def _run(
**kwargs: Any, # noqa: ANN401
) -> tuple[OutputT | str, SubagentStructuredResult | SubagentTextResult]:
) -> tuple[
OutputT | str,
SubagentStructuredResult | SubagentTextResult,
]:
content = InputSchema(**kwargs)
return await invoke_agent_structured(content, None)

Expand Down Expand Up @@ -1568,11 +1614,69 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
return LC_ToolCall(id=call.id, name=name, args=args)


def _map_content_from_langchain(
    content: str | list[str | dict[str, Any]],
) -> str | list[str | ContentBlock]:
    """Convert LangChain message content into SDK message content.

    A plain string is returned unchanged; a list is converted element by
    element with ``_map_content_block_from_langchain``.
    """
    if not isinstance(content, str):
        return [_map_content_block_from_langchain(item) for item in content]
    return content


def _map_content_block_from_langchain(
    block: str | dict[str, Any],
) -> str | ContentBlock:
    """Convert one LangChain content block into its SDK counterpart.

    Strings are returned unchanged. A dict whose ``"type"`` is ``"text"``
    becomes a ``TextBlock``; every other dict is wrapped in an
    ``OpaqueBlock``.
    """
    if isinstance(block, str):
        return block

    if block.get("type") != "text":
        # NOTE: we return data we're not handling
        # as opaque content blocks so they
        # are preserved and sent back to the LLM
        return OpaqueBlock(_data=block)

    return TextBlock(
        text=block["text"],
        extras=block.get("extras"),
        id=block.get("id"),
    )


def _map_content_to_langchain(
    content: str | list[str | ContentBlock],
) -> str | list[str | dict[str, Any]]:
    """Convert SDK message content into LangChain message content.

    A plain string is returned unchanged; a list is converted element by
    element with ``_map_content_block_to_langchain``.
    """
    if not isinstance(content, str):
        return [_map_content_block_to_langchain(item) for item in content]
    return content


def _map_content_block_to_langchain(block: str | ContentBlock) -> str | dict[str, Any]:
if isinstance(block, str):
return block

match block:
case TextBlock():
result: dict[str, Any] = {
"type": "text",
"text": block.text,
"id": block.id,
}
if block.extras:
result["extras"] = block.extras
return result
case OpaqueBlock():
return block._data # pyright: ignore[reportPrivateUsage]


def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
match message:
case LC_AIMessage():
return AIMessage(
content=message.content.__str__(),
content=_map_content_from_langchain(message.content), # pyright: ignore[reportUnknownArgumentType]
calls=[
_map_tool_call_from_langchain(tc)
for tc in message.tool_calls
Expand All @@ -1587,6 +1691,7 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
for tc in message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
],
extras=cast(dict[str, Any], message.additional_kwargs),
)
case LC_HumanMessage():
return HumanMessage(content=message.content.__str__())
Expand All @@ -1601,7 +1706,10 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
match message:
case AIMessage():
lc_message = LC_AIMessage(content=message.content)
lc_message = LC_AIMessage(
content=_map_content_to_langchain(message.content),
additional_kwargs=message.extras or {},
)
# This field can't be set via constructor
lc_message.tool_calls = [
_map_tool_call_to_langchain(c) for c in message.calls
Expand Down
2 changes: 0 additions & 2 deletions splunklib/ai/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,3 @@ async def model_middleware(
if self._deadline is not None and monotonic() >= self._deadline:
raise TimeoutExceededException(timeout_seconds=self._seconds)
return await handler(request)


Loading