Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions skyrl-tx/tests/tinker/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def start_api_server(overrides: dict[str, str] | None = None):
"host": "0.0.0.0",
"port": str(TEST_SERVER_PORT),
"base-model": BASE_MODEL,
"backend-config": '{"max_lora_adapters": 4}',
"backend-config": '{"max_lora_adapters": 4, "deterministic": true}',
"database-url": f"sqlite:///{db_path}",
}
if overrides:
Expand Down Expand Up @@ -186,16 +186,34 @@ def test_training_workflow(service_client):
fwdbwd_result3 = training_client.forward_backward(processed_examples, "cross_entropy").result()
assert fwdbwd_result3.loss_fn_outputs == fwdbwd_result.loss_fn_outputs

# Run one train step from resume state, then save a sampler checkpoint.
_ = training_client.forward_backward(processed_examples, "cross_entropy").result()
_ = training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
sampling_path = training_client.save_weights_for_sampler(name="final").result().path
parsed = urlparse(sampling_path)
training_run_id = parsed.netloc
checkpoint_id = parsed.path.lstrip("/")

# Re-run the same train step from the same resume point and verify checkpoint bytes match.
training_client.load_state(resume_path)
_ = training_client.forward_backward(processed_examples, "cross_entropy").result()
_ = training_client.optim_step(types.AdamParams(learning_rate=1e-4)).result()
sampling_path_2 = training_client.save_weights_for_sampler(name="final_replayed").result().path
parsed_2 = urlparse(sampling_path_2)
training_run_id_2 = parsed_2.netloc
checkpoint_id_2 = parsed_2.path.lstrip("/")

rest_client = service_client.create_rest_client()
# Download the checkpoint
checkpoint_response = rest_client.get_checkpoint_archive_url(training_run_id, checkpoint_id).result()
with tempfile.NamedTemporaryFile() as tmp_archive:
urllib.request.urlretrieve(checkpoint_response.url, tmp_archive.name)
assert os.path.getsize(tmp_archive.name) > 0
with urllib.request.urlopen(checkpoint_response.url) as resp:
checkpoint_bytes = resp.read()
assert len(checkpoint_bytes) > 0

checkpoint_response_2 = rest_client.get_checkpoint_archive_url(training_run_id_2, checkpoint_id_2).result()
with urllib.request.urlopen(checkpoint_response_2.url) as resp:
checkpoint_bytes_2 = resp.read()
assert checkpoint_bytes == checkpoint_bytes_2
Comment on lines +190 to +216
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's significant code duplication here for running a training step and fetching the resulting checkpoint. The logic in lines 190-195 and 208-211 (train step, save, parse the path, download) is repeated with minor differences in lines 198-204 and 213-215.

To improve readability and maintainability, consider extracting this logic into a helper function. This function could take the training_client, rest_client, and a name for the checkpoint, and return the checkpoint bytes. This would make the test's intent clearer and reduce redundancy.


# List all checkpoints for the original training run
checkpoints_response = rest_client.list_checkpoints(original_training_run_id).result()
Expand Down
31 changes: 23 additions & 8 deletions skyrl-tx/tx/tinker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ def force_exit():
app = FastAPI(title="Tinker API Mock", version="0.0.1", lifespan=lifespan)


def _is_jax_deterministic_mode(request: Request) -> bool:
    """Report whether the JAX backend is running with deterministic mode on.

    Reads the engine config stored at ``request.app.state.engine_config``.

    Args:
        request: Incoming request; only its ``app.state.engine_config`` is used.

    Returns:
        True only when the config's ``backend`` equals ``"jax"`` and the
        ``"deterministic"`` entry of its ``backend_config`` is truthy. A
        missing entry counts as disabled.
    """
    engine_config = request.app.state.engine_config
    # Any non-JAX backend is never considered deterministic here.
    if engine_config.backend != "jax":
        return False
    deterministic_flag = engine_config.backend_config.get("deterministic", False)
    return bool(deterministic_flag)


async def get_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
"""Dependency to get a database session."""
async with AsyncSession(request.app.state.db_engine) as session:
Expand Down Expand Up @@ -365,14 +371,16 @@ class SamplingParams(BaseModel):
top_k: int = -1
top_p: float = 1

