diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py index db56376e3..194956412 100644 --- a/transformers4rec/torch/utils/data_utils.py +++ b/transformers4rec/torch/utils/data_utils.py @@ -398,6 +398,7 @@ def _augment_schema( cats = cats or [] conts = conts or [] labels = labels or [] + lists = lists or [] schema = schema.select_by_name(conts + cats + labels + lists)