Feature Details

Implement a `collate_fn` that takes a list of per-sample sparse-lag items and returns padded, mask-aware tensors suitable for the model. Each sample contains selected lag features and categorical IDs; sequences have variable $K$ (number of selected lags). The collate must:

• Pad every sample to the batch's maximum $K$ and return a boolean mask marking the real (non-padded) lag positions.
• Return consistent dtypes (values `float`, IDs `int64`, masks `bool`).

Input (per sample)
    {
        "vals": np.ndarray | torch.Tensor,     # shape: (K_i, F_val)
        "lag_ids": np.ndarray | torch.Tensor,  # shape: (K_i,), int
        "ticker_id": int,                      # scalar
        # optional
        "sector_id": int,
        "meta": {...}                          # optional passthrough
    }

Output (batched)
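A minimal sketch of the padding and mask construction, assuming the batched layout is `(B, K_max, F_val)` for values and `(B, K_max)` for lag IDs and the mask, where `B` is the batch size and `K_max` the largest $K_i$ in the batch. These shapes are an assumption, and the sample tensors are made up for illustration. It uses `torch.nn.utils.rnn.pad_sequence` to pad along the lag axis.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three illustrative samples with K_i = 3, 1, 2 selected lags and F_val = 4 features.
vals_list = [torch.randn(3, 4), torch.randn(1, 4), torch.randn(2, 4)]
lag_ids_list = [torch.tensor([1, 5, 7]), torch.tensor([2]), torch.tensor([3, 9])]

lengths = torch.tensor([v.shape[0] for v in vals_list])  # K_i per sample
K_max = int(lengths.max())

# Pad along the lag axis: vals -> (B, K_max, F_val), lag_ids -> (B, K_max).
vals = pad_sequence(vals_list, batch_first=True, padding_value=0.0)
lag_ids = pad_sequence(lag_ids_list, batch_first=True, padding_value=0)

# mask[b, k] is True where position k holds a real lag of sample b.
mask = torch.arange(K_max).unsqueeze(0) < lengths.unsqueeze(1)
```

Because a padding ID of `0` can collide with a real lag ID, downstream code should treat `mask`, not the ID value, as the source of truth for which positions are padding.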
Affected Modules
As stated in the parent issue.
Implementation Checklist
• `CollatedBatch` typed container (NamedTuple/dataclass) for outputs.
• `collate_variable_lags(samples, *, pad_value=0.0, pad_idx=0, sort_by_len=False, pin_memory=False)`.
• Convert inputs to `torch.Tensor` (avoid extra copies).
• Heterogeneous
• Dtype checks: `vals` == float32/64, IDs == int64, `mask` == bool.
• Optional fields present/absent.
• Sorting on/off preserves content; if sorted, return restore_idx.
• Edge cases:
• Perf sanity (
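The checklist above can be sketched as follows, assuming a `NamedTuple` container. Only `CollatedBatch`, `collate_variable_lags`, its keyword arguments and `restore_idx` come from the text; the other field names (`lengths`, `ticker_id`, ...), the rule that `sector_id` is batched only when every sample carries it, and the `meta` passthrough as a plain list are hypothetical choices. `torch.as_tensor` is used because it reuses the input buffer when dtype and device already match.

```python
from typing import Any, List, NamedTuple, Optional

import torch
from torch.nn.utils.rnn import pad_sequence


class CollatedBatch(NamedTuple):
    # Hypothetical field layout; shapes use B = batch size, K_max = max K_i.
    vals: torch.Tensor                   # (B, K_max, F_val), float, padded with pad_value
    lag_ids: torch.Tensor                # (B, K_max), int64, padded with pad_idx
    mask: torch.Tensor                   # (B, K_max), bool, True = real lag
    lengths: torch.Tensor                # (B,), int64, K_i per sample
    ticker_id: torch.Tensor              # (B,), int64
    sector_id: Optional[torch.Tensor]    # (B,), int64, or None (hypothetical policy)
    meta: Optional[List[Any]]            # passthrough list, or None
    restore_idx: Optional[torch.Tensor]  # (B,), set only when sort_by_len=True


def collate_variable_lags(samples, *, pad_value=0.0, pad_idx=0,
                          sort_by_len=False, pin_memory=False):
    # as_tensor avoids a copy when the input already has the target dtype/device.
    vals_list = [torch.as_tensor(s["vals"]) for s in samples]
    ids_list = [torch.as_tensor(s["lag_ids"], dtype=torch.int64) for s in samples]
    lengths = torch.tensor([v.shape[0] for v in vals_list], dtype=torch.int64)

    restore_idx = None
    if sort_by_len:
        # Longest-first order; stable=True keeps ties in input order.
        _, order = torch.sort(lengths, descending=True, stable=True)
        # Inverse permutation: sorted_tensor[restore_idx] recovers input order.
        restore_idx = torch.argsort(order)
        idx = order.tolist()
        samples = [samples[i] for i in idx]
        vals_list = [vals_list[i] for i in idx]
        ids_list = [ids_list[i] for i in idx]
        lengths = lengths[order]

    vals = pad_sequence(vals_list, batch_first=True, padding_value=pad_value)
    lag_ids = pad_sequence(ids_list, batch_first=True, padding_value=pad_idx)
    # mask[b, k] is True for real lags, False for padding.
    mask = torch.arange(vals.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)
    ticker_id = torch.tensor([int(s["ticker_id"]) for s in samples], dtype=torch.int64)

    # Hypothetical policy: batch sector_id only if every sample provides it.
    sector_id = None
    if all("sector_id" in s for s in samples):
        sector_id = torch.tensor([int(s["sector_id"]) for s in samples], dtype=torch.int64)
    # meta is passed through untouched; samples without it contribute None.
    meta = [s.get("meta") for s in samples] if any("meta" in s for s in samples) else None

    if pin_memory:  # pinning needs an accelerator (e.g. CUDA) to be available
        vals, lag_ids, mask, lengths, ticker_id = (
            t.pin_memory() for t in (vals, lag_ids, mask, lengths, ticker_id))
        if sector_id is not None:
            sector_id = sector_id.pin_memory()

    return CollatedBatch(vals, lag_ids, mask, lengths, ticker_id,
                         sector_id, meta, restore_idx)
```

The function can be passed to a `DataLoader` via `collate_fn=functools.partial(collate_variable_lags, sort_by_len=True)`. Sorting is optional; returning `restore_idx` lets callers map model outputs back to the original sample order, which is what the "sorting on/off preserves content" check would compare.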
Limitations
As stated in the parent issue.