Skip to content

[tx] replace calls to jax.process_index() to resolve rank ordering issue with multi-host TPUs#1252

Open
andrewsykim wants to merge 11 commits intoNovaSky-AI:mainfrom
andrewsykim:remove-jax-process-id
Open

[tx] replace calls to jax.process_index() to resolve rank ordering issue with multi-host TPUs#1252
andrewsykim wants to merge 11 commits intoNovaSky-AI:mainfrom
andrewsykim:remove-jax-process-id

Conversation

@andrewsykim
Copy link

@andrewsykim andrewsykim commented Mar 2, 2026

jax.process_index() is inconsistent between GPU and TPU backends. For GPU / CPU, it returns the process_id passed into jax.distributed.initialize. However, for TPU, the index assignment happens from the TPU slice, regardless of what is assigned from jax.distributed.initialize or the TPU environment (e.g. TPU_WORKER_ID from GKE). This leads tojax.process_index() returning a different rank from --process-id set from the user.

Assuming jax.process_index() won't be updated in the near future to be consistent across TPU / GPU, this PR updates the SkyRL JAX backend to only use the index provided from the --process-id argument.

I've tested this PR on my 4x4 TPU v6e cluster with both Qwen3-0.6B and Qwen3-8B.

Resolves #1024


Open with Devin

… and GPU

Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
@andrewsykim andrewsykim changed the title Replace calls to jax.process_index() to resolve rank ordering issue with multi-host TPUs [tx] replace calls to jax.process_index() to resolve rank ordering issue with multi-host TPUs Mar 2, 2026
gemini-code-assist[bot]

This comment was marked as resolved.

devin-ai-integration[bot]

This comment was marked as resolved.

Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
@pcmoritz
Copy link
Collaborator

pcmoritz commented Mar 3, 2026

We will also need to update

with pack_and_upload(output_path, rank=jax.process_index()) as temp_dir:
so the checkpointing works correctly in the multi-node setting :)

@pcmoritz pcmoritz added the tx label Mar 3, 2026
Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
devin-ai-integration[bot]

This comment was marked as resolved.

Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
@andrewsykim
Copy link
Author

so the checkpointing works correctly

Updated

adapter_config: LoraConfig,
adapter_index: int,
output_path: Path | CloudPath,
rank: int = 0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would probably be better to remove the default from this, otherwise it is too easy to make a mistake and not specify this in call sites (before, it was kind of doing implicitly the right thing by using jax.process_index(), but now it is important to pass in the right value here).

I can make that change before merging the PR.

@pcmoritz
Copy link
Collaborator

pcmoritz commented Mar 3, 2026

/gemini review

gemini-code-assist[bot]

This comment was marked as resolved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support Google Cloud TPUs

2 participants