diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py index 65ba520b..1141c57d 100644 --- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py +++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py @@ -21,11 +21,23 @@ from peft import LoraConfig from transformers import DataCollatorForSeq2Seq, TrainingArguments from transformers.trainer_utils import RemoveColumnsCollator -from trl import ( # pylint: disable=import-error, no-name-in-module - DataCollatorForCompletionOnlyLM, -) import torch +# Handle trl version compatibility +# In trl < 0.19: DataCollatorForCompletionOnlyLM +# In trl >= 0.19: May be renamed or moved +try: + # pylint: disable=import-error, no-name-in-module + from trl import DataCollatorForCompletionOnlyLM +except ImportError: + # Fallback for newer trl versions where it might be renamed + try: + from trl.trainer.utils import DataCollatorForCompletionOnlyLM + except ImportError: + # If still not available, create a placeholder that will never match + # This allows the plugin to load even if this specific collator isn't used + DataCollatorForCompletionOnlyLM = type('DataCollatorForCompletionOnlyLM', (), {}) + class PaddingFreeAccelerationPlugin(AccelerationPlugin):