def to_types(self) -> types.SamplingParams:
def to_types(self, deterministic: bool = False) -> types.SamplingParams:
if self.max_tokens is None:
raise HTTPException(status_code=400, detail="max_tokens is currently required")
if self.max_tokens <= 0:
raise HTTPException(status_code=400, detail="max_tokens must be a positive number")

# Generate a random seed if not provided
seed = self.seed if self.seed is not None else random.randint(0, 2**31 - 1)
# In deterministic mode, default to a fixed seed instead of random fallback.
seed = self.seed if self.seed is not None else (
0 if deterministic else random.randint(0, 2**31 - 1)
)

# Determine if stop values are token IDs (int) or strings
stop_tokens = None
Expand Down Expand Up @@ -600,7 +608,7 @@ async def create_sampling_session(request: CreateSamplingSessionRequest, session


@app.post("/api/v1/create_model", response_model=CreateModelResponse)
async def create_model(request: CreateModelRequest, session: AsyncSession = Depends(get_session)):
async def create_model(request: CreateModelRequest, req: Request, session: AsyncSession = Depends(get_session)):
"""Create a new model, optionally with a LoRA adapter."""
# Validate session exists
session_db = await session.get(SessionDB, request.session_id)
Expand All @@ -610,8 +618,13 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe
model_id = f"model_{uuid4().hex[:8]}"

# alpha = 32 seems to be the tinker default (see https://thinkingmachines.ai/blog/lora/)
# Generate a random seed if not provided
seed = request.lora_config.seed if request.lora_config.seed is not None else random.randint(0, 2**31 - 1)
# In deterministic mode, default to a fixed seed instead of random fallback.
deterministic = _is_jax_deterministic_mode(req)
seed = (
request.lora_config.seed
if request.lora_config.seed is not None
else (0 if deterministic else random.randint(0, 2**31 - 1))
)
Comment on lines +623 to +627
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for determining the seed based on the deterministic flag is duplicated from the SamplingParams.to_types method (lines 381-383).

To avoid code duplication and improve maintainability, consider extracting this logic into a shared helper function, for example:

def _get_seed(provided_seed: int | None, deterministic: bool) -> int:
    """Return a seed, falling back to a fixed or random seed based on deterministic mode."""
    if provided_seed is not None:
        return provided_seed
    return 0 if deterministic else random.randint(0, 2**31 - 1)

You could then call this helper function here and in SamplingParams.to_types.

    seed = _get_seed(request.lora_config.seed, deterministic)

lora_config = types.LoraConfig(rank=request.lora_config.rank, alpha=32.0, seed=seed)
request_id = await create_future(
session=session,
Expand All @@ -635,7 +648,7 @@ async def create_model(request: CreateModelRequest, session: AsyncSession = Depe
return CreateModelResponse(
model_id=model_id,
base_model=request.base_model,
lora_config=request.lora_config,
lora_config=LoRAConfig(rank=request.lora_config.rank, seed=seed),
status="created",
request_id=str(request_id),
)
Expand Down Expand Up @@ -890,6 +903,8 @@ async def asample(request: SampleRequest, req: Request, session: AsyncSession =
# Validate that the checkpoint exists and is ready
await validate_checkpoint(req, model_id, checkpoint_id, types.CheckpointType.SAMPLER, session)

deterministic = _is_jax_deterministic_mode(req)

request_id = await create_future(
session=session,
request_type=(
Expand All @@ -899,7 +914,7 @@ async def asample(request: SampleRequest, req: Request, session: AsyncSession =
request_data=types.SampleInput(
base_model=base_model,
prompt=request.prompt.to_types(),
sampling_params=request.sampling_params.to_types(),
sampling_params=request.sampling_params.to_types(deterministic=deterministic),
num_samples=request.num_samples,
checkpoint_id=checkpoint_id,
prompt_logprobs=request.prompt_logprobs if request.prompt_logprobs is not None else False,
Expand Down
8 changes: 8 additions & 0 deletions skyrl-tx/tx/tinker/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ class JaxBackendConfig(BaseModel, extra="forbid"):
default=1024,
description="Chunk size for cross-entropy loss computation. Reduces memory by avoiding full [B*T, V] logits materialization. Set to 0 to disable chunking.",
)
deterministic: bool = Field(
default=False,
description=(
"Enable deterministic behavior for JAX backend operations and checkpoint serialization. "
"This mode assumes fixed hardware/software/runtime."
),
)
# Multi-node configuration
coordinator_address: str | None = Field(
default=None,
Expand Down Expand Up @@ -861,6 +868,7 @@ def save_sampler_checkpoint(self, output_path: AnyPath, model_id: str, persist:
lora_model.lora_config,
lora_model.adapter_index,
output_path,
deterministic=self.config.deterministic,
)
logger.info(f"Saved LoRA sampler checkpoint to {output_path}")

Expand Down
6 changes: 4 additions & 2 deletions skyrl-tx/tx/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def save_safetensors(

tensors = {k: multihost_utils.process_allgather(v, tiled=True) for k, v in tensors.items()}

safetensors.numpy.save_file({k: np.asarray(v) for k, v in tensors.items()}, filename)
safetensors.numpy.save_file({k: np.asarray(tensors[k]) for k in sorted(tensors)}, filename)


def filter_lora(adapter_config: LoraConfig, path: tuple[str, ...]) -> bool:
Expand Down Expand Up @@ -272,6 +272,7 @@ def save_lora_checkpoint(
adapter_config: LoraConfig,
adapter_index: int,
output_path: Path | CloudPath,
deterministic: bool = False,
):
"""Save a LoRA checkpoint as a tar.gz archive.

Expand All @@ -280,12 +281,13 @@ def save_lora_checkpoint(
adapter_config: LoRA adapter configuration
adapter_index: Index of the adapter to save
output_path: Path to save the checkpoint tar.gz file
deterministic: Normalize archive metadata and ordering for reproducible bytes
"""
peft_config = peft.LoraConfig(
base_model_name_or_path=base_model_name, r=adapter_config.rank, lora_alpha=adapter_config.alpha
)

with pack_and_upload(output_path, rank=jax.process_index()) as temp_dir:
with pack_and_upload(output_path, rank=jax.process_index(), deterministic=deterministic) as temp_dir:

save_safetensors(
model.config,
Expand Down
36 changes: 33 additions & 3 deletions skyrl-tx/tx/utils/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,19 @@


@contextmanager
def pack_and_upload(dest: AnyPath, rank: Optional[int] = None) -> Generator[Path, None, None]:
def pack_and_upload(
dest: AnyPath,
rank: Optional[int] = None,
deterministic: bool = False,
) -> Generator[Path, None, None]:
"""Give the caller a temp directory that gets uploaded as a tar.gz archive on exit.

Args:
dest: Destination path for the tar.gz file
rank: Process rank for multi-rank deduplication. If provided and a probe
file exists at {dest}.probe, only rank 0 writes.
deterministic: If true, normalize tar/gzip metadata and entry order so
resulting archives are byte-for-byte reproducible.
"""
with TemporaryDirectory() as tmp:
tmp_path = Path(tmp)
Expand All @@ -34,9 +40,33 @@ def pack_and_upload(dest: AnyPath, rank: Optional[int] = None) -> Generator[Path

with dest.open("wb") as f:
# Use compresslevel=0 to prioritize speed, as checkpoint files don't compress well.
with gzip.GzipFile(fileobj=f, mode="wb", compresslevel=0) as gz_stream:
with gzip.GzipFile(
fileobj=f,
filename="",
mode="wb",
compresslevel=0,
mtime=0 if deterministic else None,
) as gz_stream:
with tarfile.open(fileobj=gz_stream, mode="w:") as tar:
tar.add(tmp_path, arcname="")
if not deterministic:
tar.add(tmp_path, arcname="")
else:
# Deterministic pack: stable traversal + normalized metadata.
for path in sorted(tmp_path.rglob("*"), key=lambda p: p.relative_to(tmp_path).as_posix()):
rel = path.relative_to(tmp_path)
arcname = rel.as_posix()
tarinfo = tar.gettarinfo(str(path), arcname=arcname)
tarinfo.mtime = 0
tarinfo.uid = 0
tarinfo.gid = 0
tarinfo.uname = ""
tarinfo.gname = ""

if tarinfo.isdir():
tar.addfile(tarinfo)
else:
with path.open("rb") as src:
tar.addfile(tarinfo, src)


@contextmanager
Expand Down
Loading