62 changes: 37 additions & 25 deletions README.md
@@ -93,18 +93,18 @@ parser = VisionParser(
# Initialize parser with Azure OpenAI model
parser = VisionParser(
model_name="gpt-4o",
api_key="your-azure-openai-api-key", # replace with your Azure OpenAI API key
image_mode="url",
detailed_extraction=False, # set to True for more detailed extraction
enable_concurrency=True,
openai_config={
"AZURE_ENDPOINT_URL": "https://****.openai.azure.com/", # replace with your Azure endpoint URL
"AZURE_DEPLOYMENT_NAME": "*******", # replace with Azure deployment name, if needed
"AZURE_OPENAI_API_KEY": "***********", # replace with your Azure OpenAI API key
"AZURE_OPENAI_API_VERSION": "2024-08-01-preview", # replace with latest Azure OpenAI API version
provider_config={
"base_url": "https://****.openai.azure.com/", # replace with your Azure endpoint URL
"api_version": "2024-08-01-preview", # replace with latest Azure OpenAI API version
"azure": True, # specify that this is Azure OpenAI
"azure_deployment": "*******", # replace with Azure deployment name
},
)


# Initialize parser with Google Gemini model
parser = VisionParser(
model_name="gemini-1.5-flash",
@@ -126,6 +126,20 @@ parser = VisionParser(
detailed_extraction=False, # set to True for more detailed extraction
enable_concurrency=True,
)

# Initialize parser with model on LiteLLM proxy
parser = VisionParser(
model_name="litellm/provider/model",
api_key="your-litellm-proxy-api-key",
temperature=0.7,
top_p=0.4,
image_mode="url",
detailed_extraction=False, # set to True for more detailed extraction
enable_concurrency=True,
provider_config={
"base_url": "https://litellm.proxy.domain",
},
)
```

## ✅ Tested Models
@@ -162,30 +176,28 @@ The following Vision LLM models have been thoroughly tested with Vision Parse, b

### Provider-Specific Configuration

#### OpenAI Configuration
The `provider_config` parameter lets you configure provider-specific settings through a unified interface:

```python
openai_config = {
# For standard OpenAI
"OPENAI_BASE_URL": "https://api.openai.com/v1", # optional
"OPENAI_MAX_RETRIES": 3, # optional
"OPENAI_TIMEOUT": 240.0, # optional

# For Azure OpenAI
"AZURE_ENDPOINT_URL": "https://your-resource.openai.azure.com/",
"AZURE_DEPLOYMENT_NAME": "your-deployment-name",
"AZURE_OPENAI_API_KEY": "your-azure-api-key",
"AZURE_OPENAI_API_VERSION": "2024-08-01-preview",
# For OpenAI
provider_config = {
"base_url": "https://api.openai.com/v1", # optional
"max_retries": 3, # optional
"timeout": 240.0, # optional
}
```

#### Gemini Configuration (Google AI Studio)
# For Azure OpenAI
provider_config = {
"base_url": "https://your-resource.openai.azure.com/",
"api_version": "2024-08-01-preview",
"azure": True,
"azure_deployment": "your-deployment-name",
}

```python
gemini_config = {
"GOOGLE_API_KEY": "your-google-api-key", # API key from Google AI Studio (not Vertex AI)
"GEMINI_MAX_RETRIES": 3, # optional
"GEMINI_TIMEOUT": 240.0, # optional
# For Gemini (Google AI Studio)
provider_config = {
"max_retries": 3, # optional
"timeout": 240.0, # optional
}
```

2 changes: 1 addition & 1 deletion src/vision_parse/constants.py
@@ -9,7 +9,7 @@

# Common model prefixes for provider detection
PROVIDER_PREFIXES: Dict[str, List[str]] = {
"openai": ["gpt"],
"openai": ["gpt", "litellm"],
"azure": ["gpt"],
"gemini": ["gemini"],
"deepseek": ["deepseek"],
144 changes: 60 additions & 84 deletions src/vision_parse/llm.py
@@ -51,21 +51,19 @@ class LLM:
def __init__(
self,
model_name: str,
api_key: Optional[str],
temperature: float,
top_p: float,
openai_config: Optional[Dict],
gemini_config: Optional[Dict],
image_mode: Literal["url", "base64", None],
custom_prompt: Optional[str],
detailed_extraction: bool,
enable_concurrency: bool,
api_key: Optional[str] = None,
temperature: float = 0.7,
top_p: float = 0.7,
provider_config: Optional[Dict] = None,
image_mode: Literal["url", "base64", None] = None,
custom_prompt: Optional[str] = None,
detailed_extraction: bool = False,
enable_concurrency: bool = False,
**kwargs: Any,
):
self.model_name = model_name
self.api_key = api_key
self.openai_config = openai_config or {}
self.gemini_config = gemini_config or {}
self.provider_config = provider_config or {}
self.temperature = temperature
self.top_p = top_p
self.image_mode = image_mode
@@ -103,12 +101,18 @@ def _get_provider_name(self, model_name: str) -> str:
Returns:
str: The provider name (e.g., 'openai', 'gemini').

Note:
Models with the 'litellm' prefix are treated as OpenAI models;
the prefix is replaced with 'openai' before any request is issued.

Raises:
UnsupportedProviderError: If the model name doesn't match any known provider.
"""

for provider, prefixes in PROVIDER_PREFIXES.items():
if any(model_name.startswith(prefix) for prefix in prefixes):
if model_name.startswith("litellm"):
self.model_name = model_name.replace("litellm", "openai")
return provider

supported_providers = ", ".join(
@@ -129,66 +133,17 @@ def _get_model_params(self, structured: bool = False) -> Dict[str, Any]:
Returns:
Dict[str, Any]: Dictionary containing model parameters for API calls.
"""
# Base parameters that are common across providers

# Base parameters common across providers
params = {
"model": self.model_name,
"temperature": 0.0 if structured else self.temperature,
"top_p": 0.4 if structured else self.top_p,
**self.kwargs,
}

# Filter kwargs based on provider
if self.provider in ["openai", "azure"]:
# Only include OpenAI-compatible parameters
openai_params = {
k: v
for k, v in self.kwargs.items()
if k not in ["device", "num_workers", "ollama_config"]
}
params.update(openai_params)

if self.openai_config.get("AZURE_OPENAI_API_KEY"):
params.update(
{
"api_key": self.openai_config["AZURE_OPENAI_API_KEY"],
"api_base": self.openai_config["AZURE_ENDPOINT_URL"],
"api_version": self.openai_config.get(
"AZURE_OPENAI_API_VERSION", "2024-08-01-preview"
),
"deployment_id": self.openai_config.get(
"AZURE_DEPLOYMENT_NAME"
),
}
)
else:
params.update(
{
"api_key": self.api_key,
"base_url": self.openai_config.get("OPENAI_BASE_URL"),
"max_retries": self.openai_config.get("OPENAI_MAX_RETRIES", 3),
"timeout": self.openai_config.get("OPENAI_TIMEOUT", 240.0),
}
)
elif self.provider == "gemini":
# Only include Gemini-compatible parameters
gemini_params = {
k: v
for k, v in self.kwargs.items()
if k not in ["device", "num_workers", "ollama_config"]
}
params.update(gemini_params)
params.update(
{
"api_key": self.api_key,
**self.gemini_config,
}
)
elif self.provider == "deepseek":
# Handle DeepSeek parameters
params.update(
{
"api_key": self.api_key,
}
)
# Add API key and provider-specific config
params.update({"api_key": self.api_key, **self.provider_config})

return params
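For a concrete picture of the simplified merge, here is a sketch of the `params` dict this method would return for the Azure setup shown in the README (all values are placeholders, with `structured=True`):

```python
# Illustrative result of _get_model_params(structured=True) for an Azure
# OpenAI configuration; all values are placeholders from the README example.
params = {
    "model": "gpt-4o",
    "temperature": 0.0,  # forced to 0.0 when structured=True
    "top_p": 0.4,        # forced to 0.4 when structured=True
    "api_key": "your-azure-openai-api-key",
    # merged in verbatim from provider_config:
    "base_url": "https://your-resource.openai.azure.com/",
    "api_version": "2024-08-01-preview",
    "azure": True,
    "azure_deployment": "your-deployment-name",
}
```

Because `provider_config` is spread into the update last, any key it defines overrides a base parameter of the same name.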

Expand All @@ -214,6 +169,7 @@ async def _get_response(
Raises:
LLMError: If LLM processing fails.
"""

try:
messages = [
{
@@ -232,26 +188,46 @@

params = self._get_model_params(structured)

if structured:
response = await self.client.chat.completions.create(
messages=messages,
response_model=ImageDescription,
**params,
)
return response.model_dump_json()
if self.enable_concurrency:
# Async path
if structured:
response = await self.client.chat.completions.create(
messages=messages,
response_model=ImageDescription,
**params,
)
return response.model_dump_json()
else:
# For non-structured responses, use str as the response model
response = await self.client.chat.completions.create(
messages=messages,
response_model=str,
**params,
)
else:
# For non-structured responses, use str as the response model
response = await self.client.chat.completions.create(
messages=messages,
response_model=str,
**params,
)
return re.sub(
r"```(?:markdown)?\n(.*?)\n```",
r"\1",
response,
flags=re.DOTALL,
)
# Sync path
if structured:
response = self.client.chat.completions.create(
messages=messages,
response_model=ImageDescription,
**params,
)
return response.model_dump_json()
else:
# For non-structured responses, use str as the response model
response = self.client.chat.completions.create(
messages=messages,
response_model=str,
**params,
)

# Process the response for non-structured output
return re.sub(
r"```(?:markdown)?\n(.*?)\n```",
r"\1",
response,
flags=re.DOTALL,
)
except Exception as e:
raise LLMError(f"LLM processing failed: {str(e)}")
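The final `re.sub` strips an optional markdown code fence from non-structured responses; a quick standalone check of the same pattern:

```python
import re

# Same fence-stripping pattern as above, applied to a sample LLM response.
response = "```markdown\n# Page Title\n\nExtracted text.\n```"
cleaned = re.sub(r"```(?:markdown)?\n(.*?)\n```", r"\1", response, flags=re.DOTALL)
print(cleaned)  # -> "# Page Title\n\nExtracted text."
```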

18 changes: 5 additions & 13 deletions src/vision_parse/parser.py
@@ -11,7 +11,7 @@

from .exceptions import UnsupportedFileError, VisionParserError
from .llm import LLM
from .utils import get_device_config
from .utils import get_num_workers

logger = logging.getLogger(__name__)
nest_asyncio.apply()
@@ -41,9 +41,7 @@ def __init__(
api_key: Optional[str] = None,
temperature: float = 0.7,
top_p: float = 0.7,
ollama_config: Optional[Dict] = None,
openai_config: Optional[Dict] = None,
gemini_config: Optional[Dict] = None,
provider_config: Optional[Dict] = None,
image_mode: Literal["url", "base64", None] = None,
custom_prompt: Optional[str] = None,
detailed_extraction: bool = False,
@@ -58,9 +56,7 @@ def __init__(
api_key (Optional[str]): API key for the LLM provider.
temperature (float): Controls randomness in LLM output. Defaults to 0.7.
top_p (float): Controls diversity in LLM output. Defaults to 0.7.
ollama_config (Optional[Dict]): Configuration for Ollama provider.
openai_config (Optional[Dict]): Configuration for OpenAI provider.
gemini_config (Optional[Dict]): Configuration for Google AI Studio provider.
provider_config (Optional[Dict]): Configuration for the LLM provider.
image_mode (Literal["url", "base64", None]): Mode for handling embedded images.
custom_prompt (Optional[str]): Custom prompt for LLM processing.
detailed_extraction (bool): Enables detailed text extraction. Defaults to False.
@@ -69,23 +65,19 @@
"""

self.page_config = page_config or PDFPageConfig()
self.device, self.num_workers = get_device_config()
self.num_workers = get_num_workers()
self.enable_concurrency = enable_concurrency

self.llm = LLM(
model_name=model_name,
api_key=api_key,
temperature=temperature,
top_p=top_p,
ollama_config=ollama_config,
openai_config=openai_config,
gemini_config=gemini_config,
provider_config=provider_config,
image_mode=image_mode,
detailed_extraction=detailed_extraction,
custom_prompt=custom_prompt,
enable_concurrency=enable_concurrency,
device=self.device,
num_workers=self.num_workers,
**kwargs,
)
