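"""custom_agent.py — building a custom agent on top of beeai-framework.

The agent below subclasses BaseAgent and asks a chat model for a structured
(JSON) response matching a small schema, returning both the final answer and
the model's intermediate "thought" as typed state.

Running the `main()` example assumes a local Ollama server with the
"granite3.1-dense:8b" model available; any other ChatModel adapter could be
substituted.
"""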
import asyncio
import sys
import traceback

from pydantic import BaseModel, Field, InstanceOf

from beeai_framework.adapters.ollama import OllamaChatModel
from beeai_framework.agents import AgentMeta, BaseAgent, BaseAgentRunOptions
from beeai_framework.backend import AnyMessage, AssistantMessage, ChatModel, SystemMessage, UserMessage
from beeai_framework.context import Run, RunContext
from beeai_framework.emitter import Emitter
from beeai_framework.errors import FrameworkError
from beeai_framework.memory import BaseMemory, UnconstrainedMemory


class State(BaseModel):
    thought: str
    final_answer: str


class RunInput(BaseModel):
    message: InstanceOf[AnyMessage]


class CustomAgentRunOptions(BaseAgentRunOptions):
    max_retries: int | None = None


class CustomAgentRunOutput(BaseModel):
    message: InstanceOf[AnyMessage]
    state: State


class CustomAgent(BaseAgent[CustomAgentRunOutput]):
    memory: BaseMemory | None = None

    def __init__(self, llm: ChatModel, memory: BaseMemory) -> None:
        super().__init__()
        self.model = llm
        self.memory = memory

    def _create_emitter(self) -> Emitter:
        # Events produced by this agent are published under the "agent.custom" namespace.
        return Emitter.root().child(
            namespace=["agent", "custom"],
            creator=self,
        )

    def run(
        self,
        run_input: RunInput,
        options: CustomAgentRunOptions | None = None,
    ) -> Run[CustomAgentRunOutput]:
        async def handler(context: RunContext) -> CustomAgentRunOutput:
            # Schema the model must follow when producing its structured response.
            class CustomSchema(BaseModel):
                thought: str = Field(
                    description="Describe your thought process before coming up with a final answer."
                )
                final_answer: str = Field(
                    description="Here you should provide a concise answer to the original question."
                )

            response = await self.model.create_structure(
                schema=CustomSchema,
                messages=[
                    SystemMessage("You are a helpful assistant. Always use JSON format for your responses."),
                    *(self.memory.messages if self.memory is not None else []),
                    run_input.message,
                ],
                max_retries=options.max_retries if options else None,
                abort_signal=context.signal,
            )

            result = AssistantMessage(response.object["final_answer"])
            if self.memory is not None:
                await self.memory.add(result)

            return CustomAgentRunOutput(
                message=result,
                state=State(thought=response.object["thought"], final_answer=response.object["final_answer"]),
            )

        return self._to_run(
            handler, signal=options.signal if options else None, run_params={"input": run_input, "options": options}
        )

    @property
    def meta(self) -> AgentMeta:
        return AgentMeta(
            name="CustomAgent",
            description="Custom Agent is a simple LLM agent.",
            tools=[],
        )


async def main() -> None:
    agent = CustomAgent(
        llm=OllamaChatModel("granite3.1-dense:8b"),
        memory=UnconstrainedMemory(),
    )

    response = await agent.run(RunInput(message=UserMessage("Why is the sky blue?")))
    print(response.state)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except FrameworkError as e:
        traceback.print_exc()
        sys.exit(e.explain())