Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/tokenutil/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@
"gpt-3.5",
"text-embedding-3",
"text-embedding-ada",
# Sluice policy aliases.
"auto",
"cheap-fast",
"cheap-reasoning",
"cheap-long-context",
"cheap-coding",
"premium",
"openrouter-free",
# Inference aggregators hosting OpenAI-compatible open-weight models
"groq",
"together",
Expand Down Expand Up @@ -97,7 +105,7 @@ def count_tokens_for_text(text: str, model: str) -> int:
# Optional SentencePiece path for Gemini
if "gemini" in model.lower():
try:
from tokenutil._sentencepiece import count_sp # type: ignore[import-not-found]
from tokenutil._sentencepiece import count_sp

return count_sp(text)
except ImportError:
Expand Down
42 changes: 42 additions & 0 deletions src/tokenutil/_sentencepiece.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Optional SentencePiece backend for Gemini-family models.

This module is imported lazily, only when a Gemini model is requested. It
produces counts only when the ``tokenutil[gemini]`` extra is installed;
otherwise ``count_sp`` raises ``ImportError`` and the caller falls back. The
initial implementation loads a generic SentencePiece model from a path given
by environment variable, because Google's Gemini tokenizer model is not
distributed by this package.
"""

from __future__ import annotations

import os
import warnings


# Loaded processors, keyed by (processor class, model path). Keying on the
# class as well as the path means a processor built by one ``sentencepiece``
# module object is never handed out after that module has been replaced.
_PROCESSOR_CACHE = {}


def count_sp(text: str) -> int:
    """Count tokens with a SentencePiece model configured at runtime.

    Set ``TOKENUTIL_SENTENCEPIECE_MODEL`` to the local ``.model`` file. The
    processor is built on first use and reused by later calls with the same
    path, so the model file is not re-read for every count. Failed loads are
    not remembered; the next call tries again.

    Returns the number of token ids produced by encoding ``text``.

    Raises ``ImportError`` if ``sentencepiece`` is not installed, if the
    environment variable is unset or empty (a ``UserWarning`` is emitted
    first), or if the model cannot be loaded. A single exception type is used
    so the caller can fall back to the conservative heuristic.
    """
    try:
        import sentencepiece as spm  # type: ignore[import-not-found]
    except ImportError as exc:  # pragma: no cover - optional dependency
        raise ImportError("sentencepiece is not installed") from exc

    model_path = os.getenv("TOKENUTIL_SENTENCEPIECE_MODEL", "")
    if not model_path:
        warnings.warn(
            "tokenutil: TOKENUTIL_SENTENCEPIECE_MODEL is not set; "
            "using fallback tokenizer for Gemini.",
            UserWarning,
            stacklevel=2,
        )
        raise ImportError("TOKENUTIL_SENTENCEPIECE_MODEL is not set")

    try:
        # Attribute lookup and cache access stay inside the ``try`` so that
        # any failure on the load path still surfaces as ``ImportError``.
        factory = spm.SentencePieceProcessor
        cache_key = (factory, model_path)
        processor = _PROCESSOR_CACHE.get(cache_key)
        if processor is None:
            processor = factory(model_file=model_path)
            _PROCESSOR_CACHE[cache_key] = processor
    except Exception as exc:
        raise ImportError("failed to load TOKENUTIL_SENTENCEPIECE_MODEL") from exc

    return len(processor.encode(text, out_type=int))
27 changes: 21 additions & 6 deletions tests/test_count_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import warnings
from types import SimpleNamespace

import pytest

Expand Down Expand Up @@ -87,13 +88,9 @@ def test_responses_api_input_text_type() -> None:


def test_sluice_alias_cheap_fast_uses_cl100k() -> None:
    """The ``cheap-fast`` policy alias must count exactly like ``gpt-4``."""
    text = "The quick brown fox jumps over the lazy dog."
    n = count_tokens(text, model="cheap-fast")
    # Exact equality (rather than ``n > 0``) pins the alias to the same
    # tokenizer result as gpt-4 instead of accepting any positive estimate.
    assert n == count_tokens(text, model="gpt-4")


def test_sluice_alias_groq_model() -> None:
Expand All @@ -109,6 +106,24 @@ def test_kimi_coding_alias() -> None:
assert n == count_tokens(text, model="gpt-4") # kimi → cl100k_base


def test_gemini_sentencepiece_load_error_falls_back(monkeypatch: pytest.MonkeyPatch) -> None:
    """A SentencePiece model that fails to load must not break Gemini counting.

    A stand-in ``sentencepiece`` module is installed whose processor raises
    ``OSError`` on construction. ``count_tokens`` is then expected to emit the
    heuristic warning and return the 4 chars/token estimate.
    """
    # Function-scope import: only this test needs ``sys``.
    import sys

    class BrokenSentencePieceProcessor:
        """Processor stand-in whose constructor always fails."""

        def __init__(self, model_file: str) -> None:
            raise OSError(f"missing model: {model_file}")

    monkeypatch.setenv("TOKENUTIL_SENTENCEPIECE_MODEL", "missing.model")
    monkeypatch.setitem(
        sys.modules,
        "sentencepiece",
        SimpleNamespace(SentencePieceProcessor=BrokenSentencePieceProcessor),
    )

    with pytest.warns(UserWarning, match="using 4 chars/token heuristic"):
        n = count_tokens("a" * 400, model="gemini-flash")

    # 400 characters at 4 chars/token.
    assert n == 100


# ---------------------------------------------------------------------------
# Unknown model — should warn and return a positive heuristic count
# ---------------------------------------------------------------------------
Expand Down
Loading