diff --git a/README.md b/README.md index fb41caa..7b1103f 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # Scikit-LLM: Scikit-Learn Meets Large Language Models -Seamlessly integrate powerful language models like ChatGPT into scikit-learn for enhanced text analysis tasks. +Seamlessly integrate powerful language models like ChatGPT, Claude, and MiniMax into scikit-learn for enhanced text analysis tasks. ## Installation 💾 @@ -60,6 +60,23 @@ clf.fit(X,y) clf.predict(X) ``` +### Using MiniMax + +```python +from skllm.config import SKLLMConfig +from skllm.models.minimax.classification.zero_shot import ZeroShotMiniMaxClassifier + +# Configure the credentials +SKLLMConfig.set_minimax_key("") + +# Initialize the model and make predictions +clf = ZeroShotMiniMaxClassifier(model="MiniMax-M2.7") +clf.fit(X, y) +clf.predict(X) +``` + +Available MiniMax models: `MiniMax-M2.7`, `MiniMax-M2.5`, `MiniMax-M2.5-highspeed`. + For more information please refer to the **[documentation](https://skllm.beastbyte.ai)**. ## Citation diff --git a/skllm/config.py b/skllm/config.py index ea6709b..d063c3f 100644 --- a/skllm/config.py +++ b/skllm/config.py @@ -11,6 +11,7 @@ _GGUF_DOWNLOAD_PATH = "SKLLM_CONFIG_GGUF_DOWNLOAD_PATH" _GGUF_MAX_GPU_LAYERS = "SKLLM_CONFIG_GGUF_MAX_GPU_LAYERS" _GGUF_VERBOSE = "SKLLM_CONFIG_GGUF_VERBOSE" +_MINIMAX_KEY_VAR = "SKLLM_CONFIG_MINIMAX_KEY" class SKLLMConfig: @@ -192,6 +193,28 @@ def get_anthropic_key() -> Optional[str]: """ return os.environ.get(_ANTHROPIC_KEY_VAR, None) + @staticmethod + def set_minimax_key(key: str) -> None: + """Sets the MiniMax API key. + + Parameters + ---------- + key : str + MiniMax API key. + """ + os.environ[_MINIMAX_KEY_VAR] = key + + @staticmethod + def get_minimax_key() -> Optional[str]: + """Gets the MiniMax API key. + + Returns + ------- + Optional[str] + MiniMax API key. 
+ """ + return os.environ.get(_MINIMAX_KEY_VAR, None) + @staticmethod def reset_gpt_url(): """Resets the GPT URL.""" diff --git a/skllm/llm/minimax/completion.py b/skllm/llm/minimax/completion.py new file mode 100644 index 0000000..ded3b87 --- /dev/null +++ b/skllm/llm/minimax/completion.py @@ -0,0 +1,102 @@ +import re +from typing import Dict, List, Optional +from skllm.llm.minimax.credentials import set_credentials +from skllm.utils import retry +from skllm.model_constants import MINIMAX_MODEL + + +def _strip_think_tags(text: str) -> str: + """Strip ... tags from MiniMax model responses. + + Also handles unclosed tags (e.g. when the model runs out of tokens + while still in the thinking phase). + """ + # First strip properly closed think tags + text = re.sub(r"[\s\S]*?\s*", "", text) + # Then strip unclosed think tags (content truncated during thinking) + text = re.sub(r"[\s\S]*$", "", text) + return text.strip() + + +@retry(max_retries=3) +def get_chat_completion( + messages: List[Dict], + key: str, + model: str = MINIMAX_MODEL, + max_tokens: int = 1000, + temperature: float = 0.0, + system: Optional[str] = None, + json_response: bool = False, +) -> dict: + """Gets a chat completion from the MiniMax API via OpenAI-compatible endpoint. + + Parameters + ---------- + messages : list + Input messages to use. + key : str + The MiniMax API key to use. + model : str, optional + The MiniMax model to use. + max_tokens : int, optional + Maximum tokens to generate. + temperature : float, optional + Sampling temperature (0.0 to 1.0). + system : str, optional + System message to set the assistant's behavior. + json_response : bool, optional + Whether to request a JSON-formatted response. + + Returns + ------- + response : dict + The completion response from the API. 
+ """ + if not messages: + raise ValueError("Messages list cannot be empty") + if not isinstance(messages, list): + raise TypeError("Messages must be a list") + + # Clamp temperature to MiniMax's supported range [0.0, 1.0] + temperature = max(0.0, min(1.0, temperature)) + + client = set_credentials(key) + + formatted_messages = [] + if system: + if json_response: + system = f"{system.rstrip('.')}. Respond in JSON format." + formatted_messages.append({"role": "system", "content": system}) + elif json_response: + formatted_messages.append( + {"role": "system", "content": "Respond in JSON format."} + ) + + for message in messages: + role = message.get("role", "user") + content = message.get("content", "") + formatted_messages.append({"role": role, "content": content}) + + model_dict = { + "model": model, + "max_tokens": max_tokens, + "temperature": temperature, + "messages": formatted_messages, + } + + if json_response: + model_dict["response_format"] = {"type": "json_object"} + + response = client.chat.completions.create(**model_dict) + + # Strip ... tags from the response content + if ( + response.choices + and response.choices[0].message.content + and isinstance(response.choices[0].message.content, str) + ): + response.choices[0].message.content = _strip_think_tags( + response.choices[0].message.content + ) + + return response diff --git a/skllm/llm/minimax/credentials.py b/skllm/llm/minimax/credentials.py new file mode 100644 index 0000000..eaff356 --- /dev/null +++ b/skllm/llm/minimax/credentials.py @@ -0,0 +1,20 @@ +from openai import OpenAI + +MINIMAX_BASE_URL = "https://api.minimax.io/v1" + + +def set_credentials(key: str) -> OpenAI: + """Set MiniMax credentials and return an OpenAI-compatible client. + + Parameters + ---------- + key : str + The MiniMax API key to use. + + Returns + ------- + client : OpenAI + An OpenAI client configured for MiniMax. 
+ """ + client = OpenAI(api_key=key, base_url=MINIMAX_BASE_URL) + return client diff --git a/skllm/llm/minimax/mixin.py b/skllm/llm/minimax/mixin.py new file mode 100644 index 0000000..ecd10c3 --- /dev/null +++ b/skllm/llm/minimax/mixin.py @@ -0,0 +1,102 @@ +from typing import Optional, Union, Any, List, Dict, Mapping +from skllm.config import SKLLMConfig as _Config +from skllm.llm.minimax.completion import get_chat_completion +from skllm.utils import extract_json_key +from skllm.llm.base import BaseTextCompletionMixin, BaseClassifierMixin +import json + + +class MiniMaxMixin: + """A mixin class that provides MiniMax API key to other classes.""" + + _prefer_json_output = False + + def _set_keys(self, key: Optional[str] = None) -> None: + """Set the MiniMax API key.""" + self.key = key + + def _get_minimax_key(self) -> str: + """Get the MiniMax key from the class or config.""" + key = self.key + if key is None: + key = _Config.get_minimax_key() + if key is None: + raise RuntimeError("MiniMax API key was not found") + return key + + +class MiniMaxTextCompletionMixin(MiniMaxMixin, BaseTextCompletionMixin): + """A mixin class that provides text completion capabilities using the MiniMax API.""" + + def _get_chat_completion( + self, + model: str, + messages: Union[str, List[Dict[str, str]]], + system_message: Optional[str] = None, + **kwargs: Any, + ): + """Gets a chat completion from the MiniMax API. + + Parameters + ---------- + model : str + The model to use. + messages : Union[str, List[Dict[str, str]]] + Input messages to use. + system_message : Optional[str] + A system message to use. 
+ + Returns + ------- + completion : dict + """ + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + elif isinstance(messages, list): + messages = [ + {"role": msg.get("role", "user"), "content": msg.get("content", "")} + for msg in messages + ] + + completion = get_chat_completion( + messages=messages, + key=self._get_minimax_key(), + model=model, + system=system_message, + json_response=self._prefer_json_output, + **kwargs, + ) + return completion + + def _convert_completion_to_str(self, completion: Mapping[str, Any]): + """Converts MiniMax API completion to string.""" + try: + if hasattr(completion, "choices"): + return str(completion.choices[0].message.content) + return str(completion["choices"][0]["message"]["content"]) + except Exception as e: + print(f"Error converting completion to string: {str(e)}") + return "" + + +class MiniMaxClassifierMixin(MiniMaxTextCompletionMixin, BaseClassifierMixin): + """A mixin class that provides classification capabilities using MiniMax API.""" + + _prefer_json_output = True + + def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> str: + """Extracts the label from a MiniMax API completion.""" + try: + content = self._convert_completion_to_str(completion) + if not self._prefer_json_output: + return content.strip() + try: + label = extract_json_key(content, "label") + if label is not None: + return label + except Exception: + pass + return "" + except Exception as e: + print(f"Error extracting label: {str(e)}") + return "" diff --git a/skllm/model_constants.py b/skllm/model_constants.py index 91c408e..26e2fb5 100644 --- a/skllm/model_constants.py +++ b/skllm/model_constants.py @@ -6,5 +6,8 @@ # Anthropic (Claude) models ANTHROPIC_CLAUDE_MODEL = "claude-3-haiku-20240307" +# MiniMax models +MINIMAX_MODEL = "MiniMax-M2.7" + # Vertex AI models VERTEX_DEFAULT_MODEL = "text-bison@002" diff --git a/skllm/models/minimax/classification/few_shot.py 
b/skllm/models/minimax/classification/few_shot.py new file mode 100644 index 0000000..b820687 --- /dev/null +++ b/skllm/models/minimax/classification/few_shot.py @@ -0,0 +1,148 @@ +from skllm.models._base.classifier import ( + BaseFewShotClassifier, + BaseDynamicFewShotClassifier, + SingleLabelMixin, + MultiLabelMixin, +) +from skllm.llm.minimax.mixin import MiniMaxClassifierMixin +from skllm.models.gpt.vectorization import GPTVectorizer +from skllm.models._base.vectorizer import BaseVectorizer +from skllm.memory.base import IndexConstructor +from typing import Optional +from skllm.model_constants import MINIMAX_MODEL, OPENAI_EMBEDDING_MODEL + + +class FewShotMiniMaxClassifier( + BaseFewShotClassifier, MiniMaxClassifierMixin, SingleLabelMixin +): + """Few-shot text classifier using MiniMax API for single-label classification tasks.""" + + def __init__( + self, + model: str = MINIMAX_MODEL, + default_label: str = "Random", + prompt_template: Optional[str] = None, + key: Optional[str] = None, + **kwargs, + ): + """ + Few-shot text classifier using MiniMax API. + + Parameters + ---------- + model : str, optional + Model to use, by default "MiniMax-M2.7". + default_label : str, optional + Default label for failed predictions; if "Random", selects + randomly based on class frequencies, by default "Random". + prompt_template : Optional[str], optional + Custom prompt template to use, by default None. + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from the + global config, by default None. 
+ """ + super().__init__( + model=model, + default_label=default_label, + prompt_template=prompt_template, + **kwargs, + ) + self._set_keys(key) + + +class MultiLabelFewShotMiniMaxClassifier( + BaseFewShotClassifier, MiniMaxClassifierMixin, MultiLabelMixin +): + """Few-shot text classifier using MiniMax API for multi-label classification tasks.""" + + def __init__( + self, + model: str = MINIMAX_MODEL, + default_label: str = "Random", + max_labels: Optional[int] = 5, + prompt_template: Optional[str] = None, + key: Optional[str] = None, + **kwargs, + ): + """ + Multi-label few-shot text classifier using MiniMax API. + + Parameters + ---------- + model : str, optional + Model to use, by default "MiniMax-M2.7". + default_label : str, optional + Default label for failed predictions; if "Random", selects + randomly based on class frequencies, by default "Random". + max_labels : Optional[int], optional + Maximum labels per sample, by default 5. + prompt_template : Optional[str], optional + Custom prompt template to use, by default None. + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from the + global config, by default None. + """ + super().__init__( + model=model, + default_label=default_label, + max_labels=max_labels, + prompt_template=prompt_template, + **kwargs, + ) + self._set_keys(key) + + +class DynamicFewShotMiniMaxClassifier( + BaseDynamicFewShotClassifier, MiniMaxClassifierMixin, SingleLabelMixin +): + """Dynamic few-shot text classifier using MiniMax API with dynamic example selection.""" + + def __init__( + self, + model: str = MINIMAX_MODEL, + default_label: str = "Random", + prompt_template: Optional[str] = None, + key: Optional[str] = None, + n_examples: int = 3, + memory_index: Optional[IndexConstructor] = None, + vectorizer: Optional[BaseVectorizer] = None, + metric: Optional[str] = "euclidean", + **kwargs, + ): + """ + Dynamic few-shot text classifier using MiniMax API. 
+ For each sample, N closest examples are retrieved from the memory. + + Parameters + ---------- + model : str, optional + Model to use, by default "MiniMax-M2.7". + default_label : str, optional + Default label for failed predictions; if "Random", selects + randomly based on class frequencies, by default "Random". + prompt_template : Optional[str], optional + Custom prompt template to use, by default None. + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from the + global config, by default None. + n_examples : int, optional + Number of closest examples per class, by default 3. + memory_index : Optional[IndexConstructor], optional + Custom memory index, for details check `skllm.memory` submodule. + vectorizer : Optional[BaseVectorizer], optional + Scikit-llm vectorizer; if None, `GPTVectorizer` is used. + metric : Optional[str], optional + Metric used for similarity search, by default "euclidean". + """ + if vectorizer is None: + vectorizer = GPTVectorizer(model=OPENAI_EMBEDDING_MODEL, key=key) + super().__init__( + model=model, + default_label=default_label, + prompt_template=prompt_template, + n_examples=n_examples, + memory_index=memory_index, + vectorizer=vectorizer, + metric=metric, + ) + self._set_keys(key) diff --git a/skllm/models/minimax/classification/zero_shot.py b/skllm/models/minimax/classification/zero_shot.py new file mode 100644 index 0000000..7a24f60 --- /dev/null +++ b/skllm/models/minimax/classification/zero_shot.py @@ -0,0 +1,127 @@ +from skllm.models._base.classifier import ( + SingleLabelMixin as _SingleLabelMixin, + MultiLabelMixin as _MultiLabelMixin, + BaseZeroShotClassifier as _BaseZeroShotClassifier, + BaseCoTClassifier as _BaseCoTClassifier, +) +from skllm.llm.minimax.mixin import MiniMaxClassifierMixin as _MiniMaxClassifierMixin +from typing import Optional +from skllm.model_constants import MINIMAX_MODEL + + +class ZeroShotMiniMaxClassifier( + _BaseZeroShotClassifier, _MiniMaxClassifierMixin, 
_SingleLabelMixin +): + """Zero-shot text classifier using MiniMax models for single-label classification.""" + + def __init__( + self, + model: str = MINIMAX_MODEL, + default_label: str = "Random", + prompt_template: Optional[str] = None, + key: Optional[str] = None, + **kwargs, + ): + """ + Zero-shot text classifier using MiniMax models. + + Parameters + ---------- + model : str, optional + Model to use, by default "MiniMax-M2.7". + default_label : str, optional + Default label for failed predictions; if "Random", selects + randomly based on class frequencies, by default "Random". + prompt_template : Optional[str], optional + Custom prompt template to use, by default None. + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from the + global config, by default None. + """ + super().__init__( + model=model, + default_label=default_label, + prompt_template=prompt_template, + **kwargs, + ) + self._set_keys(key) + + +class CoTMiniMaxClassifier( + _BaseCoTClassifier, _MiniMaxClassifierMixin, _SingleLabelMixin +): + """Chain-of-thought text classifier using MiniMax models for single-label classification.""" + + def __init__( + self, + model: str = MINIMAX_MODEL, + default_label: str = "Random", + prompt_template: Optional[str] = None, + key: Optional[str] = None, + **kwargs, + ): + """ + Chain-of-thought text classifier using MiniMax models. + + Parameters + ---------- + model : str, optional + Model to use, by default "MiniMax-M2.7". + default_label : str, optional + Default label for failed predictions; if "Random", selects + randomly based on class frequencies, by default "Random". + prompt_template : Optional[str], optional + Custom prompt template to use, by default None. + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from the + global config, by default None. 
+ """ + super().__init__( + model=model, + default_label=default_label, + prompt_template=prompt_template, + **kwargs, + ) + self._set_keys(key) + + +class MultiLabelZeroShotMiniMaxClassifier( + _BaseZeroShotClassifier, _MiniMaxClassifierMixin, _MultiLabelMixin +): + """Zero-shot text classifier using MiniMax models for multi-label classification.""" + + def __init__( + self, + model: str = MINIMAX_MODEL, + default_label: str = "Random", + max_labels: Optional[int] = 5, + prompt_template: Optional[str] = None, + key: Optional[str] = None, + **kwargs, + ): + """ + Multi-label zero-shot text classifier using MiniMax models. + + Parameters + ---------- + model : str, optional + Model to use, by default "MiniMax-M2.7". + default_label : str, optional + Default label for failed predictions; if "Random", selects + randomly based on class frequencies, by default "Random". + max_labels : Optional[int], optional + Maximum number of labels per sample, by default 5. + prompt_template : Optional[str], optional + Custom prompt template to use, by default None. + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from the + global config, by default None. 
+ """ + super().__init__( + model=model, + default_label=default_label, + max_labels=max_labels, + prompt_template=prompt_template, + **kwargs, + ) + self._set_keys(key) diff --git a/skllm/models/minimax/tagging/ner.py b/skllm/models/minimax/tagging/ner.py new file mode 100644 index 0000000..c2d7a08 --- /dev/null +++ b/skllm/models/minimax/tagging/ner.py @@ -0,0 +1,46 @@ +from skllm.models._base.tagger import ExplainableNER as _ExplainableNER +from skllm.llm.minimax.mixin import ( + MiniMaxTextCompletionMixin as _MiniMaxTextCompletionMixin, +) +from typing import Optional, Dict +from skllm.model_constants import MINIMAX_MODEL + + +class MiniMaxExplainableNER(_ExplainableNER, _MiniMaxTextCompletionMixin): + """Named Entity Recognition model using MiniMax API for explainable entity extraction.""" + + def __init__( + self, + entities: Dict[str, str], + display_predictions: bool = False, + sparse_output: bool = True, + model: str = MINIMAX_MODEL, + key: Optional[str] = None, + num_workers: int = 1, + ) -> None: + """ + Named entity recognition using MiniMax API. + + Parameters + ---------- + entities : dict + Dictionary of entities to recognize, with keys as entity + names and values as descriptions. + display_predictions : bool, optional + Whether to display predictions, by default False. + sparse_output : bool, optional + Whether to generate a sparse representation, by default True. + model : str, optional + Model to use, by default "MiniMax-M2.7". + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from the + global config, by default None. + num_workers : int, optional + Number of workers (threads) to use, by default 1. 
+ """ + self._set_keys(key) + self.model = model + self.entities = entities + self.display_predictions = display_predictions + self.sparse_output = sparse_output + self.num_workers = num_workers diff --git a/skllm/models/minimax/text2text/__init__.py b/skllm/models/minimax/text2text/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/skllm/models/minimax/text2text/__init__.py @@ -0,0 +1 @@ + diff --git a/skllm/models/minimax/text2text/summarization.py b/skllm/models/minimax/text2text/summarization.py new file mode 100644 index 0000000..f6df5dd --- /dev/null +++ b/skllm/models/minimax/text2text/summarization.py @@ -0,0 +1,37 @@ +from skllm.models._base.text2text import BaseSummarizer as _BaseSummarizer +from skllm.llm.minimax.mixin import ( + MiniMaxTextCompletionMixin as _MiniMaxTextCompletionMixin, +) +from typing import Optional +from skllm.model_constants import MINIMAX_MODEL + + +class MiniMaxSummarizer(_BaseSummarizer, _MiniMaxTextCompletionMixin): + """Text summarizer using MiniMax API.""" + + def __init__( + self, + model: str = MINIMAX_MODEL, + key: Optional[str] = None, + max_words: int = 15, + focus: Optional[str] = None, + ) -> None: + """ + Text summarizer using MiniMax API. + + Parameters + ---------- + model : str, optional + Model to use, by default "MiniMax-M2.7". + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from + the global config, by default None. + max_words : int, optional + Soft limit of the summary length, by default 15. + focus : Optional[str], optional + Concept in the text to focus on, by default None. 
+ """ + self._set_keys(key) + self.model = model + self.max_words = max_words + self.focus = focus diff --git a/skllm/models/minimax/text2text/translation.py b/skllm/models/minimax/text2text/translation.py new file mode 100644 index 0000000..8cefaac --- /dev/null +++ b/skllm/models/minimax/text2text/translation.py @@ -0,0 +1,35 @@ +from skllm.models._base.text2text import BaseTranslator as _BaseTranslator +from skllm.llm.minimax.mixin import ( + MiniMaxTextCompletionMixin as _MiniMaxTextCompletionMixin, +) +from typing import Optional +from skllm.model_constants import MINIMAX_MODEL + + +class MiniMaxTranslator(_BaseTranslator, _MiniMaxTextCompletionMixin): + """Text translator using MiniMax API.""" + + default_output = "Translation is unavailable." + + def __init__( + self, + model: str = MINIMAX_MODEL, + key: Optional[str] = None, + output_language: str = "English", + ) -> None: + """ + Text translator using MiniMax API. + + Parameters + ---------- + model : str, optional + Model to use, by default "MiniMax-M2.7". + key : Optional[str], optional + Estimator-specific API key; if None, retrieved from + the global config, by default None. + output_language : str, optional + Target language, by default "English". + """ + self._set_keys(key) + self.model = model + self.output_language = output_language diff --git a/tests/llm/minimax/__init__.py b/tests/llm/minimax/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/llm/minimax/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/llm/minimax/test_minimax_integration.py b/tests/llm/minimax/test_minimax_integration.py new file mode 100644 index 0000000..b5b05da --- /dev/null +++ b/tests/llm/minimax/test_minimax_integration.py @@ -0,0 +1,107 @@ +"""Integration tests for MiniMax provider. + +These tests require a valid MINIMAX_API_KEY environment variable. +They are skipped when the key is not available. 
+""" +import os +import unittest + +MINIMAX_KEY = os.environ.get("MINIMAX_API_KEY") +SKIP_REASON = "MINIMAX_API_KEY not set" + + +@unittest.skipUnless(MINIMAX_KEY, SKIP_REASON) +class TestMiniMaxCompletionIntegration(unittest.TestCase): + def test_basic_completion(self): + from skllm.llm.minimax.completion import get_chat_completion + + response = get_chat_completion( + messages=[{"role": "user", "content": "Say hello in one word."}], + key=MINIMAX_KEY, + model="MiniMax-M2.5-highspeed", + max_tokens=500, + ) + content = response.choices[0].message.content + self.assertIsInstance(content, str) + self.assertTrue(len(content) > 0) + + def test_completion_with_system_message(self): + from skllm.llm.minimax.completion import get_chat_completion + + response = get_chat_completion( + messages=[{"role": "user", "content": "What are you?"}], + key=MINIMAX_KEY, + model="MiniMax-M2.5-highspeed", + system="You are a friendly bot. Respond in one sentence.", + max_tokens=500, + ) + content = response.choices[0].message.content + self.assertIsInstance(content, str) + self.assertTrue(len(content) > 0) + + def test_completion_json_response(self): + import json + from skllm.llm.minimax.completion import get_chat_completion + from skllm.utils import find_json_in_string + + response = get_chat_completion( + messages=[ + { + "role": "user", + "content": 'Classify the sentiment of "I love this product" as positive, negative, or neutral. 
Return JSON with a "label" key.', + } + ], + key=MINIMAX_KEY, + model="MiniMax-M2.5-highspeed", + json_response=True, + max_tokens=1000, + ) + content = response.choices[0].message.content + # The model may wrap JSON in markdown code fences + json_str = find_json_in_string(content) + data = json.loads(json_str) + self.assertIn("label", data) + + +@unittest.skipUnless(MINIMAX_KEY, SKIP_REASON) +class TestMiniMaxMixinIntegration(unittest.TestCase): + def test_text_completion_mixin(self): + from skllm.llm.minimax.mixin import MiniMaxTextCompletionMixin + + mixin = MiniMaxTextCompletionMixin() + mixin._set_keys(MINIMAX_KEY) + + completion = mixin._get_chat_completion( + model="MiniMax-M2.5-highspeed", + messages="Say 'test' and nothing else.", + ) + result = mixin._convert_completion_to_str(completion) + self.assertIsInstance(result, str) + self.assertTrue(len(result) > 0) + + def test_classifier_mixin(self): + from skllm.llm.minimax.mixin import MiniMaxClassifierMixin + + mixin = MiniMaxClassifierMixin() + mixin._set_keys(MINIMAX_KEY) + + completion = mixin._get_chat_completion( + model="MiniMax-M2.5-highspeed", + messages='Classify "I love this!" as positive, negative, or neutral.', + system_message="You are a text classifier. 
Return a JSON object with a single key 'label'.", + ) + label = mixin._extract_out_label(completion) + self.assertIsInstance(label, str) + + +@unittest.skipUnless(MINIMAX_KEY, SKIP_REASON) +class TestMiniMaxConfigIntegration(unittest.TestCase): + def test_config_set_get_key(self): + from skllm.config import SKLLMConfig + + SKLLMConfig.set_minimax_key(MINIMAX_KEY) + self.assertEqual(SKLLMConfig.get_minimax_key(), MINIMAX_KEY) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/llm/minimax/test_minimax_mixins.py b/tests/llm/minimax/test_minimax_mixins.py new file mode 100644 index 0000000..047b7e8 --- /dev/null +++ b/tests/llm/minimax/test_minimax_mixins.py @@ -0,0 +1,319 @@ +import unittest +from unittest.mock import patch, MagicMock +import json +from skllm.llm.minimax.mixin import ( + MiniMaxMixin, + MiniMaxTextCompletionMixin, + MiniMaxClassifierMixin, +) + + +class TestMiniMaxMixin(unittest.TestCase): + def test_set_and_get_key(self): + mixin = MiniMaxMixin() + mixin._set_keys("test_minimax_key") + self.assertEqual(mixin._get_minimax_key(), "test_minimax_key") + + def test_get_key_from_config(self): + mixin = MiniMaxMixin() + mixin._set_keys(None) + with patch("skllm.llm.minimax.mixin._Config") as mock_config: + mock_config.get_minimax_key.return_value = "config_key" + self.assertEqual(mixin._get_minimax_key(), "config_key") + + def test_get_key_raises_when_not_found(self): + mixin = MiniMaxMixin() + mixin._set_keys(None) + with patch("skllm.llm.minimax.mixin._Config") as mock_config: + mock_config.get_minimax_key.return_value = None + with self.assertRaises(RuntimeError): + mixin._get_minimax_key() + + +class TestMiniMaxTextCompletionMixin(unittest.TestCase): + @patch("skllm.llm.minimax.mixin.get_chat_completion") + def test_chat_completion_with_string_message(self, mock_get_chat_completion): + mixin = MiniMaxTextCompletionMixin() + mixin._set_keys("test_key") + + mock_completion = MagicMock() + mock_completion.choices = [MagicMock()] + 
mock_completion.choices[0].message.content = "test response" + mock_get_chat_completion.return_value = mock_completion + + completion = mixin._get_chat_completion( + model="MiniMax-M2.7", + messages="Hello", + system_message="Test system", + ) + + self.assertEqual( + mixin._convert_completion_to_str(completion), + "test response", + ) + mock_get_chat_completion.assert_called_once() + + @patch("skllm.llm.minimax.mixin.get_chat_completion") + def test_chat_completion_with_list_messages(self, mock_get_chat_completion): + mixin = MiniMaxTextCompletionMixin() + mixin._set_keys("test_key") + + mock_completion = MagicMock() + mock_completion.choices = [MagicMock()] + mock_completion.choices[0].message.content = "response" + mock_get_chat_completion.return_value = mock_completion + + completion = mixin._get_chat_completion( + model="MiniMax-M2.7", + messages=[{"role": "user", "content": "Hello"}], + ) + + self.assertEqual( + mixin._convert_completion_to_str(completion), + "response", + ) + + @patch("skllm.llm.minimax.mixin.get_chat_completion") + def test_convert_completion_dict_format(self, mock_get_chat_completion): + mixin = MiniMaxTextCompletionMixin() + mixin._set_keys("test_key") + + completion = { + "choices": [{"message": {"content": "dict response"}}] + } + self.assertEqual( + mixin._convert_completion_to_str(completion), + "dict response", + ) + + +class TestMiniMaxClassifierMixin(unittest.TestCase): + @patch("skllm.llm.minimax.mixin.get_chat_completion") + def test_extract_out_label_with_valid_json(self, mock_get_chat_completion): + mixin = MiniMaxClassifierMixin() + mixin._set_keys("test_key") + + mock_completion = MagicMock() + mock_completion.choices = [MagicMock()] + mock_completion.choices[0].message.content = '{"label":"positive"}' + mock_get_chat_completion.return_value = mock_completion + + completion = mixin._get_chat_completion( + model="MiniMax-M2.7", + messages="Classify this text", + system_message="You are a classifier", + ) + 
self.assertEqual(mixin._extract_out_label(completion), "positive") + + @patch("skllm.llm.minimax.mixin.get_chat_completion") + def test_extract_out_label_with_invalid_json(self, mock_get_chat_completion): + mixin = MiniMaxClassifierMixin() + mixin._set_keys("test_key") + + mock_completion = MagicMock() + mock_completion.choices = [MagicMock()] + mock_completion.choices[0].message.content = "not json" + mock_get_chat_completion.return_value = mock_completion + + completion = mixin._get_chat_completion( + model="MiniMax-M2.7", + messages="test", + ) + self.assertEqual(mixin._extract_out_label(completion), "") + + def test_prefer_json_output_is_true(self): + mixin = MiniMaxClassifierMixin() + self.assertTrue(mixin._prefer_json_output) + + +class TestMiniMaxCompletion(unittest.TestCase): + @patch("skllm.llm.minimax.completion.set_credentials") + def test_get_chat_completion_validates_messages(self, mock_creds): + from skllm.llm.minimax.completion import get_chat_completion + + # @retry wraps exceptions in RuntimeError after max_retries + with self.assertRaises(RuntimeError): + get_chat_completion(messages=[], key="test") + + with self.assertRaises(RuntimeError): + get_chat_completion(messages="not a list", key="test") + + @patch("skllm.llm.minimax.completion.set_credentials") + def test_get_chat_completion_clamps_temperature(self, mock_creds): + mock_client = MagicMock() + mock_creds.return_value = mock_client + mock_client.chat.completions.create.return_value = MagicMock() + + from skllm.llm.minimax.completion import get_chat_completion + + get_chat_completion( + messages=[{"role": "user", "content": "hi"}], + key="test", + temperature=2.0, + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + self.assertLessEqual(call_kwargs["temperature"], 1.0) + + @patch("skllm.llm.minimax.completion.set_credentials") + def test_get_chat_completion_with_system_message(self, mock_creds): + mock_client = MagicMock() + mock_creds.return_value = mock_client + 
mock_client.chat.completions.create.return_value = MagicMock() + + from skllm.llm.minimax.completion import get_chat_completion + + get_chat_completion( + messages=[{"role": "user", "content": "hi"}], + key="test", + system="You are helpful.", + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + messages = call_kwargs["messages"] + self.assertEqual(messages[0]["role"], "system") + self.assertEqual(messages[0]["content"], "You are helpful.") + + @patch("skllm.llm.minimax.completion.set_credentials") + def test_get_chat_completion_json_response(self, mock_creds): + mock_client = MagicMock() + mock_creds.return_value = mock_client + mock_client.chat.completions.create.return_value = MagicMock() + + from skllm.llm.minimax.completion import get_chat_completion + + get_chat_completion( + messages=[{"role": "user", "content": "hi"}], + key="test", + json_response=True, + ) + + call_kwargs = mock_client.chat.completions.create.call_args[1] + self.assertEqual( + call_kwargs["response_format"], {"type": "json_object"} + ) + + +class TestMiniMaxStripThinkTags(unittest.TestCase): + def test_strip_think_tags(self): + from skllm.llm.minimax.completion import _strip_think_tags + + text = '<think>\nSome reasoning here\n</think>\n\n{"label": "positive"}' + result = _strip_think_tags(text) + self.assertEqual(result, '{"label": "positive"}') + + def test_strip_think_tags_no_tags(self): + from skllm.llm.minimax.completion import _strip_think_tags + + text = '{"label": "positive"}' + result = _strip_think_tags(text) + self.assertEqual(result, '{"label": "positive"}') + + def test_strip_think_tags_empty(self): + from skllm.llm.minimax.completion import _strip_think_tags + + text = "" + result = _strip_think_tags(text) + self.assertEqual(result, "") + + +class TestMiniMaxCredentials(unittest.TestCase): + @patch("skllm.llm.minimax.credentials.OpenAI") + def test_set_credentials_returns_client(self, mock_openai): + from skllm.llm.minimax.credentials import set_credentials, 
MINIMAX_BASE_URL + + mock_client = MagicMock() + mock_openai.return_value = mock_client + + client = set_credentials("test_key") + mock_openai.assert_called_once_with( + api_key="test_key", base_url=MINIMAX_BASE_URL + ) + self.assertEqual(client, mock_client) + + +class TestMiniMaxConfig(unittest.TestCase): + def test_set_and_get_minimax_key(self): + from skllm.config import SKLLMConfig + import os + + SKLLMConfig.set_minimax_key("test_minimax_key_123") + self.assertEqual( + SKLLMConfig.get_minimax_key(), "test_minimax_key_123" + ) + # Clean up + os.environ.pop("SKLLM_CONFIG_MINIMAX_KEY", None) + + def test_get_minimax_key_returns_none_when_not_set(self): + from skllm.config import SKLLMConfig + import os + + os.environ.pop("SKLLM_CONFIG_MINIMAX_KEY", None) + self.assertIsNone(SKLLMConfig.get_minimax_key()) + + +class TestMiniMaxModelConstants(unittest.TestCase): + def test_minimax_model_constant(self): + from skllm.model_constants import MINIMAX_MODEL + + self.assertEqual(MINIMAX_MODEL, "MiniMax-M2.7") + + +class TestMiniMaxModels(unittest.TestCase): + def test_zero_shot_classifier_init(self): + from skllm.models.minimax.classification.zero_shot import ( + ZeroShotMiniMaxClassifier, + ) + + clf = ZeroShotMiniMaxClassifier(key="test_key") + self.assertEqual(clf.model, "MiniMax-M2.7") + self.assertEqual(clf.key, "test_key") + + def test_cot_classifier_init(self): + from skllm.models.minimax.classification.zero_shot import ( + CoTMiniMaxClassifier, + ) + + clf = CoTMiniMaxClassifier(key="test_key", model="MiniMax-M2.5") + self.assertEqual(clf.model, "MiniMax-M2.5") + + def test_multi_label_classifier_init(self): + from skllm.models.minimax.classification.zero_shot import ( + MultiLabelZeroShotMiniMaxClassifier, + ) + + clf = MultiLabelZeroShotMiniMaxClassifier(key="test_key", max_labels=3) + self.assertEqual(clf.max_labels, 3) + + def test_few_shot_classifier_init(self): + from skllm.models.minimax.classification.few_shot import ( + FewShotMiniMaxClassifier, + ) + + 
clf = FewShotMiniMaxClassifier(key="test_key") + self.assertEqual(clf.model, "MiniMax-M2.7") + + def test_summarizer_init(self): + from skllm.models.minimax.text2text.summarization import MiniMaxSummarizer + + s = MiniMaxSummarizer(key="test_key", max_words=20) + self.assertEqual(s.model, "MiniMax-M2.7") + self.assertEqual(s.max_words, 20) + + def test_translator_init(self): + from skllm.models.minimax.text2text.translation import MiniMaxTranslator + + t = MiniMaxTranslator(key="test_key", output_language="French") + self.assertEqual(t.output_language, "French") + + def test_ner_init(self): + from skllm.models.minimax.tagging.ner import MiniMaxExplainableNER + + entities = {"PERSON": "A person's name", "ORG": "An organization"} + ner = MiniMaxExplainableNER(entities=entities, key="test_key") + self.assertEqual(ner.entities, entities) + self.assertEqual(ner.model, "MiniMax-M2.7") + + +if __name__ == "__main__": + unittest.main()