VLM Training in SkyRL
This document outlines the changes required for SkyRL/skyrl to support vision inputs, similar to the official tinker implementation. The proposed changes apply only to the PyTorch / skyrl-train backend; the JAX backend is out of scope.
Test Case: The supervised vision-language training (VLM SFT) example (here) should run out of the box and converge. The task is to classify Caltech101, a standard image object recognition benchmark (expected accuracy >90%, especially post-SFT).
An example implementation is on this branch.
Control Flow Overview
The largest change is passing the list of `ModelInputChunk`s to the SkyRL-train backend. This reconciles the fact that the PyTorch / Hugging Face path expects sequences with image pad tokens already inserted, while vLLM inserts image padding within the engine. The proposed workaround is to pass the `ImageChunk`s to SkyRL-train, which decides, based on the method called, whether or not to add image pad tokens.
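To make the branching concrete, here is a minimal sketch of the proposed dispatch. `prepare_model_input`, `render_for_vllm`, and `chunks_to_input_ids` are hypothetical names standing in for the actual entry points; the two helpers are sketched in the sections below.

```python
from typing import Any, Sequence

def prepare_model_input(chunks: Sequence[Any], method: str) -> Any:
    """Decide per method call whether image pad tokens get inserted."""
    if method == "sample":
        # vLLM path: pass chunks through untouched; vLLM expands image
        # placeholders inside the engine.
        return render_for_vllm(chunks)
    if method == "forward_backward":
        # HF/FSDP2 path: the model expects image pad tokens to already be
        # present in input_ids, so insert them here.
        return chunks_to_input_ids(chunks)
    raise ValueError(f"unsupported method: {method}")
```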
Tokenization Responsibilities
The client handles tokenization of text and encoding images into bytes. The server then has two code paths:
- Sampling through vLLM, and
- Training with Hugging Face / FSDP2.
VLLM: skyrl-train will be responsible for converting the base64-encoded image bytes into a Pillow image. This will then be passed through a vLLM renderer.
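A minimal sketch of that conversion (the helper name is illustrative, not an existing skyrl-train function):

```python
import base64
import io

from PIL import Image

def image_bytes_to_pil(image_base64: str) -> Image.Image:
    """Decode a client-supplied base64 image payload into a Pillow image."""
    raw = base64.b64decode(image_base64)
    return Image.open(io.BytesIO(raw)).convert("RGB")

# The decoded image can then be attached to the vLLM request, e.g.:
# llm.generate({"prompt": prompt,
#               "multi_modal_data": {"image": image_bytes_to_pil(b64)}})
```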
HF / FSDP2 Training: Because the datum is a list of encoded chunks that must be concatenated, the skyrl-train backend should (see the sketch after this list):
- Patchify the image
- Compute the expected number of image tokens (validating against the client-provided count, if any)
- Concatenate input chunks to create the model input
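A minimal sketch of these three steps, assuming illustrative chunk shapes; the token count is model-specific in reality, so `num_image_tokens` below is a placeholder formula (ViT-style patch grid with 2x2 spatial merging):

```python
import math
from dataclasses import dataclass

@dataclass
class EncodedTextChunk:
    tokens: list[int]

@dataclass
class ImageChunk:
    width: int
    height: int
    expected_num_tokens: int | None = None  # optionally supplied by the client

def num_image_tokens(chunk: ImageChunk, patch_size: int = 14, merge: int = 2) -> int:
    # Patchify: how many placeholder tokens the vision tower emits.
    grid_w = math.ceil(chunk.width / patch_size)
    grid_h = math.ceil(chunk.height / patch_size)
    return (grid_w // merge) * (grid_h // merge)

def chunks_to_input_ids(chunks, image_pad_token_id: int) -> list[int]:
    input_ids: list[int] = []
    for chunk in chunks:
        if isinstance(chunk, EncodedTextChunk):
            input_ids.extend(chunk.tokens)
        else:
            n = num_image_tokens(chunk)
            # Validate against the client-provided count if one was sent.
            if chunk.expected_num_tokens is not None and chunk.expected_num_tokens != n:
                raise ValueError(
                    f"client expected {chunk.expected_num_tokens} image tokens, computed {n}"
                )
            input_ids.extend([image_pad_token_id] * n)
    return input_ids
```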
Tasks Breakdown
There are a lot of moving pieces on the way to a full VLM SFT implementation. Instead of one large change, here is a set of incremental changes that should preserve backwards compatibility and interoperability with the JAX backend.
1. Inference Engine Piping
Changes are localized around the `sample` workflow. We will only support the new inference stack (`inference_engine_client_http_endpoint.py`).
- `/v1/chat/completion/render` + inclusion of multi-modal data.
- `/inference/v1/generate`: the renderer refactor is already completed here, so this is primarily a piping issue.
- Multi-modal data in `GenerateRequest`, which is expected by the `v1/generate` API (sketched below).
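For illustration, the piping could look like the following; the field names here are assumptions, not the actual `GenerateRequest` schema:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerateRequest:
    prompt_token_ids: list[int]
    sampling_params: dict
    # New, optional field: existing text-only callers never set it, so the
    # request shape stays backwards compatible.
    multi_modal_data: Optional[dict] = None  # e.g. {"image": ["<base64>", ...]}
```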
Compatibility: Existing `sample` calls should not be affected. Multi-modal data piping defaults to `None`.
Tests: (a) Integration test with a VL model: pass a red square and ask the model what color it is (see fixture sketch below). (b) CI test verifying our pre-processing yields the same result as vLLM, along with a log-probabilities comparison test.
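A sketch of the red-square fixture for test (a); the client call at the end is pseudocode for whatever the test harness exposes:

```python
import base64
import io

from PIL import Image

def red_square_b64(size: int = 64) -> str:
    """Build the red-square fixture as a base64-encoded PNG."""
    img = Image.new("RGB", (size, size), color=(255, 0, 0))
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()

# Hypothetical harness call:
# response = client.sample(prompt="What color is this square?",
#                          images=[red_square_b64()])
# assert "red" in response.text.lower()
```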
2. Training Integration (FSDP2 only)
- Update `_forward_backward_micro` to pull out `multi_modal_inputs` and pass them to `model_wrapper.py`.
- Update the `model_wrapper.py` forward call signature to accept multi-modal inputs and pass them to the FSDP2-wrapped Hugging Face model (see the sketch below).
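A minimal sketch of the extended forward path, assuming the wrapper holds the FSDP2-wrapped HF model as `self.model`; the kwarg names (`pixel_values`, etc.) follow HF VLM conventions but are illustrative:

```python
from typing import Optional

import torch

class ModelWrapper:
    def __init__(self, model: torch.nn.Module):
        self.model = model  # FSDP2-wrapped Hugging Face model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        multi_modal_inputs: Optional[dict] = None,
    ):
        # Text-only callers pass nothing extra; VLM callers pass e.g.
        # {"pixel_values": ..., "image_grid_thw": ...} straight through.
        extra = multi_modal_inputs or {}
        return self.model(input_ids=input_ids, attention_mask=attention_mask, **extra)
```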
Compatibility: The new prompt-ID construction is only used when image chunks are present; otherwise, multi-modal data defaults to `None`.
Tests: Unit tests for `image_chunk_to_image_tensor` and `chunks_to_input_ids`. Update `Experience` and `TrainingInput` tests.
3. SkyRL-train chunk processing
- `ModelInputChunk` aggregation and pass-through for `prepare_sample_batch`.
- `ModelInputChunk` aggregation and pass-through for `prepare_model_pass_batch`.
Compatibility: Existing code paths expecting a `list[int]` for prompt IDs should still work.
Tests: Multiple encoded text chunks should yield the same result as a single encoded text chunk. Similarly, passing encoded text chunks should yield the same result as not passing `ModelInputChunk`s to the backend at all (plain token IDs).
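A sketch of the first equivalence check, using a self-contained, text-only reduction of the `chunks_to_input_ids` helper sketched earlier:

```python
from dataclasses import dataclass

@dataclass
class EncodedTextChunk:
    tokens: list[int]

def chunks_to_input_ids(chunks) -> list[int]:
    # Text-only reduction of the earlier sketch: just concatenate tokens.
    out: list[int] = []
    for chunk in chunks:
        out.extend(chunk.tokens)
    return out

def test_split_text_chunks_match_single_chunk():
    tokens = [101, 7, 8, 9, 102]
    single = chunks_to_input_ids([EncodedTextChunk(tokens)])
    split = chunks_to_input_ids([EncodedTextChunk(tokens[:2]), EncodedTextChunk(tokens[2:])])
    # Splitting pre-tokenized text across chunks must not change the result,
    # and must match passing the raw token list (the no-chunks code path).
    assert single == split == tokens
```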
4. Core tinker types expansion
- Extend `ModelInputChunk` beyond text (generalized `InputChunk`); take the `ImageChunk` definition from the Tinker SDK (see the sketch below).
- [`renderer.py`] Create a renderer which uses vLLM to process vision inputs and produce appropriate `token_ids` and `pixel_values` for the training backend.
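A sketch of what the generalized union could look like; the field names loosely follow the tinker SDK shapes and are assumptions, not the final definitions:

```python
from dataclasses import dataclass
from typing import Union

@dataclass
class EncodedTextChunk:
    tokens: list[int]

@dataclass
class ImageChunk:
    data: bytes                            # decoded image payload
    format: str                            # e.g. "png" or "jpeg"
    expected_num_tokens: int | None = None  # optional client-side count

# Generalized union replacing the text-only chunk type.
ModelInputChunk = Union[EncodedTextChunk, ImageChunk]
```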
Compatibility: No new control paths for text-only training; the new paths apply only to vision-language training.
Tests: Test initialization and chunk validation logic. Correctness test via the Caltech101 tinker cookbook example.