From e9559e903ec13d1c6ced135f95ed323737f1e665 Mon Sep 17 00:00:00 2001
From: tqtensor
Date: Thu, 8 May 2025 21:25:35 +0000
Subject: [PATCH 1/2] feat: unify provider configs

---
 README.md               |  48 +++++++-------
 src/vision_parse/llm.py | 138 ++++++++++++++++------------------
 tests/test_llm.py       |  40 +++---------
 tests/test_parser.py    |  17 ++---
 4 files changed, 94 insertions(+), 149 deletions(-)

diff --git a/README.md b/README.md
index 3defd19..3d1184c 100644
--- a/README.md
+++ b/README.md
@@ -93,18 +93,18 @@ parser = VisionParser(
 # Initialize parser with Azure OpenAI model
 parser = VisionParser(
     model_name="gpt-4o",
+    api_key="your-azure-openai-api-key",  # replace with your Azure OpenAI API key
     image_mode="url",
     detailed_extraction=False,  # set to True for more detailed extraction
     enable_concurrency=True,
-    openai_config={
-        "AZURE_ENDPOINT_URL": "https://****.openai.azure.com/",  # replace with your Azure endpoint URL
-        "AZURE_DEPLOYMENT_NAME": "*******",  # replace with Azure deployment name, if needed
-        "AZURE_OPENAI_API_KEY": "***********",  # replace with your Azure OpenAI API key
-        "AZURE_OPENAI_API_VERSION": "2024-08-01-preview",  # replace with latest Azure OpenAI API version
+    provider_config={
+        "base_url": "https://****.openai.azure.com/",  # replace with your Azure endpoint URL
+        "api_version": "2024-08-01-preview",  # replace with latest Azure OpenAI API version
+        "azure": True,  # specify that this is Azure OpenAI
+        "azure_deployment": "*******",  # replace with Azure deployment name
     },
 )
 
-
 # Initialize parser with Google Gemini model
 parser = VisionParser(
     model_name="gemini-1.5-flash",
@@ -162,30 +162,28 @@ The following Vision LLM models have been thoroughly tested with Vision Parse, b
 
 ### Provider-Specific Configuration
 
-#### OpenAI Configuration
+The `provider_config` parameter lets you configure provider-specific settings through a unified interface:
 
 ```python
-openai_config = {
-    # For standard OpenAI
-    "OPENAI_BASE_URL": "https://api.openai.com/v1",  # optional
-    "OPENAI_MAX_RETRIES": 3,  # optional
-    "OPENAI_TIMEOUT": 240.0,  # optional
-
-    # For Azure OpenAI
-    "AZURE_ENDPOINT_URL": "https://your-resource.openai.azure.com/",
-    "AZURE_DEPLOYMENT_NAME": "your-deployment-name",
-    "AZURE_OPENAI_API_KEY": "your-azure-api-key",
-    "AZURE_OPENAI_API_VERSION": "2024-08-01-preview",
+# For OpenAI
+provider_config = {
+    "base_url": "https://api.openai.com/v1",  # optional
+    "max_retries": 3,  # optional
+    "timeout": 240.0,  # optional
 }
-```
 
-#### Gemini Configuration (Google AI Studio)
+# For Azure OpenAI
+provider_config = {
+    "base_url": "https://your-resource.openai.azure.com/",
+    "api_version": "2024-08-01-preview",
+    "azure": True,
+    "azure_deployment": "your-deployment-name",
+}
 
-```python
-gemini_config = {
-    "GOOGLE_API_KEY": "your-google-api-key",  # API key from Google AI Studio (not Vertex AI)
-    "GEMINI_MAX_RETRIES": 3,  # optional
-    "GEMINI_TIMEOUT": 240.0,  # optional
+# For Gemini (Google AI Studio)
+provider_config = {
+    "max_retries": 3,  # optional
+    "timeout": 240.0,  # optional
 }
 ```
 
diff --git a/src/vision_parse/llm.py b/src/vision_parse/llm.py
index 21335e1..a95c313 100644
--- a/src/vision_parse/llm.py
+++ b/src/vision_parse/llm.py
@@ -51,21 +51,19 @@ class LLM:
     def __init__(
         self,
         model_name: str,
-        api_key: Optional[str],
-        temperature: float,
-        top_p: float,
-        openai_config: Optional[Dict],
-        gemini_config: Optional[Dict],
-        image_mode: Literal["url", "base64", None],
-        custom_prompt: Optional[str],
-        detailed_extraction: bool,
-        enable_concurrency: bool,
+        api_key: Optional[str] = None,
+        temperature: float = 0.7,
+        top_p: float = 0.7,
+        provider_config: Optional[Dict] = None,
+        image_mode: Literal["url", "base64", None] = None,
+        custom_prompt: Optional[str] = None,
+        detailed_extraction: bool = False,
+        enable_concurrency: bool = False,
         **kwargs: Any,
     ):
         self.model_name = model_name
         self.api_key = api_key
-        self.openai_config = openai_config or {}
-        self.gemini_config = gemini_config or {}
+        self.provider_config = provider_config or {}
         self.temperature = temperature
         self.top_p = top_p
         self.image_mode = image_mode
@@ -129,66 +127,17 @@ def _get_model_params(self, structured: bool = False) -> Dict[str, Any]:
         Returns:
             Dict[str, Any]: Dictionary containing model parameters for API calls.
         """
-        # Base parameters that are common across providers
+
+        # Base parameters common across providers
         params = {
             "model": self.model_name,
             "temperature": 0.0 if structured else self.temperature,
             "top_p": 0.4 if structured else self.top_p,
+            **self.kwargs,
         }
 
-        # Filter kwargs based on provider
-        if self.provider in ["openai", "azure"]:
-            # Only include OpenAI-compatible parameters
-            openai_params = {
-                k: v
-                for k, v in self.kwargs.items()
-                if k not in ["device", "num_workers", "ollama_config"]
-            }
-            params.update(openai_params)
-
-            if self.openai_config.get("AZURE_OPENAI_API_KEY"):
-                params.update(
-                    {
-                        "api_key": self.openai_config["AZURE_OPENAI_API_KEY"],
-                        "api_base": self.openai_config["AZURE_ENDPOINT_URL"],
-                        "api_version": self.openai_config.get(
-                            "AZURE_OPENAI_API_VERSION", "2024-08-01-preview"
-                        ),
-                        "deployment_id": self.openai_config.get(
-                            "AZURE_DEPLOYMENT_NAME"
-                        ),
-                    }
-                )
-            else:
-                params.update(
-                    {
-                        "api_key": self.api_key,
-                        "base_url": self.openai_config.get("OPENAI_BASE_URL"),
-                        "max_retries": self.openai_config.get("OPENAI_MAX_RETRIES", 3),
-                        "timeout": self.openai_config.get("OPENAI_TIMEOUT", 240.0),
-                    }
-                )
-        elif self.provider == "gemini":
-            # Only include Gemini-compatible parameters
-            gemini_params = {
-                k: v
-                for k, v in self.kwargs.items()
-                if k not in ["device", "num_workers", "ollama_config"]
-            }
-            params.update(gemini_params)
-            params.update(
-                {
-                    "api_key": self.api_key,
-                    **self.gemini_config,
-                }
-            )
-        elif self.provider == "deepseek":
-            # Handle DeepSeek parameters
-            params.update(
-                {
-                    "api_key": self.api_key,
-                }
-            )
+        # Add API key and provider-specific config
+        params.update({"api_key": self.api_key, **self.provider_config})
 
         return params
 
@@ -214,6 +163,7 @@ async def _get_response(
         Raises:
             LLMError: If LLM processing fails.
""" + try: messages = [ { @@ -232,26 +182,46 @@ async def _get_response( params = self._get_model_params(structured) - if structured: - response = await self.client.chat.completions.create( - messages=messages, - response_model=ImageDescription, - **params, - ) - return response.model_dump_json() + if self.enable_concurrency: + # Async path + if structured: + response = await self.client.chat.completions.create( + messages=messages, + response_model=ImageDescription, + **params, + ) + return response.model_dump_json() + else: + # For non-structured responses, use str as the response model + response = await self.client.chat.completions.create( + messages=messages, + response_model=str, + **params, + ) else: - # For non-structured responses, use str as the response model - response = await self.client.chat.completions.create( - messages=messages, - response_model=str, - **params, - ) - return re.sub( - r"```(?:markdown)?\n(.*?)\n```", - r"\1", - response, - flags=re.DOTALL, - ) + # Sync path + if structured: + response = self.client.chat.completions.create( + messages=messages, + response_model=ImageDescription, + **params, + ) + return response.model_dump_json() + else: + # For non-structured responses, use str as the response model + response = self.client.chat.completions.create( + messages=messages, + response_model=str, + **params, + ) + + # Process the response for non-structured output + return re.sub( + r"```(?:markdown)?\n(.*?)\n```", + r"\1", + response, + flags=re.DOTALL, + ) except Exception as e: raise LLMError(f"LLM processing failed: {str(e)}") diff --git a/tests/test_llm.py b/tests/test_llm.py index 625df03..bc0fc63 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -53,15 +53,6 @@ def test_unsupported_model(): model_name="unsupported-model", temperature=0.7, top_p=0.7, - api_key=None, - openai_config=None, - gemini_config=None, - image_mode=None, - custom_prompt=None, - detailed_extraction=False, - enable_concurrency=False, - device=None, - num_workers=1, ) assert "not from a supported provider" in str(exc_info.value) @@ -81,14 +72,11 @@ async def test_openai_generate_markdown( api_key="test-key", temperature=0.7, top_p=0.7, - openai_config={"OPENAI_API_KEY": "test-key"}, - gemini_config=None, + provider_config={"base_url": "https://api.openai.com/v1"}, image_mode=None, custom_prompt=None, detailed_extraction=True, enable_concurrency=True, - device=None, - num_workers=1, ) result = await llm.generate_markdown(sample_base64_image, mock_pixmap, 0) @@ -109,22 +97,18 @@ async def test_azure_openai_generate_markdown( llm = LLM( model_name="gpt-4o", - api_key=None, + provider_config={ + "api_key": "test-key", + "api_base": "https://test.openai.azure.com/", + "deployment_id": "gpt-4o", + "api_version": "2024-08-01-preview", + }, temperature=0.7, top_p=0.7, - openai_config={ - "AZURE_ENDPOINT_URL": "https://test.openai.azure.com/", - "AZURE_DEPLOYMENT_NAME": "gpt-4o", - "AZURE_OPENAI_API_KEY": "test-key", - "AZURE_OPENAI_API_VERSION": "2024-08-01-preview", - }, - gemini_config=None, image_mode=None, custom_prompt=None, detailed_extraction=True, enable_concurrency=True, - device=None, - num_workers=1, ) result = await llm.generate_markdown(sample_base64_image, mock_pixmap, 0) @@ -145,17 +129,13 @@ async def test_gemini_generate_markdown( llm = LLM( model_name="gemini-2.5-pro", - api_key=None, + api_key="test-key", temperature=0.7, top_p=0.7, - openai_config=None, - gemini_config={"GOOGLE_API_KEY": "test-key"}, image_mode=None, custom_prompt=None, detailed_extraction=True, 
         enable_concurrency=True,
-        device=None,
-        num_workers=1,
     )
 
     result = await llm.generate_markdown(sample_base64_image, mock_pixmap, 0)
@@ -179,14 +159,10 @@ async def test_deepseek_generate_markdown(
         api_key="test-key",
         temperature=0.7,
         top_p=0.7,
-        openai_config=None,
-        gemini_config=None,
         image_mode=None,
         custom_prompt=None,
         detailed_extraction=True,
         enable_concurrency=True,
-        device=None,
-        num_workers=1,
     )
 
     result = await llm.generate_markdown(sample_base64_image, mock_pixmap, 0)
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 1215884..f80be6e 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -198,19 +198,20 @@ def test_parser_with_custom_page_config():
     assert parser.page_config.preserve_transparency
 
 
-def test_parser_with_openai_config():
-    """Test parser initialization with OpenAI configuration."""
-    openai_config = {
-        "OPENAI_BASE_URL": "https://api.openai.com/v1",
-        "OPENAI_MAX_RETRIES": 3,
-        "OPENAI_TIMEOUT": 240.0,
+def test_parser_with_provider_config():
+    """Test parser initialization with provider-specific configuration."""
+    provider_config = {
+        "base_url": "https://api.openai.com/v1",
+        "max_retries": 3,
+        "timeout": 240.0,
     }
 
     parser = VisionParser(
         model_name="gpt-4o",
         api_key="test-key",
         temperature=0.7,
         top_p=0.7,
-        openai_config=openai_config,
+        provider_config=provider_config,
     )
-    assert parser.llm.openai_config == openai_config
+    # Test that provider_config is correctly passed through to the LLM
+    assert parser.llm.provider_config == provider_config

From 7640b1ab35bef70f0ec066ce11f9db7dceb936a0 Mon Sep 17 00:00:00 2001
From: tqtensor
Date: Thu, 8 May 2025 21:42:56 +0000
Subject: [PATCH 2/2] feat: add LiteLLM proxy models

---
 README.md                     | 14 ++++++++++++++
 src/vision_parse/constants.py |  2 +-
 src/vision_parse/llm.py       |  6 ++++++
 src/vision_parse/parser.py    | 18 +++++-----------
 src/vision_parse/utils.py     | 39 ++++++++++++++-------------------------
 5 files changed, 40 insertions(+), 39 deletions(-)

diff --git a/README.md b/README.md
index 3d1184c..e8f9a54 100644
--- a/README.md
+++ b/README.md
@@ -126,6 +126,20 @@ parser = VisionParser(
     detailed_extraction=False,  # set to True for more detailed extraction
     enable_concurrency=True,
 )
+
+# Initialize parser with model on LiteLLM proxy
+parser = VisionParser(
+    model_name="litellm/provider/model",
+    api_key="your-litellm-proxy-api-key",
+    temperature=0.7,
+    top_p=0.4,
+    image_mode="url",
+    detailed_extraction=False,  # set to True for more detailed extraction
+    enable_concurrency=True,
+    provider_config={
+        "base_url": "https://litellm.proxy.domain",
+    },
+)
 ```
 
 ## ✅ Tested Models
diff --git a/src/vision_parse/constants.py b/src/vision_parse/constants.py
index 13f55c0..fe0deea 100644
--- a/src/vision_parse/constants.py
+++ b/src/vision_parse/constants.py
@@ -9,7 +9,7 @@
 
 # Common model prefixes for provider detection
 PROVIDER_PREFIXES: Dict[str, List[str]] = {
-    "openai": ["gpt"],
+    "openai": ["gpt", "litellm"],
     "azure": ["gpt"],
     "gemini": ["gemini"],
     "deepseek": ["deepseek"],
diff --git a/src/vision_parse/llm.py b/src/vision_parse/llm.py
index a95c313..72c161b 100644
--- a/src/vision_parse/llm.py
+++ b/src/vision_parse/llm.py
@@ -101,12 +101,18 @@ def _get_provider_name(self, model_name: str) -> str:
         Returns:
             str: The provider name (e.g., 'openai', 'gemini').
 
+        Note:
+            Models with 'litellm' prefix are treated as OpenAI models and
+            converted accordingly by replacing 'litellm' with 'openai'.
+
         Raises:
             UnsupportedProviderError: If the model name doesn't match
                 any known provider.
""" for provider, prefixes in PROVIDER_PREFIXES.items(): if any(model_name.startswith(prefix) for prefix in prefixes): + if model_name.startswith("litellm"): + self.model_name = model_name.replace("litellm", "openai") return provider supported_providers = ", ".join( diff --git a/src/vision_parse/parser.py b/src/vision_parse/parser.py index a3f162e..288dfcc 100644 --- a/src/vision_parse/parser.py +++ b/src/vision_parse/parser.py @@ -11,7 +11,7 @@ from .exceptions import UnsupportedFileError, VisionParserError from .llm import LLM -from .utils import get_device_config +from .utils import get_num_workers logger = logging.getLogger(__name__) nest_asyncio.apply() @@ -41,9 +41,7 @@ def __init__( api_key: Optional[str] = None, temperature: float = 0.7, top_p: float = 0.7, - ollama_config: Optional[Dict] = None, - openai_config: Optional[Dict] = None, - gemini_config: Optional[Dict] = None, + provider_config: Optional[Dict] = None, image_mode: Literal["url", "base64", None] = None, custom_prompt: Optional[str] = None, detailed_extraction: bool = False, @@ -58,9 +56,7 @@ def __init__( api_key (Optional[str]): API key for the LLM provider. temperature (float): Controls randomness in LLM output. Defaults to 0.7. top_p (float): Controls diversity in LLM output. Defaults to 0.7. - ollama_config (Optional[Dict]): Configuration for Ollama provider. - openai_config (Optional[Dict]): Configuration for OpenAI provider. - gemini_config (Optional[Dict]): Configuration for Google AI Studio provider. + provider_config (Optional[Dict]): Configuration for the LLM provider. image_mode (Literal["url", "base64", None]): Mode for handling embedded images. custom_prompt (Optional[str]): Custom prompt for LLM processing. detailed_extraction (bool): Enables detailed text extraction. Defaults to False. @@ -69,7 +65,7 @@ def __init__( """ self.page_config = page_config or PDFPageConfig() - self.device, self.num_workers = get_device_config() + self.num_workers = get_num_workers() self.enable_concurrency = enable_concurrency self.llm = LLM( @@ -77,15 +73,11 @@ def __init__( api_key=api_key, temperature=temperature, top_p=top_p, - ollama_config=ollama_config, - openai_config=openai_config, - gemini_config=gemini_config, + provider_config=provider_config, image_mode=image_mode, detailed_extraction=detailed_extraction, custom_prompt=custom_prompt, enable_concurrency=enable_concurrency, - device=self.device, - num_workers=self.num_workers, **kwargs, ) diff --git a/src/vision_parse/utils.py b/src/vision_parse/utils.py index dc3f233..f8d3a83 100644 --- a/src/vision_parse/utils.py +++ b/src/vision_parse/utils.py @@ -1,11 +1,9 @@ import base64 import logging import os -import platform -import subprocess from dataclasses import dataclass from threading import Lock -from typing import ClassVar, List, Literal, Tuple +from typing import ClassVar, List, Literal import cv2 import fitz @@ -51,6 +49,7 @@ def _prepare_image_for_detection(image: np.ndarray) -> np.ndarray: Raises: ImageExtractionError: If image processing fails. """ + try: grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) smooth = cv2.GaussianBlur(grayscale, (5, 5), 0) @@ -83,6 +82,7 @@ def _check_region_validity( Raises: ImageExtractionError: If image validation fails. """ + try: width, height = region_dims region_area = cv2.contourArea(contour) / (width * height) @@ -127,6 +127,7 @@ def extract_images( Raises: ImageExtractionError: If image extraction or processing fails. 
""" + with cls._lock: try: min_width, min_height = min_dimensions @@ -198,29 +199,17 @@ def extract_images( raise ImageExtractionError(f"Image processing failed: {str(e)}") -def get_device_config() -> Tuple[Literal["cuda", "mps", "cpu"], int]: - """Determines optimal device configuration for processing. +def get_num_workers() -> int: + """Determines optimal number of worker processes for concurrent operations. - This function checks available hardware (GPU, MPS, CPU) and returns - the appropriate device type and number of worker processes. + Uses system CPU information to choose a reasonable number of worker processes + for PDF page processing. Generally uses half the available CPU cores to avoid + overwhelming the system. Returns: - Tuple[Literal["cuda", "mps", "cpu"], int]: Device type and optimal number - of worker processes. + int: Optimal number of worker processes (minimum 2, maximum 8). """ - try: - nvidia_smi = subprocess.run( - ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - ) - if nvidia_smi.returncode == 0: - return "cuda", min(len(nvidia_smi.stdout.strip().split("\n")) * 2, 8) - except Exception: - pass - - if platform.system() == "Darwin" and platform.processor() == "arm": - return "mps", 4 - - return "cpu", max(2, (os.cpu_count() // 2)) + + cpu_count = os.cpu_count() or 4 # default to 4 if detection fails + # Use half of available CPUs, but keep between 2-8 workers for reasonable performance + return max(2, min(cpu_count // 2, 8))