Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/agents/react/react_using_mellea.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Simple tool for searching. Requires the langchain-community package.
# Mellea allows you to interop with langchain defined tools.
lc_ddg_search = DuckDuckGoSearchResults(output_format="list")
search_tool = MelleaTool.from_langchain(lc_ddg_search)
search_tool: MelleaTool = MelleaTool.from_langchain(lc_ddg_search)


class Email(pydantic.BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/tools/smolagents_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
python_tool_hf = PythonInterpreterTool()

# Convert to Mellea tool - now you can use it with Mellea!
python_tool = MelleaTool.from_smolagents(python_tool_hf)
python_tool: MelleaTool = MelleaTool.from_smolagents(python_tool_hf)

# Use with Mellea session
m = start_session()
Expand Down
67 changes: 41 additions & 26 deletions mellea/backends/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
from collections import defaultdict
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from typing import Any, Literal, overload
from typing import Any, Literal, ParamSpec, TypeVar, overload

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Generic is imported but not used — MelleaTool picks up genericity via AbstractMelleaTool[P, R]. Project ignores F401 so CI won't complain, but it's dead code.

Suggested change
from typing import Any, Literal, ParamSpec, TypeVar, overload

from pydantic import BaseModel, ConfigDict, Field

Expand All @@ -22,16 +22,23 @@
from ..core.base import AbstractMelleaTool
from .model_options import ModelOption

P = ParamSpec("P")
R = TypeVar("R")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Both mellea/core/base.py (lines 958-959) and this module declare their own P = ParamSpec("P") / R = TypeVar("R"). Since MelleaTool subclasses AbstractMelleaTool, tools.py could import P, R from base.py for a single identity. Minor.


class MelleaTool(AbstractMelleaTool):

class MelleaTool(AbstractMelleaTool[P, R]):
"""Tool class to represent a callable tool with an OpenAI-compatible JSON schema.

Wraps a Python callable alongside its JSON schema representation so it can be
registered with backends that support tool calling (OpenAI, Ollama, HuggingFace, etc.).

Type parameters:
P: Parameter specification for the underlying callable
R: Return type of the tool

Args:
name (str): The tool name used for identification and lookup.
tool_call (Callable): The underlying Python callable to invoke when the tool is run.
tool_call (Callable[P, R]): The underlying Python callable to invoke when the tool is run.
as_json_tool (dict[str, Any]): The OpenAI-compatible JSON schema dict describing
the tool's parameters.

Expand All @@ -42,25 +49,25 @@ class MelleaTool(AbstractMelleaTool):

name: str
_as_json_tool: dict[str, Any]
_call_tool: Callable[..., Any]
_call_tool: Callable[P, R]

def __init__(
self, name: str, tool_call: Callable, as_json_tool: dict[str, Any]
self, name: str, tool_call: Callable[P, R], as_json_tool: dict[str, Any]
) -> None:
"""Initialize the tool with a name, tool call and as_json_tool dict."""
self.name = name
self._as_json_tool = as_json_tool
self._call_tool = tool_call

def run(self, *args, **kwargs) -> Any:
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Run the tool with the given arguments.

Args:
args: Positional arguments forwarded to the underlying callable.
kwargs: Keyword arguments forwarded to the underlying callable.
*args: Positional arguments forwarded to the underlying callable.
**kwargs: Keyword arguments forwarded to the underlying callable.

Returns:
Any: The return value of the underlying callable.
R: The return value of the underlying callable.
"""
return self._call_tool(*args, **kwargs)

Expand All @@ -70,14 +77,14 @@ def as_json_tool(self) -> dict[str, Any]:
return self._as_json_tool.copy()

@classmethod
def from_langchain(cls, tool: Any) -> "MelleaTool":
def from_langchain(cls, tool: Any) -> "MelleaTool[..., Any]":
"""Create a MelleaTool from a LangChain tool object.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: MelleaTool[P, Any] with a module-level P collapses to MelleaTool[..., Any] under pyright and MelleaTool[Any, Any] under mypy — both type-check fine, but spelling it as MelleaTool[..., Any] is the explicit form for "unknown parameter spec" and matches what pyright shows. Worth updating the docstring Returns: line too.

Suggested change
"""Create a MelleaTool from a LangChain tool object.
def from_langchain(cls, tool: Any) -> "MelleaTool[..., Any]":


Args:
tool (Any): A ``langchain_core.tools.BaseTool`` instance to wrap.

Returns:
MelleaTool: A Mellea tool wrapping the LangChain tool.
MelleaTool[..., Any]: A Mellea tool wrapping the LangChain tool.

Raises:
ImportError: If ``langchain-core`` is not installed.
Expand Down Expand Up @@ -117,14 +124,14 @@ def parameter_remapper(*args, **kwargs):
) from e

@classmethod
def from_smolagents(cls, tool: Any) -> "MelleaTool":
def from_smolagents(cls, tool: Any) -> "MelleaTool[..., Any]":
"""Create a Tool from a HuggingFace smolagents tool object.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Same as from_langchainMelleaTool[..., Any] is the idiomatic spelling for "unknown parameter spec". Worth updating the docstring Returns: line too.

Suggested change
"""Create a Tool from a HuggingFace smolagents tool object.
def from_smolagents(cls, tool: Any) -> "MelleaTool[..., Any]":


Args:
tool: A smolagents.Tool instance

Returns:
MelleaTool: A Mellea tool wrapping the smolagents tool
MelleaTool[..., Any]: A Mellea tool wrapping the smolagents tool

