initialize.py: 8 additions, 0 deletions
@@ -38,6 +38,7 @@ def _normalize_model_kwargs(kwargs: dict) -> dict:
limit_requests=current_settings["chat_model_rl_requests"],
limit_input=current_settings["chat_model_rl_input"],
limit_output=current_settings["chat_model_rl_output"],
limit_concurrent=current_settings["chat_model_rl_concurrent"],
kwargs=_normalize_model_kwargs(current_settings["chat_model_kwargs"]),
)

@@ -51,6 +52,7 @@ def _normalize_model_kwargs(kwargs: dict) -> dict:
limit_requests=current_settings["util_model_rl_requests"],
limit_input=current_settings["util_model_rl_input"],
limit_output=current_settings["util_model_rl_output"],
limit_concurrent=current_settings["util_model_rl_concurrent"],
kwargs=_normalize_model_kwargs(current_settings["util_model_kwargs"]),
)
# embedding model from user settings
@@ -60,6 +62,8 @@ def _normalize_model_kwargs(kwargs: dict) -> dict:
name=current_settings["embed_model_name"],
api_base=current_settings["embed_model_api_base"],
limit_requests=current_settings["embed_model_rl_requests"],
limit_input=current_settings["embed_model_rl_input"],
limit_concurrent=current_settings["embed_model_rl_concurrent"],
kwargs=_normalize_model_kwargs(current_settings["embed_model_kwargs"]),
)
# browser model from user settings
@@ -69,6 +73,10 @@ def _normalize_model_kwargs(kwargs: dict) -> dict:
name=current_settings["browser_model_name"],
api_base=current_settings["browser_model_api_base"],
vision=current_settings["browser_model_vision"],
limit_requests=current_settings["browser_model_rl_requests"],
limit_input=current_settings["browser_model_rl_input"],
limit_output=current_settings["browser_model_rl_output"],
limit_concurrent=current_settings["browser_model_rl_concurrent"],
kwargs=_normalize_model_kwargs(current_settings["browser_model_kwargs"]),
)
# agent configuration
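All four model configurations now read new `*_rl_*` settings keys (`chat_model_rl_concurrent`, `util_model_rl_concurrent`, `embed_model_rl_input`, `embed_model_rl_concurrent`, and the four `browser_model_rl_*` keys). Those keys are not defined anywhere in this diff; a hypothetical defaults snippet for wherever `current_settings` is assembled:

```python
# Hypothetical defaults for the new settings keys read above; the real
# defaults belong wherever current_settings is built, outside this diff.
DEFAULT_RATE_LIMIT_SETTINGS = {
    "chat_model_rl_concurrent": 0,      # 0 disables the corresponding limit
    "util_model_rl_concurrent": 0,
    "embed_model_rl_input": 0,
    "embed_model_rl_concurrent": 0,
    "browser_model_rl_requests": 0,
    "browser_model_rl_input": 0,
    "browser_model_rl_output": 0,
    "browser_model_rl_concurrent": 0,
}
```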
models.py: 113 additions, 77 deletions
@@ -14,7 +14,7 @@
TypedDict,
)

from litellm import completion, acompletion, embedding
from litellm import EmbeddingResponse, completion, acompletion, embedding
import litellm
import openai
from litellm.types.utils import ModelResponse
@@ -76,6 +76,7 @@ class ModelConfig:
limit_requests: int = 0
limit_input: int = 0
limit_output: int = 0
limit_concurrent: int = 0
vision: bool = False
kwargs: dict = field(default_factory=dict)

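The new `limit_concurrent` field defaults to 0, which the limiter code below treats as "no concurrency cap". A hypothetical construction, with illustrative values that are not part of this PR:

```python
# Illustrative only: field names are from this diff, values are made up.
conf = ModelConfig(
    provider="openai",
    name="gpt-4o-mini",
    limit_requests=60,       # requests per rate window
    limit_input=100_000,     # input tokens per rate window
    limit_output=20_000,     # output tokens per rate window
    limit_concurrent=4,      # at most 4 in-flight calls; 0 = unlimited
)
```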
@@ -215,13 +216,14 @@ def get_api_key(service: str) -> str:


def get_rate_limiter(
provider: str, name: str, requests: int, input: int, output: int
provider: str, name: str, requests: int, input: int, output: int, concurrent: int = 0
) -> RateLimiter:
key = f"{provider}\\{name}"
rate_limiters[key] = limiter = rate_limiters.get(key, RateLimiter(seconds=60))
limiter.limits["requests"] = requests or 0
limiter.limits["input"] = input or 0
limiter.limits["output"] = output or 0
limiter.set_concurrent_limit(concurrent or 0)
return limiter


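`set_concurrent_limit`, and the `acquire`/`release` pair used by the call sites below, are RateLimiter additions whose implementation is not part of this diff. A minimal sketch of the shape those call sites assume (an assumption, not the project's actual code):

```python
import asyncio
from typing import Any, Awaitable, Callable, Optional


class RateLimiterSketch:
    """Sketch of the concurrency additions this diff assumes on RateLimiter."""

    def __init__(self, seconds: int = 60):
        self.seconds = seconds
        self.limits: dict[str, int] = {"requests": 0, "input": 0, "output": 0}
        self._concurrent_limit = 0
        self._semaphore: Optional[asyncio.Semaphore] = None

    async def wait(self, callback: Optional[Callable[..., Awaitable[Any]]] = None) -> None:
        # Placeholder for the pre-existing request/token budget wait.
        return None

    def set_concurrent_limit(self, limit: int) -> None:
        # Rebuild the semaphore only when the cap changes; 0 means "no cap".
        if limit != self._concurrent_limit:
            self._concurrent_limit = limit
            self._semaphore = asyncio.Semaphore(limit) if limit > 0 else None

    async def acquire(
        self, callback: Optional[Callable[..., Awaitable[Any]]] = None
    ) -> Optional[asyncio.Semaphore]:
        # Presumably subsumes the old wait(): block until the budget allows
        # the call, then take a concurrency slot if a cap is configured.
        await self.wait(callback)
        if self._semaphore is None:
            return None
        await self._semaphore.acquire()
        return self._semaphore

    def release(self, semaphore: Optional[asyncio.Semaphore]) -> None:
        # Callers hand back whatever acquire() returned; None is a no-op.
        if semaphore is not None:
            semaphore.release()
```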
@@ -258,18 +260,20 @@ async def apply_rate_limiter(
) = None,
):
if not model_config:
return
return None, None
limiter = get_rate_limiter(
model_config.provider,
model_config.name,
model_config.limit_requests,
model_config.limit_input,
model_config.limit_output,
model_config.limit_concurrent,
)
limiter.add(input=approximate_tokens(input_text))
limiter.add(requests=1)
await limiter.wait(rate_limiter_callback)
return limiter
semaphore = await limiter.acquire(rate_limiter_callback)
return limiter, semaphore

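Every call site below uses the returned pair the same way: run the model call inside a `try` and hand the semaphore back in `finally`, skipping the release when no limiter was configured. A minimal usage sketch (`do_model_call` is a placeholder, not a real function in this codebase):

```python
async def call_with_limits(model_conf, prompt: str):
    limiter, semaphore = await apply_rate_limiter(model_conf, prompt)
    try:
        return await do_model_call(prompt)  # placeholder for the actual model call
    finally:
        # apply_rate_limiter returns (None, None) when model_conf is falsy.
        if limiter:
            limiter.release(semaphore)
```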

def apply_rate_limiter_sync(
@@ -280,7 +284,7 @@ def apply_rate_limiter_sync(
) = None,
):
if not model_config:
return
return None, None
import asyncio, nest_asyncio

nest_asyncio.apply()
@@ -385,17 +389,21 @@ def _call(
msgs = self._convert_messages(messages)

# Apply rate limiting if configured
apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
limiter, semaphore = apply_rate_limiter_sync(self.a0_model_conf, str(msgs))

# Call the model
resp = completion(
model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
)
try:
# Call the model
resp = completion(
model=self.model_name, messages=msgs, stop=stop, **{**self.kwargs, **kwargs}
)

# Parse output
parsed = _parse_chunk(resp)
output = ChatGenerationResult(parsed).output()
return output["response_delta"]
# Parse output
parsed = _parse_chunk(resp)
output = ChatGenerationResult(parsed).output()
return output["response_delta"]
finally:
if limiter:
limiter.release(semaphore)

def _stream(
self,
@@ -409,26 +417,30 @@ def _stream(
msgs = self._convert_messages(messages)

# Apply rate limiting if configured
apply_rate_limiter_sync(self.a0_model_conf, str(msgs))
limiter, semaphore = apply_rate_limiter_sync(self.a0_model_conf, str(msgs))

result = ChatGenerationResult()
try:
result = ChatGenerationResult()

for chunk in completion(
model=self.model_name,
messages=msgs,
stream=True,
stop=stop,
**{**self.kwargs, **kwargs},
):
# parse chunk
parsed = _parse_chunk(chunk) # chunk parsing
output = result.add_chunk(parsed) # chunk processing

# Only yield chunks with non-None content
if output["response_delta"]:
yield ChatGenerationChunk(
message=AIMessageChunk(content=output["response_delta"])
)
for chunk in completion(
model=self.model_name,
messages=msgs,
stream=True,
stop=stop,
**{**self.kwargs, **kwargs},
):
# parse chunk
parsed = _parse_chunk(chunk) # chunk parsing
output = result.add_chunk(parsed) # chunk processing

# Only yield chunks with non-None content
if output["response_delta"]:
yield ChatGenerationChunk(
message=AIMessageChunk(content=output["response_delta"])
)
finally:
if limiter:
limiter.release(semaphore)

async def _astream(
self,
@@ -440,27 +452,31 @@ async def _astream(
msgs = self._convert_messages(messages)

# Apply rate limiting if configured
await apply_rate_limiter(self.a0_model_conf, str(msgs))
limiter, semaphore = await apply_rate_limiter(self.a0_model_conf, str(msgs))

result = ChatGenerationResult()
try:
result = ChatGenerationResult()

response = await acompletion(
model=self.model_name,
messages=msgs,
stream=True,
stop=stop,
**{**self.kwargs, **kwargs},
)
async for chunk in response: # type: ignore
# parse chunk
parsed = _parse_chunk(chunk) # chunk parsing
output = result.add_chunk(parsed) # chunk processing

# Only yield chunks with non-None content
if output["response_delta"]:
yield ChatGenerationChunk(
message=AIMessageChunk(content=output["response_delta"])
)
response = await acompletion(
model=self.model_name,
messages=msgs,
stream=True,
stop=stop,
**{**self.kwargs, **kwargs},
)
async for chunk in response: # type: ignore
# parse chunk
parsed = _parse_chunk(chunk) # chunk parsing
output = result.add_chunk(parsed) # chunk processing

# Only yield chunks with non-None content
if output["response_delta"]:
yield ChatGenerationChunk(
message=AIMessageChunk(content=output["response_delta"])
)
finally:
if limiter:
limiter.release(semaphore)

async def unified_call(
self,
@@ -490,10 +506,6 @@ async def unified_call(
# convert to litellm format
msgs_conv = self._convert_messages(messages, explicit_caching=explicit_caching)

# Apply rate limiting if configured
limiter = await apply_rate_limiter(
self.a0_model_conf, str(msgs_conv), rate_limiter_callback
)

# Prepare call kwargs and retry config (strip A0-only params before calling LiteLLM)
call_kwargs: dict[str, Any] = {**self.kwargs, **kwargs}
@@ -506,6 +518,10 @@

attempt = 0
while True:
# Apply rate limiting if configured
limiter, semaphore = await apply_rate_limiter(
self.a0_model_conf, str(msgs_conv), rate_limiter_callback
)
got_any_chunk = False
try:
# call model
@@ -570,6 +586,9 @@ async def unified_call(
raise
attempt += 1
await asyncio.sleep(retry_delay_s)
finally:
if limiter:
limiter.release(semaphore)

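Note that `unified_call` now acquires inside the retry loop: each attempt re-waits on the rate budget and holds its own concurrency slot, which the `finally` frees before the next iteration re-acquires. Schematically (retry bookkeeping simplified; `attempt_call` is a placeholder):

```python
import asyncio

async def unified_call_pattern(model_conf, prompt, callback, retries, retry_delay_s):
    attempt = 0
    while True:
        limiter, semaphore = await apply_rate_limiter(model_conf, prompt, callback)
        try:
            return await attempt_call()  # placeholder for one model-call attempt
        except Exception:
            if attempt >= retries:
                raise
            attempt += 1
            await asyncio.sleep(retry_delay_s)
        finally:
            # Runs after each attempt (including its backoff), so the slot is
            # freed before the loop acquires a fresh one.
            if limiter:
                limiter.release(semaphore)
```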

class AsyncAIChatReplacement:
@@ -625,7 +644,7 @@ async def _acall(
**kwargs: Any,
):
# Apply rate limiting if configured
apply_rate_limiter_sync(self._wrapper.a0_model_conf, str(messages))
limiter, semaphore = await apply_rate_limiter(self._wrapper.a0_model_conf, str(messages))

# Call the model
try:
@@ -655,6 +674,9 @@ async def _acall(

except Exception as e:
raise e
finally:
if limiter:
limiter.release(semaphore)

# another hack for browser-use post process invalid jsons
try:
@@ -685,21 +707,27 @@ def __init__(

def embed_documents(self, texts: List[str]) -> List[List[float]]:
# Apply rate limiting if configured
apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))

resp = embedding(model=self.model_name, input=texts, **self.kwargs)
return [
item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore
for item in resp.data # type: ignore
]
limiter, semaphore = apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
try:
resp: EmbeddingResponse = embedding(model=self.model_name, input=texts, **self.kwargs)
return [
item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore
for item in resp.data # type: ignore
]
finally:
if limiter:
limiter.release(semaphore)

def embed_query(self, text: str) -> List[float]:
# Apply rate limiting if configured
apply_rate_limiter_sync(self.a0_model_conf, text)

resp = embedding(model=self.model_name, input=[text], **self.kwargs)
item = resp.data[0] # type: ignore
return item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore
limiter, semaphore = apply_rate_limiter_sync(self.a0_model_conf, text)
try:
resp: EmbeddingResponse = embedding(model=self.model_name, input=[text], **self.kwargs)
item = resp.data[0]
return item.get("embedding") if isinstance(item, dict) else item.embedding # type: ignore
finally:
if limiter:
limiter.release(semaphore)


class LocalSentenceTransformerWrapper(Embeddings):
@@ -736,20 +764,28 @@ def __init__(

def embed_documents(self, texts: List[str]) -> List[List[float]]:
# Apply rate limiting if configured
apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))
limiter, semaphore = apply_rate_limiter_sync(self.a0_model_conf, " ".join(texts))

embeddings = self.model.encode(texts, convert_to_tensor=False) # type: ignore
return embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings # type: ignore
try:
embeddings = self.model.encode(texts, convert_to_tensor=False) # type: ignore
return embeddings.tolist() if hasattr(embeddings, "tolist") else embeddings # type: ignore
finally:
if limiter:
limiter.release(semaphore)

def embed_query(self, text: str) -> List[float]:
# Apply rate limiting if configured
apply_rate_limiter_sync(self.a0_model_conf, text)
limiter, semaphore = apply_rate_limiter_sync(self.a0_model_conf, text)

embedding = self.model.encode([text], convert_to_tensor=False) # type: ignore
result = (
embedding[0].tolist() if hasattr(embedding[0], "tolist") else embedding[0]
)
return result # type: ignore
try:
embedding = self.model.encode([text], convert_to_tensor=False) # type: ignore
result = (
embedding[0].tolist() if hasattr(embedding[0], "tolist") else embedding[0]
)
return result # type: ignore
finally:
if limiter:
limiter.release(semaphore)


def _get_litellm_chat(