diff --git a/dev/run_yes_no_maybe_kl_advantage_tinker.py b/dev/run_yes_no_maybe_kl_advantage_tinker.py
new file mode 100644
index 00000000..468001d3
--- /dev/null
+++ b/dev/run_yes_no_maybe_kl_advantage_tinker.py
@@ -0,0 +1,104 @@
+"""Launch yes-no-maybe-kl-advantage-tinker training on SkyPilot (Kubernetes).
+
+Usage:
+ uv run dev/run_yes_no_maybe_kl_advantage_tinker.py
+ uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --fast
+ uv run dev/run_yes_no_maybe_kl_advantage_tinker.py --base-model Qwen/Qwen2.5-7B-Instruct
+"""
+
+import argparse
+import os
+import textwrap
+
+from dotenv import load_dotenv
+import sky
+from sky import ClusterStatus
+
+load_dotenv()
+
+parser = argparse.ArgumentParser(
+ description="Launch yes-no-maybe KL advantage training (Tinker) on SkyPilot."
+)
+parser.add_argument(
+ "--fast", action="store_true", help="Skip setup (for re-runs on existing cluster)."
+)
+parser.add_argument(
+ "--base-model", type=str, default="meta-llama/Llama-3.1-8B-Instruct"
+)
+parser.add_argument("--num-steps", type=int, default=20)
+parser.add_argument("--kl-penalty-coef", type=float, default=0.1)
+parser.add_argument("--accelerator", type=str, default="H200:1")
+parser.add_argument("--cluster-name", type=str, default=None)
+parser.add_argument(
+ "--kl-ref-step",
+ type=int,
+ default=None,
+    help="Checkpoint step of the training model to use as the KL reference",
+)
+args = parser.parse_args()
+
+cluster_name = args.cluster_name or f"ynm-tinker-kl-{args.kl_penalty_coef}"
+cluster_prefix = os.environ.get("CLUSTER_PREFIX")
+if cluster_prefix:
+ cluster_name = f"{cluster_prefix}-{cluster_name}"
+
+setup_script = textwrap.dedent("""\
+ echo 'Setting up environment...'
+ apt install -y nvtop
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ source $HOME/.local/bin/env
+""")
+
+kl_ref_env = ""
+if args.kl_ref_step is not None:
+ kl_ref_env = f"KL_REF_STEP={args.kl_ref_step} "
+
+run_script = textwrap.dedent(f"""\
+ source $HOME/.local/bin/env
+ cd ~/sky_workdir
+ {kl_ref_env}BASE_MODEL={args.base_model} NUM_STEPS={args.num_steps} KL_PENALTY_COEF={args.kl_penalty_coef} uv run --python 3.11 --extra tinker dev/yes-no-maybe-kl-advantage-tinker.py
+""")
+
+task = sky.Task(
+ name="yes-no-maybe-kl-advantage-tinker",
+ setup=setup_script,
+ run=run_script,
+ workdir=".",
+)
+task.set_resources(
+ sky.Resources(accelerators=args.accelerator, cloud=sky.clouds.Kubernetes())
+)
+task.set_file_mounts(
+ {
+ "~/sky_workdir/.env": ".env",
+ }
+)
+
+print(f"Launching on cluster: {cluster_name}")
+print(f" base_model: {args.base_model}")
+print(f" accelerator: {args.accelerator}")
+print(f" num_steps: {args.num_steps}")
+print(f" kl_penalty_coef: {args.kl_penalty_coef}")
+if args.kl_ref_step is not None:
+ print(f" kl_ref_step: {args.kl_ref_step}")
+
+# Cancel any existing jobs on this cluster
+cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name]))
+if len(cluster_status) > 0 and cluster_status[0]["status"] == ClusterStatus.UP:
+ print(f"Cluster {cluster_name} is UP. Canceling any active jobs...")
+ sky.stream_and_get(sky.cancel(cluster_name, all=True))
+
+job_id, _ = sky.stream_and_get(
+ sky.launch(
+ task,
+ cluster_name=cluster_name,
+ retry_until_up=True,
+ idle_minutes_to_autostop=60,
+ down=True,
+ fast=args.fast,
+ )
+)
+
+print(f"Job submitted (ID: {job_id}). Streaming logs...")
+exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True)
+print(f"Job {job_id} finished with exit code {exit_code}.")
diff --git a/dev/yes-no-maybe-kl-advantage-tinker.py b/dev/yes-no-maybe-kl-advantage-tinker.py
new file mode 100644
index 00000000..5983a1f2
--- /dev/null
+++ b/dev/yes-no-maybe-kl-advantage-tinker.py
@@ -0,0 +1,111 @@
+"""Yes-no-maybe training with KL-penalized advantage adjustment (Tinker backend).
+
+Demonstrates the kl_penalty_coef feature: tokens whose KL to the reference
+model is above the batch average get reduced advantages, while tokens below
+the average get increased advantages.
+
+Uses meta-llama/Llama-3.1-8B-Instruct as the base model (trained via Tinker).
+"""
+
+import asyncio
+from itertools import permutations
+import os
+import random
+import string
+
+from dotenv import load_dotenv
+import openai
+
+import art
+from art.tinker_native import TinkerNativeBackend
+
+
+async def rollout(
+ client: openai.AsyncOpenAI, model: art.TrainableModel, prompt: str
+) -> art.Trajectory:
+ messages: art.Messages = [
+ {
+ "role": "user",
+ "content": prompt,
+ }
+ ]
+ chat_completion = await client.chat.completions.create(
+ messages=messages, model=model.get_inference_name(), max_tokens=100, timeout=100
+ )
+ choice = chat_completion.choices[0]
+ content = choice.message.content
+ assert isinstance(content, str)
+ if content == "yes":
+ reward = 0.5
+ elif content == "no":
+ reward = 0.75
+ elif content == "maybe":
+ reward = 1.0
+ else:
+ reward = 0.0
+ return art.Trajectory(messages_and_choices=[*messages, choice], reward=reward)
+
+
+def with_quotes(w: str) -> str:
+ return f"'{w}'"
+
+
+async def main():
+ load_dotenv()
+
+ backend = TinkerNativeBackend()
+ base_model = os.environ.get("BASE_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
+ kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1"))
+ random_suffix = "".join(random.choices(string.ascii_lowercase, k=4))
+ model = art.TrainableModel(
+ name=os.environ.get("MODEL_NAME", f"tinker-{random_suffix}-{kl_penalty_coef}"),
+ project="yes-no-maybe",
+ base_model=base_model,
+ )
+ await model.register(backend)
+
+ kl_penalty_reference_step: int | None = (
+ int(os.environ["KL_REF_STEP"])
+ if os.environ.get("KL_REF_STEP") is not None
+ else None
+ )
+
+    prompts = [
+        f"{prefix} with {phrase}"
+        for prefix in ["respond", "just respond"]
+        for use_quotes in [True, False]
+        for words in (
+            list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n)
+        )
+        for phrase in [
+            ", ".join(with_quotes(w) if use_quotes else w for w in words)
+            if len(words) == 3
+            else f"{words[0]} or {words[1]}"
+        ]
+    ]
+
+ openai_client = model.openai_client()
+ max_steps = int(os.environ.get("NUM_STEPS", "20"))
+ start_step = await model.get_step()
+ for step in range(start_step, start_step + max_steps):
+ train_groups = await art.gather_trajectory_groups(
+ (
+ art.TrajectoryGroup(
+ rollout(openai_client, model, prompt) for _ in range(32)
+ )
+ for prompt in prompts
+ )
+ )
+ result = await backend.train(
+ model,
+ train_groups,
+ learning_rate=1e-4,
+ kl_penalty_coef=kl_penalty_coef,
+ kl_penalty_reference_step=kl_penalty_reference_step,
+ )
+ await model.log(
+ train_groups,
+ metrics=result.metrics,
+ step=result.step,
+ split="train",
+ )
+ print(f"step {result.step}: {result.metrics}")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/dev/yes-no-maybe-kl-advantage.py b/dev/yes-no-maybe-kl-advantage.py
index 41ce0b11..ccd21b24 100644
--- a/dev/yes-no-maybe-kl-advantage.py
+++ b/dev/yes-no-maybe-kl-advantage.py
@@ -10,6 +10,8 @@
import asyncio
from itertools import permutations
import os
+import random
+import string
from dotenv import load_dotenv
import openai
@@ -54,8 +56,9 @@ async def main():
backend = LocalBackend()
base_model = os.environ.get("BASE_MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
kl_penalty_coef = float(os.environ.get("KL_PENALTY_COEF", "0.1"))
+ random_suffix = "".join(random.choices(string.ascii_lowercase, k=4))
model = art.TrainableModel(
- name=os.environ.get("MODEL_NAME", f"kl-{kl_penalty_coef}"),
+ name=os.environ.get("MODEL_NAME", f"local-{random_suffix}-{kl_penalty_coef}"),
project="yes-no-maybe",
base_model=base_model,
)
diff --git a/docs/features/checkpoint-forking.mdx b/docs/features/checkpoint-forking.mdx
index c3d3603d..4a214c85 100644
--- a/docs/features/checkpoint-forking.mdx
+++ b/docs/features/checkpoint-forking.mdx
@@ -31,7 +31,7 @@ import art
from art.local import LocalBackend
async def train():
- with LocalBackend() as backend:
+ async with LocalBackend() as backend:
# Create a new model that will fork from an existing checkpoint
model = art.TrainableModel(
name="my-model-v2",
@@ -115,7 +115,7 @@ low_lr_model = art.TrainableModel(
)
async def experiment():
- with LocalBackend() as backend:
+ async with LocalBackend() as backend:
# Fork the model from the base model
await backend._experimental_fork_checkpoint(
low_lr_model,
diff --git a/docs/fundamentals/art-backend.mdx b/docs/fundamentals/art-backend.mdx
index f65473a9..9e6019c0 100644
--- a/docs/fundamentals/art-backend.mdx
+++ b/docs/fundamentals/art-backend.mdx
@@ -73,6 +73,30 @@ backend = LocalBackend(
)
```
+If you're using `PipelineTrainer`, `LocalBackend` is currently supported only in dedicated mode, where training and inference run on separate GPUs.
+
+```python
+from art import TrainableModel
+from art.dev import InternalModelConfig
+from art.local import LocalBackend
+
+backend = LocalBackend(path="./.art")
+model = TrainableModel(
+ name="pipeline-localbackend",
+ project="my-project",
+ base_model="Qwen/Qwen3-0.6B",
+ _internal_config=InternalModelConfig(
+ trainer_gpu_ids=[0],
+ inference_gpu_ids=[1],
+ ),
+)
+```
+
+Shared `LocalBackend` still pauses inference during training, so ART rejects that configuration for `PipelineTrainer`.
+
+In dedicated mode, a new checkpoint becomes the default inference target only after its LoRA has been reloaded into vLLM. That checkpoint publication flow is backend-specific, so `save_checkpoint` does not have identical semantics across every ART backend.
+Requests that are already in flight keep using the adapter they started with; the reload only affects subsequent routing to the latest served step.
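+
+Resolve the inference name through the model rather than assuming the newest checkpoint is already being served. A minimal sketch (the prompt is illustrative):
+
+```python
+client = model.openai_client()
+
+# With no step argument, get_inference_name() resolves to whichever adapter
+# step vLLM has actually loaded, which can briefly lag the newest checkpoint
+# on disk while a reload is in flight.
+completion = await client.chat.completions.create(
+    messages=[{"role": "user", "content": "Say hello."}],
+    model=model.get_inference_name(),
+)
+```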
+
## Using a backend
Once initialized, a backend can be used in the same way regardless of whether it runs locally or remotely.
diff --git a/docs/fundamentals/training-loop.mdx b/docs/fundamentals/training-loop.mdx
index b7bc8fe9..4c8f4f75 100644
--- a/docs/fundamentals/training-loop.mdx
+++ b/docs/fundamentals/training-loop.mdx
@@ -22,6 +22,8 @@ ART's functionality is divided into a [**client**](/fundamentals/art-client) and
This training loop runs until a specified number of inference and training iterations have completed.
+This describes the default shared-resource loop. `PipelineTrainer` can also run with `LocalBackend` in dedicated mode, where training and inference stay on separate GPUs and the latest served step advances only after vLLM reloads the new LoRA.
+
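+A minimal dedicated-mode sketch (the GPU ids, batch sizes, rollout function, and scenario iterator are illustrative):
+
+```python
+import art
+from art.local import LocalBackend
+from art.pipeline_trainer import PipelineTrainer
+
+
+async def main() -> None:
+    model = art.TrainableModel(
+        name="my-model",
+        project="my-project",
+        base_model="Qwen/Qwen3-0.6B",
+        # Dedicated mode: training and inference on separate GPUs.
+        _internal_config=art.dev.InternalModelConfig(
+            trainer_gpu_ids=[0],
+            inference_gpu_ids=[1],
+        ),
+    )
+    async with LocalBackend() as backend:
+        await model.register(backend)
+        trainer = PipelineTrainer(
+            model=model,
+            backend=backend,
+            rollout_fn=my_rollout_fn,  # your async rollout function
+            scenarios=my_scenarios,  # your scenario iterator
+            config=None,
+            num_rollout_workers=2,
+            min_batch_size=1,
+            max_batch_size=4,
+            max_steps=10,
+            eval_fn=None,
+        )
+        await trainer.train()
+```
+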
Training and inference use both the ART **client** and **backend**. Learn more by following the links below!
diff --git a/src/art/dev/train.py b/src/art/dev/train.py
index b0e232c5..5da3e1ab 100644
--- a/src/art/dev/train.py
+++ b/src/art/dev/train.py
@@ -18,6 +18,7 @@ class TrainConfig(TypedDict, total=False):
]
kimi_k2_tau: float | None
kl_penalty_coef: float
+ kl_penalty_source: Literal["current_learner", "sample"]
kl_ref_adapter_path: str | None
logprob_calculation_chunk_size: int
mask_prob_ratio: bool
diff --git a/src/art/local/backend.py b/src/art/local/backend.py
index 5baf200f..723383ad 100644
--- a/src/art/local/backend.py
+++ b/src/art/local/backend.py
@@ -159,6 +159,9 @@ def _allocated_gpu_count(self, model: Model) -> int:
def __enter__(self) -> Self:
return self
+ async def __aenter__(self) -> Self:
+ return self
+
def __exit__(
self,
exc_type: type[BaseException] | None,
@@ -167,14 +170,30 @@ def __exit__(
) -> None:
self._close()
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc: BaseException | None,
+ tb: TracebackType | None,
+ ) -> None:
+ await self.close()
+
async def close(self) -> None:
"""
If running vLLM in a separate process, this will kill that process and close the communication threads.
"""
- self._close()
+ for service in self._services.values():
+ aclose = getattr(service, "aclose", None)
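+            # Prefer the async shutdown hook when the service provides one.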
+            if aclose is not None:
+                await aclose()
+            else:
+                close = getattr(service, "close", None)
+                if close is not None:
+                    close()
+ close_proxy(service)
def _close(self) -> None:
- for _, service in self._services.items():
+ for service in self._services.values():
close = getattr(service, "close", None)
if close is not None:
close()
@@ -219,25 +238,27 @@ def _model_inference_name(self, model: Model, step: int | None = None) -> str:
If None, returns name for latest checkpoint (step 0 initially).
"""
- # For LocalBackend, vLLM always serves LoRA adapters with @step suffix
- # Default to step 0 when not specified (the initial checkpoint created at registration)
- if step is not None:
- actual_step = step
- elif model.name in self._services and self._in_process:
- # In dedicated mode the service tracks which adapter vLLM has
- # actually loaded. Reading the filesystem would race: the
- # checkpoint directory appears before the HTTP reload completes.
- svc = self._services[model.name]
- loaded_step = getattr(svc, "_latest_step", None)
- actual_step = (
- loaded_step if loaded_step is not None else self.__get_step(model)
- )
- else:
- actual_step = self.__get_step(model)
- name = f"{model.name}@{actual_step}"
+ requested_step = step
+
+ if step is None and isinstance(model, TrainableModel):
+ from ..dev.validate import is_dedicated_mode
+
+ service = self._services.get(model.name)
+ if service is not None and is_dedicated_mode(
+ model._internal_config or dev.InternalModelConfig()
+ ):
+ loaded_step = getattr(service, "_latest_step", None)
+ if isinstance(loaded_step, int):
+ step = loaded_step
+
+ if step is None:
+            # Fall back to the step recorded on disk. In dedicated mode the
+            # checkpoint directory can appear before vLLM finishes reloading
+            # the new adapter, which is why the served step above takes
+            # precedence when it is available.
+ step = self.__get_step(model)
+ name = f"{model.name}@{step}"
logger.debug(
- f"[BACKEND] _model_inference_name: step_arg={step} "
- f"actual_step={actual_step} -> {name}"
+ f"[BACKEND] _model_inference_name: step_arg={requested_step} "
+ f"actual_step={step} -> {name}"
)
return name
@@ -502,12 +523,15 @@ async def train( # type: ignore[override]
*,
# Core training parameters
learning_rate: float = 5e-6,
+ loss_fn: Literal["cispo", "ppo"] = "cispo",
+ loss_fn_config: dict | None = None,
+ normalize_advantages: bool = True,
+ adam_params: object | None = None,
# KL-penalized advantage adjustment
kl_penalty_coef: float = 0.0,
kl_penalty_reference_step: int | None = None,
kl_ref_adapter_path: str | None = None,
- # RL algorithm settings
- ppo: bool = False,
+ kl_penalty_source: Literal["current_learner", "sample"] = "current_learner",
epsilon: float | None = None,
epsilon_high: float | None = None,
# Advantage computation
@@ -544,6 +568,14 @@ async def train( # type: ignore[override]
model: The trainable model to train.
trajectory_groups: Batches of trajectories to train on.
learning_rate: Learning rate for training. Defaults to 5e-6.
+ loss_fn: RL loss function. LocalBackend currently supports
+ "cispo" and "ppo".
+ loss_fn_config: Additional loss-function config. Not supported by
+ LocalBackend.
+ normalize_advantages: Whether to normalize advantages. LocalBackend
+ currently requires True.
+ adam_params: Custom optimizer params. Not supported by
+ LocalBackend.
kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
Tokens diverging more from the reference get reduced advantages.
Defaults to 0.0 (disabled).
@@ -553,8 +585,12 @@ async def train( # type: ignore[override]
kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
checkpoint to use as the KL reference. Alternative to
kl_penalty_reference_step.
- ppo: Whether to use PPO clipping. Defaults to False.
- epsilon: Clip epsilon for importance sampling. Defaults based on ppo.
+            kl_penalty_source: Which policy's logprobs to compare against the
+                reference when building the centered KL penalty. Use
+                "current_learner" to match the original ART implementation, or
+                "sample" to compute the penalty from the stored rollout-policy
+                logprobs, which is usually better for async/off-policy
+                workloads.
+ epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
advantage_balance: Balance between negative and positive advantages
in range [-1.0, 1.0]. Defaults to 0.0 (balanced).
@@ -597,19 +633,31 @@ async def train( # type: ignore[override]
# await model.log(metrics=result.metrics, step=result.step)
"""
groups_list = list(trajectory_groups)
+ if loss_fn not in {"cispo", "ppo"}:
+ raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
+ if loss_fn_config is not None:
+ raise ValueError("LocalBackend requires loss_fn_config=None.")
+ if not normalize_advantages:
+ raise ValueError("LocalBackend requires normalize_advantages=True.")
+ if adam_params is not None:
+ raise ValueError("LocalBackend requires adam_params=None.")
+        if kl_penalty_source not in {"current_learner", "sample"}:
+            raise ValueError(
+                "kl_penalty_source must be 'current_learner' or 'sample'."
+            )
# Build config objects from explicit kwargs
config = TrainConfig(
- learning_rate=learning_rate, kl_penalty_coef=kl_penalty_coef
+ learning_rate=learning_rate,
+ kl_penalty_coef=kl_penalty_coef,
+ kl_penalty_source=kl_penalty_source,
)
dev_config: dev.TrainConfig = {
"advantage_balance": advantage_balance,
"allow_training_without_logprobs": allow_training_without_logprobs,
"importance_sampling_level": importance_sampling_level,
"kl_penalty_coef": kl_penalty_coef,
+ "kl_penalty_source": kl_penalty_source,
"mask_prob_ratio": mask_prob_ratio,
"plot_tensors": plot_tensors,
- "ppo": ppo,
+ "ppo": loss_fn == "ppo",
"precalculate_logprobs": precalculate_logprobs,
"scale_learning_rate_by_reward_std_dev": scale_learning_rate_by_reward_std_dev,
"scale_rewards": scale_rewards,
diff --git a/src/art/loss.py b/src/art/loss.py
index 5a73d7b7..59cfa46a 100644
--- a/src/art/loss.py
+++ b/src/art/loss.py
@@ -95,7 +95,14 @@ def loss_fn(
kl_policy_ref: torch.Tensor | None = None
kl_penalty_coef = experimental_config.get("kl_penalty_coef", 0.0)
if kl_penalty_coef > 0 and ref_logprobs is not None:
- kl_per_token = (new_logprobs - ref_logprobs).detach() * assistant_mask
+ match experimental_config.get("kl_penalty_source", "current_learner"):
+ case "sample":
+ kl_source_logprobs = old_logprobs.detach()
+ case "current_learner":
+ kl_source_logprobs = new_logprobs.detach()
+ case other:
+ raise AssertionError(other)
+ kl_per_token = (kl_source_logprobs - ref_logprobs).detach() * assistant_mask
avg_kl = kl_per_token.sum() / (assistant_mask.sum() + 1e-6)
kl_penalty = kl_penalty_coef * (avg_kl - kl_per_token) * assistant_mask
advantages = advantages + kl_penalty
diff --git a/src/art/pipeline_trainer/trainer.py b/src/art/pipeline_trainer/trainer.py
index 1458d153..2522cc58 100644
--- a/src/art/pipeline_trainer/trainer.py
+++ b/src/art/pipeline_trainer/trainer.py
@@ -78,6 +78,8 @@ def __init__(
loss_fn_config: dict | None = None,
normalize_advantages: bool = True,
adam_params: object | None = None,
+ kl_penalty_coef: float = 0.0,
+ kl_penalty_step_lag: int | None = None,
max_steps: int | None = None,
# Discard handling
discard_queue_multiplier: int = 100,
@@ -112,6 +114,8 @@ def __init__(
raise ValueError("log_interval_seconds must be > 0")
if discard_queue_multiplier <= 0:
raise ValueError("discard_queue_multiplier must be > 0")
+ if kl_penalty_step_lag is not None and kl_penalty_step_lag < 1:
+ raise ValueError("kl_penalty_step_lag must be >= 1")
self.model = model
self.backend = backend
self.rollout_fn = rollout_fn
@@ -129,6 +133,8 @@ def __init__(
self.loss_fn_config = loss_fn_config
self.normalize_advantages = normalize_advantages
self.adam_params = adam_params
+ self.kl_penalty_coef = kl_penalty_coef
+ self.kl_penalty_step_lag = kl_penalty_step_lag
self.max_steps = max_steps
self._status_log_interval_seconds = log_interval_seconds
self.eval_every_n_steps = eval_every_n_steps
@@ -154,6 +160,7 @@ def __init__(
total_scenarios=total_scenarios,
num_workers=num_rollout_workers,
)
+ self._validate_backend_support()
async def train(self, *, handle_signals: bool = True) -> None:
"""Run the training pipeline over the configured scenario iterator."""
@@ -277,6 +284,42 @@ async def _notify_policy() -> None:
except asyncio.QueueFull:
loop.create_task(self._output_queue.put(None))
+ def _validate_backend_support(self) -> None:
+ from art.dev.validate import is_dedicated_mode
+ from art.local.backend import LocalBackend
+
+ if not isinstance(self.backend, LocalBackend):
+ return
+
+ model_config = self.model._internal_config or art.dev.InternalModelConfig()
+ if not is_dedicated_mode(model_config):
+ raise ValueError(
+ "PipelineTrainer only supports LocalBackend in dedicated mode. "
+ "Shared LocalBackend pauses inference during training and is not "
+ "a supported async PipelineTrainer path. Set both "
+ "trainer_gpu_ids and inference_gpu_ids on the TrainableModel "
+ "_internal_config to use LocalBackend with PipelineTrainer."
+ )
+ if self.loss_fn not in {"cispo", "ppo"}:
+ raise ValueError(
+ "PipelineTrainer + LocalBackend(dedicated) only supports "
+ "loss_fn='cispo' or loss_fn='ppo'."
+ )
+ if self.loss_fn_config is not None:
+ raise ValueError(
+ "PipelineTrainer + LocalBackend(dedicated) requires "
+ "loss_fn_config=None."
+ )
+ if not self.normalize_advantages:
+ raise ValueError(
+ "PipelineTrainer + LocalBackend(dedicated) requires "
+ "normalize_advantages=True."
+ )
+ if self.adam_params is not None:
+ raise ValueError(
+ "PipelineTrainer + LocalBackend(dedicated) requires adam_params=None."
+ )
+
async def _skip_scenarios(
self, scenarios: AsyncIterator[ScenarioT], count: int
) -> int:
@@ -415,6 +458,16 @@ async def _training_stage(self) -> None:
if os.getenv("ART_TRAIN_STEP_LOG"):
print(f"[train] step {expected_step} starting (batch={len(batch)})")
try:
+ kl_train_kwargs: dict[str, object] = {}
+ if self.kl_penalty_coef > 0.0:
+ kl_train_kwargs["kl_penalty_coef"] = self.kl_penalty_coef
+ kl_train_kwargs["kl_penalty_source"] = "sample"
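+                # Anchor the KL reference to the initial checkpoint (step 0)
+                # unless a step lag is configured, in which case track the
+                # policy from kl_penalty_step_lag steps ago, floored at 0.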
+ if self.kl_penalty_step_lag is None:
+ kl_train_kwargs["kl_penalty_reference_step"] = 0
+ else:
+ kl_train_kwargs["kl_penalty_reference_step"] = max(
+ 0, current_step - self.kl_penalty_step_lag
+ )
result = await self.backend.train(
self.model,
batch,
@@ -424,6 +477,7 @@ async def _training_stage(self) -> None:
normalize_advantages=self.normalize_advantages,
save_checkpoint=should_checkpoint,
adam_params=self.adam_params,
+ **kl_train_kwargs,
)
except Exception:
self._status.note_training_end()
@@ -515,7 +569,7 @@ async def _collect_batch(
continue
batch.append(item)
- while not saw_sentinel:
+ while not saw_sentinel and len(batch) < self.max_batch_size:
try:
item = self._output_queue.get_nowait()
except asyncio.QueueEmpty:
diff --git a/src/art/test/test_kl_advantage.py b/src/art/test/test_kl_advantage.py
index d944efc6..82c0f2a2 100644
--- a/src/art/test/test_kl_advantage.py
+++ b/src/art/test/test_kl_advantage.py
@@ -2,7 +2,7 @@
import torch
-from art.loss import Loss, loss_fn
+from art.loss import loss_fn, shift_tensor
def _make_inputs(
@@ -114,3 +114,50 @@ def test_kl_advantage_does_not_affect_when_no_ref():
loss = loss_fn(inputs, new_logprobs, None, None, {"kl_penalty_coef": 0.5})
assert loss.kl_policy_ref is None
+
+
+def test_kl_advantage_can_use_sample_logprobs() -> None:
+ """Sample-source KL should use stored rollout logprobs rather than learner logprobs."""
+ inputs = _make_inputs(seq_len=8)
+ inputs["logprobs"] = torch.tensor(
+ [[0.0, -0.2, -0.4, -0.6, -0.8, -1.0, -1.2, -1.4]], dtype=torch.float32
+ )
+ new_logprobs = torch.tensor(
+ [[0.0, -1.2, -1.1, -1.0, -0.9, -0.8, -0.7, -0.6]], dtype=torch.float32
+ )
+ ref_logprobs = torch.full((1, 8), -0.5)
+ assistant_mask = shift_tensor(inputs["assistant_mask"], False).to(
+ new_logprobs.dtype
+ )
+ sampled_logprobs = torch.where(
+ torch.isnan(shift_tensor(inputs["logprobs"], float("nan"))),
+ new_logprobs.detach(),
+ shift_tensor(inputs["logprobs"], float("nan")),
+ )
+ expected_sample_kl = ((sampled_logprobs - ref_logprobs) * assistant_mask).sum() / (
+ assistant_mask.sum() + 1e-6
+ )
+ expected_current_kl = ((new_logprobs - ref_logprobs) * assistant_mask).sum() / (
+ assistant_mask.sum() + 1e-6
+ )
+
+ sample_loss = loss_fn(
+ inputs,
+ new_logprobs,
+ ref_logprobs,
+ None,
+ {"kl_penalty_coef": 0.5, "kl_penalty_source": "sample"},
+ )
+ learner_loss = loss_fn(
+ inputs,
+ new_logprobs,
+ ref_logprobs,
+ None,
+ {"kl_penalty_coef": 0.5, "kl_penalty_source": "current_learner"},
+ )
+
+ assert sample_loss.kl_policy_ref is not None
+ assert learner_loss.kl_policy_ref is not None
+ assert torch.isclose(sample_loss.kl_policy_ref, expected_sample_kl)
+ assert torch.isclose(learner_loss.kl_policy_ref, expected_current_kl)
+ assert not torch.isclose(sample_loss.kl_policy_ref, learner_loss.kl_policy_ref)
diff --git a/src/art/test/test_step_skipping.py b/src/art/test/test_step_skipping.py
index f4c85a1b..0a048b5a 100755
--- a/src/art/test/test_step_skipping.py
+++ b/src/art/test/test_step_skipping.py
@@ -44,7 +44,7 @@ async def test_step_skipping():
# Set up backend with custom art path
art_path = os.path.join(tmpdir, ".art")
- with LocalBackend(path=art_path) as backend:
+ async with LocalBackend(path=art_path) as backend:
# Create a test model
model = TrainableModel(
name=f"test-step-skip-{uuid.uuid4()}",
diff --git a/src/art/tinker_native/backend.py b/src/art/tinker_native/backend.py
index c1687bf7..65f59ca6 100644
--- a/src/art/tinker_native/backend.py
+++ b/src/art/tinker_native/backend.py
@@ -24,6 +24,7 @@
from openai.types.chat.completion_create_params import CompletionCreateParams
from openai.types.completion_usage import CompletionUsage
import tinker
+import torch
import uvicorn
from art.tinker.cookbook_v import renderers, tokenizer_utils
@@ -82,6 +83,76 @@ def _canonicalize_upstream_metric_key(metric: str) -> str:
return _UPSTREAM_TRAIN_METRIC_KEYS.get(metric, metric)
+async def _apply_kl_penalty(
+ datums: list[tinker.Datum],
+ reference_sampling_client: tinker.SamplingClient,
+ kl_penalty_coef: float,
+) -> dict[str, float]:
+ assert datums
+ assert kl_penalty_coef > 0.0
+
+ full_sequences: list[tinker.ModelInput] = []
+ sampled_logprobs_by_datum: list[torch.Tensor] = []
+ masks_by_datum: list[torch.Tensor] = []
+ advantages_by_datum: list[torch.Tensor] = []
+ for datum in datums:
+ target_tokens = datum.loss_fn_inputs["target_tokens"].to_torch()
+ assert target_tokens.numel() > 0
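+        # model_input holds the prompt plus all but the last completion token;
+        # appending the final target token reconstructs the full sequence, so
+        # compute_logprobs yields one reference logprob per target position
+        # (after dropping the leading None below).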
+ full_sequences.append(
+ datum.model_input.append_int(int(target_tokens[-1].item()))
+ )
+ sampled_logprobs_by_datum.append(datum.loss_fn_inputs["logprobs"].to_torch())
+ masks_by_datum.append(datum.loss_fn_inputs["mask"].to_torch().float())
+ advantages_by_datum.append(datum.loss_fn_inputs["advantages"].to_torch())
+
+ reference_logprobs_by_datum = await asyncio.gather(
+ *[
+ reference_sampling_client.compute_logprobs_async(full_sequence)
+ for full_sequence in full_sequences
+ ]
+ )
+
+ logprob_diffs_by_datum: list[torch.Tensor] = []
+ for reference_logprobs, sampled_logprobs, mask in zip(
+ reference_logprobs_by_datum,
+ sampled_logprobs_by_datum,
+ masks_by_datum,
+ strict=True,
+ ):
+ reference_values = reference_logprobs[1:]
+ assert len(reference_values) == sampled_logprobs.numel()
+ assert all(value is not None for value in reference_values)
+ reference_logprobs_tensor = torch.tensor(
+ reference_values,
+ dtype=sampled_logprobs.dtype,
+ )
+ logprob_diffs_by_datum.append(
+ (sampled_logprobs - reference_logprobs_tensor) * mask
+ )
+
+ total_tokens = torch.stack([mask.sum() for mask in masks_by_datum]).sum()
+ assert total_tokens.item() > 0
+ avg_logprob_diff = (
+ torch.stack(
+ [logprob_diff.sum() for logprob_diff in logprob_diffs_by_datum]
+ ).sum()
+ / total_tokens
+ )
+
+ for datum, advantages, mask, logprob_diff in zip(
+ datums,
+ advantages_by_datum,
+ masks_by_datum,
+ logprob_diffs_by_datum,
+ strict=True,
+ ):
+ datum.loss_fn_inputs["advantages"] = tinker.TensorData.from_torch(
+ advantages + kl_penalty_coef * (avg_logprob_diff - logprob_diff) * mask
+ )
+
+ return {"loss/kl_policy_ref": float(avg_logprob_diff)}
+
+
@dataclass
class ModelState:
service_client: tinker.ServiceClient
@@ -239,7 +310,14 @@ async def train( # type: ignore[override]
save_checkpoint: bool = False,
loss_fn_config: dict | None = None,
adam_params: tinker.AdamParams | None = None,
+ kl_penalty_coef: float = 0.0,
+ kl_penalty_reference_step: int | None = None,
+ kl_penalty_source: Literal["sample"] = "sample",
) -> TrainResult:
+ assert kl_penalty_source == "sample", (
+ "TinkerNativeBackend only supports kl_penalty_source='sample'."
+ )
+
state = self._model_state[model.name]
groups_list = list(trajectory_groups)
summary = summarize_trajectory_groups(groups_list)
@@ -272,6 +350,23 @@ async def train( # type: ignore[override]
train_tokens, pricing
)
trainer_started = time.monotonic()
+ sampled_kl_policy_ref: float | None = None
+
+ if kl_penalty_coef > 0:
+ kl_metrics = await self._tinker_sample_call(
+ "apply_kl_penalty",
+ _apply_kl_penalty(
+ datums,
+ await self._get_kl_reference_sampling_client(
+ state,
+ model.base_model,
+ kl_penalty_reference_step,
+ ),
+ kl_penalty_coef,
+ ),
+ )
+ sampled_kl_policy_ref = kl_metrics["loss/kl_policy_ref"]
+ metrics.update(kl_metrics)
if adam_params is None:
adam_params = tinker.AdamParams(
@@ -310,6 +405,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
if value is None:
continue
canonical_key = _canonicalize_upstream_metric_key(key)
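+                # Keep the KL metric computed against the configured reference;
+                # don't let the upstream-reported value overwrite it.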
+ if (
+ sampled_kl_policy_ref is not None
+ and canonical_key == "loss/kl_policy_ref"
+ ):
+ continue
if canonical_key:
metrics[canonical_key] = float(value)
if optim_output.metrics:
@@ -317,6 +417,11 @@ def remove_mask(datum: tinker.Datum) -> tinker.Datum:
if value is None:
continue
canonical_key = _canonicalize_upstream_metric_key(key)
+ if (
+ sampled_kl_policy_ref is not None
+ and canonical_key == "loss/kl_policy_ref"
+ ):
+ continue
if canonical_key:
metrics[canonical_key] = float(value)
@@ -697,6 +802,19 @@ async def _get_sampler_client(
state.sampler_clients[actual_step] = sampler_client
return sampler_client
+ async def _get_kl_reference_sampling_client(
+ self,
+ state: ModelState,
+ base_model: str,
+ step: int | None,
+ ) -> tinker.SamplingClient:
+ if step is not None:
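+        # A specific step pins the reference to that saved checkpoint; with no
+        # step provided, fall back to the untrained base model as the reference.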
+ return await self._get_sampler_client(state, step)
+ return await self._tinker_sample_call(
+ "create_sampling_client_async",
+ state.service_client.create_sampling_client_async(base_model=base_model),
+ )
+
def _normalize_messages(self, messages: Iterable[Any]) -> list[dict[str, Any]]:
normalized: list[dict[str, Any]] = []
for message in messages:
diff --git a/src/art/types.py b/src/art/types.py
index 088041ad..317fc156 100644
--- a/src/art/types.py
+++ b/src/art/types.py
@@ -17,6 +17,7 @@
class TrainConfig(pydantic.BaseModel):
learning_rate: float = 5e-6
kl_penalty_coef: float = 0.0
+ kl_penalty_source: Literal["current_learner", "sample"] = "current_learner"
class TrainSFTConfig(pydantic.BaseModel):
diff --git a/src/art/unsloth/service.py b/src/art/unsloth/service.py
index f24be80f..5b6a563c 100644
--- a/src/art/unsloth/service.py
+++ b/src/art/unsloth/service.py
@@ -35,7 +35,7 @@
from ..utils.get_model_step import get_step_from_dir
from ..utils.output_dirs import get_step_checkpoint_dir
from ..vllm import get_llm, get_worker, openai_server_task, run_on_workers
-from .train import gc_and_empty_cuda_cache, train
+from .train import StopTrainingLoop, gc_and_empty_cuda_cache, train
logger = logging.getLogger(__name__)
@@ -55,6 +55,15 @@ class SupportsLoadLora(Protocol):
def load_lora(self, lora_path: str, load_tensors: bool = True) -> LoRARequest: ...
+class _StopTrainInputs:
+ """Dedicated sentinel for stopping the background trainer loop."""
+
+
+_STOP_TRAIN_INPUT = _StopTrainInputs()
+_TRAIN_TASK_SHUTDOWN_TIMEOUT_S = 5.0
+_TrainLoopInput = TrainInputs | _StopTrainInputs
+
+
def precalculate_new_logprobs(
trainer: "GRPOTrainer",
peft_model: "PeftModelForCausalLM",
@@ -91,7 +100,7 @@ async def process_train_batch(
packed_tensors: PackedTensors,
config: types.TrainConfig,
_config: dev.TrainConfig,
- inputs_queue: asyncio.Queue[TrainInputs],
+ inputs_queue: asyncio.Queue[_TrainLoopInput],
results_queue: asyncio.Queue[dict[str, float]],
train_task: asyncio.Task[None],
trainer: "GRPOTrainer",
@@ -215,7 +224,7 @@ class UnslothState:
tokenizer: PreTrainedTokenizerBase
peft_model: peft.peft_model.PeftModelForCausalLM
trainer: GRPOTrainer
- inputs_queue: asyncio.Queue[TrainInputs]
+ inputs_queue: asyncio.Queue[_TrainLoopInput]
results_queue: asyncio.Queue[dict[str, float]]
_is_offloaded: bool = False
_pinned_buffers: dict[str, torch.Tensor] | None = None
@@ -316,6 +325,7 @@ class UnslothService:
_vllm_log_file: Any = field(default=None, repr=False)
_vllm_host: str = "127.0.0.1"
_vllm_port: int = 0
+ _train_task: asyncio.Task[None] | None = field(default=None, init=False, repr=False)
@property
def is_dedicated(self) -> bool:
@@ -326,6 +336,24 @@ def _next_lora_id(self) -> int:
self._lora_id_counter += 1
return self._lora_id_counter
+ async def aclose(self) -> None:
+ train_task = self._train_task
+ self._train_task = None
+ if train_task is None or train_task.done():
+ self.close()
+ return
+
+        # `_state` is a cached_property; read it from __dict__ directly so
+        # that closing does not instantiate the trainer state just to stop
+        # the task.
+ state = self.__dict__.get("_state")
+ assert isinstance(state, UnslothState)
+ state.inputs_queue.put_nowait(_STOP_TRAIN_INPUT)
+ try:
+ await asyncio.wait_for(train_task, timeout=_TRAIN_TASK_SHUTDOWN_TIMEOUT_S)
+ except asyncio.TimeoutError:
+ train_task.cancel()
+ self.close()
+
# =========================================================================
# Dedicated mode: vLLM subprocess lifecycle
# =========================================================================
@@ -595,7 +623,7 @@ async def _train_dedicated(
await self._state.results_queue.join()
- if not hasattr(self, "_train_task") or self._train_task is None:
+ if self._train_task is None:
self._train_task = asyncio.create_task(
train(
trainer=self._state.trainer,
@@ -685,7 +713,7 @@ async def _train_shared(
await self._state.results_queue.join()
# If we haven't already, start the training task
- if not hasattr(self, "_train_task") or self._train_task is None:
+ if self._train_task is None:
self._train_task = asyncio.create_task(
train(
trainer=self._state.trainer,
@@ -981,17 +1009,19 @@ def _state(self) -> UnslothState:
trainer.create_optimizer()
# Initialize queues
- inputs_queue: asyncio.Queue[TrainInputs] = asyncio.Queue()
+ inputs_queue: asyncio.Queue[_TrainLoopInput] = asyncio.Queue()
results_queue: asyncio.Queue[dict[str, float]] = asyncio.Queue()
# Patch trainer _prepare_inputs() to pull from queue
def _async_prepare_inputs(*_: Any, **__: Any) -> dict[str, torch.Tensor]:
- async def get_inputs() -> TrainInputs:
+ async def get_inputs() -> _TrainLoopInput:
return await inputs_queue.get()
# Force otherwise synchronous _prepare_inputs() to yield
# with nested asyncio.run() call
inputs = asyncio.run(get_inputs())
+ if isinstance(inputs, _StopTrainInputs):
+ raise StopTrainingLoop()
return cast(dict[str, torch.Tensor], inputs)
diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index f6c42a2c..43387904 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -39,6 +39,10 @@
}
+class StopTrainingLoop(Exception):
+ """Signal that the background trainer loop should exit cleanly."""
+
+
def _canonicalize_upstream_metric_key(metric: str) -> str:
if "/" in metric:
return metric
@@ -75,6 +79,8 @@ async def train(
trainer._metrics = {"train": defaultdict(list)}
try:
trainer.train()
+ except StopTrainingLoop:
+ return
finally:
trainer.compute_loss = _compute_loss
trainer.log = _log # ty:ignore[invalid-assignment]
diff --git a/src/art/utils/benchmarking/pull_model_trajectories.py b/src/art/utils/benchmarking/pull_model_trajectories.py
index b8a06bfe..2708d5cf 100644
--- a/src/art/utils/benchmarking/pull_model_trajectories.py
+++ b/src/art/utils/benchmarking/pull_model_trajectories.py
@@ -31,8 +31,9 @@ async def pull_model_trajectories(model: ArtModel) -> None:
"Environment variable BACKUP_BUCKET is required but was not found."
)
- # Use the LocalBackend context manager to work with the on-disk artefacts.
- with LocalBackend() as backend:
+ # Use the LocalBackend async context manager so backend cleanup can await
+ # any background service shutdown before returning.
+ async with LocalBackend() as backend:
print(
f"Pulling trajectories for model '{model.name}' from S3 bucket '{bucket}'…",
flush=True,
diff --git a/tests/integration/test_pipeline_localbackend_dedicated.py b/tests/integration/test_pipeline_localbackend_dedicated.py
new file mode 100644
index 00000000..f1154a33
--- /dev/null
+++ b/tests/integration/test_pipeline_localbackend_dedicated.py
@@ -0,0 +1,206 @@
+"""Dedicated LocalBackend smoke test for PipelineTrainer."""
+
+import asyncio
+import json
+import math
+import os
+from pathlib import Path
+import tempfile
+import uuid
+
+import openai
+import pytest
+
+torch = pytest.importorskip("torch")
+pytest.importorskip("vllm")
+
+import art
+from art.local import LocalBackend
+from art.pipeline_trainer import PipelineTrainer
+
+DEFAULT_BASE_MODEL = "Qwen/Qwen3-0.6B"
+DEFAULT_GPU_MEMORY_UTILIZATION = 0.2
+DEFAULT_MAX_MODEL_LEN = 2048
+DEFAULT_MAX_SEQ_LENGTH = 2048
+
+
+def get_base_model() -> str:
+ return os.environ.get("BASE_MODEL", DEFAULT_BASE_MODEL)
+
+
+def get_safe_gpu_memory_utilization() -> float:
+ requested = float(
+ os.environ.get(
+ "ART_TEST_GPU_MEMORY_UTILIZATION",
+ str(DEFAULT_GPU_MEMORY_UTILIZATION),
+ )
+ )
+ min_free_gib = float(os.environ.get("ART_TEST_MIN_FREE_GPU_GIB", "8"))
+ free_ratios: list[float] = []
+ for device in (0, 1):
+ free_bytes, total_bytes = torch.cuda.mem_get_info(device)
+ free_gib = free_bytes / (1024**3)
+ if free_gib < min_free_gib:
+ pytest.skip(
+ "Insufficient free GPU memory for dedicated LocalBackend smoke test: "
+ f"GPU {device} has {free_gib:.1f} GiB free < {min_free_gib:.1f} GiB required."
+ )
+ free_ratios.append(free_bytes / total_bytes)
+ return max(0.02, min(requested, min(free_ratios) * 0.8))
+
+
+def get_dedicated_vllm_test_config() -> art.dev.InternalModelConfig:
+ return {
+ "trainer_gpu_ids": [0],
+ "inference_gpu_ids": [1],
+ "engine_args": {
+ "gpu_memory_utilization": get_safe_gpu_memory_utilization(),
+ "max_model_len": int(
+ os.environ.get("ART_TEST_MAX_MODEL_LEN", str(DEFAULT_MAX_MODEL_LEN))
+ ),
+ "max_num_seqs": 8,
+ "enforce_eager": True,
+ },
+ "init_args": {
+ "max_seq_length": int(
+ os.environ.get("ART_TEST_MAX_SEQ_LENGTH", str(DEFAULT_MAX_SEQ_LENGTH))
+ ),
+ },
+ }
+
+
+def reward_for_answer(text: str) -> float:
+ content = text.lower()
+ if "yes" in content:
+ return 1.0
+ if "no" in content:
+ return 0.5
+ if "maybe" in content:
+ return 0.25
+ return 0.0
+
+
+async def assert_chat_logprobs(
+ client: openai.AsyncOpenAI,
+ model_name: str,
+) -> None:
+ completion = await client.chat.completions.create(
+ messages=[{"role": "user", "content": "Say hello."}],
+ model=model_name,
+ max_tokens=8,
+ timeout=60,
+ logprobs=True,
+ top_logprobs=0,
+ )
+ assert completion.choices[0].logprobs is not None
+
+
+@pytest.mark.skipif(
+ not torch.cuda.is_available() or torch.cuda.device_count() < 2,
+ reason="Need at least 2 CUDA GPUs for dedicated LocalBackend PipelineTrainer test",
+)
+async def test_pipeline_trainer_local_backend_dedicated_smoke() -> None:
+ model_name = f"test-pipeline-local-dedicated-{uuid.uuid4().hex[:8]}"
+ prompts = [
+ "Say yes",
+ "Say no",
+ "Say maybe",
+ "Say hello",
+ "Say yes again",
+ "Say no again",
+ ]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ async with LocalBackend(path=tmpdir) as backend:
+ model = art.TrainableModel(
+ name=model_name,
+ project="integration-tests",
+ base_model=get_base_model(),
+ _internal_config=get_dedicated_vllm_test_config(),
+ )
+
+ async def scenario_iter():
+ for prompt in prompts:
+ yield {"prompt": prompt}
+
+ await model.register(backend)
+ client = model.openai_client()
+ try:
+
+ async def rollout_fn(
+ rollout_model: art.TrainableModel,
+ scenario: dict[str, str],
+ _config: None,
+ ) -> art.TrajectoryGroup:
+ await asyncio.sleep(0.2)
+ messages: art.Messages = [
+ {"role": "user", "content": scenario["prompt"]}
+ ]
+ completion = await client.chat.completions.create(
+ messages=messages,
+ model=rollout_model.get_inference_name(),
+ max_tokens=10,
+ timeout=60,
+ temperature=1,
+ n=2,
+ logprobs=True,
+ top_logprobs=0,
+ )
+ return art.TrajectoryGroup(
+ [
+ art.Trajectory(
+ messages_and_choices=[*messages, choice],
+ reward=reward_for_answer(choice.message.content or ""),
+ )
+ for choice in completion.choices
+ ]
+ )
+
+ trainer = PipelineTrainer(
+ model=model,
+ backend=backend,
+ rollout_fn=rollout_fn,
+ scenarios=scenario_iter(),
+ config=None,
+ num_rollout_workers=2,
+ min_batch_size=1,
+ max_batch_size=1,
+ max_steps=2,
+ kl_penalty_coef=0.25,
+ loss_fn="cispo",
+ eval_fn=None,
+ )
+
+ await trainer.train()
+
+ latest_step = await model.get_step()
+ assert latest_step >= 2
+
+ await assert_chat_logprobs(client, model.get_inference_name(step=0))
+ await assert_chat_logprobs(
+ client, model.get_inference_name(step=latest_step)
+ )
+
+ model_ids = [m.id async for m in client.models.list()]
+ assert f"{model.name}@0" in model_ids
+ assert f"{model.name}@{latest_step}" in model_ids
+
+ history_path = (
+ Path(tmpdir)
+ / model.project
+ / "models"
+ / model.name
+ / "history.jsonl"
+ )
+ history_rows = [
+ json.loads(line) for line in history_path.read_text().splitlines()
+ ]
+ kl_values = [
+ row["loss/kl_policy_ref"]
+ for row in history_rows
+ if "loss/kl_policy_ref" in row
+ ]
+ assert kl_values
+ assert all(math.isfinite(value) for value in kl_values)
+ finally:
+ await client.close()
diff --git a/tests/integration/test_tinker_native_backend.py b/tests/integration/test_tinker_native_backend.py
index 09ff33c4..4e5c61a5 100644
--- a/tests/integration/test_tinker_native_backend.py
+++ b/tests/integration/test_tinker_native_backend.py
@@ -9,6 +9,8 @@
import art
from art.tinker_native import TinkerNativeBackend
+from art.tinker_native.backend import _apply_kl_penalty
+from art.tinker_native.data import trajectory_groups_to_datums
DEFAULT_BASE_MODEL = "Qwen/Qwen3-30B-A3B-Instruct-2507"
@@ -37,6 +39,8 @@ async def simple_rollout(
max_tokens=10,
timeout=60,
temperature=1,
+ logprobs=True,
+ top_logprobs=0,
)
choice = chat_completion.choices[0]
content = (choice.message.content or "").lower()
@@ -115,6 +119,85 @@ async def make_group(prompt: str) -> art.TrajectoryGroup:
await backend.close()
+@pytest.mark.skipif(
+ "TINKER_API_KEY" not in os.environ,
+ reason="TINKER_API_KEY not set - skipping TinkerNativeBackend KL test",
+)
+async def test_tinker_native_backend_kl_identity_metric():
+ model_name = f"test-tinker-native-kl-{uuid.uuid4().hex[:8]}"
+ with tempfile.TemporaryDirectory() as tmpdir:
+ backend = TinkerNativeBackend(path=tmpdir)
+ model = art.TrainableModel(
+ name=model_name,
+ project="integration-tests",
+ base_model=get_base_model(),
+ )
+ try:
+ await model.register(backend)
+
+ openai_client = model.openai_client()
+ current_step = await model.get_step()
+ model_name_step = model.get_inference_name(step=current_step)
+ prompts = ["Say yes", "Say no", "Say maybe"]
+
+ async def make_group(prompt: str) -> art.TrajectoryGroup:
+ import asyncio
+
+ trajectories = await asyncio.gather(
+ *[
+ simple_rollout(openai_client, model_name_step, prompt)
+ for _ in range(2)
+ ]
+ )
+ return art.TrajectoryGroup(trajectories) # type: ignore[attr-defined]
+
+ train_groups = await art.gather_trajectory_groups( # type: ignore[attr-defined]
+ [make_group(prompt) for prompt in prompts]
+ )
+ ensure_reward_variance(train_groups)
+
+ state = backend._model_state[model.name]
+ datums = trajectory_groups_to_datums(
+ train_groups,
+ state.renderer,
+ state.tokenizer,
+ )
+ assert datums
+
+ reference_sampling_client = await backend._get_kl_reference_sampling_client(
+ state,
+ model.base_model,
+ current_step,
+ )
+ expected_kl = (
+ await _apply_kl_penalty(
+ trajectory_groups_to_datums(
+ train_groups,
+ state.renderer,
+ state.tokenizer,
+ ),
+ reference_sampling_client,
+ kl_penalty_coef=0.25,
+ )
+ )["loss/kl_policy_ref"]
+
+ result = await backend.train(
+ model,
+ train_groups,
+ learning_rate=1e-5,
+ kl_penalty_coef=0.25,
+ kl_penalty_reference_step=current_step,
+ )
+
+ assert result.metrics["loss/kl_policy_ref"] == pytest.approx(
+ expected_kl,
+ abs=0.05,
+ )
+ assert result.metrics["loss/kl_policy_ref"] == pytest.approx(0.0, abs=0.05)
+ finally:
+ await backend.close()
+
+
@pytest.mark.skipif(
"TINKER_API_KEY" not in os.environ,
reason="TINKER_API_KEY not set - skipping TinkerNativeBackend fork test",
diff --git a/tests/unit/test_pipeline_trainer_batching.py b/tests/unit/test_pipeline_trainer_batching.py
new file mode 100644
index 00000000..0ab412e8
--- /dev/null
+++ b/tests/unit/test_pipeline_trainer_batching.py
@@ -0,0 +1,67 @@
+import asyncio
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from art import TrainableModel, Trajectory, TrajectoryGroup
+from art.pipeline_trainer.trainer import PipelineTrainer
+
+
+def _make_group() -> TrajectoryGroup:
+ return TrajectoryGroup(
+ [
+ Trajectory(
+ reward=reward,
+ initial_policy_version=0,
+ messages_and_choices=[
+ {"role": "user", "content": f"prompt-{idx}"},
+ {"role": "assistant", "content": f"answer-{idx}"},
+ ],
+ )
+ for idx, reward in enumerate([0.0, 1.0])
+ ]
+ )
+
+
+@pytest.mark.asyncio
+async def test_collect_batch_respects_max_batch_size(tmp_path: Path) -> None:
+ model = TrainableModel(
+ name="pipeline-max-batch-size-test",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+ trainer = PipelineTrainer(
+ model=model,
+ backend=MagicMock(), # type: ignore[arg-type]
+ rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0),
+ scenarios=[],
+ config={},
+ num_rollout_workers=1,
+ min_batch_size=1,
+ max_batch_size=2,
+ max_steps=1,
+ eval_fn=None,
+ )
+ trainer._output_queue = asyncio.Queue()
+
+ first = _make_group()
+ second = _make_group()
+ third = _make_group()
+ await trainer._output_queue.put(first)
+ await trainer._output_queue.put(second)
+ await trainer._output_queue.put(third)
+ await trainer._output_queue.put(None)
+
+ batch, discarded, saw_sentinel = await trainer._collect_batch(current_step=0)
+
+ assert batch == [first, second]
+ assert discarded == 0
+ assert not saw_sentinel
+
+ batch, discarded, saw_sentinel = await trainer._collect_batch(current_step=0)
+
+ assert batch == [third]
+ assert discarded == 0
+ assert saw_sentinel
diff --git a/tests/unit/test_pipeline_trainer_local_backend.py b/tests/unit/test_pipeline_trainer_local_backend.py
new file mode 100644
index 00000000..d17af12d
--- /dev/null
+++ b/tests/unit/test_pipeline_trainer_local_backend.py
@@ -0,0 +1,398 @@
+import asyncio
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, cast
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+from art import TrainableModel, Trajectory, TrajectoryGroup
+from art.dev.model import InternalModelConfig
+from art.local import LocalBackend
+from art.pipeline_trainer.trainer import PipelineTrainer
+from art.utils.output_dirs import get_model_dir
+
+
+def _make_group(rewards: list[float]) -> TrajectoryGroup:
+ return TrajectoryGroup(
+ [
+ Trajectory(
+ reward=reward,
+ initial_policy_version=0,
+ messages_and_choices=[
+ {"role": "user", "content": f"prompt-{idx}"},
+ {"role": "assistant", "content": f"answer-{idx}"},
+ ],
+ )
+ for idx, reward in enumerate(rewards)
+ ]
+ )
+
+
+def _make_trainer(
+ *,
+ model: TrainableModel,
+ backend: object,
+ **kwargs: Any,
+) -> PipelineTrainer:
+ return PipelineTrainer(
+ model=model,
+ backend=backend, # type: ignore[arg-type]
+ rollout_fn=lambda *_args, **_kwargs: asyncio.sleep(0),
+ scenarios=[],
+ config={},
+ num_rollout_workers=1,
+ min_batch_size=1,
+ max_batch_size=1,
+ max_steps=1,
+ eval_fn=None,
+ **kwargs,
+ )
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_preserves_backend_train_kwargs(tmp_path: Path) -> None:
+ model = TrainableModel(
+ name="pipeline-default-backend-kwargs",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+ backend = MagicMock()
+ backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={}))
+ loss_fn_config = {"alpha": 0.1}
+ adam_params = object()
+
+ trainer = _make_trainer(
+ model=model,
+ backend=backend,
+ learning_rate=2e-5,
+ loss_fn="cispo",
+ loss_fn_config=loss_fn_config,
+ normalize_advantages=True,
+ adam_params=adam_params,
+ )
+ trainer._output_queue = asyncio.Queue()
+ await trainer._output_queue.put(_make_group([0.0, 1.0]))
+ await trainer._output_queue.put(None)
+
+ await trainer._training_stage()
+
+ assert backend.train.await_args.kwargs == {
+ "learning_rate": 2e-5,
+ "loss_fn": "cispo",
+ "loss_fn_config": loss_fn_config,
+ "normalize_advantages": True,
+ "save_checkpoint": False,
+ "adam_params": adam_params,
+ }
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_forwards_kl_kwargs_for_generic_backend(
+ tmp_path: Path,
+) -> None:
+ model = TrainableModel(
+ name="pipeline-generic-backend-kl-kwargs",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+ backend = MagicMock()
+ backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={}))
+
+ trainer = _make_trainer(
+ model=model,
+ backend=backend,
+ kl_penalty_coef=0.25,
+ )
+ trainer._output_queue = asyncio.Queue()
+ await trainer._output_queue.put(_make_group([0.0, 1.0]))
+ await trainer._output_queue.put(None)
+
+ await trainer._training_stage()
+
+ assert backend.train.await_args.kwargs == {
+ "learning_rate": 1e-5,
+ "loss_fn": "cispo",
+ "loss_fn_config": None,
+ "normalize_advantages": True,
+ "save_checkpoint": False,
+ "adam_params": None,
+ "kl_penalty_coef": 0.25,
+ "kl_penalty_reference_step": 0,
+ "kl_penalty_source": "sample",
+ }
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_kl_step_lag_floors_at_zero(
+ tmp_path: Path,
+) -> None:
+ """kl_penalty_step_lag floors at step 0 when lag > current_step."""
+ model = TrainableModel(
+ name="pipeline-kl-step-lag-floor",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+ backend = MagicMock()
+ backend.train = AsyncMock(return_value=SimpleNamespace(step=2, metrics={}))
+
+ trainer = _make_trainer(
+ model=model,
+ backend=backend,
+ kl_penalty_coef=0.25,
+ kl_penalty_step_lag=5,
+ )
+ trainer._output_queue = asyncio.Queue()
+ await trainer._output_queue.put(_make_group([0.0, 1.0]))
+ await trainer._output_queue.put(None)
+
+ # Simulate being at step 1
+ trainer.state.next_training_step = 1
+
+ await trainer._training_stage()
+
+ # At step 1, lag=5 → reference = max(0, 1-5) = 0
+ assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 0
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_kl_step_lag_computes_reference(
+ tmp_path: Path,
+) -> None:
+ """kl_penalty_step_lag computes reference from current step."""
+ model = TrainableModel(
+ name="pipeline-kl-step-lag",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+ backend = MagicMock()
+ backend.train = AsyncMock(return_value=SimpleNamespace(step=4, metrics={}))
+
+ trainer = _make_trainer(
+ model=model,
+ backend=backend,
+ kl_penalty_coef=0.25,
+ kl_penalty_step_lag=2,
+ )
+ trainer._output_queue = asyncio.Queue()
+ await trainer._output_queue.put(_make_group([0.0, 1.0]))
+ await trainer._output_queue.put(None)
+
+ # Simulate being at step 3
+ trainer.state.next_training_step = 3
+
+ await trainer._training_stage()
+
+ # At step 3, lag=2 → reference = max(0, 3-2) = 1
+ assert backend.train.await_args.kwargs["kl_penalty_reference_step"] == 1
+
+
+@pytest.mark.asyncio
+async def test_pipeline_trainer_uses_same_train_kwargs_for_local_backend(
+ tmp_path: Path,
+) -> None:
+ model = TrainableModel(
+ name="pipeline-local-backend-kwargs",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ _internal_config=InternalModelConfig(
+ trainer_gpu_ids=[0],
+ inference_gpu_ids=[1],
+ ),
+ )
+ backend = LocalBackend(path=str(tmp_path))
+ backend.train = AsyncMock(return_value=SimpleNamespace(step=1, metrics={})) # type: ignore[method-assign]
+
+ trainer = _make_trainer(
+ model=model,
+ backend=backend,
+ learning_rate=3e-5,
+ loss_fn="ppo",
+ )
+ trainer._output_queue = asyncio.Queue()
+ await trainer._output_queue.put(_make_group([0.0, 1.0]))
+ await trainer._output_queue.put(None)
+
+ await trainer._training_stage()
+
+ assert backend.train.await_args.kwargs == { # type: ignore[attr-defined]
+ "learning_rate": 3e-5,
+ "loss_fn": "ppo",
+ "loss_fn_config": None,
+ "normalize_advantages": True,
+ "save_checkpoint": False,
+ "adam_params": None,
+ }
+
+
+@pytest.mark.asyncio
+async def test_local_backend_train_translates_loss_fn(tmp_path: Path) -> None:
+ model = TrainableModel(
+ name="local-backend-train-translation",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+ backend = LocalBackend(path=str(tmp_path))
+ seen: dict[str, Any] = {}
+
+ async def fake_train_model(
+ _model: TrainableModel,
+ _groups: list[TrajectoryGroup],
+ config: Any,
+ dev_config: dict[str, Any],
+ verbose: bool = False,
+ ):
+ seen["config"] = config
+ seen["dev_config"] = dev_config
+ seen["verbose"] = verbose
+ yield {}
+
+ backend._train_model = fake_train_model # type: ignore[method-assign]
+ backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign]
+ with patch.object(model, "_get_wandb_run", return_value=None):
+ result = await backend.train(
+ model,
+ [_make_group([1.0])],
+ loss_fn="ppo",
+ save_checkpoint=False,
+ )
+
+ assert result.step == 1
+ assert seen["config"].learning_rate == 5e-6
+ assert seen["dev_config"]["ppo"] is True
+
+
+@pytest.mark.asyncio
+async def test_local_backend_train_passes_kl_penalty_source(tmp_path: Path) -> None:
+ model = TrainableModel(
+ name="local-backend-kl-source",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+ backend = LocalBackend(path=str(tmp_path))
+ seen: dict[str, Any] = {}
+
+ async def fake_train_model(
+ _model: TrainableModel,
+ _groups: list[TrajectoryGroup],
+ config: Any,
+ dev_config: dict[str, Any],
+ verbose: bool = False,
+ ):
+ seen["config"] = config
+ seen["dev_config"] = dev_config
+ seen["verbose"] = verbose
+ yield {}
+
+ backend._train_model = fake_train_model # type: ignore[method-assign]
+ backend._get_step = AsyncMock(return_value=1) # type: ignore[method-assign]
+ with patch.object(model, "_get_wandb_run", return_value=None):
+ result = await backend.train(
+ model,
+ [_make_group([1.0])],
+ kl_penalty_coef=0.25,
+ kl_penalty_source="sample",
+ save_checkpoint=False,
+ )
+
+ assert result.step == 1
+ assert seen["config"].kl_penalty_source == "sample"
+ assert seen["dev_config"]["kl_penalty_source"] == "sample"
+
+
+@pytest.mark.asyncio
+async def test_local_backend_async_context_manager_awaits_async_cleanup(
+ tmp_path: Path,
+) -> None:
+ backend = LocalBackend(path=str(tmp_path))
+ calls: list[str] = []
+
+ class FakeService:
+ async def aclose(self) -> None:
+ calls.append("aclose")
+
+ service = FakeService()
+ backend._services["test-service"] = cast(Any, service)
+
+ with patch("art.local.backend.close_proxy") as close_proxy:
+ async with backend:
+ pass
+
+ assert calls == ["aclose"]
+ close_proxy.assert_called_once_with(service)
+
+
+@pytest.mark.parametrize(
+ ("trainer_kwargs", "match"),
+ [
+ ({"loss_fn": "dro"}, "loss_fn='cispo' or loss_fn='ppo'"),
+ ({"loss_fn_config": {"clip": 0.2}}, "loss_fn_config=None"),
+ ({"normalize_advantages": False}, "normalize_advantages=True"),
+ ({"adam_params": object()}, "adam_params=None"),
+ ],
+)
+def test_pipeline_trainer_rejects_unsupported_local_backend_settings(
+ tmp_path: Path,
+ trainer_kwargs: dict[str, object],
+ match: str,
+) -> None:
+ model = TrainableModel(
+ name="pipeline-local-backend-invalid",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ _internal_config=InternalModelConfig(
+ trainer_gpu_ids=[0],
+ inference_gpu_ids=[1],
+ ),
+ )
+
+ with pytest.raises(ValueError, match=match):
+ _make_trainer(
+ model=model,
+ backend=LocalBackend(path=str(tmp_path)),
+ **trainer_kwargs,
+ )
+
+
+def test_pipeline_trainer_rejects_shared_local_backend(tmp_path: Path) -> None:
+ model = TrainableModel(
+ name="pipeline-local-backend-shared",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+
+ with pytest.raises(
+ ValueError, match="only supports LocalBackend in dedicated mode"
+ ):
+ _make_trainer(model=model, backend=LocalBackend(path=str(tmp_path)))
+
+
+def test_local_backend_inference_name_prefers_served_step_in_dedicated_mode(
+ tmp_path: Path,
+) -> None:
+ model = TrainableModel(
+ name="local-backend-served-step",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ _internal_config=InternalModelConfig(
+ trainer_gpu_ids=[0],
+ inference_gpu_ids=[1],
+ ),
+ )
+ backend = LocalBackend(path=str(tmp_path))
+ output_dir = Path(get_model_dir(model=model, art_path=str(tmp_path)))
+ (output_dir / "checkpoints" / "3").mkdir(parents=True)
+ backend._services[model.name] = cast(Any, SimpleNamespace(_latest_step=2))
+
+ assert backend._model_inference_name(model) == f"{model.name}@2"
+ assert backend._model_inference_name(model, step=3) == f"{model.name}@3"
diff --git a/tests/unit/test_tinker_native_kl.py b/tests/unit/test_tinker_native_kl.py
new file mode 100644
index 00000000..a2d16d01
--- /dev/null
+++ b/tests/unit/test_tinker_native_kl.py
@@ -0,0 +1,77 @@
+import pytest
+import tinker
+
+from art import TrainableModel
+from art.tinker_native.backend import TinkerNativeBackend, _apply_kl_penalty
+from art.tinker_native.data import build_datum
+
+
+class FakeSamplingClient(tinker.SamplingClient):
+ def __init__(self, responses: dict[tuple[int, ...], list[float | None]]) -> None:
+ self._responses = responses
+
+ async def compute_logprobs_async(
+ self, prompt: tinker.ModelInput
+ ) -> list[float | None]:
+ return self._responses[tuple(prompt.to_ints())]
+
+
+@pytest.mark.asyncio
+async def test_incorporate_kl_penalty_rewrites_advantages_in_place() -> None:
+ datum_a = build_datum(
+ prompt_tokens=[101, 102],
+ completion_tokens=[201, 202],
+ logprobs=[-0.4, -0.8],
+ advantage=1.0,
+ )
+ datum_b = build_datum(
+ prompt_tokens=[301, 302],
+ completion_tokens=[401],
+ logprobs=[-0.2],
+ advantage=2.0,
+ )
+ assert datum_a is not None
+ assert datum_b is not None
+
+ sampling_client = FakeSamplingClient(
+ {
+ (101, 102, 201, 202): [None, -9.0, -0.1, -0.5],
+ (301, 302, 401): [None, -7.0, -0.05],
+ }
+ )
+
+ metrics = await _apply_kl_penalty(
+ [datum_a, datum_b],
+ sampling_client,
+ kl_penalty_coef=2.0,
+ )
+
+ assert metrics == {"loss/kl_policy_ref": pytest.approx(-0.25)}
+ assert datum_a.loss_fn_inputs["advantages"].tolist() == pytest.approx(
+ [0.0, 1.1, 1.1]
+ )
+ assert datum_b.loss_fn_inputs["advantages"].tolist() == pytest.approx([0.0, 1.8])
+
+
+@pytest.mark.asyncio
+async def test_tinker_native_backend_rejects_current_learner_kl_source(
+ tmp_path,
+) -> None:
+ backend = TinkerNativeBackend(tinker_api_key="test-key", path=str(tmp_path))
+ model = TrainableModel(
+ name="tinker-native-kl-source",
+ project="pipeline-tests",
+ base_model="test-model",
+ base_path=str(tmp_path),
+ )
+
+ with pytest.raises(
+ AssertionError,
+ match="only supports kl_penalty_source='sample'",
+ ):
+ await backend.train(
+ model,
+ [],
+ kl_penalty_coef=0.25,
+ kl_penalty_source="current_learner", # ty:ignore[invalid-argument-type]
+ )