diff --git a/src/diffusers/modular_pipelines/ernie_image/encoders.py b/src/diffusers/modular_pipelines/ernie_image/encoders.py
index 24e9622c9422..74d02ffb4dba 100644
--- a/src/diffusers/modular_pipelines/ernie_image/encoders.py
+++ b/src/diffusers/modular_pipelines/ernie_image/encoders.py
@@ -15,7 +15,7 @@
 import json
 
 import torch
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model
 
 from ...configuration_utils import FrozenDict
 from ...guiders import ClassifierFreeGuidance
@@ -38,7 +38,7 @@ def description(self) -> str:
     @property
     def expected_components(self) -> list[ComponentSpec]:
         return [
-            ComponentSpec("pe", AutoModelForCausalLM),
+            ComponentSpec("pe", Ministral3ForCausalLM),
             ComponentSpec("pe_tokenizer", AutoTokenizer),
         ]
 
@@ -83,7 +83,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
 
     @staticmethod
     def _enhance_prompt(
-        pe: AutoModelForCausalLM,
+        pe: Ministral3ForCausalLM,
         pe_tokenizer: AutoTokenizer,
         prompt: str,
         device: torch.device,
@@ -160,7 +160,7 @@ def description(self) -> str:
     @property
     def expected_components(self) -> list[ComponentSpec]:
         return [
-            ComponentSpec("text_encoder", AutoModel),
+            ComponentSpec("text_encoder", Mistral3Model),
             ComponentSpec("tokenizer", AutoTokenizer),
             ComponentSpec(
                 "guider",
@@ -200,7 +200,7 @@ def intermediate_outputs(self) -> list[OutputParam]:
 
     @staticmethod
     def _encode(
-        text_encoder: AutoModel,
+        text_encoder: Mistral3Model,
         tokenizer: AutoTokenizer,
         prompt: list[str],
         device: torch.device,
diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py
index e0231c4620c5..b8f95241779e 100644
--- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py
+++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py
@@ -20,7 +20,7 @@
 from typing import Callable, List, Optional, Union
 
 import torch
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model
 
 from ...image_processor import VaeImageProcessor
 from ...loaders import ErnieImageLoraLoaderMixin
@@ -52,10 +52,10 @@ def __init__(
         self,
         transformer: ErnieImageTransformer2DModel,
         vae: AutoencoderKLFlux2,
-        text_encoder: AutoModel,
+        text_encoder: Mistral3Model,
         tokenizer: AutoTokenizer,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        pe: Optional[AutoModelForCausalLM] = None,
+        pe: Optional[Ministral3ForCausalLM] = None,
         pe_tokenizer: Optional[AutoTokenizer] = None,
     ):
         super().__init__()