
[skyrl][RFC] VLM Training in SkyRL with the SkyRL train backend #1200

@nithinvc


VLM Training in SkyRL

This document outlines the changes required for SkyRL/skyrl to support vision inputs, similar to the official Tinker implementation. The proposed changes apply only to the PyTorch / skyrl-train backend; the JAX backend is out of scope.

Test Case: The supervised vision language training example (here) should run and converge out of the box. The task is to classify Caltech101, a standard image object recognition benchmark (expected accuracy >90%, especially post-SFT).

An example implementation is on this branch.

Control Flow Overview

The largest change is passing the list of ModelInputChunks to the SkyRL-train backend. This reconciles the fact that PyTorch / HuggingFace expects sequences with image pad tokens already inserted, while VLLM inserts image padding within the engine. The proposed workaround is to pass the ImageChunks to SkyRL-train, which decides, based on the method call, whether to add image pad tokens (sketched after the diagram).

[Figure: control flow diagram]
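
A minimal sketch of this two-path decision, under assumed names: the chunk classes, pad-token constant, and helper below are illustrative, not actual skyrl-train APIs.

```python
# Hedged sketch: chunk classes, the pad-token constant, and the helper name are
# all illustrative, not actual skyrl-train APIs.
from dataclasses import dataclass

IMAGE_PAD_TOKEN_ID = 0  # placeholder; the real id comes from the model config


@dataclass
class EncodedTextChunk:
    tokens: list[int]  # text tokenized on the client


@dataclass
class ImageChunk:
    data: bytes               # encoded image bytes (e.g. PNG)
    expected_tokens: int = 0  # optional client-side estimate of image tokens


def chunks_to_prompt_ids(chunks: list, for_training: bool) -> list[int]:
    """Flatten chunks into token ids. Only the HF/FSDP2 training path inserts
    image pad tokens; the sampling path leaves images for VLLM to expand."""
    prompt_ids: list[int] = []
    for chunk in chunks:
        if isinstance(chunk, EncodedTextChunk):
            prompt_ids.extend(chunk.tokens)
        elif isinstance(chunk, ImageChunk) and for_training:
            prompt_ids.extend([IMAGE_PAD_TOKEN_ID] * chunk.expected_tokens)
        # Sampling path: the ImageChunk is handed to VLLM unexpanded.
    return prompt_ids
```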

Tokenization Responsibilities

The client handles tokenization of text and encoding images into bytes. The server then has two code paths:

  1. Sampling through VLLM, and
  2. Training with HuggingFace / FSDP2.

VLLM: skyrl-train will be responsible for converting the base64-encoded image bytes into a Pillow image, which is then passed through a VLLM renderer.
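
The decode step is roughly the following; this is a sketch and the helper name is illustrative.

```python
# Hedged sketch of the VLLM-path decode step: base64 image bytes -> Pillow image.
import base64
import io

from PIL import Image


def decode_image_chunk(image_b64: str) -> Image.Image:
    """Decode a base64-encoded image payload into a Pillow image."""
    raw = base64.b64decode(image_b64)
    return Image.open(io.BytesIO(raw)).convert("RGB")
```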

HF / FSDP2 Training: Because the datum is a list of encoded chunks which need to be concatenated, the skyrl-train backend should (see the sketch after this list):

  1. Patchify the image
  2. Compute the expected number of image tokens (validating against the client-provided count if present)
  3. Concatenate input chunks to create the model input
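
A sketch of steps 1 and 2, assuming a Qwen2-VL-style HuggingFace processor; the helper name and the token-count formula are model-specific assumptions, not skyrl-train code.

```python
# Hedged sketch: patchify an image and compute the expected number of image
# tokens, assuming a Qwen2-VL-style image processor.
import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")


def patchify_and_count(image) -> tuple[torch.Tensor, int]:
    feats = processor.image_processor(images=image, return_tensors="pt")
    # For Qwen2-VL, image_grid_thw holds the (t, h, w) patch grid; the number of
    # placeholder tokens is the patch count divided by the spatial merge factor.
    grid_thw = feats["image_grid_thw"][0]
    merge = processor.image_processor.merge_size
    expected_tokens = int(grid_thw.prod().item()) // (merge * merge)
    return feats["pixel_values"], expected_tokens
```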

Tasks Breakdown

There are a lot of moving pieces in a full VLM SFT implementation. Rather than one large change, here is a set of incremental changes that should preserve backwards compatibility and interoperability with the JAX backend.

1. Inference Engine Piping
Changes are localized around the sample workflow. We will only support the new inference stack (inference_engine_client_http_endpoint.py).

  • Inference engine client call for /v1/chat/completion/render + inclusion of multi-modal data (a payload sketch appears after this list).
  • Implement the sample API which performs the generation by calling /inference/v1/generate.
  • [VLLM Upstream] Implement multi-modal features in inference/v1/generate (here); the renderer refactor is complete, so this is primarily a piping issue.
  • InferenceEngineClient calls the VLLM renderer endpoint to get a GenerateRequest, which is expected by the v1/generate API.
  • Verify existing RL / SFT training loops work with new inference server stack.
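
A hedged sketch of the multi-modal payload the client could send to the render endpoint, assuming OpenAI-style chat messages with base64 image URLs; the actual render request schema may differ.

```python
# Hedged sketch of a multi-modal chat payload; the real render request schema
# may differ from this OpenAI-style shape.
import base64


def build_render_payload(prompt: str, image_bytes: bytes, model: str) -> dict:
    image_b64 = base64.b64encode(image_bytes).decode("ascii")
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                ],
            }
        ],
    }
```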

Compatibility: Existing sample calls should not be affected. Multi-modal data piping defaults to None.
Tests: a) Integration test with a VL model: pass a red square and ask the model what color it is. b) CI test verifying our pre-processing yields the same result as VLLM, along with a log-probabilities comparison test.

2. Training Integration (FSDP2 only)

  • Pass multi-modal inputs through workers (Policy worker and Ref worker). In worker.py, update _forward_backward_micro to pull out multi_modal_inputs and pass them to model_wrapper.py.
  • Update the model_wrapper.py forward call signature to accept multi-modal inputs and pass them to the FSDP2-wrapped HuggingFace model (a sketch follows this list).
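
A sketch of the updated forward signature, treating multi_modal_inputs as a dict of processor outputs; names are illustrative, not the actual model_wrapper.py code.

```python
# Hedged sketch of the model_wrapper forward change; names are illustrative.
from typing import Optional

import torch


def forward(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    multi_modal_inputs: Optional[dict[str, torch.Tensor]] = None,
):
    extra = multi_modal_inputs or {}
    # The FSDP2-wrapped HF model consumes pixel_values, image_grid_thw, etc.
    # directly as keyword arguments.
    return self.model(input_ids=input_ids, attention_mask=attention_mask, **extra)
```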

Compatibility: New prompt ID construction is only used when image chunks are present. Otherwise, multi-modal data defaults to None.
Tests: Unit tests for image_chunk_to_image_tensor and chunks_to_input_ids. Update Experience and TrainingInput tests.

3. SkyRL-train chunk processing

  • Pass chunks to SkyRL-train backend.
    • ModelInputChunk aggregation and pass through for prepare_sample_batch
    • ModelInputChunk aggregation and pass through for prepare_model_pass_batch
  • Update the JAX backend to accept chunks, but no-op with them.
  • In SkyRL-train, when chunks are provided, construct text-only inputs (no vision yet).
  • Move chunk concatenation out of the engine and into the backend; the commits before and after this refactor should yield identical results.

Compatibility: Existing codepaths expecting a list[int] for prompt-ids should still work.
Tests: Multiple encoded text chunks should yield the same result as a single encoded text chunk (sketched below). Similarly, multiple encoded text chunks should yield the same result as not passing ModelInputChunks to the backend at all.
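
A sketch of the first equivalence test, assuming a hypothetical chunks_to_input_ids helper that simply concatenates encoded text chunks.

```python
# Hedged sketch of the chunk-equivalence test; the helper is a stand-in for the
# real chunks_to_input_ids implementation.
def chunks_to_input_ids(chunks: list[list[int]]) -> list[int]:
    return [tok for chunk in chunks for tok in chunk]


def test_multi_chunk_equals_single_chunk():
    single = chunks_to_input_ids([[1, 2, 3, 4]])
    multi = chunks_to_input_ids([[1, 2], [3, 4]])
    assert single == multi == [1, 2, 3, 4]
```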

4. Core Tinker types expansion

  • Extend the definition of ModelInputChunk beyond text (generalized InputChunk); see the Tinker SDK.
  • Add the ImageChunk definition from the SDK (a type sketch follows this list).
  • [New renderer.py] Create a renderer which uses VLLM to process vision inputs and produce the appropriate token_ids and pixel_values for the training backend.
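
A hedged sketch of the generalized chunk union; the field names follow the spirit of the Tinker SDK but are assumptions here, not the SDK definition.

```python
# Hedged sketch of a generalized InputChunk union; these are not the Tinker SDK
# definitions, just an illustration of the shape.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class EncodedTextChunk:
    tokens: list[int]


@dataclass
class ImageChunk:
    data: bytes                            # encoded image bytes
    format: str = "png"                    # hint for the server-side decoder
    expected_tokens: Optional[int] = None  # optional client-side token estimate


ModelInputChunk = Union[EncodedTextChunk, ImageChunk]
```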

Compatibility: No new control paths are introduced for text-only training; vision-language training uses the new chunk types.
Tests: Test initialization and chunk validation logic. Correctness test via the Caltech101 Tinker cookbook example.
