From 6fa5820e1b351ebbc1178851c8f0397d4dd1c753 Mon Sep 17 00:00:00 2001 From: JunghwanNA <70629228+shaun0927@users.noreply.github.com> Date: Fri, 17 Apr 2026 13:12:12 +0900 Subject: [PATCH] Restore default factory on Head._task_weights PR #802 replaced defaultdict(lambda: 1.0) with defaultdict(), which has the same runtime semantics as a plain dict - missing keys raise KeyError. The documented behavior ('1.0 when unset') was preserved at only one call site via .get(name, 1.0); any other direct indexing regresses to a crash. Restoring the lambda factory is a one-line change that preserves the original API contract and keeps downstream code that reads head._task_weights[task_name] working. Fixes #813 --- transformers4rec/torch/model/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformers4rec/torch/model/base.py b/transformers4rec/torch/model/base.py index 3aeb4eb3f..6d697ebce 100644 --- a/transformers4rec/torch/model/base.py +++ b/transformers4rec/torch/model/base.py @@ -269,7 +269,7 @@ def __init__( for i, task in enumerate(prediction_tasks): self.prediction_task_dict[task.task_name] = task - self._task_weights = defaultdict() + self._task_weights = defaultdict(lambda: 1.0) if task_weights: for task, val in zip(cast(List[PredictionTask], prediction_tasks), task_weights): self._task_weights[task.task_name] = val