
Port output distribution classes to DynamicEmb#297

Open
z52527 wants to merge 3 commits into NVIDIA:main from z52527:fea-output-dist

Conversation

@z52527 (Collaborator) commented Feb 6, 2026

Description

#296
Ports TorchRec's output distribution classes to the DynamicEmb library, enabling future performance optimizations.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

@z52527 (Collaborator, Author) commented Feb 6, 2026

Track CI here.

@z52527 z52527 requested a review from shijieliu February 6, 2026 10:23
@greptile-apps (Contributor) Bot commented Feb 6, 2026

Greptile Summary

This PR ports TorchRec's RW output distribution classes (RwSequenceEmbeddingDist, RwPooledEmbeddingDist) into the DynamicEmb library and wires them into the existing RwSequenceDynamicEmbeddingSharding and RwPooledDynamicEmbeddingSharding classes via new create_output_dist() overrides. The new file is missing its SPDX/Apache license header, and two issues from the prior review round — the Optional[ProcessGroup] crash risk and the lazy-init cast type-mismatch — remain unresolved (pyre-fixme comments left in place).

Confidence Score: 4/5

Safe to merge after the prior-round P1 concerns (Optional pg crash, cast type-mismatch) are addressed; remaining new findings are P2 style/cleanup.

Two P1-level issues surfaced in the previous review round (Optional ProcessGroup crash risk, lazy-init cast mismatch) remain unresolved — the pyre-fixme suppression comments are still present in the diff. All new findings in this pass are P2 (dead attribute, missing license header, stale TODO), but the outstanding prior P1s keep this from a clean 5.

corelib/dynamicemb/dynamicemb/output_dist.py (cast type-mismatch on lazy init, missing license header); corelib/dynamicemb/dynamicemb/planner/rw_sharding.py (Optional pg passed unchecked to both create_output_dist overrides)

Important Files Changed

corelib/dynamicemb/dynamicemb/output_dist.py: New file porting TorchRec's RW output distribution classes; has a dead _qcomm_codecs_registry attribute, a missing SPDX license header, and the lazy-init cast pattern carries a known type-mismatch risk (flagged in the previous review) that remains unresolved.
corelib/dynamicemb/dynamicemb/planner/rw_sharding.py: Adds create_output_dist() overrides for both sharding classes, wiring in the new dist modules; the self._pg Optional typing issue (pyre-fixme left in place) and an unresolved TODO remain from the prior review.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant RwSequenceDynamicEmbeddingSharding
    participant RwSequenceEmbeddingDist
    participant SequenceEmbeddingsAllToAll

    Caller->>RwSequenceDynamicEmbeddingSharding: create_output_dist(device)
    RwSequenceDynamicEmbeddingSharding->>RwSequenceEmbeddingDist: __init__(pg, num_features, device, qcomm_codecs_registry)
    RwSequenceEmbeddingDist->>SequenceEmbeddingsAllToAll: __init__(pg, splits, device, codecs)
    RwSequenceEmbeddingDist-->>RwSequenceDynamicEmbeddingSharding: dist instance
    RwSequenceDynamicEmbeddingSharding-->>Caller: BaseEmbeddingDist

    Caller->>RwSequenceEmbeddingDist: forward(local_embs, sharding_ctx)
    Note over RwSequenceEmbeddingDist: assert sharding_ctx is not None
    RwSequenceEmbeddingDist->>SequenceEmbeddingsAllToAll: forward(local_embs, lengths, splits, ...)
    SequenceEmbeddingsAllToAll-->>Caller: Awaitable[Tensor]

    participant RwPooledDynamicEmbeddingSharding
    participant RwPooledEmbeddingDist
    participant PooledEmbeddingsReduceScatter
    participant VariableBatchPooledEmbeddingsReduceScatter

    Caller->>RwPooledDynamicEmbeddingSharding: create_output_dist(device)
    RwPooledDynamicEmbeddingSharding->>RwPooledEmbeddingDist: __init__(pg, embedding_dims, qcomm_codecs_registry)
    Note over RwPooledEmbeddingDist: _dist = None (lazy init)
    RwPooledEmbeddingDist-->>Caller: BaseEmbeddingDist

    Caller->>RwPooledEmbeddingDist: forward(local_embs, sharding_ctx)
    alt _dist is None
        RwPooledEmbeddingDist->>RwPooledEmbeddingDist: _create_output_dist_module(sharding_ctx)
        alt variable_batch_per_feature
            RwPooledEmbeddingDist->>VariableBatchPooledEmbeddingsReduceScatter: __init__(pg, codecs)
        else
            RwPooledEmbeddingDist->>PooledEmbeddingsReduceScatter: __init__(pg, codecs)
        end
    end
    RwPooledEmbeddingDist-->>Caller: Awaitable[Tensor]

Reviews (2): Last reviewed commit: "Remove some comments."

@greptile-apps (Contributor) Bot left a comment


2 files reviewed, 2 comments


Comment on lines +155 to +166
def create_output_dist(
    self,
    device: Optional[torch.device] = None,
) -> BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]:
    return RwSequenceEmbeddingDist(
        # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
        # `Optional[ProcessGroup]`.
        self._pg,
        self._get_num_features(),
        device if device is not None else self._device,
        qcomm_codecs_registry=self.qcomm_codecs_registry,
    )

Passes Optional ProcessGroup

create_output_dist() passes self._pg through to RwSequenceEmbeddingDist, but self._pg is typed as Optional[ProcessGroup] (the added pyre-fixme only silences the type error). If self._pg is actually None at runtime (e.g., in a non-distributed / single-rank setup), RwSequenceEmbeddingDist.__init__ will call pg.size() and crash. This needs a real guard, or a guarantee that _pg is non-None before the dist module is constructed (the same pattern appears in the pooled sharding).

Also appears in: corelib/dynamicemb/dynamicemb/planner/rw_sharding.py:264-274.
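One way to address this finding, as a minimal sketch: fail fast with an explicit check instead of a pyre-fixme. The `ProcessGroup` stub below is hypothetical (standing in for `torch.distributed.ProcessGroup`), and `create_output_dist` here only models the guard, not the real sharding class:

```python
from typing import Optional


class ProcessGroup:
    """Hypothetical stub for torch.distributed.ProcessGroup."""

    def size(self) -> int:
        return 2


def create_output_dist(pg: Optional[ProcessGroup]) -> int:
    # Guard instead of a pyre-fixme: raise a clear error when the sharding
    # was constructed without a process group (non-distributed setup),
    # rather than crashing later inside the dist module's __init__.
    if pg is None:
        raise ValueError(
            "create_output_dist requires a ProcessGroup; got None"
        )
    # Now pg.size() is safe -- this is the call that would otherwise
    # crash inside RwSequenceEmbeddingDist.__init__.
    return pg.size()
```

The same guard would apply to the pooled sharding's create_output_dist() override.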

Comment on lines +130 to +146
    """
    if self._dist is None:
        self._create_output_dist_module(sharding_ctx)

    if sharding_ctx is None:
        return cast(PooledEmbeddingsReduceScatter, self._dist)(local_embs)
    elif sharding_ctx.variable_batch_per_feature:
        return cast(VariableBatchPooledEmbeddingsReduceScatter, self._dist)(
            local_embs,
            batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
            embedding_dims=self._embedding_dims,
        )
    else:
        return cast(PooledEmbeddingsReduceScatter, self._dist)(
            local_embs,
            input_splits=sharding_ctx.batch_size_per_rank,
        )

Missing context for init

When sharding_ctx is None, forward() returns PooledEmbeddingsReduceScatter(local_embs) (line 134-135), but _dist may have been created as a variable-batch module based on the first call's sharding_ctx (lines 131-133 / 151-155). If the first invocation had variable_batch_per_feature=True and a later call passes sharding_ctx=None, this calls a VariableBatchPooledEmbeddingsReduceScatter without its required arguments, causing a runtime error. Either require that sharding_ctx always be provided, or make the _dist selection independent of the first call.
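A minimal sketch of the second suggested fix: cache one module per variant instead of a single lazily-initialized _dist slot, so the variant chosen on the first call can never be misapplied to a later call. The stub classes below are hypothetical placeholders for PooledEmbeddingsReduceScatter and VariableBatchPooledEmbeddingsReduceScatter, and the dict-based sharding_ctx is a simplification:

```python
from typing import Optional


class FixedBatchRS:
    """Hypothetical stub for PooledEmbeddingsReduceScatter."""


class VariableBatchRS:
    """Hypothetical stub for VariableBatchPooledEmbeddingsReduceScatter."""


class PooledDist:
    def __init__(self) -> None:
        # One cached module per variant: the selection no longer depends
        # on whichever sharding_ctx happened to arrive first.
        self._fixed: Optional[FixedBatchRS] = None
        self._variable: Optional[VariableBatchRS] = None

    def _get_dist(self, variable_batch: bool):
        if variable_batch:
            if self._variable is None:
                self._variable = VariableBatchRS()
            return self._variable
        if self._fixed is None:
            self._fixed = FixedBatchRS()
        return self._fixed

    def forward(self, sharding_ctx=None):
        # sharding_ctx=None always maps to the fixed-batch module,
        # regardless of what earlier calls created.
        variable = bool(sharding_ctx) and sharding_ctx.get("variable", False)
        return self._get_dist(variable)
```

With this shape, a first call with variable_batch_per_feature=True followed by a call with sharding_ctx=None dispatches to the fixed-batch module rather than invoking the variable-batch module without its required arguments.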

@shijieliu (Collaborator) commented:

/build

