diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 8c49aef..301b33f 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -23,7 +23,7 @@ from model2vec.train.base import FinetunableStaticModel, TextDataset logger = logging.getLogger(__name__) -_RANDOM_SEED = 42 +_DEFAULT_RANDOM_SEED = 42 LabelType = TypeVar("LabelType", list[str], list[list[str]]) @@ -158,6 +158,7 @@ def fit( # noqa: C901 # Complexity is bad. X_val: list[str] | None = None, y_val: LabelType | None = None, class_weight: torch.Tensor | None = None, + seed: int = _DEFAULT_RANDOM_SEED, ) -> StaticModelForClassification: """ Fit a model. @@ -187,14 +188,14 @@ def fit( # noqa: C901 # Complexity is bad. :param y_val: The labels to be used for validation. :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must have the same length as the number of classes. + :param seed: The random seed to use. Defaults to 42. :return: The fitted model. :raises ValueError: If either X_val or y_val are provided, but not both. """ - pl.seed_everything(_RANDOM_SEED) + pl.seed_everything(seed) logger.info("Re-initializing model.") # Determine whether the task is multilabel based on the type of y. - self._initialize(y) if (X_val is not None) != (y_val is not None): @@ -380,14 +381,13 @@ def to_pipeline(self) -> StaticModelPipeline: """Convert the model to an sklearn pipeline.""" static_model = self.to_static_model() - random_state = np.random.RandomState(_RANDOM_SEED) + random_state = np.random.RandomState(_DEFAULT_RANDOM_SEED) n_items = len(self.classes) X = random_state.randn(n_items, static_model.dim) y = self.classes - converted = make_pipeline(MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers)) - converted.fit(X, y) - mlp_head: MLPClassifier = converted[-1] + mlp_head = MLPClassifier(hidden_layer_sizes=(self.hidden_dim,) * self.n_layers) + mlp_head.fit(X, y) for index, layer in enumerate([module for module in self.head if isinstance(module, nn.Linear)]): mlp_head.coefs_[index] = layer.weight.detach().cpu().numpy().T @@ -401,7 +401,8 @@ def to_pipeline(self) -> StaticModelPipeline: # Set to softmax or sigmoid mlp_head.out_activation_ = "logistic" if self.multilabel else "softmax" - return StaticModelPipeline(static_model, converted) + pipeline = make_pipeline(mlp_head) + return StaticModelPipeline(static_model, pipeline) class _ClassifierLightningModule(pl.LightningModule):