Raises:
ImportError: If smolagents is not installed
Expand Down Expand Up @@ -172,18 +179,20 @@ def tool_call(*args, **kwargs):
) from e

@classmethod
def from_callable(cls, func: Callable, name: str | None = None) -> "MelleaTool":
def from_callable(
cls, func: Callable[P, R], name: str | None = None
) -> "MelleaTool[P, R]":
"""Create a MelleaTool from a plain Python callable.

Introspects the callable's signature and docstring to build an
OpenAI-compatible JSON schema automatically.

Args:
func (Callable): The Python callable to wrap as a tool.
func (Callable[P, R]): The Python callable to wrap as a tool.
name (str | None): Optional name override; defaults to ``func.__name__``.

Returns:
MelleaTool: A Mellea tool wrapping the callable.
MelleaTool[P, R]: A Mellea tool wrapping the callable with preserved parameter and return types.
"""
# Use the function name if the name is '' or None.
tool_name = name or func.__name__
Expand All @@ -195,28 +204,34 @@ def from_callable(cls, func: Callable, name: str | None = None) -> "MelleaTool":


@overload
def tool(func: Callable, *, name: str | None = None) -> MelleaTool: ...
def tool(func: Callable[P, R], *, name: str | None = None) -> MelleaTool[P, R]: ...


@overload
def tool(*, name: str | None = None) -> Callable[[Callable], MelleaTool]: ...
def tool(
*, name: str | None = None
) -> Callable[[Callable[P, R]], MelleaTool[P, R]]: ...


def tool(
func: Callable | None = None, name: str | None = None
) -> MelleaTool | Callable[[Callable], MelleaTool]:
"""Decorator to mark a function as a Mellea tool.
func: Callable[P, R] | None = None, name: str | None = None
) -> MelleaTool[P, R] | Callable[[Callable[P, R]], MelleaTool[P, R]]:
"""Decorator to mark a function as a Mellea tool with type-safe parameter and return types.

This decorator wraps a function to make it usable as a tool without
requiring explicit MelleaTool.from_callable() calls. The decorated
function returns a MelleaTool instance that must be called via .run().

Type parameters:
P: Parameter specification of the decorated function
R: Return type of the decorated function

Args:
func: The function to decorate (when used without arguments)
name: Optional custom name for the tool (defaults to function name)

Returns:
A MelleaTool instance. Use .run() to invoke the tool.
A MelleaTool[P, R] instance with preserved parameter and return types. Use .run() to invoke.
The returned object passes isinstance(result, MelleaTool) checks.

Examples:
Expand All @@ -237,8 +252,8 @@ def tool(
>>> # Can be used directly in tools list (no extraction needed)
>>> tools = [get_weather]
>>>
>>> # Must use .run() to invoke the tool
>>> result = get_weather.run(location="Boston")
>>> # Must use .run() to invoke the tool - now with type hints
>>> result = get_weather.run(location="Boston") # IDE shows: location: str, days: int = 1

With custom name (as decorator):
>>> @tool(name="weather_api")
Expand All @@ -252,8 +267,8 @@ def tool(
>>> differently_named_tool = tool(new_tool, name="different_name")
"""

def decorator(f: Callable) -> MelleaTool:
# Simply return the base MelleaTool instance
def decorator(f: Callable[P, R]) -> MelleaTool[P, R]:
# Simply return the base MelleaTool instance with preserved types
return MelleaTool.from_callable(f, name=name)

# Handle both @tool and @tool() syntax
Expand Down
26 changes: 21 additions & 5 deletions mellea/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,15 @@
from copy import copy, deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Generic, Literal, Protocol, TypeVar, runtime_checkable
from typing import (
Any,
Generic,
Literal,
ParamSpec,
Protocol,
TypeVar,
runtime_checkable,
)

import typing_extensions
from PIL import Image as PILImage
Expand Down Expand Up @@ -947,8 +955,16 @@ def view_for_generation(self) -> list[Component | CBlock] | None:
...


class AbstractMelleaTool(abc.ABC):
"""Abstract base class for Mellea Tool.
P = ParamSpec("P")
R = TypeVar("R")


class AbstractMelleaTool(abc.ABC, Generic[P, R]):
"""Abstract base class for Mellea Tool with parameter and return type support.

Type parameters:
P: Parameter specification for the tool's callable (via ParamSpec)
R: Return type of the tool

Attributes:
name (str): The unique name used to identify the tool in JSON descriptions and tool-call dispatch.
Expand All @@ -960,15 +976,15 @@ class AbstractMelleaTool(abc.ABC):
"""Name of the tool."""

@abc.abstractmethod
def run(self, *args: Any, **kwargs: Any) -> Any:
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the tool with the provided arguments and returns the result.

Args:
*args: Positional arguments forwarded to the tool implementation.
**kwargs: Keyword arguments forwarded to the tool implementation.

Returns:
Any: The result produced by the tool; the concrete type depends on the implementation.
R: The result produced by the tool; the concrete type depends on the implementation.
"""

@property
Expand Down
2 changes: 1 addition & 1 deletion test/backends/test_mellea_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_from_langchain_args_handling(caplog):
@pytest.mark.ollama
@pytest.mark.e2e
def test_from_langchain_generation(session: MelleaSession):
t = MelleaTool.from_langchain(langchain_tool)
t: MelleaTool = MelleaTool.from_langchain(langchain_tool)

out = session.instruct(
"Call the langchain_tool.",
Expand Down
Loading
Loading