diff --git a/.github/workflows/build-binaries.yml b/.github/workflows/build-binaries.yml index 2454d349..365af733 100644 --- a/.github/workflows/build-binaries.yml +++ b/.github/workflows/build-binaries.yml @@ -263,7 +263,7 @@ jobs: githubToken: ${{ github.token }} install: | apt-get update - apt-get install -y --no-install-recommends python3 python3-pip python3-venv cargo + apt-get install -y --no-install-recommends python3 python3-pip python3-venv python3-dev cargo pip3 install -U pip # Create and use a virtual environment to avoid the externally-managed-environment error run: | @@ -307,7 +307,7 @@ jobs: image: alpine:latest options: -v ${{ github.workspace }}:/io -w /io run: | - apk add python3 py3-pip rust + apk add python3 python3-dev py3-pip rust python -m venv .venv .venv/bin/pip3 install dist/${{ env.PACKAGE_NAME }}-*.whl --force-reinstall .venv/bin/${{ env.EXECUTABLE_NAME }} --help @@ -348,7 +348,7 @@ jobs: distro: alpine_latest githubToken: ${{ github.token }} install: | - apk add python3 py3-pip rust + apk add python3 python3-dev py3-pip rust run: | python -m venv .venv .venv/bin/pip3 install dist/${{ env.PACKAGE_NAME }}-*.whl --force-reinstall diff --git a/.gitignore b/.gitignore index ce32d891..06be6620 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,10 @@ *.pyc __pycache__ + +# may contain sensitive data +pytest.ini + # # Artifacts from the Rust client generation process # diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..9b388533 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/README.md b/README.md index 321a3513..7a064da4 100644 --- a/README.md +++ b/README.md @@ -189,3 +189,7 @@ uv run pytest tests If you need to get the latest OpenAPI SDK, you can run `./scripts/generate-python-api-client.sh`. + +## Testing +We use pytest to run tests. Copy `pytest.ini.template` to `pytest.ini` and +replace the values of environment variables \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 5d84ebbd..08ddd708 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,11 +40,14 @@ classifiers = [ dependencies = [ "attrs==24.2.0", "httpx==0.28.1", + "huggingface-hub>=0.34.3", + "ollama>=0.4.7", + "pyiceberg==0.9.0", "python-dateutil==2.9.0.post0", ] [project.optional-dependencies] -ai = ["huggingface-hub==0.30.2", "ollama==0.4.7"] +ai = ["huggingface-hub==0.34.3", "ollama==0.4.7"] iceberg = ["polars==1.27.1", "pyarrow==19.0.1", "pyiceberg==0.9.0"] all = ["tower[ai,iceberg]"] @@ -65,5 +68,6 @@ dev = [ "openapi-python-client==0.24.3", "pytest==8.3.5", "pytest-httpx==0.35.0", + "pytest-env>=1.1.3", "pyiceberg[sql-sqlite]==0.9.0", ] diff --git a/pytest.ini.template b/pytest.ini.template new file mode 100644 index 00000000..d00fc786 --- /dev/null +++ b/pytest.ini.template @@ -0,0 +1,12 @@ +[pytest] +pythonpath = src +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = -v --tb=short +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning +env = + TOWER_INFERENCE_ROUTER_API_KEY= diff --git a/src/tower/_context.py b/src/tower/_context.py index 995dc9f2..51c63697 100644 --- a/src/tower/_context.py +++ b/src/tower/_context.py @@ -1,13 +1,16 @@ import os class TowerContext: - def __init__(self, tower_url: str, environment: str, api_key: str = None, hugging_face_provider: str = None, hugging_face_api_key: str = None, jwt: str = None): + def __init__(self, tower_url: str, environment: str, api_key: str = None, + inference_router: str = None, inference_router_api_key: str = None, + inference_provider: str = None, jwt: str = None): self.tower_url = tower_url self.environment = environment self.api_key = api_key self.jwt = jwt - self.hugging_face_provider = hugging_face_provider - self.hugging_face_api_key = hugging_face_api_key + self.inference_router = inference_router + self.inference_router_api_key = inference_router_api_key + self.inference_provider = inference_provider def is_local(self) -> bool: if self.environment is None or self.environment == "": @@ -24,17 +27,18 @@ def build(cls): tower_api_key = os.getenv("TOWER_API_KEY") tower_jwt = os.getenv("TOWER_JWT") - # NOTE: These are experimental, used only for our experimental Hugging - # Face integration for LLMs. - hugging_face_provider = os.getenv("TOWER_HUGGING_FACE_PROVIDER") - hugging_face_api_key = os.getenv("TOWER_HUGGING_FACE_API_KEY") + # Replaces the deprecated hugging_face_provider and hugging_face_api_key + inference_router = os.getenv("TOWER_INFERENCE_ROUTER") + inference_router_api_key = os.getenv("TOWER_INFERENCE_ROUTER_API_KEY") + inference_provider = os.getenv("TOWER_INFERENCE_PROVIDER") return cls( tower_url = tower_url, environment = tower_environment, api_key = tower_api_key, - hugging_face_provider = hugging_face_provider, - hugging_face_api_key = hugging_face_api_key, + inference_router = inference_router, + inference_router_api_key = inference_router_api_key, + inference_provider = inference_provider, jwt = tower_jwt, ) diff --git a/src/tower/_llms.py b/src/tower/_llms.py index 34e68a9c..f331324d 100644 --- a/src/tower/_llms.py +++ b/src/tower/_llms.py @@ -3,78 +3,474 @@ from ollama import chat, pull from ollama import ChatResponse from ollama import ResponseError +from ollama import list as ollama_list_models from huggingface_hub import InferenceClient, ChatCompletionOutput +from huggingface_hub import HfApi +from huggingface_hub.utils import RepositoryNotFoundError from ._context import TowerContext -""" -OLLAMA_MODELS and HUGGING_FACE_MODELS are dictionaries that map published model -names to the internal names used by Tower when routing LLM requests to the -underlying provider. -""" -OLLAMA_MODELS = { - "deepseek-r1": "deepseek-r1:14b", -} +# TODO: add vllm back in when we have a way to use it +LOCAL_INFERENCE_ROUTERS = [ + "ollama", +] -HUGGING_FACE_MODELS = { - "deepseek-r1": "deepseek-ai/DeepSeek-R1", -} +INFERENCE_ROUTERS = LOCAL_INFERENCE_ROUTERS + [ + "hugging_face_hub" +] -def extract_model_name(ctx: TowerContext, supported_model: str) -> str: +RAW_MODEL_FAMILIES = [ + "all-minilm", + "aya", + "aya-expanse", + "athene-v2", + "bakllava", + "bge-large", + "bge-m3", + "cogito", + "codegemma", + "codegeex4", + "codeqwen", + "codestral", + "codeup", + "codellama", + "command-a", + "command-r", + "command-r-plus", + "command-r7b", + "deepcoder", + "deepseek-coder", + "deepseek-coder-v2", + "deepseek-llm", + "deepseek-r1", + "deepseek-v2", + "deepseek-v2.5", + "deepseek-v3", + "deepscaler", + "devstral", + "dbrx", + "dolphin-mistral", + "dolphin-mixtral", + "dolphin-phi", + "dolphin3", + "dolphincoder", + "exaone-deep", + "exaone3.5", + "everythinglm", + "falcon", + "falcon3", + "gemma", + "gemma2", + "gemma3", + "gemma3n", + "glm4", + "goliath", + "granite-code", + "granite-embedding", + "granite3-dense", + "granite3-guardian", + "granite3-moe", + "granite3.1-dense", + "granite3.1-moe", + "granite3.2", + "granite3.2-vision", + "granite3.3", + "hermes3", + "internlm2", + "lafrican", + "llama-pro", + "llama-guard3", + "llama2", + "llama2-chinese", + "llama2-uncensored", + "llama3", + "llama3-chatqa", + "llama3-groq-tool-use", + "llama3-gradient", + "llama3.1", + "llama3.2", + "llama3.2-vision", + "llama3.3", + "llama4", + "llava", + "llava-llama3", + "llava-phi3", + "magicoder", + "magistral", + "marco-o1", + "mathstral", + "meditron", + "medllama2", + "megadolphin", + "minicpm-v", + "mistral", + "mistral-large", + "mistral-nemo", + "mistral-openorca", + "mistral-small", + "mistral-small3.1", + "mistral-small3.2", + "mistrallite", + "moondream", + "mxbai-embed-large", + "nemotron", + "nemotron-mini", + "neural-chat", + "nexusraven", + "notus", + "nous-hermes", + "nous-hermes2", + "nous-hermes2-mixtral", + "nomic-embed-text", + "notux", + "olmo2", + "opencoder", + "openchat", + "openthinker", + "openhermes", + "orca-mini", + "orca2", + "paraphrase-multilingual", + "phi", + "phi3", + "phi3.5", + "phi4", + "phi4-mini", + "phi4-mini-reasoning", + "phi4-reasoning", + "phind-codellama", + "qwen", + "qwen2", + "qwen2-math", + "qwen2.5", + "qwen2.5-coder", + "qwen2.5vl", + "qwen3", + "qwq", + "r1-1776", + "reader-lm", + "reflection", + "sailor2", + "samatha-mistral", + "shieldgemma", + "smallthinker", + "smollm", + "smollm2", + "snowflake-arctic-embed", + "snowflake-arctic-embed2", + "solar", + "solar-pro", + "sqlcoder", + "stable-beluga", + "stable-code", + "stablelm-zephyr", + "stablelm2", + "starcoder", + "starcoder2", + "starling-lm", + "sunbeam", + "tulu3", + "tinydolphin", + "tinyllama", + "vicuna", + "wizard-math", + "wizard-vicuna", + "wizard-vicuna-uncensored", + "wizardcoder", + "wizardlm", + "wizardlm-uncensored", + "wizardlm2", + "xwinlm", + "yarn-llama2", + "yarn-mistral", + "yi", + "yi-coder", + "zephyr" +] + +def normalize_model_family(name: str) -> str: + """ + Normalize a model family name by removing '-' and '.' characters. + Args: + name (str): The model family name to normalize. + Returns: + str: The normalized model family name. + """ + return name.replace('-', '').replace('.', '').lower() + + +MODEL_FAMILIES = {normalize_model_family(name) : name for name in RAW_MODEL_FAMILIES} + +# the %-ge of memory that we can use for inference +# TODO: add this back in when implementing memory checking for LLMs +# MEMORY_THRESHOLD = 0.8 + + + +def parse_parameter_size(size_str: str) -> float: + """ + Convert parameter size string (e.g., '8.0B', '7.2B') to number of parameters. + """ + if not size_str: + return 0 + multiplier = {'B': 1e9, 'M': 1e6, 'K': 1e3} + size_str = size_str.upper() + for suffix, mult in multiplier.items(): + if suffix in size_str: + return float(size_str.replace(suffix, '')) * mult + return float(size_str) + + +def resolve_model_name(ctx: TowerContext, requested_model: str) -> str: + """ + Resolve the model name based on the inference router and requested model. + + Args: + ctx (TowerContext): The context containing the inference router and other settings. + requested_model (str): The name of the model requested by the user. + + Returns: + str: The resolved model name. + + Raises: + ValueError: If the inference router specified in the context is not supported. + """ + if ctx.inference_router not in INFERENCE_ROUTERS: + raise ValueError(f"Inference router {ctx.inference_router} not supported.") + + if ctx.inference_router == "ollama": + return resolve_ollama_model_name(ctx,requested_model) + elif ctx.inference_router == "hugging_face_hub": + return resolve_hugging_face_hub_model_name(ctx,requested_model) + +def get_local_ollama_models() -> List[dict]: + """ + Get a list of locally installed Ollama models with their details. + Returns a list of dictionaries containing: + - name: model name with tag + - model_family: model family without tag + - size: model size in bytes + - parameter_size: number of parameters + - quantization_level: quantization level if specified + """ + try: + models = ollama_list_models() + model_list = [] + for model in models['models']: + model_name = model.get('model', '') + model_family = model_name.split(':')[0] + size = model.get('size', 0) + details = model.get('details', {}) + parameter_size=details.get('parameter_size', '') + quantization_level=details.get('quantization_level', '') + + model_list.append({ + 'model': model_name, + 'model_family': model_family, + 'size': size, + 'parameter_size': parameter_size, + 'quantization_level': quantization_level + }) + return model_list + except Exception as e: + raise RuntimeError(f"Failed to list Ollama models: {str(e)}") + + +def resolve_ollama_model_name(ctx: TowerContext, requested_model: str) -> str: """ - extract_model_name maps the relevant supported model into a model for the - underlying LLM provider that we want to use. + Resolve the Ollama model name to use. """ - if ctx.is_local(): - if supported_model not in OLLAMA_MODELS: - raise ValueError(f"Model {supported_model} not supported for Ollama.") - return OLLAMA_MODELS[supported_model] + local_models = get_local_ollama_models() + local_model_names = [model['model'] for model in local_models] + + # TODO: add this back in when implementing memory checking for LLMs + #memory = get_available_memory() + #memory_threshold = memory['available'] * MEMORY_THRESHOLD + + if normalize_model_family(requested_model) in MODEL_FAMILIES: + # Filter models by family + matching_models = [model for model in local_models if model['model_family'] == requested_model] + + # TODO: add this back in when implementing memory checking for LLMs + # Filter models by memory + # if check_for_memory: + # matching_models = [model for model in matching_models if model['size'] < memory_threshold] + + # Return the model with the largest parameter size + if matching_models: + best_model = max(matching_models, key=lambda x: parse_parameter_size(x['parameter_size']))['model'] + return best_model + else: + # TODO: add this back in when implementing memory checking for LLMs + # raise ValueError(f"No models in family {requested_model} fit in available memory ({memory['available'] / (1024**3):.2f} GB) with max memory threshold {MEMORY_THRESHOLD} or are not available locally. Please pull a model first using 'ollama pull {requested_model}'") + raise ValueError(f"No models found with name {requested_model}. Please pull a model first using 'ollama pull {requested_model}'") + elif requested_model in local_model_names: + return requested_model else: - if supported_model not in HUGGING_FACE_MODELS: - raise ValueError(f"Model {supported_model} not supported for Hugging Face Hub.") - return HUGGING_FACE_MODELS[supported_model] + raise ValueError(f"No models found with name {requested_model}. Please pull a model first using 'ollama pull {requested_model}'") + +def resolve_hugging_face_hub_model_name(ctx: TowerContext, requested_model: str) -> str: + """ + Resolve the Hugging Face Hub model name to use. + Returns a list of models with their inference provider mappings. + """ + api = HfApi(token=ctx.inference_router_api_key) + + models = None + + try: + model_info = api.model_info(requested_model, expand="inferenceProviderMapping") + models = [model_info] + except RepositoryNotFoundError as e: + # If model_info fails, it means the model does not exist under this exact name + # Therefore, fall back to "search" and look for models that partially match the name + # In Hugging Face Hub terminology Repository = Model / Dataset / Space. + pass + except Exception as e: + # for the rest of the errors, we will raise an error + raise RuntimeError(f"Error while getting model_info for {requested_model}: {str(e)}") + + + # If inference_provider is specified, search by inference provider + # We will use "search" instead of "filter" because only search allows searching inside the model name + # TODO: Add more filtering options e.g. by number of parameters, so that we do not have to retrieve so many models + # TODO: We need to retrieve >1 model because "search" returns a full text match in both model IDs and Descriptions + + if models is None: + if ctx.inference_provider is not None: + models = api.list_models( + search=f"{requested_model}", + #filter=f"inference_provider:{ctx.inference_provider}", + # this is supposed to work in recent HF versions, but it doesn't work for me + # we will do the filtering manually below + expand="inferenceProviderMapping", + limit=20) + else: + models = api.list_models( + search=f"{requested_model}", + expand="inferenceProviderMapping", + limit=20) + + # Create a list of models with their inference provider mappings + model_list = [] + try: + for model in models: + # Handle the case where inference_provider_mapping might be None or empty + inference_provider_mapping = getattr(model, 'inference_provider_mapping', []) or [] + + model_info = { + 'model_name': model.id, + 'inference_provider_mapping': inference_provider_mapping + } + + # If inference_provider is specified, only add models that support it + if ctx.inference_provider is not None: + if ctx.inference_provider not in [mapping.provider for mapping in inference_provider_mapping]: + continue + + # Check that requested_model is partially contained in model.id + if normalize_model_family(requested_model) not in normalize_model_family(model.id): + continue + + model_list.append(model_info) + except Exception as e: + raise RuntimeError(f"Error while iterating: {str(e)}") + + if not model_list: + raise ValueError(f"No models found with name {requested_model} on Hugging Face Hub") + + return model_list[0]['model_name'] + class Llm: def __init__(self, context: TowerContext, model_name: str, max_tokens: int = 1000): """ Wraps up interfacing with a language model in the Tower system. """ - self.model_name = model_name + self.requested_model_name = model_name self.max_tokens = max_tokens self.context = context - def inference(self, messages: List) -> str: + self.inference_router = context.inference_router + self.inference_router_api_key = context.inference_router_api_key + self.inference_provider = context.inference_provider + + if self.inference_router is None and self.context.is_local(): + self.inference_router = "ollama" + + # for local routers, the service is also the router + if self.inference_router in LOCAL_INFERENCE_ROUTERS: + self.inference_provider = self.inference_router + + # Check that we know this router. This will also check that router was set when not in local mode. + if context.inference_router not in INFERENCE_ROUTERS: + raise ValueError(f"Inference router {context.inference_router} not supported.") + + self.model_name = resolve_model_name( + self.context, self.requested_model_name) + + + def complete_chat(self, messages: List) -> str: """ - Simulate the inference process of a language model. - In a real-world scenario, this would involve calling an API or using a library to get the model's response. + Mimics the OpenAI Chat Completions API by sending a list of messages to the language model + and returning the generated response. + + This function provides a unified interface for chat-based interactions with different + language model providers (Ollama, Hugging Face Hub, etc.) while maintaining compatibility + with the OpenAI Chat Completions API format. + + Args: + messages: A list of message dictionaries, each containing 'role' and 'content' keys. + Follows the OpenAI Chat Completions API message format. + + Returns: + str: The generated response from the language model. + + Example: + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"} + ] + response = llm.complete_chat(messages) """ - model_name = extract_model_name(self.context, self.model_name) - if self.context.is_local(): - # Use Ollama for local inference using Apple GPUs - response = infer_with_ollama( + if self.inference_router == "ollama": + # Use Ollama for local inference + response = complete_chat_with_ollama( ctx = self.context, - model = model_name, + model = self.model_name, messages = messages ) - else: - max_tokens = self.max_tokens - response = infer_with_hugging_face_hub( + elif self.inference_router == "hugging_face_hub": + response = complete_chat_with_hugging_face_hub( ctx = self.context, - model = model_name, + model = self.model_name, messages = messages, - max_tokens=max_tokens + max_tokens=self.max_tokens ) return response def prompt(self, prompt: str) -> str: """ - Prompt a language model with a string. This basically will format the - relevant messages internally to send to the model. + Mimics the old-style OpenAI Completions API (not Chat Completions!) by sending a single prompt string + to the language model and returning the generated response. + + This function provides a simple interface for single-prompt interactions, similar to the + legacy OpenAI /v1/completions endpoint. It internally converts the prompt to a chat message + format and uses the complete_chat method. + + Args: + prompt: A single string containing the prompt to send to the language model. + + Returns: + str: The generated response from the language model. + + Example: + response = llm.prompt("What is the capital of France?") """ - return self.inference([{ + return self.complete_chat([{ "role": "user", "content": prompt, }]) @@ -92,7 +488,10 @@ def extract_ollama_message(resp: ChatResponse) -> str: def extract_hugging_face_hub_message(resp: ChatCompletionOutput) -> str: return resp.choices[0].message.content -def infer_with_ollama(ctx: TowerContext, model: str, messages: list, is_retry: bool = False) -> str: +def complete_chat_with_ollama(ctx: TowerContext, model: str, messages: list, is_retry: bool = False) -> str: + + # TODO: remove the try/except and don't pull the model if it doesn't exist. sso 7/20/25 + # the except code is not reachable right now because we always call this function with a model that exists try: response: ChatResponse = chat(model=model, messages=messages) return extract_ollama_message(response) @@ -102,21 +501,21 @@ def infer_with_ollama(ctx: TowerContext, model: str, messages: list, is_retry: b # (or if it exists) will start it for us. pull(model=model) - # Retry the inference after the model hasbeen pulled. - return infer_with_ollama(ctx, model, messages, is_retry=True) + # Retry the inference after the model has been pulled. + return complete_chat_with_ollama(ctx, model, messages, is_retry=True) # Couldn't figure out what the error was, so we'll just raise it accordingly. raise e -def infer_with_hugging_face_hub(ctx: TowerContext, model: str, messages: List, **kwargs) -> str: +def complete_chat_with_hugging_face_hub(ctx: TowerContext, model: str, messages: List, **kwargs) -> str: """ Uses the Hugging Face Hub API to perform inference. Will use configuration supplied by the environment to determine which client to connect to and all that. """ client = InferenceClient( - provider=ctx.hugging_face_provider, - api_key=ctx.hugging_face_api_key + provider=ctx.inference_provider, + api_key=ctx.inference_router_api_key ) completion = client.chat_completion(messages, @@ -125,3 +524,27 @@ def infer_with_hugging_face_hub(ctx: TowerContext, model: str, messages: List, * ) return extract_hugging_face_hub_message(completion) + + +# TODO: add this back in when implementing memory checking for LLMs +# def get_available_memory() -> dict: +# """ +# Get available system memory information. +# Returns a dictionary containing: +# - total: total physical memory in bytes +# - available: available memory in bytes +# - used: used memory in bytes +# - percent: memory usage percentage +# """ +# try: +# memory = psutil.virtual_memory() +# return { +# 'total': memory.total, +# 'available': memory.available, +# 'used': memory.used, +# 'percent': memory.percent +# } +# except Exception as e: +# raise RuntimeError(f"Failed to get memory information: {str(e)}") + + diff --git a/tests/tower/test_llms.py b/tests/tower/test_llms.py new file mode 100644 index 00000000..575a6c87 --- /dev/null +++ b/tests/tower/test_llms.py @@ -0,0 +1,231 @@ +import os +import pytest +from unittest.mock import patch, MagicMock + +from tower._llms import llms, Llm +from tower._context import TowerContext + +@pytest.fixture +def mock_ollama_context(): + """Create a mock TowerContext for testing.""" + context = MagicMock(spec=TowerContext) + context.is_local.return_value = True + context.inference_router = "ollama" + context.inference_provider = "ollama" + context.inference_router_api_key = None + return context + +@pytest.fixture +def mock_hf_together_context(): + """Create a mock TowerContext for Hugging Face Hub testing.""" + context = MagicMock(spec=TowerContext) + context.is_local.return_value = False + context.inference_router = "hugging_face_hub" + context.inference_router_api_key = os.getenv("TOWER_INFERENCE_ROUTER_API_KEY") + context.inference_provider = "together" + return context + +@pytest.fixture +def mock_hf_context(): + """Create a mock TowerContext for Hugging Face Hub testing.""" + context = MagicMock(spec=TowerContext) + context.is_local.return_value = False + context.inference_router = "hugging_face_hub" + context.inference_router_api_key = os.getenv("TOWER_INFERENCE_ROUTER_API_KEY") + context.inference_provider = None + return context + + +@pytest.fixture +def mock_ollama_response(): + """Create a mock Ollama response.""" + response = MagicMock() + response.message.content = "This is a test response" + return response + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_nameres_with_model_family_locally_1(mock_ollama_context): + """ + Test resolving a model family name to a particular model. + Run this test with ollama locally installed. + deepseek-r1 is a name that is used by both ollama and HF + """ + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_ollama_context): + + # Create LLM instance based on model family name + llm = llms("deepseek-r1") + + # Verify it's an Llm instance + assert isinstance(llm, Llm) + + # Verify the resolved model was found locally + assert llm.model_name.startswith("deepseek-r1:") + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_nameres_with_model_family_on_hugging_face_hub_1(mock_hf_together_context): + """ + Test resolving a model family name to a particular model. + Run this test against models available on Hugging Face Hub. + deepseek-r1 is a name that is used by both ollama and HF + """ + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_hf_together_context): + + assert mock_hf_together_context.inference_router_api_key is not None + + # Create LLM instance + llm = llms("deepseek-r1") + + # Verify it's an Llm instance + assert isinstance(llm, Llm) + + # Verify the resolved model was found on the Hub + assert llm.model_name.startswith("deepseek-ai") + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_nameres_with_model_family_locally_2(mock_ollama_context): + """ + Test resolving a model family name to a particular model. + Run this test with ollama locally installed. + llama3.2 is a name used by ollama. + Llama-3.2 is a name used on HF. + """ + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_ollama_context): + + # Create LLM instance based on model family name + llm = llms("llama3.2") + + # Verify it's an Llm instance + assert isinstance(llm, Llm) + + # Verify the resolved model was found locally + assert llm.model_name.startswith("llama3.2:") + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_nameres_with_model_family_on_hugging_face_hub_2(mock_hf_together_context): + """ + Test resolving a model family name to a particular model. + Run this test against models available on Hugging Face Hub. + llama3.2 is a name used by ollama. + Llama-3.2 is a name used on HF. + """ + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_hf_together_context): + + assert mock_hf_together_context.inference_router_api_key is not None + + # Create LLM instance + llm = llms("llama3.2") + + # Verify it's an Llm instance + assert isinstance(llm, Llm) + + # Verify the resolved model was found on the Hub + assert "llama" in llm.model_name + + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_nameres_with_nonexistent_model_locally(mock_ollama_context): + """Test llms function with a model that doesn't exist locally.""" + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_ollama_context): + # Mock get_local_ollama_models to return empty list + with patch('tower._llms.get_local_ollama_models', return_value=[]): + # Test with a non-existent model + with pytest.raises(ValueError) as exc_info: + llm = llms("nonexistent-model") + + # Verify the error message + assert "No models found" in str(exc_info.value) + + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_nameres_with_nonexistent_model_on_hugging_face_hub(mock_hf_together_context): + """Test llms function with a model that doesn't exist on huggingface hub.""" + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_hf_together_context): + + with pytest.raises(ValueError) as exc_info: + llm = llms("nonexistent-model") + + # Verify the error message + assert "No models found" in str(exc_info.value) + + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_nameres_with_exact_model_name_on_hugging_face_hub(mock_hf_together_context): + """Test specifying the exact name of a model on Hugging Face Hub.""" + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_hf_together_context): + + assert mock_hf_together_context.inference_router_api_key is not None + + # Create LLM instance + llm = llms("deepseek-ai/DeepSeek-R1") + + # Verify it's an Llm instance + assert isinstance(llm, Llm) + + # Verify the context was set + assert llm.context == mock_hf_together_context + + # Verify the resolved model was found on the Hub + assert llm.model_name.startswith("deepseek-ai/DeepSeek-R1") + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_nameres_with_partial_model_name_on_hugging_face_hub(mock_hf_context): + """Test specifying a partial model name on Hugging Face Hub.""" + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_hf_context): + + assert mock_hf_context.inference_router_api_key is not None + + # Create LLM instance + llm = llms("google/gemma-3") + + # Verify it's an Llm instance + assert isinstance(llm, Llm) + + # Verify the context was set + assert llm.context == mock_hf_context + + # Verify the resolved model was found on the Hub + assert llm.model_name.startswith("google/gemma-3") + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_inference_with_hugging_face_hub_1(mock_hf_together_context): + """Test actual inference on a model served by together via Hugging Face Hub.""" + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_hf_together_context): + + assert mock_hf_together_context.inference_router_api_key is not None + + # Create LLM instance + llm = llms("deepseek-ai/DeepSeek-R1") + + # Test a simple prompt + response = llm.prompt("What is your model name?") + assert "DeepSeek-R1" in response + +@pytest.mark.skip(reason="Not runnable right now in GH Actions") +def test_llms_inference_locally_1(mock_ollama_context, mock_ollama_response): + """Test local inference, but against a stubbed response.""" + # Mock the TowerContext.build() to return our mock context + with patch('tower._llms.TowerContext.build', return_value=mock_ollama_context): + # Mock the chat function to return our mock response + with patch('tower._llms.chat', return_value=mock_ollama_response): + + # Create LLM instance based on model family name + llm = llms("deepseek-r1") + + # Test a simple prompt + response = llm.prompt("Hello, how are you?") + assert response == "This is a test response" + + + + + + diff --git a/tests/tower/test_tables.py b/tests/tower/test_tables.py index 17213c5c..94c5b3c1 100644 --- a/tests/tower/test_tables.py +++ b/tests/tower/test_tables.py @@ -9,7 +9,7 @@ # We import all the things we need from Tower. import tower.polars as pl import tower.pyarrow as pa -from tower.pyiceberg.catalog.memory import InMemoryCatalog +from pyiceberg.catalog.memory import InMemoryCatalog # Imports the library under test import tower diff --git a/uv.lock b/uv.lock index bec8211e..3296a983 100644 --- a/uv.lock +++ b/uv.lock @@ -270,6 +270,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "hf-xet" +version = "1.1.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ed/d4/7685999e85945ed0d7f0762b686ae7015035390de1161dcea9d5276c134c/hf_xet-1.1.5.tar.gz", hash = "sha256:69ebbcfd9ec44fdc2af73441619eeb06b94ee34511bbcf57cd423820090f5694", size = 495969, upload-time = "2025-06-20T21:48:38.007Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/89/a1119eebe2836cb25758e7661d6410d3eae982e2b5e974bcc4d250be9012/hf_xet-1.1.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f52c2fa3635b8c37c7764d8796dfa72706cc4eded19d638331161e82b0792e23", size = 2687929, upload-time = "2025-06-20T21:48:32.284Z" }, + { url = "https://files.pythonhosted.org/packages/de/5f/2c78e28f309396e71ec8e4e9304a6483dcbc36172b5cea8f291994163425/hf_xet-1.1.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:9fa6e3ee5d61912c4a113e0708eaaef987047616465ac7aa30f7121a48fc1af8", size = 2556338, upload-time = "2025-06-20T21:48:30.079Z" }, + { url = "https://files.pythonhosted.org/packages/6d/2f/6cad7b5fe86b7652579346cb7f85156c11761df26435651cbba89376cd2c/hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc874b5c843e642f45fd85cda1ce599e123308ad2901ead23d3510a47ff506d1", size = 3102894, upload-time = "2025-06-20T21:48:28.114Z" }, + { url = "https://files.pythonhosted.org/packages/d0/54/0fcf2b619720a26fbb6cc941e89f2472a522cd963a776c089b189559447f/hf_xet-1.1.5-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dbba1660e5d810bd0ea77c511a99e9242d920790d0e63c0e4673ed36c4022d18", size = 3002134, upload-time = "2025-06-20T21:48:25.906Z" }, + { url = "https://files.pythonhosted.org/packages/f3/92/1d351ac6cef7c4ba8c85744d37ffbfac2d53d0a6c04d2cabeba614640a78/hf_xet-1.1.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ab34c4c3104133c495785d5d8bba3b1efc99de52c02e759cf711a91fd39d3a14", size = 3171009, upload-time = "2025-06-20T21:48:33.987Z" }, + { url = "https://files.pythonhosted.org/packages/c9/65/4b2ddb0e3e983f2508528eb4501288ae2f84963586fbdfae596836d5e57a/hf_xet-1.1.5-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:83088ecea236d5113de478acb2339f92c95b4fb0462acaa30621fac02f5a534a", size = 3279245, upload-time = "2025-06-20T21:48:36.051Z" }, + { url = "https://files.pythonhosted.org/packages/f0/55/ef77a85ee443ae05a9e9cba1c9f0dd9241eb42da2aeba1dc50f51154c81a/hf_xet-1.1.5-cp37-abi3-win_amd64.whl", hash = "sha256:73e167d9807d166596b4b2f0b585c6d5bd84a26dea32843665a8b58f6edba245", size = 2738931, upload-time = "2025-06-20T21:48:39.482Z" }, +] + [[package]] name = "httpcore" version = "1.0.9" @@ -300,20 +315,21 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.30.2" +version = "0.34.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, { name = "packaging" }, { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/df/22/8eb91736b1dcb83d879bd49050a09df29a57cc5cd9f38e48a4b1c45ee890/huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466", size = 400868, upload-time = "2025-04-08T08:32:45.26Z" } +sdist = { url = "https://files.pythonhosted.org/packages/91/b4/e6b465eca5386b52cf23cb6df8644ad318a6b0e12b4b96a7e0be09cbfbcc/huggingface_hub-0.34.3.tar.gz", hash = "sha256:d58130fd5aa7408480681475491c0abd7e835442082fbc3ef4d45b6c39f83853", size = 456800, upload-time = "2025-07-29T08:38:53.885Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/93/27/1fb384a841e9661faad1c31cbfa62864f59632e876df5d795234da51c395/huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28", size = 481433, upload-time = "2025-04-08T08:32:43.305Z" }, + { url = "https://files.pythonhosted.org/packages/59/a8/4677014e771ed1591a87b63a2392ce6923baf807193deef302dcfde17542/huggingface_hub-0.34.3-py3-none-any.whl", hash = "sha256:5444550099e2d86e68b2898b09e85878fbd788fc2957b506c6a79ce060e39492", size = 558847, upload-time = "2025-07-29T08:38:51.904Z" }, ] [[package]] @@ -853,6 +869,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634, upload-time = "2025-03-02T12:54:52.069Z" }, ] +[[package]] +name = "pytest-env" +version = "1.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1f/31/27f28431a16b83cab7a636dce59cf397517807d247caa38ee67d65e71ef8/pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf", size = 8911, upload-time = "2024-09-17T22:39:18.566Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/b8/87cfb16045c9d4092cfcf526135d73b88101aac83bc1adcf82dfb5fd3833/pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30", size = 6141, upload-time = "2024-09-17T22:39:16.942Z" }, +] + [[package]] name = "pytest-httpx" version = "0.35.0" @@ -1206,6 +1235,9 @@ source = { editable = "." } dependencies = [ { name = "attrs" }, { name = "httpx" }, + { name = "huggingface-hub" }, + { name = "ollama" }, + { name = "pyiceberg" }, { name = "python-dateutil" }, ] @@ -1232,6 +1264,7 @@ dev = [ { name = "openapi-python-client" }, { name = "pyiceberg", extra = ["sql-sqlite"] }, { name = "pytest" }, + { name = "pytest-env" }, { name = "pytest-httpx" }, ] @@ -1239,10 +1272,13 @@ dev = [ requires-dist = [ { name = "attrs", specifier = "==24.2.0" }, { name = "httpx", specifier = "==0.28.1" }, - { name = "huggingface-hub", marker = "extra == 'ai'", specifier = "==0.30.2" }, + { name = "huggingface-hub", specifier = ">=0.34.3" }, + { name = "huggingface-hub", marker = "extra == 'ai'", specifier = "==0.34.3" }, + { name = "ollama", specifier = ">=0.4.7" }, { name = "ollama", marker = "extra == 'ai'", specifier = "==0.4.7" }, { name = "polars", marker = "extra == 'iceberg'", specifier = "==1.27.1" }, { name = "pyarrow", marker = "extra == 'iceberg'", specifier = "==19.0.1" }, + { name = "pyiceberg", specifier = "==0.9.0" }, { name = "pyiceberg", marker = "extra == 'iceberg'", specifier = "==0.9.0" }, { name = "python-dateutil", specifier = "==2.9.0.post0" }, { name = "tower", extras = ["ai", "iceberg"], marker = "extra == 'all'", editable = "." }, @@ -1254,6 +1290,7 @@ dev = [ { name = "openapi-python-client", specifier = "==0.24.3" }, { name = "pyiceberg", extras = ["sql-sqlite"], specifier = "==0.9.0" }, { name = "pytest", specifier = "==8.3.5" }, + { name = "pytest-env", specifier = ">=1.1.3" }, { name = "pytest-httpx", specifier = "==0.35.0" }, ]