diff --git a/packages/uipath_langchain_client/CHANGELOG.md b/packages/uipath_langchain_client/CHANGELOG.md index 3674583..88ae09b 100644 --- a/packages/uipath_langchain_client/CHANGELOG.md +++ b/packages/uipath_langchain_client/CHANGELOG.md @@ -2,6 +2,11 @@ All notable changes to `uipath_langchain_client` will be documented in this file. +## [1.7.1] - 2026-04-04 + +### Added +- `custom_class` parameter in `get_chat_model()` and `get_embedding_model()` factory functions to allow instantiating a user-provided class instead of the auto-detected one + ## [1.7.0] - 2026-04-03 ### Added diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py b/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py index 66e749d..3b6d111 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/__version__.py @@ -1,3 +1,3 @@ __title__ = "UiPath LangChain Client" __description__ = "A Python client for interacting with UiPath's LLM services via LangChain." -__version__ = "1.7.0" +__version__ = "1.7.1" diff --git a/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py b/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py index 4e98d63..163df11 100644 --- a/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py +++ b/packages/uipath_langchain_client/src/uipath_langchain_client/factory.py @@ -87,6 +87,7 @@ def get_chat_model( routing_mode: RoutingMode | str = RoutingMode.PASSTHROUGH, vendor_type: VendorType | str | None = None, api_flavor: ApiFlavor | str | None = None, + custom_class: type[UiPathBaseChatModel] | None = None, **model_kwargs: Any, ) -> UiPathBaseChatModel: """Factory function to create the appropriate LangChain chat model for a given model name. @@ -106,6 +107,9 @@ def get_chat_model( - Bedrock Claude: Default uses UiPathChatAnthropicBedrock. ApiFlavor.CONVERSE uses UiPathChatBedrockConverse, ApiFlavor.INVOKE uses UiPathChatBedrock. + custom_class: A custom class to use for instantiating the chat model instead of the + auto-detected one. Must be a subclass of UiPathBaseChatModel. When provided, + the factory skips vendor detection and uses this class directly. **model_kwargs: Additional keyword arguments to pass to the model constructor. Returns: @@ -128,6 +132,14 @@ def get_chat_model( if not is_uipath_owned: client_settings.validate_byo_model(model_info) + if custom_class is not None: + return custom_class( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ) + if routing_mode == RoutingMode.NORMALIZED: from uipath_langchain_client.clients.normalized.chat_models import ( UiPathChat, @@ -248,6 +260,7 @@ def get_embedding_model( client_settings: UiPathBaseSettings | None = None, routing_mode: RoutingMode | str = RoutingMode.PASSTHROUGH, vendor_type: VendorType | str | None = None, + custom_class: type[UiPathBaseEmbeddings] | None = None, **model_kwargs: Any, ) -> UiPathBaseEmbeddings: """Factory function to create the appropriate LangChain embeddings model. @@ -262,6 +275,9 @@ def get_embedding_model( RoutingMode.PASSTHROUGH for vendor-specific APIs. vendor_type: Filter models by vendor type (e.g., VendorType.OPENAI). If not provided, auto-detected from the model discovery endpoint. + custom_class: A custom class to use for instantiating the embedding model instead of + the auto-detected one. Must be a subclass of UiPathBaseEmbeddings. When provided, + the factory skips vendor detection and uses this class directly. **model_kwargs: Additional arguments passed to the embeddings constructor. Returns: @@ -286,6 +302,14 @@ def get_embedding_model( if not is_uipath_owned: client_settings.validate_byo_model(model_info) + if custom_class is not None: + return custom_class( + model=model_name, + settings=client_settings, + byo_connection_id=byo_connection_id, + **model_kwargs, + ) + if routing_mode == RoutingMode.NORMALIZED: from uipath_langchain_client.clients.normalized.embeddings import ( UiPathEmbeddings, diff --git a/tests/cassettes.db b/tests/cassettes.db index 7caa33b..9ffafbc 100644 Binary files a/tests/cassettes.db and b/tests/cassettes.db differ diff --git a/tests/langchain/test_factory_function.py b/tests/langchain/test_factory_function.py index 50b256c..6ccdb6e 100644 --- a/tests/langchain/test_factory_function.py +++ b/tests/langchain/test_factory_function.py @@ -1,4 +1,6 @@ import pytest +from uipath_langchain_client.clients.normalized.chat_models import UiPathChat +from uipath_langchain_client.clients.normalized.embeddings import UiPathEmbeddings from uipath_langchain_client.factory import get_chat_model, get_embedding_model from tests.langchain.conftest import COMPLETION_MODEL_NAMES, EMBEDDING_MODEL_NAMES @@ -18,3 +20,27 @@ def test_get_embedding_model(self, model_name: str, client_settings: UiPathBaseS model_name=model_name, client_settings=client_settings ) assert embedding_model is not None + + @pytest.mark.parametrize("model_name", COMPLETION_MODEL_NAMES) + def test_get_chat_model_custom_class( + self, model_name: str, client_settings: UiPathBaseSettings + ): + chat_model = get_chat_model( + model_name=model_name, + client_settings=client_settings, + custom_class=UiPathChat, + ) + assert chat_model is not None + assert isinstance(chat_model, UiPathChat) + + @pytest.mark.parametrize("model_name", EMBEDDING_MODEL_NAMES) + def test_get_embedding_model_custom_class( + self, model_name: str, client_settings: UiPathBaseSettings + ): + embedding_model = get_embedding_model( + model_name=model_name, + client_settings=client_settings, + custom_class=UiPathEmbeddings, + ) + assert embedding_model is not None + assert isinstance(embedding_model, UiPathEmbeddings)