Port output distribution classes to DynamicEmb #297

z52527 wants to merge 3 commits into NVIDIA:main from
Conversation
Track CI here.
Greptile Summary

This PR ports TorchRec's RW output distribution classes (…) into DynamicEmb.

Confidence Score: 4/5

Safe to merge after the prior-round P1 concerns (Optional pg crash, cast type-mismatch) are addressed; remaining new findings are P2 style/cleanup. Two P1-level issues surfaced in the previous review round (Optional ProcessGroup crash risk, lazy-init cast mismatch) remain unresolved: the pyre-fixme suppression comments are still present in the diff. All new findings in this pass are P2 (dead attribute, missing license header, stale TODO), but the outstanding prior P1s keep this from a clean 5.

Important Files Changed

- corelib/dynamicemb/dynamicemb/output_dist.py (cast type-mismatch on lazy init, missing license header)
- corelib/dynamicemb/dynamicemb/planner/rw_sharding.py (Optional pg passed unchecked to both create_output_dist overrides)
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant RwSequenceDynamicEmbeddingSharding
    participant RwSequenceEmbeddingDist
    participant SequenceEmbeddingsAllToAll
    Caller->>RwSequenceDynamicEmbeddingSharding: create_output_dist(device)
    RwSequenceDynamicEmbeddingSharding->>RwSequenceEmbeddingDist: __init__(pg, num_features, device, qcomm_codecs_registry)
    RwSequenceEmbeddingDist->>SequenceEmbeddingsAllToAll: __init__(pg, splits, device, codecs)
    RwSequenceEmbeddingDist-->>RwSequenceDynamicEmbeddingSharding: dist instance
    RwSequenceDynamicEmbeddingSharding-->>Caller: BaseEmbeddingDist
    Caller->>RwSequenceEmbeddingDist: forward(local_embs, sharding_ctx)
    Note over RwSequenceEmbeddingDist: assert sharding_ctx is not None
    RwSequenceEmbeddingDist->>SequenceEmbeddingsAllToAll: forward(local_embs, lengths, splits, ...)
    SequenceEmbeddingsAllToAll-->>Caller: Awaitable[Tensor]
    participant RwPooledDynamicEmbeddingSharding
    participant RwPooledEmbeddingDist
    participant PooledEmbeddingsReduceScatter
    participant VariableBatchPooledEmbeddingsReduceScatter
    Caller->>RwPooledDynamicEmbeddingSharding: create_output_dist(device)
    RwPooledDynamicEmbeddingSharding->>RwPooledEmbeddingDist: __init__(pg, embedding_dims, qcomm_codecs_registry)
    Note over RwPooledEmbeddingDist: _dist = None (lazy init)
    RwPooledEmbeddingDist-->>Caller: BaseEmbeddingDist
    Caller->>RwPooledEmbeddingDist: forward(local_embs, sharding_ctx)
    alt _dist is None
        RwPooledEmbeddingDist->>RwPooledEmbeddingDist: _create_output_dist_module(sharding_ctx)
        alt variable_batch_per_feature
            RwPooledEmbeddingDist->>VariableBatchPooledEmbeddingsReduceScatter: __init__(pg, codecs)
        else
            RwPooledEmbeddingDist->>PooledEmbeddingsReduceScatter: __init__(pg, codecs)
        end
    end
    RwPooledEmbeddingDist-->>Caller: Awaitable[Tensor]
```
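The pooled path above lazily constructs its dist module on the first `forward()` call, then dispatches on `sharding_ctx`. A minimal dependency-free sketch of that pattern (class names mirror the TorchRec modules in the diagram, but these are simplified stand-ins, not the real distributed implementations):

```python
# Hypothetical stand-ins for the real reduce-scatter modules; they just
# record which path was taken instead of doing collective communication.
class PooledEmbeddingsReduceScatter:
    def __call__(self, local_embs, input_splits=None):
        return ("fixed", local_embs, input_splits)

class VariableBatchPooledEmbeddingsReduceScatter:
    def __call__(self, local_embs, batch_size_per_rank_per_feature=None,
                 embedding_dims=None):
        return ("variable", local_embs, embedding_dims)

class RwPooledEmbeddingDistSketch:
    """Lazily picks a dist module on the first forward(), as in the diagram."""

    def __init__(self, embedding_dims):
        self._embedding_dims = embedding_dims
        self._dist = None  # lazy init, resolved on first forward()

    def forward(self, local_embs, sharding_ctx):
        if self._dist is None:
            # First call decides fixed vs. variable-batch module.
            if sharding_ctx is not None and sharding_ctx.get("variable_batch_per_feature"):
                self._dist = VariableBatchPooledEmbeddingsReduceScatter()
            else:
                self._dist = PooledEmbeddingsReduceScatter()
        if sharding_ctx is None:
            return self._dist(local_embs)
        if sharding_ctx.get("variable_batch_per_feature"):
            return self._dist(
                local_embs,
                batch_size_per_rank_per_feature=sharding_ctx.get(
                    "batch_size_per_rank_per_feature"),
                embedding_dims=self._embedding_dims,
            )
        return self._dist(local_embs,
                          input_splits=sharding_ctx.get("batch_size_per_rank"))
```

Note that this sketch reproduces the hazard flagged later in the review: whichever module the first call selects is reused for all subsequent calls, regardless of their `sharding_ctx`.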
Reviews (2). Last reviewed commit: "Remove some comments."
```python
def create_output_dist(
    self,
    device: Optional[torch.device] = None,
) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
    return RwSequenceEmbeddingDist(
        # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
        # `Optional[ProcessGroup]`.
        self._pg,
        self._get_num_features(),
        device if device is not None else self._device,
        qcomm_codecs_registry=self.qcomm_codecs_registry,
    )
```
Passes Optional ProcessGroup unchecked

create_output_dist() passes self._pg through to RwSequenceEmbeddingDist, but self._pg is typed as Optional[ProcessGroup] (and a pyre-fixme has been added to silence the type error). If self._pg is actually None at runtime (e.g., in non-distributed / single-rank setups), RwSequenceEmbeddingDist.__init__ will call pg.size() and crash. This needs a real guard, or a guarantee that _pg is non-None before the dist module is constructed (the same pattern applies in pooled sharding).

Also appears in: corelib/dynamicemb/dynamicemb/planner/rw_sharding.py:264-274.
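One way to address this is to fail fast with a clear message instead of letting `pg.size()` raise an opaque `AttributeError` on `None`. A minimal sketch of the guard (the `FakeProcessGroup` class and function name are illustrative, not the DynamicEmb API; the real code would construct the dist module after the check):

```python
from typing import Optional


# Hypothetical stand-in for torch.distributed.ProcessGroup, which also
# exposes a size() method in the real API.
class FakeProcessGroup:
    def size(self) -> int:
        return 2


def create_output_dist_guarded(pg: Optional[FakeProcessGroup]) -> int:
    """Reject a missing ProcessGroup up front rather than crashing later."""
    if pg is None:
        raise ValueError(
            "RwSequenceEmbeddingDist requires an initialized ProcessGroup; "
            "got None (non-distributed / single-rank setup?)"
        )
    # Stands in for constructing RwSequenceEmbeddingDist with a non-None pg.
    return pg.size()
```

With the guard in place, the pyre-fixme suppression can be dropped, since the type narrows to `ProcessGroup` after the `None` check.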
```python
    """
    if self._dist is None:
        self._create_output_dist_module(sharding_ctx)

    if sharding_ctx is None:
        return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs)
    elif sharding_ctx.variable_batch_per_feature:
        return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)(
            local_embs,
            batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
            embedding_dims=self._embedding_dims,
        )
    else:
        return cast(PooledEmbeddingsReduceScatter, self._dist)(
            local_embs,
            input_splits=sharding_ctx.batch_size_per_rank,
        )
```
Lazy _dist init depends on the first call's sharding_ctx

When sharding_ctx is None, forward() returns PooledEmbeddingsReduceScatter(local_embs) (line 134-135), but _dist may have been created as a variable-batch module based on the first call's sharding_ctx (line 131-133 / 151-155). If the first invocation had variable_batch_per_feature=True and a later call passes sharding_ctx=None, this will call a VariableBatchPooledEmbeddingsReduceScatter without its required args, causing a runtime error. Either require that sharding_ctx always be provided, or make the _dist selection independent of the first call.
Force-pushed from d44caeb to 55c7bdc (compare).
/build
Description
#296
Port TorchRec's output distribution classes to the DynamicEmb library, enabling future performance optimizations.
Checklist