8 changes: 4 additions & 4 deletions docs/examples/bedrock/README.md
@@ -23,11 +23,11 @@ uv pip install mellea[litellm]

```python
from mellea import MelleaSession
-from mellea.backends.bedrock import create_bedrock_mantle_backend
+from mellea.backends.bedrock import create_bedrock_openai_backend
from mellea.backends.model_ids import OPENAI_GPT_OSS_120B
from mellea.stdlib.context import ChatContext

-bedrock_oai_backend = create_bedrock_mantle_backend(model_id=OPENAI_GPT_OSS_120B, region="us-east-1")
+bedrock_oai_backend = create_bedrock_openai_backend(model_id=OPENAI_GPT_OSS_120B, region="us-east-1")

m = MelleaSession(backend=bedrock_oai_backend, ctx=ChatContext())

@@ -38,10 +38,10 @@ You can also use your own model IDs as strings, as long as they're accessible us

```python
from mellea import MelleaSession
-from mellea.backends.bedrock import create_bedrock_mantle_backend
+from mellea.backends.bedrock import create_bedrock_openai_backend
from mellea.stdlib.context import ChatContext

-bedrock_oai_backend = create_bedrock_mantle_backend(
+bedrock_oai_backend = create_bedrock_openai_backend(
model_id="qwen.qwen3-coder-480b-a35b-instruct",
region="us-east-1"
)
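The README changes above cover only the OpenAI-compatible factory. A minimal sketch of the new `create_bedrock_litellm_backend` introduced in this PR, assuming the signature shown in `mellea/backends/bedrock.py` below and AWS credentials already configured in the environment:

```python
from mellea import MelleaSession
from mellea.backends.bedrock import create_bedrock_litellm_backend
from mellea.stdlib.context import ChatContext

# num_retries is forwarded to LiteLLM; the factory defaults it to 15.
backend = create_bedrock_litellm_backend(
    model_id="bedrock/converse/us.amazon.nova-pro-v1:0",  # litellm-prefixed Bedrock route
    region="us-east-1",
    num_retries=5,
)

m = MelleaSession(backend=backend, ctx=ChatContext())
print(m.chat("Give me three facts about Amazon.").content)
```
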
17 changes: 8 additions & 9 deletions docs/examples/bedrock/bedrock_litellm_example.py
@@ -11,6 +11,9 @@
import os

import mellea
+from mellea.backends.bedrock import create_bedrock_litellm_backend
+from mellea.backends.model_ids import MISTRALAI_DEVSTRAL_2_123B
+from mellea.stdlib.context import SimpleContext

try:
import boto3
@@ -20,16 +23,12 @@
"Run `uv pip install mellea[litellm]`"
)

assert "AWS_BEARER_TOKEN_BEDROCK" in os.environ.keys(), (
"Using AWS Bedrock requires setting a AWS_BEARER_TOKEN_BEDROCK environment variable. "
"Generate a key from the AWS console at: https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/api-keys?tab=long-term "
"Then run `export AWS_BEARER_TOKEN_BEDROCK=<insert your key here>"
)
MODEL_ID = MISTRALAI_DEVSTRAL_2_123B

MODEL_ID = "bedrock/converse/us.amazon.nova-pro-v1:0"
backend = create_bedrock_litellm_backend(MODEL_ID)
ctx = SimpleContext()
m = mellea.MelleaSession(backend, ctx)

-m = mellea.start_session(backend_name="litellm", model_id=MODEL_ID)

result = m.chat("Give me three facts about Amazon.")
result = m.chat("What model am I talking to rn?")

print(result.content)
7 changes: 4 additions & 3 deletions docs/examples/bedrock/bedrock_openai_example.py
@@ -11,7 +11,8 @@

from mellea import MelleaSession
from mellea.backends import model_ids
-from mellea.backends.bedrock import create_bedrock_mantle_backend
+from mellea.backends.bedrock import create_bedrock_openai_backend
+from mellea.backends.model_ids import OPENAI_GPT_OSS_120B
from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.context import ChatContext

@@ -22,10 +23,10 @@
)

m = MelleaSession(
-    backend=create_bedrock_mantle_backend(model_id=model_ids.OPENAI_GPT_OSS_120B),
+    backend=create_bedrock_openai_backend(model_id=OPENAI_GPT_OSS_120B),
ctx=ChatContext(),
)

result = m.chat("Give me three facts about Amazon.")
result = m.chat("What model am I talking to rn?")

print(result.content)
85 changes: 84 additions & 1 deletion mellea/backends/bedrock.py
@@ -2,13 +2,64 @@

import logging
import os

import boto3
import botocore.exceptions
from openai import OpenAI
from openai.pagination import SyncPage

from mellea.backends.litellm import LiteLLMBackend
from mellea.backends.model_ids import ModelIdentifier
from mellea.backends.openai import OpenAIBackend

# botocore logs a credential-resolution message on every boto3.Session() call. Suppress it.
logging.getLogger("botocore.credentials").setLevel(logging.WARNING)

def _assert_region(region: str | None) -> str:
resolved_region = (
region
or os.environ.get("AWS_REGION_NAME")
or os.environ.get("AWS_DEFAULT_REGION")
or os.environ.get("AWS_REGION")
)
assert (
    resolved_region is not None
), "You must specify a region: pass `region` explicitly or set AWS_REGION_NAME, AWS_DEFAULT_REGION, or AWS_REGION."
return resolved_region


def _assert_bedrock_auth() -> None:
"""Raises if no valid AWS credentials can be resolved.

Accepts any credential source that boto3 supports:
- Static env vars (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY)
- Named profile (AWS_PROFILE or ~/.aws/credentials)
- ECS task role (AWS_CONTAINER_CREDENTIALS_RELATIVE_URI)
- EC2 / ECS instance profile (IMDSv2)
- LiteLLM-specific Bedrock API key (AWS_BEARER_TOKEN_BEDROCK)
"""
if "AWS_BEARER_TOKEN_BEDROCK" in os.environ:
return

try:
creds = boto3.Session().get_credentials()
if creds is None:
raise botocore.exceptions.NoCredentialsError()
# Resolve to catch expired/invalid assume-role chains early.
creds.get_frozen_credentials()
except botocore.exceptions.NoCredentialsError:
raise AssertionError(
"No AWS credentials found. Provide one of:\n"
" - AWS_BEARER_TOKEN_BEDROCK (Bedrock API key)\n"
" - AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY\n"
" - AWS_PROFILE pointing to a configured profile\n"
" - An IAM role attached to the instance/task (EC2, ECS, Lambda)"
)
except botocore.exceptions.NoRegionError:
pass # Credentials exist; region is validated separately.


def _make_region_for_uri(region: str | None):
if region is None:
region = "us-east-1"
@@ -53,7 +104,39 @@ def stringify_mantle_model_ids(region: str | None = None) -> str:
return f" * {model_names}"


-def create_bedrock_mantle_backend(
+def create_bedrock_litellm_backend(
model_id: ModelIdentifier | str, region: str | None = None, num_retries: int = 15
) -> LiteLLMBackend:
"""Returns a LiteLLM backend that points to Bedrock for model `model_id`.

Use this instead of `create_bedrock_openai_backend` when you need to authenticate with an AWS_ACCESS_KEY_ID.
"""
_assert_bedrock_auth()
_assert_region(region)

model_name = ""
match model_id:
case ModelIdentifier():
if model_id.bedrock_litellm_name is None:
raise Exception(
f"We do not have a known bedrock model identifier for {model_id}. If Bedrock supports this model, please pass the model_id string directly and open an issue to add the model id: https://github.com/generative-computing/mellea/issues/new"
)
else:
model_name = model_id.bedrock_litellm_name
case str():
model_name = model_id
assert (
    model_name != ""
), f"Model identifier {model_id} does not specify a bedrock_litellm_name."

backend = LiteLLMBackend(model_id=model_name, num_retries=num_retries)

# TODO: litellm does not appear to use this attribute; confirm and remove it.
backend._base_url = None # type: ignore
return backend


def create_bedrock_openai_backend(
model_id: ModelIdentifier | str, region: str | None = None
) -> OpenAIBackend:
"""Return an OpenAI backend that points to Bedrock mantle for the given model.
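Because `_assert_bedrock_auth` runs before any model call, misconfiguration surfaces immediately. A sketch of the fail-fast path, assuming a machine with no other boto3 credential sources (no shared credentials file, no instance profile):

```python
import os

from mellea.backends.bedrock import create_bedrock_litellm_backend

# Hypothetical cleanup: simulate an environment with no AWS credentials.
for var in ("AWS_BEARER_TOKEN_BEDROCK", "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY", "AWS_PROFILE"):
    os.environ.pop(var, None)

try:
    create_bedrock_litellm_backend(
        "bedrock/converse/us.amazon.nova-pro-v1:0", region="us-east-1"
    )
except AssertionError as err:
    print(err)  # Lists the accepted credential sources defined above.
```
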
79 changes: 61 additions & 18 deletions mellea/backends/litellm.py
@@ -17,6 +17,7 @@
'Please install them with: pip install "mellea[litellm]"'
) from e


from ..backends import model_ids
from ..core import (
BaseModelSubclass,
@@ -55,7 +56,9 @@
validate_tool_arguments,
)

-format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+format: None = (
+    None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+)


class LiteLLMBackend(FormatterBackend):
@@ -84,6 +87,7 @@ def __init__(
formatter: ChatFormatter | None = None,
base_url: str | None = "http://localhost:11434",
model_options: dict | None = None,
num_retries: int = 0,
):
"""Initialize a LiteLLM-compatible backend for the given model ID and endpoint."""
super().__init__(
@@ -104,6 +108,8 @@ def __init__(
else:
self._base_url = base_url

self._num_retries = num_retries

# A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
# These are usually values that must be extracted before hand or that are common among backend providers.
# OpenAI has some deprecated parameters. Those map to the same mellea parameter, but
@@ -164,6 +170,7 @@ async def _generate_from_context(
assert ctx.is_chat_context, NotImplementedError(
    "The LiteLLM backend only supports chat-like contexts."
)

span = start_generate_span(
backend=self, action=action, ctx=ctx, format=format, tool_calls=tool_calls
)
@@ -260,10 +267,20 @@ def _make_backend_specific_and_remove(
# We want to flag both for the end user.
standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
unknown_keys = [] # Keys that are unknown to litellm.
-unsupported_openai_params = []  # OpenAI params that are known to litellm but not supported for this model/provider.
+unsupported_openai_params = (
+    []
+)  # OpenAI params that are known to litellm but not supported for this model/provider.
+# Bedrock-specific pass-through params that LiteLLM accepts but doesn't list as supported OpenAI params.
+known_provider_passthrough = {
+    "additional_model_request_fields",
+    "additional_model_response_field_paths",
+}

for key in backend_specific.keys():
if key not in supported_params:
-if key in standard_openai_subset:
+if key in known_provider_passthrough:
+    pass  # Expected provider-specific params; no warning needed.
+elif key in standard_openai_subset:
# LiteLLM is pretty confident that this standard OpenAI parameter won't work.
unsupported_openai_params.append(key)
else:
@@ -288,18 +305,19 @@ async def _generate_from_chat_context_standard(
action: Component[C] | CBlock,
ctx: Context,
*,
-_format: type[BaseModelSubclass]
-| None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
+_format: (
+    type[BaseModelSubclass] | None
+) = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
model_options: dict | None = None,
tool_calls: bool = False,
) -> ModelOutputThunk[C]:
await self.do_generate_walk(action)

model_opts = self._simplify_and_merge(model_options)
linearized_context = ctx.view_for_generation()
-assert linearized_context is not None, (
-    "Cannot generate from a non-linear context in a FormatterBackend."
-)
+assert (
+    linearized_context is not None
+), "Cannot generate from a non-linear context in a FormatterBackend."
# Convert our linearized context into a sequence of chat messages. Template formatters have a standard way of doing this.
messages: list[Message] = self.formatter.to_chat_messages(linearized_context)

@@ -361,6 +379,7 @@ async def _generate_from_chat_context_standard(
tools=formatted_tools,
reasoning_effort=thinking, # type: ignore
drop_params=True, # See note in `_make_backend_specific_and_remove`.
num_retries=self._num_retries,
**extra_params,
**model_specific_options,
)
@@ -441,6 +460,16 @@ async def processing(
if content_chunk is not None:
mot._underlying_value += content_chunk

if getattr(choice, "logprobs", None) is not None:
mot._meta["logprobs"] = choice.logprobs

# In some cases (converse API) Bedrock returns logprobs via additionalModelResponseFields.
additional_fields = getattr(chunk, "model_extra", {}) or {}
if "additionalModelResponseFields" in additional_fields:
mot._meta["additionalModelResponseFields"] = additional_fields[
"additionalModelResponseFields"
]

# Store the full response (includes usage) as a dict
mot._meta["litellm_full_response"] = chunk.model_dump()
# Also store just the choice for backward compatibility
@@ -459,6 +488,12 @@ async def processing(
if content_chunk is not None:
mot._underlying_value += content_chunk

stream_logprobs = getattr(chunk.choices[0], "logprobs", None)
if stream_logprobs is not None:
    # Accumulate per-chunk logprobs in arrival order.
    mot._meta.setdefault("logprobs", []).append(stream_logprobs)
Comment on lines +491 to +495
Member:

@ajbozarth can you chime in on whether this is something we should add a `mot` field for instead of putting it in `_meta` (cf. `GenerationMetadata`)?

Contributor:

It took some time to figure out exactly what I wanted to say, but I think I got Claude to summarize my thoughts correctly:


The pattern here would be `mot.logprobs` rather than adding it into `GenerationMetadata` (which holds execution metadata: usage, model, provider, etc.). I'd lean toward keeping it in `_meta` for now, for two reasons:

1. Type isn't settled. In streaming it accumulates as a list of per-chunk objects; in non-streaming it's a single object. A public field needs a consistent, defined type.
2. Coverage is narrow. Logprobs are only returned when explicitly requested, and only by some backends. The existing public fields on `mot` are things every backend always populates.

If logprobs support expands to other backends, that would be the right time to define the type and promote it to a real field, following the same process the telemetry fields went through before landing in `GenerationMetadata`.



if mot._meta.get("litellm_chat_response_streamed", None) is None:
mot._meta["litellm_chat_response_streamed"] = []
mot._meta["litellm_chat_response_streamed"].append(
@@ -493,6 +528,7 @@ async def post_processing(
_format: The structured output format class used during generation, if any.
"""
# Reconstruct the chat_response from chunks if streamed.
streamed_chunks = mot._meta.get("litellm_chat_response_streamed", None)
if streamed_chunks is not None:
# Must handle ollama differently due to: https://github.com/BerriAI/litellm/issues/14579.
@@ -504,19 +540,20 @@ async def post_processing(
streamed_chunks, force_all_tool_calls_separate=separate_tools
)

-assert mot._action is not None, (
-    "ModelOutputThunks should have their action assigned during generation"
-)
-assert mot._model_options is not None, (
-    "ModelOutputThunks should have their model_opts assigned during generation"
-)
+assert (
+    mot._action is not None
+), "ModelOutputThunks should have their action assigned during generation"
+assert (
+    mot._model_options is not None
+), "ModelOutputThunks should have their model_opts assigned during generation"

# OpenAI-like streamed responses potentially give you chunks of tool calls.
# As a result, we have to store data between calls and only then
# check for complete tool calls in the post_processing step.
tool_chunk = extract_model_tool_requests(
tools, mot._meta["litellm_chat_response"]
)

if tool_chunk is not None:
if mot.tool_calls is None:
mot.tool_calls = {}
@@ -539,6 +576,7 @@ async def post_processing(
}
generate_log.action = mot._action
generate_log.result = mot

mot._generate_log = generate_log

# Extract token usage from full response dict or streaming usage
@@ -674,7 +712,10 @@ async def generate_from_raw(
prompts = [self.formatter.print(action) for action in actions]

completion_response = await litellm.atext_completion(
-model=self._model_id, prompt=prompts, **model_specific_options
+model=self._model_id,
+prompt=prompts,
+num_retries=self._num_retries,
+**model_specific_options,
)

# Necessary for type checker.
@@ -695,9 +736,11 @@ async def generate_from_raw(
output._model_options = model_opts
output._meta = {
"litellm_chat_response": res.model_dump(),
"usage": completion_response.usage.model_dump()
if completion_response.usage
else None,
"usage": (
completion_response.usage.model_dump()
if completion_response.usage
else None
),
}

output.parsed_repr = (
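Finally, a minimal sketch of the retry plumbing added in this PR, assuming the `LiteLLMBackend` constructor shown above and a litellm-routable Bedrock model id:

```python
from mellea.backends.litellm import LiteLLMBackend

# num_retries (new in this PR, default 0) is stored as self._num_retries and
# forwarded to both litellm.acompletion and litellm.atext_completion.
backend = LiteLLMBackend(
    model_id="bedrock/converse/us.amazon.nova-pro-v1:0",
    num_retries=3,
)
```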