Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# Scikit-LLM: Scikit-Learn Meets Large Language Models

Seamlessly integrate powerful language models like ChatGPT into scikit-learn for enhanced text analysis tasks.
Seamlessly integrate powerful language models like ChatGPT, Claude, and MiniMax into scikit-learn for enhanced text analysis tasks.

## Installation 💾

Expand Down Expand Up @@ -60,6 +60,23 @@ clf.fit(X,y)
clf.predict(X)
```

### Using MiniMax

```python
from skllm.config import SKLLMConfig
from skllm.models.minimax.classification.zero_shot import ZeroShotMiniMaxClassifier

# Configure the credentials
SKLLMConfig.set_minimax_key("<YOUR_MINIMAX_API_KEY>")

# Initialize the model and make predictions
clf = ZeroShotMiniMaxClassifier(model="MiniMax-M2.7")
clf.fit(X, y)
clf.predict(X)
```

Available MiniMax models: `MiniMax-M2.7`, `MiniMax-M2.5`, `MiniMax-M2.5-highspeed`.

For more information please refer to the **[documentation](https://skllm.beastbyte.ai)**.

## Citation
Expand Down
23 changes: 23 additions & 0 deletions skllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_GGUF_DOWNLOAD_PATH = "SKLLM_CONFIG_GGUF_DOWNLOAD_PATH"
_GGUF_MAX_GPU_LAYERS = "SKLLM_CONFIG_GGUF_MAX_GPU_LAYERS"
_GGUF_VERBOSE = "SKLLM_CONFIG_GGUF_VERBOSE"
_MINIMAX_KEY_VAR = "SKLLM_CONFIG_MINIMAX_KEY"


class SKLLMConfig:
Expand Down Expand Up @@ -192,6 +193,28 @@ def get_anthropic_key() -> Optional[str]:
"""
return os.environ.get(_ANTHROPIC_KEY_VAR, None)

@staticmethod
def set_minimax_key(key: str) -> None:
"""Sets the MiniMax API key.

Parameters
----------
key : str
MiniMax API key.
"""
os.environ[_MINIMAX_KEY_VAR] = key

@staticmethod
def get_minimax_key() -> Optional[str]:
"""Gets the MiniMax API key.

Returns
-------
Optional[str]
MiniMax API key.
"""
return os.environ.get(_MINIMAX_KEY_VAR, None)

@staticmethod
def reset_gpt_url():
"""Resets the GPT URL."""
Expand Down
102 changes: 102 additions & 0 deletions skllm/llm/minimax/completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import re
from typing import Dict, List, Optional
from skllm.llm.minimax.credentials import set_credentials
from skllm.utils import retry
from skllm.model_constants import MINIMAX_MODEL


def _strip_think_tags(text: str) -> str:
"""Strip <think>...</think> tags from MiniMax model responses.

Also handles unclosed <think> tags (e.g. when the model runs out of tokens
while still in the thinking phase).
"""
# First strip properly closed think tags
text = re.sub(r"<think>[\s\S]*?</think>\s*", "", text)
# Then strip unclosed think tags (content truncated during thinking)
text = re.sub(r"<think>[\s\S]*$", "", text)
return text.strip()


@retry(max_retries=3)
def get_chat_completion(
messages: List[Dict],
key: str,
model: str = MINIMAX_MODEL,
max_tokens: int = 1000,
temperature: float = 0.0,
system: Optional[str] = None,
json_response: bool = False,
) -> dict:
"""Gets a chat completion from the MiniMax API via OpenAI-compatible endpoint.

Parameters
----------
messages : list
Input messages to use.
key : str
The MiniMax API key to use.
model : str, optional
The MiniMax model to use.
max_tokens : int, optional
Maximum tokens to generate.
temperature : float, optional
Sampling temperature (0.0 to 1.0).
system : str, optional
System message to set the assistant's behavior.
json_response : bool, optional
Whether to request a JSON-formatted response.

Returns
-------
response : dict
The completion response from the API.
"""
if not messages:
raise ValueError("Messages list cannot be empty")
if not isinstance(messages, list):
raise TypeError("Messages must be a list")

# Clamp temperature to MiniMax's supported range [0.0, 1.0]
temperature = max(0.0, min(1.0, temperature))

client = set_credentials(key)

formatted_messages = []
if system:
if json_response:
system = f"{system.rstrip('.')}. Respond in JSON format."
formatted_messages.append({"role": "system", "content": system})
elif json_response:
formatted_messages.append(
{"role": "system", "content": "Respond in JSON format."}
)

for message in messages:
role = message.get("role", "user")
content = message.get("content", "")
formatted_messages.append({"role": role, "content": content})

model_dict = {
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
"messages": formatted_messages,
}

if json_response:
model_dict["response_format"] = {"type": "json_object"}

response = client.chat.completions.create(**model_dict)

# Strip <think>...</think> tags from the response content
if (
response.choices
and response.choices[0].message.content
and isinstance(response.choices[0].message.content, str)
):
response.choices[0].message.content = _strip_think_tags(
response.choices[0].message.content
)

return response
20 changes: 20 additions & 0 deletions skllm/llm/minimax/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from openai import OpenAI

MINIMAX_BASE_URL = "https://api.minimax.io/v1"


def set_credentials(key: str) -> OpenAI:
"""Set MiniMax credentials and return an OpenAI-compatible client.

Parameters
----------
key : str
The MiniMax API key to use.

Returns
-------
client : OpenAI
An OpenAI client configured for MiniMax.
"""
client = OpenAI(api_key=key, base_url=MINIMAX_BASE_URL)
return client
102 changes: 102 additions & 0 deletions skllm/llm/minimax/mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import Optional, Union, Any, List, Dict, Mapping
from skllm.config import SKLLMConfig as _Config
from skllm.llm.minimax.completion import get_chat_completion
from skllm.utils import extract_json_key
from skllm.llm.base import BaseTextCompletionMixin, BaseClassifierMixin
import json


class MiniMaxMixin:
"""A mixin class that provides MiniMax API key to other classes."""

_prefer_json_output = False

def _set_keys(self, key: Optional[str] = None) -> None:
"""Set the MiniMax API key."""
self.key = key

def _get_minimax_key(self) -> str:
"""Get the MiniMax key from the class or config."""
key = self.key
if key is None:
key = _Config.get_minimax_key()
if key is None:
raise RuntimeError("MiniMax API key was not found")
return key


class MiniMaxTextCompletionMixin(MiniMaxMixin, BaseTextCompletionMixin):
"""A mixin class that provides text completion capabilities using the MiniMax API."""

def _get_chat_completion(
self,
model: str,
messages: Union[str, List[Dict[str, str]]],
system_message: Optional[str] = None,
**kwargs: Any,
):
"""Gets a chat completion from the MiniMax API.

Parameters
----------
model : str
The model to use.
messages : Union[str, List[Dict[str, str]]]
Input messages to use.
system_message : Optional[str]
A system message to use.

Returns
-------
completion : dict
"""
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
elif isinstance(messages, list):
messages = [
{"role": msg.get("role", "user"), "content": msg.get("content", "")}
for msg in messages
]

completion = get_chat_completion(
messages=messages,
key=self._get_minimax_key(),
model=model,
system=system_message,
json_response=self._prefer_json_output,
**kwargs,
)
return completion

def _convert_completion_to_str(self, completion: Mapping[str, Any]):
"""Converts MiniMax API completion to string."""
try:
if hasattr(completion, "choices"):
return str(completion.choices[0].message.content)
return str(completion["choices"][0]["message"]["content"])
except Exception as e:
print(f"Error converting completion to string: {str(e)}")
return ""


class MiniMaxClassifierMixin(MiniMaxTextCompletionMixin, BaseClassifierMixin):
"""A mixin class that provides classification capabilities using MiniMax API."""

_prefer_json_output = True

def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> str:
"""Extracts the label from a MiniMax API completion."""
try:
content = self._convert_completion_to_str(completion)
if not self._prefer_json_output:
return content.strip()
try:
label = extract_json_key(content, "label")
if label is not None:
return label
except Exception:
pass
return ""
except Exception as e:
print(f"Error extracting label: {str(e)}")
return ""
3 changes: 3 additions & 0 deletions skllm/model_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,8 @@
# Anthropic (Claude) models
ANTHROPIC_CLAUDE_MODEL = "claude-3-haiku-20240307"

# MiniMax models
MINIMAX_MODEL = "MiniMax-M2.7"

# Vertex AI models
VERTEX_DEFAULT_MODEL = "text-bison@002"
Loading
Loading