Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from model2vec.train.base import FinetunableStaticModel, TextDataset

logger = logging.getLogger(__name__)
_RANDOM_SEED = 42
_DEFAULT_RANDOM_SEED = 42

LabelType = TypeVar("LabelType", list[str], list[list[str]])

Expand Down Expand Up @@ -158,6 +158,7 @@ def fit( # noqa: C901 # Complexity is bad.
X_val: list[str] | None = None,
y_val: LabelType | None = None,
class_weight: torch.Tensor | None = None,
seed: int = _DEFAULT_RANDOM_SEED,
) -> StaticModelForClassification:
"""
Fit a model.
Expand Down Expand Up @@ -187,14 +188,14 @@ def fit( # noqa: C901 # Complexity is bad.
:param y_val: The labels to be used for validation.
:param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
have the same length as the number of classes.
:param seed: The random seed to use. Defaults to 42.
:return: The fitted model.
:raises ValueError: If either X_val or y_val are provided, but not both.
"""
pl.seed_everything(_RANDOM_SEED)
pl.seed_everything(seed)
logger.info("Re-initializing model.")

# Determine whether the task is multilabel based on the type of y.

self._initialize(y)

if (X_val is not None) != (y_val is not None):
Expand Down Expand Up @@ -380,14 +381,13 @@ def to_pipeline(self) -> StaticModelPipeline:
"""Convert the model to an sklearn pipeline."""
static_model = self.to_static_model()

random_state = np.random.RandomState(_RANDOM_SEED)
random_state = np.random.RandomState(_DEFAULT_RANDOM_SEED)
n_items = len(self.classes)
X = random_state.randn(n_items, static_model.dim)
y = self.classes

converted = make_pipeline(MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers))
converted.fit(X, y)
mlp_head: MLPClassifier = converted[-1]
mlp_head = MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers)
mlp_head.fit(X, y)

for index, layer in enumerate([module for module in self.head if isinstance(module, nn.Linear)]):
mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T
Expand All @@ -401,7 +401,8 @@ def to_pipeline(self) -> StaticModelPipeline:
# Set to softmax or sigmoid
mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax"

return StaticModelPipeline(static_model, converted)
pipeline = make_pipeline(mlp_head)
return StaticModelPipeline(static_model, pipeline)


class _ClassifierLightningModule(pl.LightningModule):
Expand Down